##// END OF EJS Templates
Backport PR #5166: remove mktemp usage...
MinRK -
Show More
@@ -1,348 +1,351 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Paging capabilities for IPython.core
3 Paging capabilities for IPython.core
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Fernando Perez
8 * Fernando Perez
9
9
10 Notes
10 Notes
11 -----
11 -----
12
12
13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
14 rid of that dependency, we could move it there.
14 rid of that dependency, we could move it there.
15 -----
15 -----
16 """
16 """
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Copyright (C) 2008-2011 The IPython Development Team
19 # Copyright (C) 2008-2011 The IPython Development Team
20 #
20 #
21 # Distributed under the terms of the BSD License. The full license is in
21 # Distributed under the terms of the BSD License. The full license is in
22 # the file COPYING, distributed as part of this software.
22 # the file COPYING, distributed as part of this software.
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24
24
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 # Imports
26 # Imports
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 from __future__ import print_function
28 from __future__ import print_function
29
29
30 import os
30 import os
31 import re
31 import re
32 import sys
32 import sys
33 import tempfile
33 import tempfile
34
34
35 from io import UnsupportedOperation
35 from io import UnsupportedOperation
36
36
37 from IPython import get_ipython
37 from IPython import get_ipython
38 from IPython.core.error import TryNext
38 from IPython.core.error import TryNext
39 from IPython.utils.data import chop
39 from IPython.utils.data import chop
40 from IPython.utils import io
40 from IPython.utils import io
41 from IPython.utils.process import system
41 from IPython.utils.process import system
42 from IPython.utils.terminal import get_terminal_size
42 from IPython.utils.terminal import get_terminal_size
43 from IPython.utils import py3compat
43 from IPython.utils import py3compat
44
44
45
45
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47 # Classes and functions
47 # Classes and functions
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49
49
50 esc_re = re.compile(r"(\x1b[^m]+m)")
50 esc_re = re.compile(r"(\x1b[^m]+m)")
51
51
52 def page_dumb(strng, start=0, screen_lines=25):
52 def page_dumb(strng, start=0, screen_lines=25):
53 """Very dumb 'pager' in Python, for when nothing else works.
53 """Very dumb 'pager' in Python, for when nothing else works.
54
54
55 Only moves forward, same interface as page(), except for pager_cmd and
55 Only moves forward, same interface as page(), except for pager_cmd and
56 mode."""
56 mode."""
57
57
58 out_ln = strng.splitlines()[start:]
58 out_ln = strng.splitlines()[start:]
59 screens = chop(out_ln,screen_lines-1)
59 screens = chop(out_ln,screen_lines-1)
60 if len(screens) == 1:
60 if len(screens) == 1:
61 print(os.linesep.join(screens[0]), file=io.stdout)
61 print(os.linesep.join(screens[0]), file=io.stdout)
62 else:
62 else:
63 last_escape = ""
63 last_escape = ""
64 for scr in screens[0:-1]:
64 for scr in screens[0:-1]:
65 hunk = os.linesep.join(scr)
65 hunk = os.linesep.join(scr)
66 print(last_escape + hunk, file=io.stdout)
66 print(last_escape + hunk, file=io.stdout)
67 if not page_more():
67 if not page_more():
68 return
68 return
69 esc_list = esc_re.findall(hunk)
69 esc_list = esc_re.findall(hunk)
70 if len(esc_list) > 0:
70 if len(esc_list) > 0:
71 last_escape = esc_list[-1]
71 last_escape = esc_list[-1]
72 print(last_escape + os.linesep.join(screens[-1]), file=io.stdout)
72 print(last_escape + os.linesep.join(screens[-1]), file=io.stdout)
73
73
74 def _detect_screen_size(screen_lines_def):
74 def _detect_screen_size(screen_lines_def):
75 """Attempt to work out the number of lines on the screen.
75 """Attempt to work out the number of lines on the screen.
76
76
77 This is called by page(). It can raise an error (e.g. when run in the
77 This is called by page(). It can raise an error (e.g. when run in the
78 test suite), so it's separated out so it can easily be called in a try block.
78 test suite), so it's separated out so it can easily be called in a try block.
79 """
79 """
80 TERM = os.environ.get('TERM',None)
80 TERM = os.environ.get('TERM',None)
81 if not((TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5'):
81 if not((TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5'):
82 # curses causes problems on many terminals other than xterm, and
82 # curses causes problems on many terminals other than xterm, and
83 # some termios calls lock up on Sun OS5.
83 # some termios calls lock up on Sun OS5.
84 return screen_lines_def
84 return screen_lines_def
85
85
86 try:
86 try:
87 import termios
87 import termios
88 import curses
88 import curses
89 except ImportError:
89 except ImportError:
90 return screen_lines_def
90 return screen_lines_def
91
91
92 # There is a bug in curses, where *sometimes* it fails to properly
92 # There is a bug in curses, where *sometimes* it fails to properly
93 # initialize, and then after the endwin() call is made, the
93 # initialize, and then after the endwin() call is made, the
94 # terminal is left in an unusable state. Rather than trying to
94 # terminal is left in an unusable state. Rather than trying to
95 # check everytime for this (by requesting and comparing termios
95 # check everytime for this (by requesting and comparing termios
96 # flags each time), we just save the initial terminal state and
96 # flags each time), we just save the initial terminal state and
97 # unconditionally reset it every time. It's cheaper than making
97 # unconditionally reset it every time. It's cheaper than making
98 # the checks.
98 # the checks.
99 term_flags = termios.tcgetattr(sys.stdout)
99 term_flags = termios.tcgetattr(sys.stdout)
100
100
101 # Curses modifies the stdout buffer size by default, which messes
101 # Curses modifies the stdout buffer size by default, which messes
102 # up Python's normal stdout buffering. This would manifest itself
102 # up Python's normal stdout buffering. This would manifest itself
103 # to IPython users as delayed printing on stdout after having used
103 # to IPython users as delayed printing on stdout after having used
104 # the pager.
104 # the pager.
105 #
105 #
106 # We can prevent this by manually setting the NCURSES_NO_SETBUF
106 # We can prevent this by manually setting the NCURSES_NO_SETBUF
107 # environment variable. For more details, see:
107 # environment variable. For more details, see:
108 # http://bugs.python.org/issue10144
108 # http://bugs.python.org/issue10144
109 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
109 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
110 os.environ['NCURSES_NO_SETBUF'] = ''
110 os.environ['NCURSES_NO_SETBUF'] = ''
111
111
112 # Proceed with curses initialization
112 # Proceed with curses initialization
113 try:
113 try:
114 scr = curses.initscr()
114 scr = curses.initscr()
115 except AttributeError:
115 except AttributeError:
116 # Curses on Solaris may not be complete, so we can't use it there
116 # Curses on Solaris may not be complete, so we can't use it there
117 return screen_lines_def
117 return screen_lines_def
118
118
119 screen_lines_real,screen_cols = scr.getmaxyx()
119 screen_lines_real,screen_cols = scr.getmaxyx()
120 curses.endwin()
120 curses.endwin()
121
121
122 # Restore environment
122 # Restore environment
123 if NCURSES_NO_SETBUF is None:
123 if NCURSES_NO_SETBUF is None:
124 del os.environ['NCURSES_NO_SETBUF']
124 del os.environ['NCURSES_NO_SETBUF']
125 else:
125 else:
126 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
126 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
127
127
128 # Restore terminal state in case endwin() didn't.
128 # Restore terminal state in case endwin() didn't.
129 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
129 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
130 # Now we have what we needed: the screen size in rows/columns
130 # Now we have what we needed: the screen size in rows/columns
131 return screen_lines_real
131 return screen_lines_real
132 #print '***Screen size:',screen_lines_real,'lines x',\
132 #print '***Screen size:',screen_lines_real,'lines x',\
133 #screen_cols,'columns.' # dbg
133 #screen_cols,'columns.' # dbg
134
134
135 def page(strng, start=0, screen_lines=0, pager_cmd=None):
135 def page(strng, start=0, screen_lines=0, pager_cmd=None):
136 """Print a string, piping through a pager after a certain length.
136 """Print a string, piping through a pager after a certain length.
137
137
138 The screen_lines parameter specifies the number of *usable* lines of your
138 The screen_lines parameter specifies the number of *usable* lines of your
139 terminal screen (total lines minus lines you need to reserve to show other
139 terminal screen (total lines minus lines you need to reserve to show other
140 information).
140 information).
141
141
142 If you set screen_lines to a number <=0, page() will try to auto-determine
142 If you set screen_lines to a number <=0, page() will try to auto-determine
143 your screen size and will only use up to (screen_size+screen_lines) for
143 your screen size and will only use up to (screen_size+screen_lines) for
144 printing, paging after that. That is, if you want auto-detection but need
144 printing, paging after that. That is, if you want auto-detection but need
145 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
145 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
146 auto-detection without any lines reserved simply use screen_lines = 0.
146 auto-detection without any lines reserved simply use screen_lines = 0.
147
147
148 If a string won't fit in the allowed lines, it is sent through the
148 If a string won't fit in the allowed lines, it is sent through the
149 specified pager command. If none given, look for PAGER in the environment,
149 specified pager command. If none given, look for PAGER in the environment,
150 and ultimately default to less.
150 and ultimately default to less.
151
151
152 If no system pager works, the string is sent through a 'dumb pager'
152 If no system pager works, the string is sent through a 'dumb pager'
153 written in python, very simplistic.
153 written in python, very simplistic.
154 """
154 """
155
155
156 # Some routines may auto-compute start offsets incorrectly and pass a
156 # Some routines may auto-compute start offsets incorrectly and pass a
157 # negative value. Offset to 0 for robustness.
157 # negative value. Offset to 0 for robustness.
158 start = max(0, start)
158 start = max(0, start)
159
159
160 # first, try the hook
160 # first, try the hook
161 ip = get_ipython()
161 ip = get_ipython()
162 if ip:
162 if ip:
163 try:
163 try:
164 ip.hooks.show_in_pager(strng)
164 ip.hooks.show_in_pager(strng)
165 return
165 return
166 except TryNext:
166 except TryNext:
167 pass
167 pass
168
168
169 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
169 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
170 TERM = os.environ.get('TERM','dumb')
170 TERM = os.environ.get('TERM','dumb')
171 if TERM in ['dumb','emacs'] and os.name != 'nt':
171 if TERM in ['dumb','emacs'] and os.name != 'nt':
172 print(strng)
172 print(strng)
173 return
173 return
174 # chop off the topmost part of the string we don't want to see
174 # chop off the topmost part of the string we don't want to see
175 str_lines = strng.splitlines()[start:]
175 str_lines = strng.splitlines()[start:]
176 str_toprint = os.linesep.join(str_lines)
176 str_toprint = os.linesep.join(str_lines)
177 num_newlines = len(str_lines)
177 num_newlines = len(str_lines)
178 len_str = len(str_toprint)
178 len_str = len(str_toprint)
179
179
180 # Dumb heuristics to guesstimate number of on-screen lines the string
180 # Dumb heuristics to guesstimate number of on-screen lines the string
181 # takes. Very basic, but good enough for docstrings in reasonable
181 # takes. Very basic, but good enough for docstrings in reasonable
182 # terminals. If someone later feels like refining it, it's not hard.
182 # terminals. If someone later feels like refining it, it's not hard.
183 numlines = max(num_newlines,int(len_str/80)+1)
183 numlines = max(num_newlines,int(len_str/80)+1)
184
184
185 screen_lines_def = get_terminal_size()[1]
185 screen_lines_def = get_terminal_size()[1]
186
186
187 # auto-determine screen size
187 # auto-determine screen size
188 if screen_lines <= 0:
188 if screen_lines <= 0:
189 try:
189 try:
190 screen_lines += _detect_screen_size(screen_lines_def)
190 screen_lines += _detect_screen_size(screen_lines_def)
191 except (TypeError, UnsupportedOperation):
191 except (TypeError, UnsupportedOperation):
192 print(str_toprint, file=io.stdout)
192 print(str_toprint, file=io.stdout)
193 return
193 return
194
194
195 #print 'numlines',numlines,'screenlines',screen_lines # dbg
195 #print 'numlines',numlines,'screenlines',screen_lines # dbg
196 if numlines <= screen_lines :
196 if numlines <= screen_lines :
197 #print '*** normal print' # dbg
197 #print '*** normal print' # dbg
198 print(str_toprint, file=io.stdout)
198 print(str_toprint, file=io.stdout)
199 else:
199 else:
200 # Try to open pager and default to internal one if that fails.
200 # Try to open pager and default to internal one if that fails.
201 # All failure modes are tagged as 'retval=1', to match the return
201 # All failure modes are tagged as 'retval=1', to match the return
202 # value of a failed system command. If any intermediate attempt
202 # value of a failed system command. If any intermediate attempt
203 # sets retval to 1, at the end we resort to our own page_dumb() pager.
203 # sets retval to 1, at the end we resort to our own page_dumb() pager.
204 pager_cmd = get_pager_cmd(pager_cmd)
204 pager_cmd = get_pager_cmd(pager_cmd)
205 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
205 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
206 if os.name == 'nt':
206 if os.name == 'nt':
207 if pager_cmd.startswith('type'):
207 if pager_cmd.startswith('type'):
208 # The default WinXP 'type' command is failing on complex strings.
208 # The default WinXP 'type' command is failing on complex strings.
209 retval = 1
209 retval = 1
210 else:
210 else:
211 tmpname = tempfile.mktemp('.txt')
211 fd, tmpname = tempfile.mkstemp('.txt')
212 tmpfile = open(tmpname,'wt')
212 try:
213 os.close(fd)
214 with open(tmpname, 'wt') as tmpfile:
213 tmpfile.write(strng)
215 tmpfile.write(strng)
214 tmpfile.close()
215 cmd = "%s < %s" % (pager_cmd,tmpname)
216 cmd = "%s < %s" % (pager_cmd, tmpname)
217 # tmpfile needs to be closed for windows
216 if os.system(cmd):
218 if os.system(cmd):
217 retval = 1
219 retval = 1
218 else:
220 else:
219 retval = None
221 retval = None
222 finally:
220 os.remove(tmpname)
223 os.remove(tmpname)
221 else:
224 else:
222 try:
225 try:
223 retval = None
226 retval = None
224 # if I use popen4, things hang. No idea why.
227 # if I use popen4, things hang. No idea why.
225 #pager,shell_out = os.popen4(pager_cmd)
228 #pager,shell_out = os.popen4(pager_cmd)
226 pager = os.popen(pager_cmd, 'w')
229 pager = os.popen(pager_cmd, 'w')
227 try:
230 try:
228 pager_encoding = pager.encoding or sys.stdout.encoding
231 pager_encoding = pager.encoding or sys.stdout.encoding
229 pager.write(py3compat.cast_bytes_py2(
232 pager.write(py3compat.cast_bytes_py2(
230 strng, encoding=pager_encoding))
233 strng, encoding=pager_encoding))
231 finally:
234 finally:
232 retval = pager.close()
235 retval = pager.close()
233 except IOError as msg: # broken pipe when user quits
236 except IOError as msg: # broken pipe when user quits
234 if msg.args == (32, 'Broken pipe'):
237 if msg.args == (32, 'Broken pipe'):
235 retval = None
238 retval = None
236 else:
239 else:
237 retval = 1
240 retval = 1
238 except OSError:
241 except OSError:
239 # Other strange problems, sometimes seen in Win2k/cygwin
242 # Other strange problems, sometimes seen in Win2k/cygwin
240 retval = 1
243 retval = 1
241 if retval is not None:
244 if retval is not None:
242 page_dumb(strng,screen_lines=screen_lines)
245 page_dumb(strng,screen_lines=screen_lines)
243
246
244
247
245 def page_file(fname, start=0, pager_cmd=None):
248 def page_file(fname, start=0, pager_cmd=None):
246 """Page a file, using an optional pager command and starting line.
249 """Page a file, using an optional pager command and starting line.
247 """
250 """
248
251
249 pager_cmd = get_pager_cmd(pager_cmd)
252 pager_cmd = get_pager_cmd(pager_cmd)
250 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
253 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
251
254
252 try:
255 try:
253 if os.environ['TERM'] in ['emacs','dumb']:
256 if os.environ['TERM'] in ['emacs','dumb']:
254 raise EnvironmentError
257 raise EnvironmentError
255 system(pager_cmd + ' ' + fname)
258 system(pager_cmd + ' ' + fname)
256 except:
259 except:
257 try:
260 try:
258 if start > 0:
261 if start > 0:
259 start -= 1
262 start -= 1
260 page(open(fname).read(),start)
263 page(open(fname).read(),start)
261 except:
264 except:
262 print('Unable to show file',repr(fname))
265 print('Unable to show file',repr(fname))
263
266
264
267
265 def get_pager_cmd(pager_cmd=None):
268 def get_pager_cmd(pager_cmd=None):
266 """Return a pager command.
269 """Return a pager command.
267
270
268 Makes some attempts at finding an OS-correct one.
271 Makes some attempts at finding an OS-correct one.
269 """
272 """
270 if os.name == 'posix':
273 if os.name == 'posix':
271 default_pager_cmd = 'less -r' # -r for color control sequences
274 default_pager_cmd = 'less -r' # -r for color control sequences
272 elif os.name in ['nt','dos']:
275 elif os.name in ['nt','dos']:
273 default_pager_cmd = 'type'
276 default_pager_cmd = 'type'
274
277
275 if pager_cmd is None:
278 if pager_cmd is None:
276 try:
279 try:
277 pager_cmd = os.environ['PAGER']
280 pager_cmd = os.environ['PAGER']
278 except:
281 except:
279 pager_cmd = default_pager_cmd
282 pager_cmd = default_pager_cmd
280 return pager_cmd
283 return pager_cmd
281
284
282
285
283 def get_pager_start(pager, start):
286 def get_pager_start(pager, start):
284 """Return the string for paging files with an offset.
287 """Return the string for paging files with an offset.
285
288
286 This is the '+N' argument which less and more (under Unix) accept.
289 This is the '+N' argument which less and more (under Unix) accept.
287 """
290 """
288
291
289 if pager in ['less','more']:
292 if pager in ['less','more']:
290 if start:
293 if start:
291 start_string = '+' + str(start)
294 start_string = '+' + str(start)
292 else:
295 else:
293 start_string = ''
296 start_string = ''
294 else:
297 else:
295 start_string = ''
298 start_string = ''
296 return start_string
299 return start_string
297
300
298
301
299 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
302 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
300 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
303 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
301 import msvcrt
304 import msvcrt
302 def page_more():
305 def page_more():
303 """ Smart pausing between pages
306 """ Smart pausing between pages
304
307
305 @return: True if need print more lines, False if quit
308 @return: True if need print more lines, False if quit
306 """
309 """
307 io.stdout.write('---Return to continue, q to quit--- ')
310 io.stdout.write('---Return to continue, q to quit--- ')
308 ans = msvcrt.getch()
311 ans = msvcrt.getch()
309 if ans in ("q", "Q"):
312 if ans in ("q", "Q"):
310 result = False
313 result = False
311 else:
314 else:
312 result = True
315 result = True
313 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
316 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
314 return result
317 return result
315 else:
318 else:
316 def page_more():
319 def page_more():
317 ans = raw_input('---Return to continue, q to quit--- ')
320 ans = raw_input('---Return to continue, q to quit--- ')
318 if ans.lower().startswith('q'):
321 if ans.lower().startswith('q'):
319 return False
322 return False
320 else:
323 else:
321 return True
324 return True
322
325
323
326
324 def snip_print(str,width = 75,print_full = 0,header = ''):
327 def snip_print(str,width = 75,print_full = 0,header = ''):
325 """Print a string snipping the midsection to fit in width.
328 """Print a string snipping the midsection to fit in width.
326
329
327 print_full: mode control:
330 print_full: mode control:
328 - 0: only snip long strings
331 - 0: only snip long strings
329 - 1: send to page() directly.
332 - 1: send to page() directly.
330 - 2: snip long strings and ask for full length viewing with page()
333 - 2: snip long strings and ask for full length viewing with page()
331 Return 1 if snipping was necessary, 0 otherwise."""
334 Return 1 if snipping was necessary, 0 otherwise."""
332
335
333 if print_full == 1:
336 if print_full == 1:
334 page(header+str)
337 page(header+str)
335 return 0
338 return 0
336
339
337 print(header, end=' ')
340 print(header, end=' ')
338 if len(str) < width:
341 if len(str) < width:
339 print(str)
342 print(str)
340 snip = 0
343 snip = 0
341 else:
344 else:
342 whalf = int((width -5)/2)
345 whalf = int((width -5)/2)
343 print(str[:whalf] + ' <...> ' + str[-whalf:])
346 print(str[:whalf] + ' <...> ' + str[-whalf:])
344 snip = 1
347 snip = 1
345 if snip and print_full == 2:
348 if snip and print_full == 2:
346 if raw_input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
349 if raw_input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
347 page(str)
350 page(str)
348 return snip
351 return snip
@@ -1,556 +1,557 b''
1 """Utilities for connecting to kernels
1 """Utilities for connecting to kernels
2
2
3 Authors:
3 Authors:
4
4
5 * Min Ragan-Kelley
5 * Min Ragan-Kelley
6
6
7 """
7 """
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2013 The IPython Development Team
10 # Copyright (C) 2013 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 from __future__ import absolute_import
20 from __future__ import absolute_import
21
21
22 import glob
22 import glob
23 import json
23 import json
24 import os
24 import os
25 import socket
25 import socket
26 import sys
26 import sys
27 from getpass import getpass
27 from getpass import getpass
28 from subprocess import Popen, PIPE
28 from subprocess import Popen, PIPE
29 import tempfile
29 import tempfile
30
30
31 import zmq
31 import zmq
32
32
33 # external imports
33 # external imports
34 from IPython.external.ssh import tunnel
34 from IPython.external.ssh import tunnel
35
35
36 # IPython imports
36 # IPython imports
37 from IPython.config import Configurable
37 from IPython.config import Configurable
38 from IPython.core.profiledir import ProfileDir
38 from IPython.core.profiledir import ProfileDir
39 from IPython.utils.localinterfaces import LOCALHOST
39 from IPython.utils.localinterfaces import LOCALHOST
40 from IPython.utils.path import filefind, get_ipython_dir
40 from IPython.utils.path import filefind, get_ipython_dir
41 from IPython.utils.py3compat import str_to_bytes, bytes_to_str, cast_bytes_py2
41 from IPython.utils.py3compat import str_to_bytes, bytes_to_str, cast_bytes_py2
42 from IPython.utils.traitlets import (
42 from IPython.utils.traitlets import (
43 Bool, Integer, Unicode, CaselessStrEnum,
43 Bool, Integer, Unicode, CaselessStrEnum,
44 )
44 )
45
45
46
46
47 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
48 # Working with Connection Files
48 # Working with Connection Files
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50
50
51 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
51 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
52 control_port=0, ip=LOCALHOST, key=b'', transport='tcp',
52 control_port=0, ip=LOCALHOST, key=b'', transport='tcp',
53 signature_scheme='hmac-sha256',
53 signature_scheme='hmac-sha256',
54 ):
54 ):
55 """Generates a JSON config file, including the selection of random ports.
55 """Generates a JSON config file, including the selection of random ports.
56
56
57 Parameters
57 Parameters
58 ----------
58 ----------
59
59
60 fname : unicode
60 fname : unicode
61 The path to the file to write
61 The path to the file to write
62
62
63 shell_port : int, optional
63 shell_port : int, optional
64 The port to use for ROUTER (shell) channel.
64 The port to use for ROUTER (shell) channel.
65
65
66 iopub_port : int, optional
66 iopub_port : int, optional
67 The port to use for the SUB channel.
67 The port to use for the SUB channel.
68
68
69 stdin_port : int, optional
69 stdin_port : int, optional
70 The port to use for the ROUTER (raw input) channel.
70 The port to use for the ROUTER (raw input) channel.
71
71
72 control_port : int, optional
72 control_port : int, optional
73 The port to use for the ROUTER (control) channel.
73 The port to use for the ROUTER (control) channel.
74
74
75 hb_port : int, optional
75 hb_port : int, optional
76 The port to use for the heartbeat REP channel.
76 The port to use for the heartbeat REP channel.
77
77
78 ip : str, optional
78 ip : str, optional
79 The ip address the kernel will bind to.
79 The ip address the kernel will bind to.
80
80
81 key : str, optional
81 key : str, optional
82 The Session key used for message authentication.
82 The Session key used for message authentication.
83
83
84 signature_scheme : str, optional
84 signature_scheme : str, optional
85 The scheme used for message authentication.
85 The scheme used for message authentication.
86 This has the form 'digest-hash', where 'digest'
86 This has the form 'digest-hash', where 'digest'
87 is the scheme used for digests, and 'hash' is the name of the hash function
87 is the scheme used for digests, and 'hash' is the name of the hash function
88 used by the digest scheme.
88 used by the digest scheme.
89 Currently, 'hmac' is the only supported digest scheme,
89 Currently, 'hmac' is the only supported digest scheme,
90 and 'sha256' is the default hash function.
90 and 'sha256' is the default hash function.
91
91
92 """
92 """
93 # default to temporary connector file
93 # default to temporary connector file
94 if not fname:
94 if not fname:
95 fname = tempfile.mktemp('.json')
95 fd, fname = tempfile.mkstemp('.json')
96 os.close(fd)
96
97
97 # Find open ports as necessary.
98 # Find open ports as necessary.
98
99
99 ports = []
100 ports = []
100 ports_needed = int(shell_port <= 0) + \
101 ports_needed = int(shell_port <= 0) + \
101 int(iopub_port <= 0) + \
102 int(iopub_port <= 0) + \
102 int(stdin_port <= 0) + \
103 int(stdin_port <= 0) + \
103 int(control_port <= 0) + \
104 int(control_port <= 0) + \
104 int(hb_port <= 0)
105 int(hb_port <= 0)
105 if transport == 'tcp':
106 if transport == 'tcp':
106 for i in range(ports_needed):
107 for i in range(ports_needed):
107 sock = socket.socket()
108 sock = socket.socket()
108 sock.bind(('', 0))
109 sock.bind(('', 0))
109 ports.append(sock)
110 ports.append(sock)
110 for i, sock in enumerate(ports):
111 for i, sock in enumerate(ports):
111 port = sock.getsockname()[1]
112 port = sock.getsockname()[1]
112 sock.close()
113 sock.close()
113 ports[i] = port
114 ports[i] = port
114 else:
115 else:
115 N = 1
116 N = 1
116 for i in range(ports_needed):
117 for i in range(ports_needed):
117 while os.path.exists("%s-%s" % (ip, str(N))):
118 while os.path.exists("%s-%s" % (ip, str(N))):
118 N += 1
119 N += 1
119 ports.append(N)
120 ports.append(N)
120 N += 1
121 N += 1
121 if shell_port <= 0:
122 if shell_port <= 0:
122 shell_port = ports.pop(0)
123 shell_port = ports.pop(0)
123 if iopub_port <= 0:
124 if iopub_port <= 0:
124 iopub_port = ports.pop(0)
125 iopub_port = ports.pop(0)
125 if stdin_port <= 0:
126 if stdin_port <= 0:
126 stdin_port = ports.pop(0)
127 stdin_port = ports.pop(0)
127 if control_port <= 0:
128 if control_port <= 0:
128 control_port = ports.pop(0)
129 control_port = ports.pop(0)
129 if hb_port <= 0:
130 if hb_port <= 0:
130 hb_port = ports.pop(0)
131 hb_port = ports.pop(0)
131
132
132 cfg = dict( shell_port=shell_port,
133 cfg = dict( shell_port=shell_port,
133 iopub_port=iopub_port,
134 iopub_port=iopub_port,
134 stdin_port=stdin_port,
135 stdin_port=stdin_port,
135 control_port=control_port,
136 control_port=control_port,
136 hb_port=hb_port,
137 hb_port=hb_port,
137 )
138 )
138 cfg['ip'] = ip
139 cfg['ip'] = ip
139 cfg['key'] = bytes_to_str(key)
140 cfg['key'] = bytes_to_str(key)
140 cfg['transport'] = transport
141 cfg['transport'] = transport
141 cfg['signature_scheme'] = signature_scheme
142 cfg['signature_scheme'] = signature_scheme
142
143
143 with open(fname, 'w') as f:
144 with open(fname, 'w') as f:
144 f.write(json.dumps(cfg, indent=2))
145 f.write(json.dumps(cfg, indent=2))
145
146
146 return fname, cfg
147 return fname, cfg
147
148
148
149
149 def get_connection_file(app=None):
150 def get_connection_file(app=None):
150 """Return the path to the connection file of an app
151 """Return the path to the connection file of an app
151
152
152 Parameters
153 Parameters
153 ----------
154 ----------
154 app : IPKernelApp instance [optional]
155 app : IPKernelApp instance [optional]
155 If unspecified, the currently running app will be used
156 If unspecified, the currently running app will be used
156 """
157 """
157 if app is None:
158 if app is None:
158 from IPython.kernel.zmq.kernelapp import IPKernelApp
159 from IPython.kernel.zmq.kernelapp import IPKernelApp
159 if not IPKernelApp.initialized():
160 if not IPKernelApp.initialized():
160 raise RuntimeError("app not specified, and not in a running Kernel")
161 raise RuntimeError("app not specified, and not in a running Kernel")
161
162
162 app = IPKernelApp.instance()
163 app = IPKernelApp.instance()
163 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
164 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
164
165
165
166
166 def find_connection_file(filename, profile=None):
167 def find_connection_file(filename, profile=None):
167 """find a connection file, and return its absolute path.
168 """find a connection file, and return its absolute path.
168
169
169 The current working directory and the profile's security
170 The current working directory and the profile's security
170 directory will be searched for the file if it is not given by
171 directory will be searched for the file if it is not given by
171 absolute path.
172 absolute path.
172
173
173 If profile is unspecified, then the current running application's
174 If profile is unspecified, then the current running application's
174 profile will be used, or 'default', if not run from IPython.
175 profile will be used, or 'default', if not run from IPython.
175
176
176 If the argument does not match an existing file, it will be interpreted as a
177 If the argument does not match an existing file, it will be interpreted as a
177 fileglob, and the matching file in the profile's security dir with
178 fileglob, and the matching file in the profile's security dir with
178 the latest access time will be used.
179 the latest access time will be used.
179
180
180 Parameters
181 Parameters
181 ----------
182 ----------
182 filename : str
183 filename : str
183 The connection file or fileglob to search for.
184 The connection file or fileglob to search for.
184 profile : str [optional]
185 profile : str [optional]
185 The name of the profile to use when searching for the connection file,
186 The name of the profile to use when searching for the connection file,
186 if different from the current IPython session or 'default'.
187 if different from the current IPython session or 'default'.
187
188
188 Returns
189 Returns
189 -------
190 -------
190 str : The absolute path of the connection file.
191 str : The absolute path of the connection file.
191 """
192 """
192 from IPython.core.application import BaseIPythonApplication as IPApp
193 from IPython.core.application import BaseIPythonApplication as IPApp
193 try:
194 try:
194 # quick check for absolute path, before going through logic
195 # quick check for absolute path, before going through logic
195 return filefind(filename)
196 return filefind(filename)
196 except IOError:
197 except IOError:
197 pass
198 pass
198
199
199 if profile is None:
200 if profile is None:
200 # profile unspecified, check if running from an IPython app
201 # profile unspecified, check if running from an IPython app
201 if IPApp.initialized():
202 if IPApp.initialized():
202 app = IPApp.instance()
203 app = IPApp.instance()
203 profile_dir = app.profile_dir
204 profile_dir = app.profile_dir
204 else:
205 else:
205 # not running in IPython, use default profile
206 # not running in IPython, use default profile
206 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
207 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
207 else:
208 else:
208 # find profiledir by profile name:
209 # find profiledir by profile name:
209 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
210 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
210 security_dir = profile_dir.security_dir
211 security_dir = profile_dir.security_dir
211
212
212 try:
213 try:
213 # first, try explicit name
214 # first, try explicit name
214 return filefind(filename, ['.', security_dir])
215 return filefind(filename, ['.', security_dir])
215 except IOError:
216 except IOError:
216 pass
217 pass
217
218
218 # not found by full name
219 # not found by full name
219
220
220 if '*' in filename:
221 if '*' in filename:
221 # given as a glob already
222 # given as a glob already
222 pat = filename
223 pat = filename
223 else:
224 else:
224 # accept any substring match
225 # accept any substring match
225 pat = '*%s*' % filename
226 pat = '*%s*' % filename
226 matches = glob.glob( os.path.join(security_dir, pat) )
227 matches = glob.glob( os.path.join(security_dir, pat) )
227 if not matches:
228 if not matches:
228 raise IOError("Could not find %r in %r" % (filename, security_dir))
229 raise IOError("Could not find %r in %r" % (filename, security_dir))
229 elif len(matches) == 1:
230 elif len(matches) == 1:
230 return matches[0]
231 return matches[0]
231 else:
232 else:
232 # get most recent match, by access time:
233 # get most recent match, by access time:
233 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
234 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
234
235
235
236
236 def get_connection_info(connection_file=None, unpack=False, profile=None):
237 def get_connection_info(connection_file=None, unpack=False, profile=None):
237 """Return the connection information for the current Kernel.
238 """Return the connection information for the current Kernel.
238
239
239 Parameters
240 Parameters
240 ----------
241 ----------
241 connection_file : str [optional]
242 connection_file : str [optional]
242 The connection file to be used. Can be given by absolute path, or
243 The connection file to be used. Can be given by absolute path, or
243 IPython will search in the security directory of a given profile.
244 IPython will search in the security directory of a given profile.
244 If run from IPython,
245 If run from IPython,
245
246
246 If unspecified, the connection file for the currently running
247 If unspecified, the connection file for the currently running
247 IPython Kernel will be used, which is only allowed from inside a kernel.
248 IPython Kernel will be used, which is only allowed from inside a kernel.
248 unpack : bool [default: False]
249 unpack : bool [default: False]
249 if True, return the unpacked dict, otherwise just the string contents
250 if True, return the unpacked dict, otherwise just the string contents
250 of the file.
251 of the file.
251 profile : str [optional]
252 profile : str [optional]
252 The name of the profile to use when searching for the connection file,
253 The name of the profile to use when searching for the connection file,
253 if different from the current IPython session or 'default'.
254 if different from the current IPython session or 'default'.
254
255
255
256
256 Returns
257 Returns
257 -------
258 -------
258 The connection dictionary of the current kernel, as string or dict,
259 The connection dictionary of the current kernel, as string or dict,
259 depending on `unpack`.
260 depending on `unpack`.
260 """
261 """
261 if connection_file is None:
262 if connection_file is None:
262 # get connection file from current kernel
263 # get connection file from current kernel
263 cf = get_connection_file()
264 cf = get_connection_file()
264 else:
265 else:
265 # connection file specified, allow shortnames:
266 # connection file specified, allow shortnames:
266 cf = find_connection_file(connection_file, profile=profile)
267 cf = find_connection_file(connection_file, profile=profile)
267
268
268 with open(cf) as f:
269 with open(cf) as f:
269 info = f.read()
270 info = f.read()
270
271
271 if unpack:
272 if unpack:
272 info = json.loads(info)
273 info = json.loads(info)
273 # ensure key is bytes:
274 # ensure key is bytes:
274 info['key'] = str_to_bytes(info.get('key', ''))
275 info['key'] = str_to_bytes(info.get('key', ''))
275 return info
276 return info
276
277
277
278
278 def connect_qtconsole(connection_file=None, argv=None, profile=None):
279 def connect_qtconsole(connection_file=None, argv=None, profile=None):
279 """Connect a qtconsole to the current kernel.
280 """Connect a qtconsole to the current kernel.
280
281
281 This is useful for connecting a second qtconsole to a kernel, or to a
282 This is useful for connecting a second qtconsole to a kernel, or to a
282 local notebook.
283 local notebook.
283
284
284 Parameters
285 Parameters
285 ----------
286 ----------
286 connection_file : str [optional]
287 connection_file : str [optional]
287 The connection file to be used. Can be given by absolute path, or
288 The connection file to be used. Can be given by absolute path, or
288 IPython will search in the security directory of a given profile.
289 IPython will search in the security directory of a given profile.
289 If run from IPython,
290 If run from IPython,
290
291
291 If unspecified, the connection file for the currently running
292 If unspecified, the connection file for the currently running
292 IPython Kernel will be used, which is only allowed from inside a kernel.
293 IPython Kernel will be used, which is only allowed from inside a kernel.
293 argv : list [optional]
294 argv : list [optional]
294 Any extra args to be passed to the console.
295 Any extra args to be passed to the console.
295 profile : str [optional]
296 profile : str [optional]
296 The name of the profile to use when searching for the connection file,
297 The name of the profile to use when searching for the connection file,
297 if different from the current IPython session or 'default'.
298 if different from the current IPython session or 'default'.
298
299
299
300
300 Returns
301 Returns
301 -------
302 -------
302 subprocess.Popen instance running the qtconsole frontend
303 subprocess.Popen instance running the qtconsole frontend
303 """
304 """
304 argv = [] if argv is None else argv
305 argv = [] if argv is None else argv
305
306
306 if connection_file is None:
307 if connection_file is None:
307 # get connection file from current kernel
308 # get connection file from current kernel
308 cf = get_connection_file()
309 cf = get_connection_file()
309 else:
310 else:
310 cf = find_connection_file(connection_file, profile=profile)
311 cf = find_connection_file(connection_file, profile=profile)
311
312
312 cmd = ';'.join([
313 cmd = ';'.join([
313 "from IPython.qt.console import qtconsoleapp",
314 "from IPython.qt.console import qtconsoleapp",
314 "qtconsoleapp.main()"
315 "qtconsoleapp.main()"
315 ])
316 ])
316
317
317 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
318 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
318 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
319 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
319 )
320 )
320
321
321
322
322 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
323 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
323 """tunnel connections to a kernel via ssh
324 """tunnel connections to a kernel via ssh
324
325
325 This will open four SSH tunnels from localhost on this machine to the
326 This will open four SSH tunnels from localhost on this machine to the
326 ports associated with the kernel. They can be either direct
327 ports associated with the kernel. They can be either direct
327 localhost-localhost tunnels, or if an intermediate server is necessary,
328 localhost-localhost tunnels, or if an intermediate server is necessary,
328 the kernel must be listening on a public IP.
329 the kernel must be listening on a public IP.
329
330
330 Parameters
331 Parameters
331 ----------
332 ----------
332 connection_info : dict or str (path)
333 connection_info : dict or str (path)
333 Either a connection dict, or the path to a JSON connection file
334 Either a connection dict, or the path to a JSON connection file
334 sshserver : str
335 sshserver : str
335 The ssh sever to use to tunnel to the kernel. Can be a full
336 The ssh sever to use to tunnel to the kernel. Can be a full
336 `user@server:port` string. ssh config aliases are respected.
337 `user@server:port` string. ssh config aliases are respected.
337 sshkey : str [optional]
338 sshkey : str [optional]
338 Path to file containing ssh key to use for authentication.
339 Path to file containing ssh key to use for authentication.
339 Only necessary if your ssh config does not already associate
340 Only necessary if your ssh config does not already associate
340 a keyfile with the host.
341 a keyfile with the host.
341
342
342 Returns
343 Returns
343 -------
344 -------
344
345
345 (shell, iopub, stdin, hb) : ints
346 (shell, iopub, stdin, hb) : ints
346 The four ports on localhost that have been forwarded to the kernel.
347 The four ports on localhost that have been forwarded to the kernel.
347 """
348 """
348 if isinstance(connection_info, basestring):
349 if isinstance(connection_info, basestring):
349 # it's a path, unpack it
350 # it's a path, unpack it
350 with open(connection_info) as f:
351 with open(connection_info) as f:
351 connection_info = json.loads(f.read())
352 connection_info = json.loads(f.read())
352
353
353 cf = connection_info
354 cf = connection_info
354
355
355 lports = tunnel.select_random_ports(4)
356 lports = tunnel.select_random_ports(4)
356 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
357 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
357
358
358 remote_ip = cf['ip']
359 remote_ip = cf['ip']
359
360
360 if tunnel.try_passwordless_ssh(sshserver, sshkey):
361 if tunnel.try_passwordless_ssh(sshserver, sshkey):
361 password=False
362 password=False
362 else:
363 else:
363 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
364 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
364
365
365 for lp,rp in zip(lports, rports):
366 for lp,rp in zip(lports, rports):
366 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
367 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
367
368
368 return tuple(lports)
369 return tuple(lports)
369
370
370
371
371 #-----------------------------------------------------------------------------
372 #-----------------------------------------------------------------------------
372 # Mixin for classes that work with connection files
373 # Mixin for classes that work with connection files
373 #-----------------------------------------------------------------------------
374 #-----------------------------------------------------------------------------
374
375
375 channel_socket_types = {
376 channel_socket_types = {
376 'hb' : zmq.REQ,
377 'hb' : zmq.REQ,
377 'shell' : zmq.DEALER,
378 'shell' : zmq.DEALER,
378 'iopub' : zmq.SUB,
379 'iopub' : zmq.SUB,
379 'stdin' : zmq.DEALER,
380 'stdin' : zmq.DEALER,
380 'control': zmq.DEALER,
381 'control': zmq.DEALER,
381 }
382 }
382
383
383 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
384 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
384
385
385 class ConnectionFileMixin(Configurable):
386 class ConnectionFileMixin(Configurable):
386 """Mixin for configurable classes that work with connection files"""
387 """Mixin for configurable classes that work with connection files"""
387
388
388 # The addresses for the communication channels
389 # The addresses for the communication channels
389 connection_file = Unicode('')
390 connection_file = Unicode('')
390 _connection_file_written = Bool(False)
391 _connection_file_written = Bool(False)
391
392
392 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
393 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
393
394
394 ip = Unicode(LOCALHOST, config=True,
395 ip = Unicode(LOCALHOST, config=True,
395 help="""Set the kernel\'s IP address [default localhost].
396 help="""Set the kernel\'s IP address [default localhost].
396 If the IP address is something other than localhost, then
397 If the IP address is something other than localhost, then
397 Consoles on other machines will be able to connect
398 Consoles on other machines will be able to connect
398 to the Kernel, so be careful!"""
399 to the Kernel, so be careful!"""
399 )
400 )
400
401
401 def _ip_default(self):
402 def _ip_default(self):
402 if self.transport == 'ipc':
403 if self.transport == 'ipc':
403 if self.connection_file:
404 if self.connection_file:
404 return os.path.splitext(self.connection_file)[0] + '-ipc'
405 return os.path.splitext(self.connection_file)[0] + '-ipc'
405 else:
406 else:
406 return 'kernel-ipc'
407 return 'kernel-ipc'
407 else:
408 else:
408 return LOCALHOST
409 return LOCALHOST
409
410
410 def _ip_changed(self, name, old, new):
411 def _ip_changed(self, name, old, new):
411 if new == '*':
412 if new == '*':
412 self.ip = '0.0.0.0'
413 self.ip = '0.0.0.0'
413
414
414 # protected traits
415 # protected traits
415
416
416 shell_port = Integer(0)
417 shell_port = Integer(0)
417 iopub_port = Integer(0)
418 iopub_port = Integer(0)
418 stdin_port = Integer(0)
419 stdin_port = Integer(0)
419 control_port = Integer(0)
420 control_port = Integer(0)
420 hb_port = Integer(0)
421 hb_port = Integer(0)
421
422
422 @property
423 @property
423 def ports(self):
424 def ports(self):
424 return [ getattr(self, name) for name in port_names ]
425 return [ getattr(self, name) for name in port_names ]
425
426
426 #--------------------------------------------------------------------------
427 #--------------------------------------------------------------------------
427 # Connection and ipc file management
428 # Connection and ipc file management
428 #--------------------------------------------------------------------------
429 #--------------------------------------------------------------------------
429
430
430 def get_connection_info(self):
431 def get_connection_info(self):
431 """return the connection info as a dict"""
432 """return the connection info as a dict"""
432 return dict(
433 return dict(
433 transport=self.transport,
434 transport=self.transport,
434 ip=self.ip,
435 ip=self.ip,
435 shell_port=self.shell_port,
436 shell_port=self.shell_port,
436 iopub_port=self.iopub_port,
437 iopub_port=self.iopub_port,
437 stdin_port=self.stdin_port,
438 stdin_port=self.stdin_port,
438 hb_port=self.hb_port,
439 hb_port=self.hb_port,
439 control_port=self.control_port,
440 control_port=self.control_port,
440 signature_scheme=self.session.signature_scheme,
441 signature_scheme=self.session.signature_scheme,
441 key=self.session.key,
442 key=self.session.key,
442 )
443 )
443
444
444 def cleanup_connection_file(self):
445 def cleanup_connection_file(self):
445 """Cleanup connection file *if we wrote it*
446 """Cleanup connection file *if we wrote it*
446
447
447 Will not raise if the connection file was already removed somehow.
448 Will not raise if the connection file was already removed somehow.
448 """
449 """
449 if self._connection_file_written:
450 if self._connection_file_written:
450 # cleanup connection files on full shutdown of kernel we started
451 # cleanup connection files on full shutdown of kernel we started
451 self._connection_file_written = False
452 self._connection_file_written = False
452 try:
453 try:
453 os.remove(self.connection_file)
454 os.remove(self.connection_file)
454 except (IOError, OSError, AttributeError):
455 except (IOError, OSError, AttributeError):
455 pass
456 pass
456
457
457 def cleanup_ipc_files(self):
458 def cleanup_ipc_files(self):
458 """Cleanup ipc files if we wrote them."""
459 """Cleanup ipc files if we wrote them."""
459 if self.transport != 'ipc':
460 if self.transport != 'ipc':
460 return
461 return
461 for port in self.ports:
462 for port in self.ports:
462 ipcfile = "%s-%i" % (self.ip, port)
463 ipcfile = "%s-%i" % (self.ip, port)
463 try:
464 try:
464 os.remove(ipcfile)
465 os.remove(ipcfile)
465 except (IOError, OSError):
466 except (IOError, OSError):
466 pass
467 pass
467
468
468 def write_connection_file(self):
469 def write_connection_file(self):
469 """Write connection info to JSON dict in self.connection_file."""
470 """Write connection info to JSON dict in self.connection_file."""
470 if self._connection_file_written:
471 if self._connection_file_written:
471 return
472 return
472
473
473 self.connection_file, cfg = write_connection_file(self.connection_file,
474 self.connection_file, cfg = write_connection_file(self.connection_file,
474 transport=self.transport, ip=self.ip, key=self.session.key,
475 transport=self.transport, ip=self.ip, key=self.session.key,
475 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
476 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
476 shell_port=self.shell_port, hb_port=self.hb_port,
477 shell_port=self.shell_port, hb_port=self.hb_port,
477 control_port=self.control_port,
478 control_port=self.control_port,
478 signature_scheme=self.session.signature_scheme,
479 signature_scheme=self.session.signature_scheme,
479 )
480 )
480 # write_connection_file also sets default ports:
481 # write_connection_file also sets default ports:
481 for name in port_names:
482 for name in port_names:
482 setattr(self, name, cfg[name])
483 setattr(self, name, cfg[name])
483
484
484 self._connection_file_written = True
485 self._connection_file_written = True
485
486
486 def load_connection_file(self):
487 def load_connection_file(self):
487 """Load connection info from JSON dict in self.connection_file."""
488 """Load connection info from JSON dict in self.connection_file."""
488 with open(self.connection_file) as f:
489 with open(self.connection_file) as f:
489 cfg = json.loads(f.read())
490 cfg = json.loads(f.read())
490
491
491 self.transport = cfg.get('transport', 'tcp')
492 self.transport = cfg.get('transport', 'tcp')
492 self.ip = cfg['ip']
493 self.ip = cfg['ip']
493 for name in port_names:
494 for name in port_names:
494 setattr(self, name, cfg[name])
495 setattr(self, name, cfg[name])
495 if 'key' in cfg:
496 if 'key' in cfg:
496 self.session.key = str_to_bytes(cfg['key'])
497 self.session.key = str_to_bytes(cfg['key'])
497 if cfg.get('signature_scheme'):
498 if cfg.get('signature_scheme'):
498 self.session.signature_scheme = cfg['signature_scheme']
499 self.session.signature_scheme = cfg['signature_scheme']
499
500
500 #--------------------------------------------------------------------------
501 #--------------------------------------------------------------------------
501 # Creating connected sockets
502 # Creating connected sockets
502 #--------------------------------------------------------------------------
503 #--------------------------------------------------------------------------
503
504
504 def _make_url(self, channel):
505 def _make_url(self, channel):
505 """Make a ZeroMQ URL for a given channel."""
506 """Make a ZeroMQ URL for a given channel."""
506 transport = self.transport
507 transport = self.transport
507 ip = self.ip
508 ip = self.ip
508 port = getattr(self, '%s_port' % channel)
509 port = getattr(self, '%s_port' % channel)
509
510
510 if transport == 'tcp':
511 if transport == 'tcp':
511 return "tcp://%s:%i" % (ip, port)
512 return "tcp://%s:%i" % (ip, port)
512 else:
513 else:
513 return "%s://%s-%s" % (transport, ip, port)
514 return "%s://%s-%s" % (transport, ip, port)
514
515
515 def _create_connected_socket(self, channel, identity=None):
516 def _create_connected_socket(self, channel, identity=None):
516 """Create a zmq Socket and connect it to the kernel."""
517 """Create a zmq Socket and connect it to the kernel."""
517 url = self._make_url(channel)
518 url = self._make_url(channel)
518 socket_type = channel_socket_types[channel]
519 socket_type = channel_socket_types[channel]
519 self.log.info("Connecting to: %s" % url)
520 self.log.info("Connecting to: %s" % url)
520 sock = self.context.socket(socket_type)
521 sock = self.context.socket(socket_type)
521 if identity:
522 if identity:
522 sock.identity = identity
523 sock.identity = identity
523 sock.connect(url)
524 sock.connect(url)
524 return sock
525 return sock
525
526
526 def connect_iopub(self, identity=None):
527 def connect_iopub(self, identity=None):
527 """return zmq Socket connected to the IOPub channel"""
528 """return zmq Socket connected to the IOPub channel"""
528 sock = self._create_connected_socket('iopub', identity=identity)
529 sock = self._create_connected_socket('iopub', identity=identity)
529 sock.setsockopt(zmq.SUBSCRIBE, b'')
530 sock.setsockopt(zmq.SUBSCRIBE, b'')
530 return sock
531 return sock
531
532
532 def connect_shell(self, identity=None):
533 def connect_shell(self, identity=None):
533 """return zmq Socket connected to the Shell channel"""
534 """return zmq Socket connected to the Shell channel"""
534 return self._create_connected_socket('shell', identity=identity)
535 return self._create_connected_socket('shell', identity=identity)
535
536
536 def connect_stdin(self, identity=None):
537 def connect_stdin(self, identity=None):
537 """return zmq Socket connected to the StdIn channel"""
538 """return zmq Socket connected to the StdIn channel"""
538 return self._create_connected_socket('stdin', identity=identity)
539 return self._create_connected_socket('stdin', identity=identity)
539
540
540 def connect_hb(self, identity=None):
541 def connect_hb(self, identity=None):
541 """return zmq Socket connected to the Heartbeat channel"""
542 """return zmq Socket connected to the Heartbeat channel"""
542 return self._create_connected_socket('hb', identity=identity)
543 return self._create_connected_socket('hb', identity=identity)
543
544
544 def connect_control(self, identity=None):
545 def connect_control(self, identity=None):
545 """return zmq Socket connected to the Heartbeat channel"""
546 """return zmq Socket connected to the Heartbeat channel"""
546 return self._create_connected_socket('control', identity=identity)
547 return self._create_connected_socket('control', identity=identity)
547
548
548
549
549 __all__ = [
550 __all__ = [
550 'write_connection_file',
551 'write_connection_file',
551 'get_connection_file',
552 'get_connection_file',
552 'find_connection_file',
553 'find_connection_file',
553 'get_connection_info',
554 'get_connection_info',
554 'connect_qtconsole',
555 'connect_qtconsole',
555 'tunnel_to_kernel',
556 'tunnel_to_kernel',
556 ]
557 ]
@@ -1,542 +1,541 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
24
23
25 import zmq
24 import zmq
26
25
27 from IPython import parallel
26 from IPython import parallel
28 from IPython.parallel.client import client as clientmod
27 from IPython.parallel.client import client as clientmod
29 from IPython.parallel import error
28 from IPython.parallel import error
30 from IPython.parallel import AsyncResult, AsyncHubResult
29 from IPython.parallel import AsyncResult, AsyncHubResult
31 from IPython.parallel import LoadBalancedView, DirectView
30 from IPython.parallel import LoadBalancedView, DirectView
32
31
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
34
33
35 def setup():
34 def setup():
36 add_engines(4, total=True)
35 add_engines(4, total=True)
37
36
38 class TestClient(ClusterTestCase):
37 class TestClient(ClusterTestCase):
39
38
40 def test_ids(self):
39 def test_ids(self):
41 n = len(self.client.ids)
40 n = len(self.client.ids)
42 self.add_engines(2)
41 self.add_engines(2)
43 self.assertEqual(len(self.client.ids), n+2)
42 self.assertEqual(len(self.client.ids), n+2)
44
43
45 def test_view_indexing(self):
44 def test_view_indexing(self):
46 """test index access for views"""
45 """test index access for views"""
47 self.minimum_engines(4)
46 self.minimum_engines(4)
48 targets = self.client._build_targets('all')[-1]
47 targets = self.client._build_targets('all')[-1]
49 v = self.client[:]
48 v = self.client[:]
50 self.assertEqual(v.targets, targets)
49 self.assertEqual(v.targets, targets)
51 t = self.client.ids[2]
50 t = self.client.ids[2]
52 v = self.client[t]
51 v = self.client[t]
53 self.assertTrue(isinstance(v, DirectView))
52 self.assertTrue(isinstance(v, DirectView))
54 self.assertEqual(v.targets, t)
53 self.assertEqual(v.targets, t)
55 t = self.client.ids[2:4]
54 t = self.client.ids[2:4]
56 v = self.client[t]
55 v = self.client[t]
57 self.assertTrue(isinstance(v, DirectView))
56 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, t)
57 self.assertEqual(v.targets, t)
59 v = self.client[::2]
58 v = self.client[::2]
60 self.assertTrue(isinstance(v, DirectView))
59 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[::2])
60 self.assertEqual(v.targets, targets[::2])
62 v = self.client[1::3]
61 v = self.client[1::3]
63 self.assertTrue(isinstance(v, DirectView))
62 self.assertTrue(isinstance(v, DirectView))
64 self.assertEqual(v.targets, targets[1::3])
63 self.assertEqual(v.targets, targets[1::3])
65 v = self.client[:-3]
64 v = self.client[:-3]
66 self.assertTrue(isinstance(v, DirectView))
65 self.assertTrue(isinstance(v, DirectView))
67 self.assertEqual(v.targets, targets[:-3])
66 self.assertEqual(v.targets, targets[:-3])
68 v = self.client[-1]
67 v = self.client[-1]
69 self.assertTrue(isinstance(v, DirectView))
68 self.assertTrue(isinstance(v, DirectView))
70 self.assertEqual(v.targets, targets[-1])
69 self.assertEqual(v.targets, targets[-1])
71 self.assertRaises(TypeError, lambda : self.client[None])
70 self.assertRaises(TypeError, lambda : self.client[None])
72
71
73 def test_lbview_targets(self):
72 def test_lbview_targets(self):
74 """test load_balanced_view targets"""
73 """test load_balanced_view targets"""
75 v = self.client.load_balanced_view()
74 v = self.client.load_balanced_view()
76 self.assertEqual(v.targets, None)
75 self.assertEqual(v.targets, None)
77 v = self.client.load_balanced_view(-1)
76 v = self.client.load_balanced_view(-1)
78 self.assertEqual(v.targets, [self.client.ids[-1]])
77 self.assertEqual(v.targets, [self.client.ids[-1]])
79 v = self.client.load_balanced_view('all')
78 v = self.client.load_balanced_view('all')
80 self.assertEqual(v.targets, None)
79 self.assertEqual(v.targets, None)
81
80
82 def test_dview_targets(self):
81 def test_dview_targets(self):
83 """test direct_view targets"""
82 """test direct_view targets"""
84 v = self.client.direct_view()
83 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
84 self.assertEqual(v.targets, 'all')
86 v = self.client.direct_view('all')
85 v = self.client.direct_view('all')
87 self.assertEqual(v.targets, 'all')
86 self.assertEqual(v.targets, 'all')
88 v = self.client.direct_view(-1)
87 v = self.client.direct_view(-1)
89 self.assertEqual(v.targets, self.client.ids[-1])
88 self.assertEqual(v.targets, self.client.ids[-1])
90
89
91 def test_lazy_all_targets(self):
90 def test_lazy_all_targets(self):
92 """test lazy evaluation of rc.direct_view('all')"""
91 """test lazy evaluation of rc.direct_view('all')"""
93 v = self.client.direct_view()
92 v = self.client.direct_view()
94 self.assertEqual(v.targets, 'all')
93 self.assertEqual(v.targets, 'all')
95
94
96 def double(x):
95 def double(x):
97 return x*2
96 return x*2
98 seq = range(100)
97 seq = range(100)
99 ref = [ double(x) for x in seq ]
98 ref = [ double(x) for x in seq ]
100
99
101 # add some engines, which should be used
100 # add some engines, which should be used
102 self.add_engines(1)
101 self.add_engines(1)
103 n1 = len(self.client.ids)
102 n1 = len(self.client.ids)
104
103
105 # simple apply
104 # simple apply
106 r = v.apply_sync(lambda : 1)
105 r = v.apply_sync(lambda : 1)
107 self.assertEqual(r, [1] * n1)
106 self.assertEqual(r, [1] * n1)
108
107
109 # map goes through remotefunction
108 # map goes through remotefunction
110 r = v.map_sync(double, seq)
109 r = v.map_sync(double, seq)
111 self.assertEqual(r, ref)
110 self.assertEqual(r, ref)
112
111
113 # add a couple more engines, and try again
112 # add a couple more engines, and try again
114 self.add_engines(2)
113 self.add_engines(2)
115 n2 = len(self.client.ids)
114 n2 = len(self.client.ids)
116 self.assertNotEqual(n2, n1)
115 self.assertNotEqual(n2, n1)
117
116
118 # apply
117 # apply
119 r = v.apply_sync(lambda : 1)
118 r = v.apply_sync(lambda : 1)
120 self.assertEqual(r, [1] * n2)
119 self.assertEqual(r, [1] * n2)
121
120
122 # map
121 # map
123 r = v.map_sync(double, seq)
122 r = v.map_sync(double, seq)
124 self.assertEqual(r, ref)
123 self.assertEqual(r, ref)
125
124
126 def test_targets(self):
125 def test_targets(self):
127 """test various valid targets arguments"""
126 """test various valid targets arguments"""
128 build = self.client._build_targets
127 build = self.client._build_targets
129 ids = self.client.ids
128 ids = self.client.ids
130 idents,targets = build(None)
129 idents,targets = build(None)
131 self.assertEqual(ids, targets)
130 self.assertEqual(ids, targets)
132
131
133 def test_clear(self):
132 def test_clear(self):
134 """test clear behavior"""
133 """test clear behavior"""
135 self.minimum_engines(2)
134 self.minimum_engines(2)
136 v = self.client[:]
135 v = self.client[:]
137 v.block=True
136 v.block=True
138 v.push(dict(a=5))
137 v.push(dict(a=5))
139 v.pull('a')
138 v.pull('a')
140 id0 = self.client.ids[-1]
139 id0 = self.client.ids[-1]
141 self.client.clear(targets=id0, block=True)
140 self.client.clear(targets=id0, block=True)
142 a = self.client[:-1].get('a')
141 a = self.client[:-1].get('a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 self.client.clear(block=True)
143 self.client.clear(block=True)
145 for i in self.client.ids:
144 for i in self.client.ids:
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147
146
148 def test_get_result(self):
147 def test_get_result(self):
149 """test getting results from the Hub."""
148 """test getting results from the Hub."""
150 c = clientmod.Client(profile='iptest')
149 c = clientmod.Client(profile='iptest')
151 t = c.ids[-1]
150 t = c.ids[-1]
152 ar = c[t].apply_async(wait, 1)
151 ar = c[t].apply_async(wait, 1)
153 # give the monitor time to notice the message
152 # give the monitor time to notice the message
154 time.sleep(.25)
153 time.sleep(.25)
155 ahr = self.client.get_result(ar.msg_ids[0])
154 ahr = self.client.get_result(ar.msg_ids[0])
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEqual(ahr.get(), ar.get())
156 self.assertEqual(ahr.get(), ar.get())
158 ar2 = self.client.get_result(ar.msg_ids[0])
157 ar2 = self.client.get_result(ar.msg_ids[0])
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.close()
159 c.close()
161
160
162 def test_get_execute_result(self):
161 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
162 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
163 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
164 t = c.ids[-1]
166 cell = '\n'.join([
165 cell = '\n'.join([
167 'import time',
166 'import time',
168 'time.sleep(0.25)',
167 'time.sleep(0.25)',
169 '5'
168 '5'
170 ])
169 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
170 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
171 # give the monitor time to notice the message
173 time.sleep(.25)
172 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids[0])
173 ahr = self.client.get_result(ar.msg_ids[0])
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
174 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
175 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 ar2 = self.client.get_result(ar.msg_ids[0])
176 ar2 = self.client.get_result(ar.msg_ids[0])
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
177 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 c.close()
178 c.close()
180
179
181 def test_ids_list(self):
180 def test_ids_list(self):
182 """test client.ids"""
181 """test client.ids"""
183 ids = self.client.ids
182 ids = self.client.ids
184 self.assertEqual(ids, self.client._ids)
183 self.assertEqual(ids, self.client._ids)
185 self.assertFalse(ids is self.client._ids)
184 self.assertFalse(ids is self.client._ids)
186 ids.remove(ids[-1])
185 ids.remove(ids[-1])
187 self.assertNotEqual(ids, self.client._ids)
186 self.assertNotEqual(ids, self.client._ids)
188
187
189 def test_queue_status(self):
188 def test_queue_status(self):
190 ids = self.client.ids
189 ids = self.client.ids
191 id0 = ids[0]
190 id0 = ids[0]
192 qs = self.client.queue_status(targets=id0)
191 qs = self.client.queue_status(targets=id0)
193 self.assertTrue(isinstance(qs, dict))
192 self.assertTrue(isinstance(qs, dict))
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
193 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 allqs = self.client.queue_status()
194 allqs = self.client.queue_status()
196 self.assertTrue(isinstance(allqs, dict))
195 self.assertTrue(isinstance(allqs, dict))
197 intkeys = list(allqs.keys())
196 intkeys = list(allqs.keys())
198 intkeys.remove('unassigned')
197 intkeys.remove('unassigned')
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
198 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 unassigned = allqs.pop('unassigned')
199 unassigned = allqs.pop('unassigned')
201 for eid,qs in allqs.items():
200 for eid,qs in allqs.items():
202 self.assertTrue(isinstance(qs, dict))
201 self.assertTrue(isinstance(qs, dict))
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
202 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204
203
205 def test_shutdown(self):
204 def test_shutdown(self):
206 ids = self.client.ids
205 ids = self.client.ids
207 id0 = ids[0]
206 id0 = ids[0]
208 self.client.shutdown(id0, block=True)
207 self.client.shutdown(id0, block=True)
209 while id0 in self.client.ids:
208 while id0 in self.client.ids:
210 time.sleep(0.1)
209 time.sleep(0.1)
211 self.client.spin()
210 self.client.spin()
212
211
213 self.assertRaises(IndexError, lambda : self.client[id0])
212 self.assertRaises(IndexError, lambda : self.client[id0])
214
213
215 def test_result_status(self):
214 def test_result_status(self):
216 pass
215 pass
217 # to be written
216 # to be written
218
217
219 def test_db_query_dt(self):
218 def test_db_query_dt(self):
220 """test db query by date"""
219 """test db query by date"""
221 hist = self.client.hub_history()
220 hist = self.client.hub_history()
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
221 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 tic = middle['submitted']
222 tic = middle['submitted']
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
223 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
224 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 self.assertEqual(len(before)+len(after),len(hist))
225 self.assertEqual(len(before)+len(after),len(hist))
227 for b in before:
226 for b in before:
228 self.assertTrue(b['submitted'] < tic)
227 self.assertTrue(b['submitted'] < tic)
229 for a in after:
228 for a in after:
230 self.assertTrue(a['submitted'] >= tic)
229 self.assertTrue(a['submitted'] >= tic)
231 same = self.client.db_query({'submitted' : tic})
230 same = self.client.db_query({'submitted' : tic})
232 for s in same:
231 for s in same:
233 self.assertTrue(s['submitted'] == tic)
232 self.assertTrue(s['submitted'] == tic)
234
233
235 def test_db_query_keys(self):
234 def test_db_query_keys(self):
236 """test extracting subset of record keys"""
235 """test extracting subset of record keys"""
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
236 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 for rec in found:
237 for rec in found:
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
238 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240
239
241 def test_db_query_default_keys(self):
240 def test_db_query_default_keys(self):
242 """default db_query excludes buffers"""
241 """default db_query excludes buffers"""
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
242 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 for rec in found:
243 for rec in found:
245 keys = set(rec.keys())
244 keys = set(rec.keys())
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
245 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
246 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248
247
249 def test_db_query_msg_id(self):
248 def test_db_query_msg_id(self):
250 """ensure msg_id is always in db queries"""
249 """ensure msg_id is always in db queries"""
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
250 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 for rec in found:
251 for rec in found:
253 self.assertTrue('msg_id' in rec.keys())
252 self.assertTrue('msg_id' in rec.keys())
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
253 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 for rec in found:
254 for rec in found:
256 self.assertTrue('msg_id' in rec.keys())
255 self.assertTrue('msg_id' in rec.keys())
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
256 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 for rec in found:
257 for rec in found:
259 self.assertTrue('msg_id' in rec.keys())
258 self.assertTrue('msg_id' in rec.keys())
260
259
261 def test_db_query_get_result(self):
260 def test_db_query_get_result(self):
262 """pop in db_query shouldn't pop from result itself"""
261 """pop in db_query shouldn't pop from result itself"""
263 self.client[:].apply_sync(lambda : 1)
262 self.client[:].apply_sync(lambda : 1)
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
263 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 rc2 = clientmod.Client(profile='iptest')
264 rc2 = clientmod.Client(profile='iptest')
266 # If this bug is not fixed, this call will hang:
265 # If this bug is not fixed, this call will hang:
267 ar = rc2.get_result(self.client.history[-1])
266 ar = rc2.get_result(self.client.history[-1])
268 ar.wait(2)
267 ar.wait(2)
269 self.assertTrue(ar.ready())
268 self.assertTrue(ar.ready())
270 ar.get()
269 ar.get()
271 rc2.close()
270 rc2.close()
272
271
273 def test_db_query_in(self):
272 def test_db_query_in(self):
274 """test db query with '$in','$nin' operators"""
273 """test db query with '$in','$nin' operators"""
275 hist = self.client.hub_history()
274 hist = self.client.hub_history()
276 even = hist[::2]
275 even = hist[::2]
277 odd = hist[1::2]
276 odd = hist[1::2]
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
277 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 found = [ r['msg_id'] for r in recs ]
278 found = [ r['msg_id'] for r in recs ]
280 self.assertEqual(set(even), set(found))
279 self.assertEqual(set(even), set(found))
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
280 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 found = [ r['msg_id'] for r in recs ]
281 found = [ r['msg_id'] for r in recs ]
283 self.assertEqual(set(odd), set(found))
282 self.assertEqual(set(odd), set(found))
284
283
285 def test_hub_history(self):
284 def test_hub_history(self):
286 hist = self.client.hub_history()
285 hist = self.client.hub_history()
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
286 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 recdict = {}
287 recdict = {}
289 for rec in recs:
288 for rec in recs:
290 recdict[rec['msg_id']] = rec
289 recdict[rec['msg_id']] = rec
291
290
292 latest = datetime(1984,1,1)
291 latest = datetime(1984,1,1)
293 for msg_id in hist:
292 for msg_id in hist:
294 rec = recdict[msg_id]
293 rec = recdict[msg_id]
295 newt = rec['submitted']
294 newt = rec['submitted']
296 self.assertTrue(newt >= latest)
295 self.assertTrue(newt >= latest)
297 latest = newt
296 latest = newt
298 ar = self.client[-1].apply_async(lambda : 1)
297 ar = self.client[-1].apply_async(lambda : 1)
299 ar.get()
298 ar.get()
300 time.sleep(0.25)
299 time.sleep(0.25)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
300 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302
301
303 def _wait_for_idle(self):
302 def _wait_for_idle(self):
304 """wait for an engine to become idle, according to the Hub"""
303 """wait for an engine to become idle, according to the Hub"""
305 rc = self.client
304 rc = self.client
306
305
307 # step 1. wait for all requests to be noticed
306 # step 1. wait for all requests to be noticed
308 # timeout 5s, polling every 100ms
307 # timeout 5s, polling every 100ms
309 msg_ids = set(rc.history)
308 msg_ids = set(rc.history)
310 hub_hist = rc.hub_history()
309 hub_hist = rc.hub_history()
311 for i in range(50):
310 for i in range(50):
312 if msg_ids.difference(hub_hist):
311 if msg_ids.difference(hub_hist):
313 time.sleep(0.1)
312 time.sleep(0.1)
314 hub_hist = rc.hub_history()
313 hub_hist = rc.hub_history()
315 else:
314 else:
316 break
315 break
317
316
318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
317 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319
318
320 # step 2. wait for all requests to be done
319 # step 2. wait for all requests to be done
321 # timeout 5s, polling every 100ms
320 # timeout 5s, polling every 100ms
322 qs = rc.queue_status()
321 qs = rc.queue_status()
323 for i in range(50):
322 for i in range(50):
324 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
323 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
325 time.sleep(0.1)
324 time.sleep(0.1)
326 qs = rc.queue_status()
325 qs = rc.queue_status()
327 else:
326 else:
328 break
327 break
329
328
330 # ensure Hub up to date:
329 # ensure Hub up to date:
331 self.assertEqual(qs['unassigned'], 0)
330 self.assertEqual(qs['unassigned'], 0)
332 for eid in rc.ids:
331 for eid in rc.ids:
333 self.assertEqual(qs[eid]['tasks'], 0)
332 self.assertEqual(qs[eid]['tasks'], 0)
334
333
335
334
336 def test_resubmit(self):
335 def test_resubmit(self):
337 def f():
336 def f():
338 import random
337 import random
339 return random.random()
338 return random.random()
340 v = self.client.load_balanced_view()
339 v = self.client.load_balanced_view()
341 ar = v.apply_async(f)
340 ar = v.apply_async(f)
342 r1 = ar.get(1)
341 r1 = ar.get(1)
343 # give the Hub a chance to notice:
342 # give the Hub a chance to notice:
344 self._wait_for_idle()
343 self._wait_for_idle()
345 ahr = self.client.resubmit(ar.msg_ids)
344 ahr = self.client.resubmit(ar.msg_ids)
346 r2 = ahr.get(1)
345 r2 = ahr.get(1)
347 self.assertFalse(r1 == r2)
346 self.assertFalse(r1 == r2)
348
347
349 def test_resubmit_chain(self):
348 def test_resubmit_chain(self):
350 """resubmit resubmitted tasks"""
349 """resubmit resubmitted tasks"""
351 v = self.client.load_balanced_view()
350 v = self.client.load_balanced_view()
352 ar = v.apply_async(lambda x: x, 'x'*1024)
351 ar = v.apply_async(lambda x: x, 'x'*1024)
353 ar.get()
352 ar.get()
354 self._wait_for_idle()
353 self._wait_for_idle()
355 ars = [ar]
354 ars = [ar]
356
355
357 for i in range(10):
356 for i in range(10):
358 ar = ars[-1]
357 ar = ars[-1]
359 ar2 = self.client.resubmit(ar.msg_ids)
358 ar2 = self.client.resubmit(ar.msg_ids)
360
359
361 [ ar.get() for ar in ars ]
360 [ ar.get() for ar in ars ]
362
361
363 def test_resubmit_header(self):
362 def test_resubmit_header(self):
364 """resubmit shouldn't clobber the whole header"""
363 """resubmit shouldn't clobber the whole header"""
365 def f():
364 def f():
366 import random
365 import random
367 return random.random()
366 return random.random()
368 v = self.client.load_balanced_view()
367 v = self.client.load_balanced_view()
369 v.retries = 1
368 v.retries = 1
370 ar = v.apply_async(f)
369 ar = v.apply_async(f)
371 r1 = ar.get(1)
370 r1 = ar.get(1)
372 # give the Hub a chance to notice:
371 # give the Hub a chance to notice:
373 self._wait_for_idle()
372 self._wait_for_idle()
374 ahr = self.client.resubmit(ar.msg_ids)
373 ahr = self.client.resubmit(ar.msg_ids)
375 ahr.get(1)
374 ahr.get(1)
376 time.sleep(0.5)
375 time.sleep(0.5)
377 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
376 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
378 h1,h2 = [ r['header'] for r in records ]
377 h1,h2 = [ r['header'] for r in records ]
379 for key in set(h1.keys()).union(set(h2.keys())):
378 for key in set(h1.keys()).union(set(h2.keys())):
380 if key in ('msg_id', 'date'):
379 if key in ('msg_id', 'date'):
381 self.assertNotEqual(h1[key], h2[key])
380 self.assertNotEqual(h1[key], h2[key])
382 else:
381 else:
383 self.assertEqual(h1[key], h2[key])
382 self.assertEqual(h1[key], h2[key])
384
383
385 def test_resubmit_aborted(self):
384 def test_resubmit_aborted(self):
386 def f():
385 def f():
387 import random
386 import random
388 return random.random()
387 return random.random()
389 v = self.client.load_balanced_view()
388 v = self.client.load_balanced_view()
390 # restrict to one engine, so we can put a sleep
389 # restrict to one engine, so we can put a sleep
391 # ahead of the task, so it will get aborted
390 # ahead of the task, so it will get aborted
392 eid = self.client.ids[-1]
391 eid = self.client.ids[-1]
393 v.targets = [eid]
392 v.targets = [eid]
394 sleep = v.apply_async(time.sleep, 0.5)
393 sleep = v.apply_async(time.sleep, 0.5)
395 ar = v.apply_async(f)
394 ar = v.apply_async(f)
396 ar.abort()
395 ar.abort()
397 self.assertRaises(error.TaskAborted, ar.get)
396 self.assertRaises(error.TaskAborted, ar.get)
398 # Give the Hub a chance to get up to date:
397 # Give the Hub a chance to get up to date:
399 self._wait_for_idle()
398 self._wait_for_idle()
400 ahr = self.client.resubmit(ar.msg_ids)
399 ahr = self.client.resubmit(ar.msg_ids)
401 r2 = ahr.get(1)
400 r2 = ahr.get(1)
402
401
403 def test_resubmit_inflight(self):
402 def test_resubmit_inflight(self):
404 """resubmit of inflight task"""
403 """resubmit of inflight task"""
405 v = self.client.load_balanced_view()
404 v = self.client.load_balanced_view()
406 ar = v.apply_async(time.sleep,1)
405 ar = v.apply_async(time.sleep,1)
407 # give the message a chance to arrive
406 # give the message a chance to arrive
408 time.sleep(0.2)
407 time.sleep(0.2)
409 ahr = self.client.resubmit(ar.msg_ids)
408 ahr = self.client.resubmit(ar.msg_ids)
410 ar.get(2)
409 ar.get(2)
411 ahr.get(2)
410 ahr.get(2)
412
411
413 def test_resubmit_badkey(self):
412 def test_resubmit_badkey(self):
414 """ensure KeyError on resubmit of nonexistant task"""
413 """ensure KeyError on resubmit of nonexistant task"""
415 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
414 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
416
415
417 def test_purge_hub_results(self):
416 def test_purge_hub_results(self):
418 # ensure there are some tasks
417 # ensure there are some tasks
419 for i in range(5):
418 for i in range(5):
420 self.client[:].apply_sync(lambda : 1)
419 self.client[:].apply_sync(lambda : 1)
421 # Wait for the Hub to realise the result is done:
420 # Wait for the Hub to realise the result is done:
422 # This prevents a race condition, where we
421 # This prevents a race condition, where we
423 # might purge a result the Hub still thinks is pending.
422 # might purge a result the Hub still thinks is pending.
424 self._wait_for_idle()
423 self._wait_for_idle()
425 rc2 = clientmod.Client(profile='iptest')
424 rc2 = clientmod.Client(profile='iptest')
426 hist = self.client.hub_history()
425 hist = self.client.hub_history()
427 ahr = rc2.get_result([hist[-1]])
426 ahr = rc2.get_result([hist[-1]])
428 ahr.wait(10)
427 ahr.wait(10)
429 self.client.purge_hub_results(hist[-1])
428 self.client.purge_hub_results(hist[-1])
430 newhist = self.client.hub_history()
429 newhist = self.client.hub_history()
431 self.assertEqual(len(newhist)+1,len(hist))
430 self.assertEqual(len(newhist)+1,len(hist))
432 rc2.spin()
431 rc2.spin()
433 rc2.close()
432 rc2.close()
434
433
435 def test_purge_local_results(self):
434 def test_purge_local_results(self):
436 # ensure there are some tasks
435 # ensure there are some tasks
437 res = []
436 res = []
438 for i in range(5):
437 for i in range(5):
439 res.append(self.client[:].apply_async(lambda : 1))
438 res.append(self.client[:].apply_async(lambda : 1))
440 self._wait_for_idle()
439 self._wait_for_idle()
441 self.client.wait(10) # wait for the results to come back
440 self.client.wait(10) # wait for the results to come back
442 before = len(self.client.results)
441 before = len(self.client.results)
443 self.assertEqual(len(self.client.metadata),before)
442 self.assertEqual(len(self.client.metadata),before)
444 self.client.purge_local_results(res[-1])
443 self.client.purge_local_results(res[-1])
445 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
444 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
446 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
445 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
447
446
448 def test_purge_local_results_outstanding(self):
447 def test_purge_local_results_outstanding(self):
449 v = self.client[-1]
448 v = self.client[-1]
450 ar = v.apply_async(lambda : 1)
449 ar = v.apply_async(lambda : 1)
451 msg_id = ar.msg_ids[0]
450 msg_id = ar.msg_ids[0]
452 ar.get()
451 ar.get()
453 self._wait_for_idle()
452 self._wait_for_idle()
454 ar2 = v.apply_async(time.sleep, 1)
453 ar2 = v.apply_async(time.sleep, 1)
455 self.assertIn(msg_id, self.client.results)
454 self.assertIn(msg_id, self.client.results)
456 self.assertIn(msg_id, self.client.metadata)
455 self.assertIn(msg_id, self.client.metadata)
457 self.client.purge_local_results(ar)
456 self.client.purge_local_results(ar)
458 self.assertNotIn(msg_id, self.client.results)
457 self.assertNotIn(msg_id, self.client.results)
459 self.assertNotIn(msg_id, self.client.metadata)
458 self.assertNotIn(msg_id, self.client.metadata)
460 with self.assertRaises(RuntimeError):
459 with self.assertRaises(RuntimeError):
461 self.client.purge_local_results(ar2)
460 self.client.purge_local_results(ar2)
462 ar2.get()
461 ar2.get()
463 self.client.purge_local_results(ar2)
462 self.client.purge_local_results(ar2)
464
463
465 def test_purge_all_local_results_outstanding(self):
464 def test_purge_all_local_results_outstanding(self):
466 v = self.client[-1]
465 v = self.client[-1]
467 ar = v.apply_async(time.sleep, 1)
466 ar = v.apply_async(time.sleep, 1)
468 with self.assertRaises(RuntimeError):
467 with self.assertRaises(RuntimeError):
469 self.client.purge_local_results('all')
468 self.client.purge_local_results('all')
470 ar.get()
469 ar.get()
471 self.client.purge_local_results('all')
470 self.client.purge_local_results('all')
472
471
473 def test_purge_all_hub_results(self):
472 def test_purge_all_hub_results(self):
474 self.client.purge_hub_results('all')
473 self.client.purge_hub_results('all')
475 hist = self.client.hub_history()
474 hist = self.client.hub_history()
476 self.assertEqual(len(hist), 0)
475 self.assertEqual(len(hist), 0)
477
476
478 def test_purge_all_local_results(self):
477 def test_purge_all_local_results(self):
479 self.client.purge_local_results('all')
478 self.client.purge_local_results('all')
480 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
479 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
481 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
480 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
482
481
483 def test_purge_all_results(self):
482 def test_purge_all_results(self):
484 # ensure there are some tasks
483 # ensure there are some tasks
485 for i in range(5):
484 for i in range(5):
486 self.client[:].apply_sync(lambda : 1)
485 self.client[:].apply_sync(lambda : 1)
487 self.client.wait(10)
486 self.client.wait(10)
488 self._wait_for_idle()
487 self._wait_for_idle()
489 self.client.purge_results('all')
488 self.client.purge_results('all')
490 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
489 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
491 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
490 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
492 hist = self.client.hub_history()
491 hist = self.client.hub_history()
493 self.assertEqual(len(hist), 0, msg="hub history not empty")
492 self.assertEqual(len(hist), 0, msg="hub history not empty")
494
493
495 def test_purge_everything(self):
494 def test_purge_everything(self):
496 # ensure there are some tasks
495 # ensure there are some tasks
497 for i in range(5):
496 for i in range(5):
498 self.client[:].apply_sync(lambda : 1)
497 self.client[:].apply_sync(lambda : 1)
499 self.client.wait(10)
498 self.client.wait(10)
500 self._wait_for_idle()
499 self._wait_for_idle()
501 self.client.purge_everything()
500 self.client.purge_everything()
502 # The client results
501 # The client results
503 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
502 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
504 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
503 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
505 # The client "bookkeeping"
504 # The client "bookkeeping"
506 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
505 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
507 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
506 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
508 # the hub results
507 # the hub results
509 hist = self.client.hub_history()
508 hist = self.client.hub_history()
510 self.assertEqual(len(hist), 0, msg="hub history not empty")
509 self.assertEqual(len(hist), 0, msg="hub history not empty")
511
510
512
511
513 def test_spin_thread(self):
512 def test_spin_thread(self):
514 self.client.spin_thread(0.01)
513 self.client.spin_thread(0.01)
515 ar = self.client[-1].apply_async(lambda : 1)
514 ar = self.client[-1].apply_async(lambda : 1)
516 time.sleep(0.1)
515 time.sleep(0.1)
517 self.assertTrue(ar.wall_time < 0.1,
516 self.assertTrue(ar.wall_time < 0.1,
518 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
517 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
519 )
518 )
520
519
521 def test_stop_spin_thread(self):
520 def test_stop_spin_thread(self):
522 self.client.spin_thread(0.01)
521 self.client.spin_thread(0.01)
523 self.client.stop_spin_thread()
522 self.client.stop_spin_thread()
524 ar = self.client[-1].apply_async(lambda : 1)
523 ar = self.client[-1].apply_async(lambda : 1)
525 time.sleep(0.15)
524 time.sleep(0.15)
526 self.assertTrue(ar.wall_time > 0.1,
525 self.assertTrue(ar.wall_time > 0.1,
527 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
526 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
528 )
527 )
529
528
530 def test_activate(self):
529 def test_activate(self):
531 ip = get_ipython()
530 ip = get_ipython()
532 magics = ip.magics_manager.magics
531 magics = ip.magics_manager.magics
533 self.assertTrue('px' in magics['line'])
532 self.assertTrue('px' in magics['line'])
534 self.assertTrue('px' in magics['cell'])
533 self.assertTrue('px' in magics['cell'])
535 v0 = self.client.activate(-1, '0')
534 v0 = self.client.activate(-1, '0')
536 self.assertTrue('px0' in magics['line'])
535 self.assertTrue('px0' in magics['line'])
537 self.assertTrue('px0' in magics['cell'])
536 self.assertTrue('px0' in magics['cell'])
538 self.assertEqual(v0.targets, self.client.ids[-1])
537 self.assertEqual(v0.targets, self.client.ids[-1])
539 v0 = self.client.activate('all', 'all')
538 v0 = self.client.activate('all', 'all')
540 self.assertTrue('pxall' in magics['line'])
539 self.assertTrue('pxall' in magics['line'])
541 self.assertTrue('pxall' in magics['cell'])
540 self.assertTrue('pxall' in magics['cell'])
542 self.assertEqual(v0.targets, 'all')
541 self.assertEqual(v0.targets, 'all')
@@ -1,801 +1,800 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from collections import namedtuple
22 from collections import namedtuple
23 from tempfile import mktemp
23 from tempfile import NamedTemporaryFile
24 from StringIO import StringIO
24 from StringIO import StringIO
25
25
26 import zmq
26 import zmq
27 from nose import SkipTest
27 from nose import SkipTest
28 from nose.plugins.attrib import attr
28 from nose.plugins.attrib import attr
29
29
30 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
31 from IPython.testing.ipunittest import ParametricTestCase
31 from IPython.testing.ipunittest import ParametricTestCase
32 from IPython.utils.io import capture_output
32 from IPython.utils.io import capture_output
33
33
34 from IPython import parallel as pmod
34 from IPython import parallel as pmod
35 from IPython.parallel import error
35 from IPython.parallel import error
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 from IPython.parallel import DirectView
37 from IPython.parallel import DirectView
38 from IPython.parallel.util import interactive
38 from IPython.parallel.util import interactive
39
39
40 from IPython.parallel.tests import add_engines
40 from IPython.parallel.tests import add_engines
41
41
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43
43
44 def setup():
44 def setup():
45 add_engines(3, total=True)
45 add_engines(3, total=True)
46
46
47 point = namedtuple("point", "x y")
47 point = namedtuple("point", "x y")
48
48
49 class TestView(ClusterTestCase, ParametricTestCase):
49 class TestView(ClusterTestCase, ParametricTestCase):
50
50
51 def setUp(self):
51 def setUp(self):
52 # On Win XP, wait for resource cleanup, else parallel test group fails
52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 time.sleep(2)
55 time.sleep(2)
56 super(TestView, self).setUp()
56 super(TestView, self).setUp()
57
57
58 @attr('crash')
58 @attr('crash')
59 def test_z_crash_mux(self):
59 def test_z_crash_mux(self):
60 """test graceful handling of engine death (direct)"""
60 """test graceful handling of engine death (direct)"""
61 # self.add_engines(1)
61 # self.add_engines(1)
62 eid = self.client.ids[-1]
62 eid = self.client.ids[-1]
63 ar = self.client[eid].apply_async(crash)
63 ar = self.client[eid].apply_async(crash)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 eid = ar.engine_id
65 eid = ar.engine_id
66 tic = time.time()
66 tic = time.time()
67 while eid in self.client.ids and time.time()-tic < 5:
67 while eid in self.client.ids and time.time()-tic < 5:
68 time.sleep(.01)
68 time.sleep(.01)
69 self.client.spin()
69 self.client.spin()
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71
71
72 def test_push_pull(self):
72 def test_push_pull(self):
73 """test pushing and pulling"""
73 """test pushing and pulling"""
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 t = self.client.ids[-1]
75 t = self.client.ids[-1]
76 v = self.client[t]
76 v = self.client[t]
77 push = v.push
77 push = v.push
78 pull = v.pull
78 pull = v.pull
79 v.block=True
79 v.block=True
80 nengines = len(self.client)
80 nengines = len(self.client)
81 push({'data':data})
81 push({'data':data})
82 d = pull('data')
82 d = pull('data')
83 self.assertEqual(d, data)
83 self.assertEqual(d, data)
84 self.client[:].push({'data':data})
84 self.client[:].push({'data':data})
85 d = self.client[:].pull('data', block=True)
85 d = self.client[:].pull('data', block=True)
86 self.assertEqual(d, nengines*[data])
86 self.assertEqual(d, nengines*[data])
87 ar = push({'data':data}, block=False)
87 ar = push({'data':data}, block=False)
88 self.assertTrue(isinstance(ar, AsyncResult))
88 self.assertTrue(isinstance(ar, AsyncResult))
89 r = ar.get()
89 r = ar.get()
90 ar = self.client[:].pull('data', block=False)
90 ar = self.client[:].pull('data', block=False)
91 self.assertTrue(isinstance(ar, AsyncResult))
91 self.assertTrue(isinstance(ar, AsyncResult))
92 r = ar.get()
92 r = ar.get()
93 self.assertEqual(r, nengines*[data])
93 self.assertEqual(r, nengines*[data])
94 self.client[:].push(dict(a=10,b=20))
94 self.client[:].push(dict(a=10,b=20))
95 r = self.client[:].pull(('a','b'), block=True)
95 r = self.client[:].pull(('a','b'), block=True)
96 self.assertEqual(r, nengines*[[10,20]])
96 self.assertEqual(r, nengines*[[10,20]])
97
97
98 def test_push_pull_function(self):
98 def test_push_pull_function(self):
99 "test pushing and pulling functions"
99 "test pushing and pulling functions"
100 def testf(x):
100 def testf(x):
101 return 2.0*x
101 return 2.0*x
102
102
103 t = self.client.ids[-1]
103 t = self.client.ids[-1]
104 v = self.client[t]
104 v = self.client[t]
105 v.block=True
105 v.block=True
106 push = v.push
106 push = v.push
107 pull = v.pull
107 pull = v.pull
108 execute = v.execute
108 execute = v.execute
109 push({'testf':testf})
109 push({'testf':testf})
110 r = pull('testf')
110 r = pull('testf')
111 self.assertEqual(r(1.0), testf(1.0))
111 self.assertEqual(r(1.0), testf(1.0))
112 execute('r = testf(10)')
112 execute('r = testf(10)')
113 r = pull('r')
113 r = pull('r')
114 self.assertEqual(r, testf(10))
114 self.assertEqual(r, testf(10))
115 ar = self.client[:].push({'testf':testf}, block=False)
115 ar = self.client[:].push({'testf':testf}, block=False)
116 ar.get()
116 ar.get()
117 ar = self.client[:].pull('testf', block=False)
117 ar = self.client[:].pull('testf', block=False)
118 rlist = ar.get()
118 rlist = ar.get()
119 for r in rlist:
119 for r in rlist:
120 self.assertEqual(r(1.0), testf(1.0))
120 self.assertEqual(r(1.0), testf(1.0))
121 execute("def g(x): return x*x")
121 execute("def g(x): return x*x")
122 r = pull(('testf','g'))
122 r = pull(('testf','g'))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124
124
125 def test_push_function_globals(self):
125 def test_push_function_globals(self):
126 """test that pushed functions have access to globals"""
126 """test that pushed functions have access to globals"""
127 @interactive
127 @interactive
128 def geta():
128 def geta():
129 return a
129 return a
130 # self.add_engines(1)
130 # self.add_engines(1)
131 v = self.client[-1]
131 v = self.client[-1]
132 v.block=True
132 v.block=True
133 v['f'] = geta
133 v['f'] = geta
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 v.execute('a=5')
135 v.execute('a=5')
136 v.execute('b=f()')
136 v.execute('b=f()')
137 self.assertEqual(v['b'], 5)
137 self.assertEqual(v['b'], 5)
138
138
139 def test_push_function_defaults(self):
139 def test_push_function_defaults(self):
140 """test that pushed functions preserve default args"""
140 """test that pushed functions preserve default args"""
141 def echo(a=10):
141 def echo(a=10):
142 return a
142 return a
143 v = self.client[-1]
143 v = self.client[-1]
144 v.block=True
144 v.block=True
145 v['f'] = echo
145 v['f'] = echo
146 v.execute('b=f()')
146 v.execute('b=f()')
147 self.assertEqual(v['b'], 10)
147 self.assertEqual(v['b'], 10)
148
148
149 def test_get_result(self):
149 def test_get_result(self):
150 """test getting results from the Hub."""
150 """test getting results from the Hub."""
151 c = pmod.Client(profile='iptest')
151 c = pmod.Client(profile='iptest')
152 # self.add_engines(1)
152 # self.add_engines(1)
153 t = c.ids[-1]
153 t = c.ids[-1]
154 v = c[t]
154 v = c[t]
155 v2 = self.client[t]
155 v2 = self.client[t]
156 ar = v.apply_async(wait, 1)
156 ar = v.apply_async(wait, 1)
157 # give the monitor time to notice the message
157 # give the monitor time to notice the message
158 time.sleep(.25)
158 time.sleep(.25)
159 ahr = v2.get_result(ar.msg_ids[0])
159 ahr = v2.get_result(ar.msg_ids[0])
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 self.assertEqual(ahr.get(), ar.get())
161 self.assertEqual(ahr.get(), ar.get())
162 ar2 = v2.get_result(ar.msg_ids[0])
162 ar2 = v2.get_result(ar.msg_ids[0])
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 c.spin()
164 c.spin()
165 c.close()
165 c.close()
166
166
167 def test_run_newline(self):
167 def test_run_newline(self):
168 """test that run appends newline to files"""
168 """test that run appends newline to files"""
169 tmpfile = mktemp()
169 with NamedTemporaryFile('w', delete=False) as f:
170 with open(tmpfile, 'w') as f:
171 f.write("""def g():
170 f.write("""def g():
172 return 5
171 return 5
173 """)
172 """)
174 v = self.client[-1]
173 v = self.client[-1]
175 v.run(tmpfile, block=True)
174 v.run(f.name, block=True)
176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
175 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177
176
178 def test_apply_tracked(self):
177 def test_apply_tracked(self):
179 """test tracking for apply"""
178 """test tracking for apply"""
180 # self.add_engines(1)
179 # self.add_engines(1)
181 t = self.client.ids[-1]
180 t = self.client.ids[-1]
182 v = self.client[t]
181 v = self.client[t]
183 v.block=False
182 v.block=False
184 def echo(n=1024*1024, **kwargs):
183 def echo(n=1024*1024, **kwargs):
185 with v.temp_flags(**kwargs):
184 with v.temp_flags(**kwargs):
186 return v.apply(lambda x: x, 'x'*n)
185 return v.apply(lambda x: x, 'x'*n)
187 ar = echo(1, track=False)
186 ar = echo(1, track=False)
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(ar.sent)
188 self.assertTrue(ar.sent)
190 ar = echo(track=True)
189 ar = echo(track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertEqual(ar.sent, ar._tracker.done)
191 self.assertEqual(ar.sent, ar._tracker.done)
193 ar._tracker.wait()
192 ar._tracker.wait()
194 self.assertTrue(ar.sent)
193 self.assertTrue(ar.sent)
195
194
196 def test_push_tracked(self):
195 def test_push_tracked(self):
197 t = self.client.ids[-1]
196 t = self.client.ids[-1]
198 ns = dict(x='x'*1024*1024)
197 ns = dict(x='x'*1024*1024)
199 v = self.client[t]
198 v = self.client[t]
200 ar = v.push(ns, block=False, track=False)
199 ar = v.push(ns, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
201 self.assertTrue(ar.sent)
203
202
204 ar = v.push(ns, block=False, track=True)
203 ar = v.push(ns, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 ar._tracker.wait()
205 ar._tracker.wait()
207 self.assertEqual(ar.sent, ar._tracker.done)
206 self.assertEqual(ar.sent, ar._tracker.done)
208 self.assertTrue(ar.sent)
207 self.assertTrue(ar.sent)
209 ar.get()
208 ar.get()
210
209
211 def test_scatter_tracked(self):
210 def test_scatter_tracked(self):
212 t = self.client.ids
211 t = self.client.ids
213 x='x'*1024*1024
212 x='x'*1024*1024
214 ar = self.client[t].scatter('x', x, block=False, track=False)
213 ar = self.client[t].scatter('x', x, block=False, track=False)
215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertTrue(ar.sent)
215 self.assertTrue(ar.sent)
217
216
218 ar = self.client[t].scatter('x', x, block=False, track=True)
217 ar = self.client[t].scatter('x', x, block=False, track=True)
219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
218 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 self.assertEqual(ar.sent, ar._tracker.done)
219 self.assertEqual(ar.sent, ar._tracker.done)
221 ar._tracker.wait()
220 ar._tracker.wait()
222 self.assertTrue(ar.sent)
221 self.assertTrue(ar.sent)
223 ar.get()
222 ar.get()
224
223
225 def test_remote_reference(self):
224 def test_remote_reference(self):
226 v = self.client[-1]
225 v = self.client[-1]
227 v['a'] = 123
226 v['a'] = 123
228 ra = pmod.Reference('a')
227 ra = pmod.Reference('a')
229 b = v.apply_sync(lambda x: x, ra)
228 b = v.apply_sync(lambda x: x, ra)
230 self.assertEqual(b, 123)
229 self.assertEqual(b, 123)
231
230
232
231
233 def test_scatter_gather(self):
232 def test_scatter_gather(self):
234 view = self.client[:]
233 view = self.client[:]
235 seq1 = range(16)
234 seq1 = range(16)
236 view.scatter('a', seq1)
235 view.scatter('a', seq1)
237 seq2 = view.gather('a', block=True)
236 seq2 = view.gather('a', block=True)
238 self.assertEqual(seq2, seq1)
237 self.assertEqual(seq2, seq1)
239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
238 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240
239
241 @skip_without('numpy')
240 @skip_without('numpy')
242 def test_scatter_gather_numpy(self):
241 def test_scatter_gather_numpy(self):
243 import numpy
242 import numpy
244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
243 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 view = self.client[:]
244 view = self.client[:]
246 a = numpy.arange(64)
245 a = numpy.arange(64)
247 view.scatter('a', a, block=True)
246 view.scatter('a', a, block=True)
248 b = view.gather('a', block=True)
247 b = view.gather('a', block=True)
249 assert_array_equal(b, a)
248 assert_array_equal(b, a)
250
249
251 def test_scatter_gather_lazy(self):
250 def test_scatter_gather_lazy(self):
252 """scatter/gather with targets='all'"""
251 """scatter/gather with targets='all'"""
253 view = self.client.direct_view(targets='all')
252 view = self.client.direct_view(targets='all')
254 x = range(64)
253 x = range(64)
255 view.scatter('x', x)
254 view.scatter('x', x)
256 gathered = view.gather('x', block=True)
255 gathered = view.gather('x', block=True)
257 self.assertEqual(gathered, x)
256 self.assertEqual(gathered, x)
258
257
259
258
260 @dec.known_failure_py3
259 @dec.known_failure_py3
261 @skip_without('numpy')
260 @skip_without('numpy')
262 def test_push_numpy_nocopy(self):
261 def test_push_numpy_nocopy(self):
263 import numpy
262 import numpy
264 view = self.client[:]
263 view = self.client[:]
265 a = numpy.arange(64)
264 a = numpy.arange(64)
266 view['A'] = a
265 view['A'] = a
267 @interactive
266 @interactive
268 def check_writeable(x):
267 def check_writeable(x):
269 return x.flags.writeable
268 return x.flags.writeable
270
269
271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
270 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273
272
274 view.push(dict(B=a))
273 view.push(dict(B=a))
275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
274 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
275 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277
276
278 @skip_without('numpy')
277 @skip_without('numpy')
279 def test_apply_numpy(self):
278 def test_apply_numpy(self):
280 """view.apply(f, ndarray)"""
279 """view.apply(f, ndarray)"""
281 import numpy
280 import numpy
282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
281 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283
282
284 A = numpy.random.random((100,100))
283 A = numpy.random.random((100,100))
285 view = self.client[-1]
284 view = self.client[-1]
286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
285 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 B = A.astype(dt)
286 B = A.astype(dt)
288 C = view.apply_sync(lambda x:x, B)
287 C = view.apply_sync(lambda x:x, B)
289 assert_array_equal(B,C)
288 assert_array_equal(B,C)
290
289
291 @skip_without('numpy')
290 @skip_without('numpy')
292 def test_push_pull_recarray(self):
291 def test_push_pull_recarray(self):
293 """push/pull recarrays"""
292 """push/pull recarrays"""
294 import numpy
293 import numpy
295 from numpy.testing.utils import assert_array_equal
294 from numpy.testing.utils import assert_array_equal
296
295
297 view = self.client[-1]
296 view = self.client[-1]
298
297
299 R = numpy.array([
298 R = numpy.array([
300 (1, 'hi', 0.),
299 (1, 'hi', 0.),
301 (2**30, 'there', 2.5),
300 (2**30, 'there', 2.5),
302 (-99999, 'world', -12345.6789),
301 (-99999, 'world', -12345.6789),
303 ], [('n', int), ('s', '|S10'), ('f', float)])
302 ], [('n', int), ('s', '|S10'), ('f', float)])
304
303
305 view['RR'] = R
304 view['RR'] = R
306 R2 = view['RR']
305 R2 = view['RR']
307
306
308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
307 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 self.assertEqual(r_dtype, R.dtype)
308 self.assertEqual(r_dtype, R.dtype)
310 self.assertEqual(r_shape, R.shape)
309 self.assertEqual(r_shape, R.shape)
311 self.assertEqual(R2.dtype, R.dtype)
310 self.assertEqual(R2.dtype, R.dtype)
312 self.assertEqual(R2.shape, R.shape)
311 self.assertEqual(R2.shape, R.shape)
313 assert_array_equal(R2, R)
312 assert_array_equal(R2, R)
314
313
315 @skip_without('pandas')
314 @skip_without('pandas')
316 def test_push_pull_timeseries(self):
315 def test_push_pull_timeseries(self):
317 """push/pull pandas.TimeSeries"""
316 """push/pull pandas.TimeSeries"""
318 import pandas
317 import pandas
319
318
320 ts = pandas.TimeSeries(range(10))
319 ts = pandas.TimeSeries(range(10))
321
320
322 view = self.client[-1]
321 view = self.client[-1]
323
322
324 view.push(dict(ts=ts), block=True)
323 view.push(dict(ts=ts), block=True)
325 rts = view['ts']
324 rts = view['ts']
326
325
327 self.assertEqual(type(rts), type(ts))
326 self.assertEqual(type(rts), type(ts))
328 self.assertTrue((ts == rts).all())
327 self.assertTrue((ts == rts).all())
329
328
330 def test_map(self):
329 def test_map(self):
331 view = self.client[:]
330 view = self.client[:]
332 def f(x):
331 def f(x):
333 return x**2
332 return x**2
334 data = range(16)
333 data = range(16)
335 r = view.map_sync(f, data)
334 r = view.map_sync(f, data)
336 self.assertEqual(r, map(f, data))
335 self.assertEqual(r, map(f, data))
337
336
338 def test_map_iterable(self):
337 def test_map_iterable(self):
339 """test map on iterables (direct)"""
338 """test map on iterables (direct)"""
340 view = self.client[:]
339 view = self.client[:]
341 # 101 is prime, so it won't be evenly distributed
340 # 101 is prime, so it won't be evenly distributed
342 arr = range(101)
341 arr = range(101)
343 # ensure it will be an iterator, even in Python 3
342 # ensure it will be an iterator, even in Python 3
344 it = iter(arr)
343 it = iter(arr)
345 r = view.map_sync(lambda x: x, it)
344 r = view.map_sync(lambda x: x, it)
346 self.assertEqual(r, list(arr))
345 self.assertEqual(r, list(arr))
347
346
348 @skip_without('numpy')
347 @skip_without('numpy')
349 def test_map_numpy(self):
348 def test_map_numpy(self):
350 """test map on numpy arrays (direct)"""
349 """test map on numpy arrays (direct)"""
351 import numpy
350 import numpy
352 from numpy.testing.utils import assert_array_equal
351 from numpy.testing.utils import assert_array_equal
353
352
354 view = self.client[:]
353 view = self.client[:]
355 # 101 is prime, so it won't be evenly distributed
354 # 101 is prime, so it won't be evenly distributed
356 arr = numpy.arange(101)
355 arr = numpy.arange(101)
357 r = view.map_sync(lambda x: x, arr)
356 r = view.map_sync(lambda x: x, arr)
358 assert_array_equal(r, arr)
357 assert_array_equal(r, arr)
359
358
360 def test_scatter_gather_nonblocking(self):
359 def test_scatter_gather_nonblocking(self):
361 data = range(16)
360 data = range(16)
362 view = self.client[:]
361 view = self.client[:]
363 view.scatter('a', data, block=False)
362 view.scatter('a', data, block=False)
364 ar = view.gather('a', block=False)
363 ar = view.gather('a', block=False)
365 self.assertEqual(ar.get(), data)
364 self.assertEqual(ar.get(), data)
366
365
367 @skip_without('numpy')
366 @skip_without('numpy')
368 def test_scatter_gather_numpy_nonblocking(self):
367 def test_scatter_gather_numpy_nonblocking(self):
369 import numpy
368 import numpy
370 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
369 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
371 a = numpy.arange(64)
370 a = numpy.arange(64)
372 view = self.client[:]
371 view = self.client[:]
373 ar = view.scatter('a', a, block=False)
372 ar = view.scatter('a', a, block=False)
374 self.assertTrue(isinstance(ar, AsyncResult))
373 self.assertTrue(isinstance(ar, AsyncResult))
375 amr = view.gather('a', block=False)
374 amr = view.gather('a', block=False)
376 self.assertTrue(isinstance(amr, AsyncMapResult))
375 self.assertTrue(isinstance(amr, AsyncMapResult))
377 assert_array_equal(amr.get(), a)
376 assert_array_equal(amr.get(), a)
378
377
379 def test_execute(self):
378 def test_execute(self):
380 view = self.client[:]
379 view = self.client[:]
381 # self.client.debug=True
380 # self.client.debug=True
382 execute = view.execute
381 execute = view.execute
383 ar = execute('c=30', block=False)
382 ar = execute('c=30', block=False)
384 self.assertTrue(isinstance(ar, AsyncResult))
383 self.assertTrue(isinstance(ar, AsyncResult))
385 ar = execute('d=[0,1,2]', block=False)
384 ar = execute('d=[0,1,2]', block=False)
386 self.client.wait(ar, 1)
385 self.client.wait(ar, 1)
387 self.assertEqual(len(ar.get()), len(self.client))
386 self.assertEqual(len(ar.get()), len(self.client))
388 for c in view['c']:
387 for c in view['c']:
389 self.assertEqual(c, 30)
388 self.assertEqual(c, 30)
390
389
391 def test_abort(self):
390 def test_abort(self):
392 view = self.client[-1]
391 view = self.client[-1]
393 ar = view.execute('import time; time.sleep(1)', block=False)
392 ar = view.execute('import time; time.sleep(1)', block=False)
394 ar2 = view.apply_async(lambda : 2)
393 ar2 = view.apply_async(lambda : 2)
395 ar3 = view.apply_async(lambda : 3)
394 ar3 = view.apply_async(lambda : 3)
396 view.abort(ar2)
395 view.abort(ar2)
397 view.abort(ar3.msg_ids)
396 view.abort(ar3.msg_ids)
398 self.assertRaises(error.TaskAborted, ar2.get)
397 self.assertRaises(error.TaskAborted, ar2.get)
399 self.assertRaises(error.TaskAborted, ar3.get)
398 self.assertRaises(error.TaskAborted, ar3.get)
400
399
401 def test_abort_all(self):
400 def test_abort_all(self):
402 """view.abort() aborts all outstanding tasks"""
401 """view.abort() aborts all outstanding tasks"""
403 view = self.client[-1]
402 view = self.client[-1]
404 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
403 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
405 view.abort()
404 view.abort()
406 view.wait(timeout=5)
405 view.wait(timeout=5)
407 for ar in ars[5:]:
406 for ar in ars[5:]:
408 self.assertRaises(error.TaskAborted, ar.get)
407 self.assertRaises(error.TaskAborted, ar.get)
409
408
410 def test_temp_flags(self):
409 def test_temp_flags(self):
411 view = self.client[-1]
410 view = self.client[-1]
412 view.block=True
411 view.block=True
413 with view.temp_flags(block=False):
412 with view.temp_flags(block=False):
414 self.assertFalse(view.block)
413 self.assertFalse(view.block)
415 self.assertTrue(view.block)
414 self.assertTrue(view.block)
416
415
417 @dec.known_failure_py3
416 @dec.known_failure_py3
418 def test_importer(self):
417 def test_importer(self):
419 view = self.client[-1]
418 view = self.client[-1]
420 view.clear(block=True)
419 view.clear(block=True)
421 with view.importer:
420 with view.importer:
422 import re
421 import re
423
422
424 @interactive
423 @interactive
425 def findall(pat, s):
424 def findall(pat, s):
426 # this globals() step isn't necessary in real code
425 # this globals() step isn't necessary in real code
427 # only to prevent a closure in the test
426 # only to prevent a closure in the test
428 re = globals()['re']
427 re = globals()['re']
429 return re.findall(pat, s)
428 return re.findall(pat, s)
430
429
431 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
430 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
432
431
433 def test_unicode_execute(self):
432 def test_unicode_execute(self):
434 """test executing unicode strings"""
433 """test executing unicode strings"""
435 v = self.client[-1]
434 v = self.client[-1]
436 v.block=True
435 v.block=True
437 if sys.version_info[0] >= 3:
436 if sys.version_info[0] >= 3:
438 code="a='é'"
437 code="a='é'"
439 else:
438 else:
440 code=u"a=u'é'"
439 code=u"a=u'é'"
441 v.execute(code)
440 v.execute(code)
442 self.assertEqual(v['a'], u'é')
441 self.assertEqual(v['a'], u'é')
443
442
444 def test_unicode_apply_result(self):
443 def test_unicode_apply_result(self):
445 """test unicode apply results"""
444 """test unicode apply results"""
446 v = self.client[-1]
445 v = self.client[-1]
447 r = v.apply_sync(lambda : u'é')
446 r = v.apply_sync(lambda : u'é')
448 self.assertEqual(r, u'é')
447 self.assertEqual(r, u'é')
449
448
450 def test_unicode_apply_arg(self):
449 def test_unicode_apply_arg(self):
451 """test passing unicode arguments to apply"""
450 """test passing unicode arguments to apply"""
452 v = self.client[-1]
451 v = self.client[-1]
453
452
454 @interactive
453 @interactive
455 def check_unicode(a, check):
454 def check_unicode(a, check):
456 assert isinstance(a, unicode), "%r is not unicode"%a
455 assert isinstance(a, unicode), "%r is not unicode"%a
457 assert isinstance(check, bytes), "%r is not bytes"%check
456 assert isinstance(check, bytes), "%r is not bytes"%check
458 assert a.encode('utf8') == check, "%s != %s"%(a,check)
457 assert a.encode('utf8') == check, "%s != %s"%(a,check)
459
458
460 for s in [ u'é', u'ßø®∫',u'asdf' ]:
459 for s in [ u'é', u'ßø®∫',u'asdf' ]:
461 try:
460 try:
462 v.apply_sync(check_unicode, s, s.encode('utf8'))
461 v.apply_sync(check_unicode, s, s.encode('utf8'))
463 except error.RemoteError as e:
462 except error.RemoteError as e:
464 if e.ename == 'AssertionError':
463 if e.ename == 'AssertionError':
465 self.fail(e.evalue)
464 self.fail(e.evalue)
466 else:
465 else:
467 raise e
466 raise e
468
467
469 def test_map_reference(self):
468 def test_map_reference(self):
470 """view.map(<Reference>, *seqs) should work"""
469 """view.map(<Reference>, *seqs) should work"""
471 v = self.client[:]
470 v = self.client[:]
472 v.scatter('n', self.client.ids, flatten=True)
471 v.scatter('n', self.client.ids, flatten=True)
473 v.execute("f = lambda x,y: x*y")
472 v.execute("f = lambda x,y: x*y")
474 rf = pmod.Reference('f')
473 rf = pmod.Reference('f')
475 nlist = list(range(10))
474 nlist = list(range(10))
476 mlist = nlist[::-1]
475 mlist = nlist[::-1]
477 expected = [ m*n for m,n in zip(mlist, nlist) ]
476 expected = [ m*n for m,n in zip(mlist, nlist) ]
478 result = v.map_sync(rf, mlist, nlist)
477 result = v.map_sync(rf, mlist, nlist)
479 self.assertEqual(result, expected)
478 self.assertEqual(result, expected)
480
479
481 def test_apply_reference(self):
480 def test_apply_reference(self):
482 """view.apply(<Reference>, *args) should work"""
481 """view.apply(<Reference>, *args) should work"""
483 v = self.client[:]
482 v = self.client[:]
484 v.scatter('n', self.client.ids, flatten=True)
483 v.scatter('n', self.client.ids, flatten=True)
485 v.execute("f = lambda x: n*x")
484 v.execute("f = lambda x: n*x")
486 rf = pmod.Reference('f')
485 rf = pmod.Reference('f')
487 result = v.apply_sync(rf, 5)
486 result = v.apply_sync(rf, 5)
488 expected = [ 5*id for id in self.client.ids ]
487 expected = [ 5*id for id in self.client.ids ]
489 self.assertEqual(result, expected)
488 self.assertEqual(result, expected)
490
489
491 def test_eval_reference(self):
490 def test_eval_reference(self):
492 v = self.client[self.client.ids[0]]
491 v = self.client[self.client.ids[0]]
493 v['g'] = range(5)
492 v['g'] = range(5)
494 rg = pmod.Reference('g[0]')
493 rg = pmod.Reference('g[0]')
495 echo = lambda x:x
494 echo = lambda x:x
496 self.assertEqual(v.apply_sync(echo, rg), 0)
495 self.assertEqual(v.apply_sync(echo, rg), 0)
497
496
498 def test_reference_nameerror(self):
497 def test_reference_nameerror(self):
499 v = self.client[self.client.ids[0]]
498 v = self.client[self.client.ids[0]]
500 r = pmod.Reference('elvis_has_left')
499 r = pmod.Reference('elvis_has_left')
501 echo = lambda x:x
500 echo = lambda x:x
502 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
501 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
503
502
504 def test_single_engine_map(self):
503 def test_single_engine_map(self):
505 e0 = self.client[self.client.ids[0]]
504 e0 = self.client[self.client.ids[0]]
506 r = range(5)
505 r = range(5)
507 check = [ -1*i for i in r ]
506 check = [ -1*i for i in r ]
508 result = e0.map_sync(lambda x: -1*x, r)
507 result = e0.map_sync(lambda x: -1*x, r)
509 self.assertEqual(result, check)
508 self.assertEqual(result, check)
510
509
511 def test_len(self):
510 def test_len(self):
512 """len(view) makes sense"""
511 """len(view) makes sense"""
513 e0 = self.client[self.client.ids[0]]
512 e0 = self.client[self.client.ids[0]]
514 yield self.assertEqual(len(e0), 1)
513 yield self.assertEqual(len(e0), 1)
515 v = self.client[:]
514 v = self.client[:]
516 yield self.assertEqual(len(v), len(self.client.ids))
515 yield self.assertEqual(len(v), len(self.client.ids))
517 v = self.client.direct_view('all')
516 v = self.client.direct_view('all')
518 yield self.assertEqual(len(v), len(self.client.ids))
517 yield self.assertEqual(len(v), len(self.client.ids))
519 v = self.client[:2]
518 v = self.client[:2]
520 yield self.assertEqual(len(v), 2)
519 yield self.assertEqual(len(v), 2)
521 v = self.client[:1]
520 v = self.client[:1]
522 yield self.assertEqual(len(v), 1)
521 yield self.assertEqual(len(v), 1)
523 v = self.client.load_balanced_view()
522 v = self.client.load_balanced_view()
524 yield self.assertEqual(len(v), len(self.client.ids))
523 yield self.assertEqual(len(v), len(self.client.ids))
525 # parametric tests seem to require manual closing?
524 # parametric tests seem to require manual closing?
526 self.client.close()
525 self.client.close()
527
526
528
527
529 # begin execute tests
528 # begin execute tests
530
529
531 def test_execute_reply(self):
530 def test_execute_reply(self):
532 e0 = self.client[self.client.ids[0]]
531 e0 = self.client[self.client.ids[0]]
533 e0.block = True
532 e0.block = True
534 ar = e0.execute("5", silent=False)
533 ar = e0.execute("5", silent=False)
535 er = ar.get()
534 er = ar.get()
536 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
535 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
537 self.assertEqual(er.pyout['data']['text/plain'], '5')
536 self.assertEqual(er.pyout['data']['text/plain'], '5')
538
537
539 def test_execute_reply_stdout(self):
538 def test_execute_reply_stdout(self):
540 e0 = self.client[self.client.ids[0]]
539 e0 = self.client[self.client.ids[0]]
541 e0.block = True
540 e0.block = True
542 ar = e0.execute("print (5)", silent=False)
541 ar = e0.execute("print (5)", silent=False)
543 er = ar.get()
542 er = ar.get()
544 self.assertEqual(er.stdout.strip(), '5')
543 self.assertEqual(er.stdout.strip(), '5')
545
544
546 def test_execute_pyout(self):
545 def test_execute_pyout(self):
547 """execute triggers pyout with silent=False"""
546 """execute triggers pyout with silent=False"""
548 view = self.client[:]
547 view = self.client[:]
549 ar = view.execute("5", silent=False, block=True)
548 ar = view.execute("5", silent=False, block=True)
550
549
551 expected = [{'text/plain' : '5'}] * len(view)
550 expected = [{'text/plain' : '5'}] * len(view)
552 mimes = [ out['data'] for out in ar.pyout ]
551 mimes = [ out['data'] for out in ar.pyout ]
553 self.assertEqual(mimes, expected)
552 self.assertEqual(mimes, expected)
554
553
555 def test_execute_silent(self):
554 def test_execute_silent(self):
556 """execute does not trigger pyout with silent=True"""
555 """execute does not trigger pyout with silent=True"""
557 view = self.client[:]
556 view = self.client[:]
558 ar = view.execute("5", block=True)
557 ar = view.execute("5", block=True)
559 expected = [None] * len(view)
558 expected = [None] * len(view)
560 self.assertEqual(ar.pyout, expected)
559 self.assertEqual(ar.pyout, expected)
561
560
562 def test_execute_magic(self):
561 def test_execute_magic(self):
563 """execute accepts IPython commands"""
562 """execute accepts IPython commands"""
564 view = self.client[:]
563 view = self.client[:]
565 view.execute("a = 5")
564 view.execute("a = 5")
566 ar = view.execute("%whos", block=True)
565 ar = view.execute("%whos", block=True)
567 # this will raise, if that failed
566 # this will raise, if that failed
568 ar.get(5)
567 ar.get(5)
569 for stdout in ar.stdout:
568 for stdout in ar.stdout:
570 lines = stdout.splitlines()
569 lines = stdout.splitlines()
571 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
570 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
572 found = False
571 found = False
573 for line in lines[2:]:
572 for line in lines[2:]:
574 split = line.split()
573 split = line.split()
575 if split == ['a', 'int', '5']:
574 if split == ['a', 'int', '5']:
576 found = True
575 found = True
577 break
576 break
578 self.assertTrue(found, "whos output wrong: %s" % stdout)
577 self.assertTrue(found, "whos output wrong: %s" % stdout)
579
578
580 def test_execute_displaypub(self):
579 def test_execute_displaypub(self):
581 """execute tracks display_pub output"""
580 """execute tracks display_pub output"""
582 view = self.client[:]
581 view = self.client[:]
583 view.execute("from IPython.core.display import *")
582 view.execute("from IPython.core.display import *")
584 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
583 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
585
584
586 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
585 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
587 for outputs in ar.outputs:
586 for outputs in ar.outputs:
588 mimes = [ out['data'] for out in outputs ]
587 mimes = [ out['data'] for out in outputs ]
589 self.assertEqual(mimes, expected)
588 self.assertEqual(mimes, expected)
590
589
591 def test_apply_displaypub(self):
590 def test_apply_displaypub(self):
592 """apply tracks display_pub output"""
591 """apply tracks display_pub output"""
593 view = self.client[:]
592 view = self.client[:]
594 view.execute("from IPython.core.display import *")
593 view.execute("from IPython.core.display import *")
595
594
596 @interactive
595 @interactive
597 def publish():
596 def publish():
598 [ display(i) for i in range(5) ]
597 [ display(i) for i in range(5) ]
599
598
600 ar = view.apply_async(publish)
599 ar = view.apply_async(publish)
601 ar.get(5)
600 ar.get(5)
602 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
601 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
603 for outputs in ar.outputs:
602 for outputs in ar.outputs:
604 mimes = [ out['data'] for out in outputs ]
603 mimes = [ out['data'] for out in outputs ]
605 self.assertEqual(mimes, expected)
604 self.assertEqual(mimes, expected)
606
605
607 def test_execute_raises(self):
606 def test_execute_raises(self):
608 """exceptions in execute requests raise appropriately"""
607 """exceptions in execute requests raise appropriately"""
609 view = self.client[-1]
608 view = self.client[-1]
610 ar = view.execute("1/0")
609 ar = view.execute("1/0")
611 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
610 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
612
611
613 def test_remoteerror_render_exception(self):
612 def test_remoteerror_render_exception(self):
614 """RemoteErrors get nice tracebacks"""
613 """RemoteErrors get nice tracebacks"""
615 view = self.client[-1]
614 view = self.client[-1]
616 ar = view.execute("1/0")
615 ar = view.execute("1/0")
617 ip = get_ipython()
616 ip = get_ipython()
618 ip.user_ns['ar'] = ar
617 ip.user_ns['ar'] = ar
619 with capture_output() as io:
618 with capture_output() as io:
620 ip.run_cell("ar.get(2)")
619 ip.run_cell("ar.get(2)")
621
620
622 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
621 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
623
622
624 def test_compositeerror_render_exception(self):
623 def test_compositeerror_render_exception(self):
625 """CompositeErrors get nice tracebacks"""
624 """CompositeErrors get nice tracebacks"""
626 view = self.client[:]
625 view = self.client[:]
627 ar = view.execute("1/0")
626 ar = view.execute("1/0")
628 ip = get_ipython()
627 ip = get_ipython()
629 ip.user_ns['ar'] = ar
628 ip.user_ns['ar'] = ar
630
629
631 with capture_output() as io:
630 with capture_output() as io:
632 ip.run_cell("ar.get(2)")
631 ip.run_cell("ar.get(2)")
633
632
634 count = min(error.CompositeError.tb_limit, len(view))
633 count = min(error.CompositeError.tb_limit, len(view))
635
634
636 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
635 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
637 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
636 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
638 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
637 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
639
638
640 def test_compositeerror_truncate(self):
639 def test_compositeerror_truncate(self):
641 """Truncate CompositeErrors with many exceptions"""
640 """Truncate CompositeErrors with many exceptions"""
642 view = self.client[:]
641 view = self.client[:]
643 msg_ids = []
642 msg_ids = []
644 for i in range(10):
643 for i in range(10):
645 ar = view.execute("1/0")
644 ar = view.execute("1/0")
646 msg_ids.extend(ar.msg_ids)
645 msg_ids.extend(ar.msg_ids)
647
646
648 ar = self.client.get_result(msg_ids)
647 ar = self.client.get_result(msg_ids)
649 try:
648 try:
650 ar.get()
649 ar.get()
651 except error.CompositeError as _e:
650 except error.CompositeError as _e:
652 e = _e
651 e = _e
653 else:
652 else:
654 self.fail("Should have raised CompositeError")
653 self.fail("Should have raised CompositeError")
655
654
656 lines = e.render_traceback()
655 lines = e.render_traceback()
657 with capture_output() as io:
656 with capture_output() as io:
658 e.print_traceback()
657 e.print_traceback()
659
658
660 self.assertTrue("more exceptions" in lines[-1])
659 self.assertTrue("more exceptions" in lines[-1])
661 count = e.tb_limit
660 count = e.tb_limit
662
661
663 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
662 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
664 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
663 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
665 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
664 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
666
665
667 @dec.skipif_not_matplotlib
666 @dec.skipif_not_matplotlib
668 def test_magic_pylab(self):
667 def test_magic_pylab(self):
669 """%pylab works on engines"""
668 """%pylab works on engines"""
670 view = self.client[-1]
669 view = self.client[-1]
671 ar = view.execute("%pylab inline")
670 ar = view.execute("%pylab inline")
672 # at least check if this raised:
671 # at least check if this raised:
673 reply = ar.get(5)
672 reply = ar.get(5)
674 # include imports, in case user config
673 # include imports, in case user config
675 ar = view.execute("plot(rand(100))", silent=False)
674 ar = view.execute("plot(rand(100))", silent=False)
676 reply = ar.get(5)
675 reply = ar.get(5)
677 self.assertEqual(len(reply.outputs), 1)
676 self.assertEqual(len(reply.outputs), 1)
678 output = reply.outputs[0]
677 output = reply.outputs[0]
679 self.assertTrue("data" in output)
678 self.assertTrue("data" in output)
680 data = output['data']
679 data = output['data']
681 self.assertTrue("image/png" in data)
680 self.assertTrue("image/png" in data)
682
681
683 def test_func_default_func(self):
682 def test_func_default_func(self):
684 """interactively defined function as apply func default"""
683 """interactively defined function as apply func default"""
685 def foo():
684 def foo():
686 return 'foo'
685 return 'foo'
687
686
688 def bar(f=foo):
687 def bar(f=foo):
689 return f()
688 return f()
690
689
691 view = self.client[-1]
690 view = self.client[-1]
692 ar = view.apply_async(bar)
691 ar = view.apply_async(bar)
693 r = ar.get(10)
692 r = ar.get(10)
694 self.assertEqual(r, 'foo')
693 self.assertEqual(r, 'foo')
695 def test_data_pub_single(self):
694 def test_data_pub_single(self):
696 view = self.client[-1]
695 view = self.client[-1]
697 ar = view.execute('\n'.join([
696 ar = view.execute('\n'.join([
698 'from IPython.kernel.zmq.datapub import publish_data',
697 'from IPython.kernel.zmq.datapub import publish_data',
699 'for i in range(5):',
698 'for i in range(5):',
700 ' publish_data(dict(i=i))'
699 ' publish_data(dict(i=i))'
701 ]), block=False)
700 ]), block=False)
702 self.assertTrue(isinstance(ar.data, dict))
701 self.assertTrue(isinstance(ar.data, dict))
703 ar.get(5)
702 ar.get(5)
704 self.assertEqual(ar.data, dict(i=4))
703 self.assertEqual(ar.data, dict(i=4))
705
704
706 def test_data_pub(self):
705 def test_data_pub(self):
707 view = self.client[:]
706 view = self.client[:]
708 ar = view.execute('\n'.join([
707 ar = view.execute('\n'.join([
709 'from IPython.kernel.zmq.datapub import publish_data',
708 'from IPython.kernel.zmq.datapub import publish_data',
710 'for i in range(5):',
709 'for i in range(5):',
711 ' publish_data(dict(i=i))'
710 ' publish_data(dict(i=i))'
712 ]), block=False)
711 ]), block=False)
713 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
712 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
714 ar.get(5)
713 ar.get(5)
715 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
714 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
716
715
717 def test_can_list_arg(self):
716 def test_can_list_arg(self):
718 """args in lists are canned"""
717 """args in lists are canned"""
719 view = self.client[-1]
718 view = self.client[-1]
720 view['a'] = 128
719 view['a'] = 128
721 rA = pmod.Reference('a')
720 rA = pmod.Reference('a')
722 ar = view.apply_async(lambda x: x, [rA])
721 ar = view.apply_async(lambda x: x, [rA])
723 r = ar.get(5)
722 r = ar.get(5)
724 self.assertEqual(r, [128])
723 self.assertEqual(r, [128])
725
724
726 def test_can_dict_arg(self):
725 def test_can_dict_arg(self):
727 """args in dicts are canned"""
726 """args in dicts are canned"""
728 view = self.client[-1]
727 view = self.client[-1]
729 view['a'] = 128
728 view['a'] = 128
730 rA = pmod.Reference('a')
729 rA = pmod.Reference('a')
731 ar = view.apply_async(lambda x: x, dict(foo=rA))
730 ar = view.apply_async(lambda x: x, dict(foo=rA))
732 r = ar.get(5)
731 r = ar.get(5)
733 self.assertEqual(r, dict(foo=128))
732 self.assertEqual(r, dict(foo=128))
734
733
735 def test_can_list_kwarg(self):
734 def test_can_list_kwarg(self):
736 """kwargs in lists are canned"""
735 """kwargs in lists are canned"""
737 view = self.client[-1]
736 view = self.client[-1]
738 view['a'] = 128
737 view['a'] = 128
739 rA = pmod.Reference('a')
738 rA = pmod.Reference('a')
740 ar = view.apply_async(lambda x=5: x, x=[rA])
739 ar = view.apply_async(lambda x=5: x, x=[rA])
741 r = ar.get(5)
740 r = ar.get(5)
742 self.assertEqual(r, [128])
741 self.assertEqual(r, [128])
743
742
744 def test_can_dict_kwarg(self):
743 def test_can_dict_kwarg(self):
745 """kwargs in dicts are canned"""
744 """kwargs in dicts are canned"""
746 view = self.client[-1]
745 view = self.client[-1]
747 view['a'] = 128
746 view['a'] = 128
748 rA = pmod.Reference('a')
747 rA = pmod.Reference('a')
749 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
748 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
750 r = ar.get(5)
749 r = ar.get(5)
751 self.assertEqual(r, dict(foo=128))
750 self.assertEqual(r, dict(foo=128))
752
751
753 def test_map_ref(self):
752 def test_map_ref(self):
754 """view.map works with references"""
753 """view.map works with references"""
755 view = self.client[:]
754 view = self.client[:]
756 ranks = sorted(self.client.ids)
755 ranks = sorted(self.client.ids)
757 view.scatter('rank', ranks, flatten=True)
756 view.scatter('rank', ranks, flatten=True)
758 rrank = pmod.Reference('rank')
757 rrank = pmod.Reference('rank')
759
758
760 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
759 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
761 drank = amr.get(5)
760 drank = amr.get(5)
762 self.assertEqual(drank, [ r*2 for r in ranks ])
761 self.assertEqual(drank, [ r*2 for r in ranks ])
763
762
764 def test_nested_getitem_setitem(self):
763 def test_nested_getitem_setitem(self):
765 """get and set with view['a.b']"""
764 """get and set with view['a.b']"""
766 view = self.client[-1]
765 view = self.client[-1]
767 view.execute('\n'.join([
766 view.execute('\n'.join([
768 'class A(object): pass',
767 'class A(object): pass',
769 'a = A()',
768 'a = A()',
770 'a.b = 128',
769 'a.b = 128',
771 ]), block=True)
770 ]), block=True)
772 ra = pmod.Reference('a')
771 ra = pmod.Reference('a')
773
772
774 r = view.apply_sync(lambda x: x.b, ra)
773 r = view.apply_sync(lambda x: x.b, ra)
775 self.assertEqual(r, 128)
774 self.assertEqual(r, 128)
776 self.assertEqual(view['a.b'], 128)
775 self.assertEqual(view['a.b'], 128)
777
776
778 view['a.b'] = 0
777 view['a.b'] = 0
779
778
780 r = view.apply_sync(lambda x: x.b, ra)
779 r = view.apply_sync(lambda x: x.b, ra)
781 self.assertEqual(r, 0)
780 self.assertEqual(r, 0)
782 self.assertEqual(view['a.b'], 0)
781 self.assertEqual(view['a.b'], 0)
783
782
784 def test_return_namedtuple(self):
783 def test_return_namedtuple(self):
785 def namedtuplify(x, y):
784 def namedtuplify(x, y):
786 from IPython.parallel.tests.test_view import point
785 from IPython.parallel.tests.test_view import point
787 return point(x, y)
786 return point(x, y)
788
787
789 view = self.client[-1]
788 view = self.client[-1]
790 p = view.apply_sync(namedtuplify, 1, 2)
789 p = view.apply_sync(namedtuplify, 1, 2)
791 self.assertEqual(p.x, 1)
790 self.assertEqual(p.x, 1)
792 self.assertEqual(p.y, 2)
791 self.assertEqual(p.y, 2)
793
792
794 def test_apply_namedtuple(self):
793 def test_apply_namedtuple(self):
795 def echoxy(p):
794 def echoxy(p):
796 return p.y, p.x
795 return p.y, p.x
797
796
798 view = self.client[-1]
797 view = self.client[-1]
799 tup = view.apply_sync(echoxy, point(1, 2))
798 tup = view.apply_sync(echoxy, point(1, 2))
800 self.assertEqual(tup, (2,1))
799 self.assertEqual(tup, (2,1))
801
800
@@ -1,442 +1,444 b''
1 """Generic testing tools.
1 """Generic testing tools.
2
2
3 Authors
3 Authors
4 -------
4 -------
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 """
6 """
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2009 The IPython Development Team
11 # Copyright (C) 2009 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 import os
21 import os
22 import re
22 import re
23 import sys
23 import sys
24 import tempfile
24 import tempfile
25
25
26 from contextlib import contextmanager
26 from contextlib import contextmanager
27 from io import StringIO
27 from io import StringIO
28 from subprocess import Popen, PIPE
28 from subprocess import Popen, PIPE
29
29
30 try:
30 try:
31 # These tools are used by parts of the runtime, so we make the nose
31 # These tools are used by parts of the runtime, so we make the nose
32 # dependency optional at this point. Nose is a hard dependency to run the
32 # dependency optional at this point. Nose is a hard dependency to run the
33 # test suite, but NOT to use ipython itself.
33 # test suite, but NOT to use ipython itself.
34 import nose.tools as nt
34 import nose.tools as nt
35 has_nose = True
35 has_nose = True
36 except ImportError:
36 except ImportError:
37 has_nose = False
37 has_nose = False
38
38
39 from IPython.config.loader import Config
39 from IPython.config.loader import Config
40 from IPython.utils.process import get_output_error_code
40 from IPython.utils.process import get_output_error_code
41 from IPython.utils.text import list_strings
41 from IPython.utils.text import list_strings
42 from IPython.utils.io import temp_pyfile, Tee
42 from IPython.utils.io import temp_pyfile, Tee
43 from IPython.utils import py3compat
43 from IPython.utils import py3compat
44 from IPython.utils.encoding import DEFAULT_ENCODING
44 from IPython.utils.encoding import DEFAULT_ENCODING
45
45
46 from . import decorators as dec
46 from . import decorators as dec
47 from . import skipdoctest
47 from . import skipdoctest
48
48
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50 # Functions and classes
50 # Functions and classes
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52
52
53 # The docstring for full_path doctests differently on win32 (different path
53 # The docstring for full_path doctests differently on win32 (different path
54 # separator) so just skip the doctest there. The example remains informative.
54 # separator) so just skip the doctest there. The example remains informative.
55 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
55 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
56
56
57 @doctest_deco
57 @doctest_deco
58 def full_path(startPath,files):
58 def full_path(startPath,files):
59 """Make full paths for all the listed files, based on startPath.
59 """Make full paths for all the listed files, based on startPath.
60
60
61 Only the base part of startPath is kept, since this routine is typically
61 Only the base part of startPath is kept, since this routine is typically
62 used with a script's __file__ variable as startPath. The base of startPath
62 used with a script's __file__ variable as startPath. The base of startPath
63 is then prepended to all the listed files, forming the output list.
63 is then prepended to all the listed files, forming the output list.
64
64
65 Parameters
65 Parameters
66 ----------
66 ----------
67 startPath : string
67 startPath : string
68 Initial path to use as the base for the results. This path is split
68 Initial path to use as the base for the results. This path is split
69 using os.path.split() and only its first component is kept.
69 using os.path.split() and only its first component is kept.
70
70
71 files : string or list
71 files : string or list
72 One or more files.
72 One or more files.
73
73
74 Examples
74 Examples
75 --------
75 --------
76
76
77 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
77 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
78 ['/foo/a.txt', '/foo/b.txt']
78 ['/foo/a.txt', '/foo/b.txt']
79
79
80 >>> full_path('/foo',['a.txt','b.txt'])
80 >>> full_path('/foo',['a.txt','b.txt'])
81 ['/a.txt', '/b.txt']
81 ['/a.txt', '/b.txt']
82
82
83 If a single file is given, the output is still a list:
83 If a single file is given, the output is still a list:
84 >>> full_path('/foo','a.txt')
84 >>> full_path('/foo','a.txt')
85 ['/a.txt']
85 ['/a.txt']
86 """
86 """
87
87
88 files = list_strings(files)
88 files = list_strings(files)
89 base = os.path.split(startPath)[0]
89 base = os.path.split(startPath)[0]
90 return [ os.path.join(base,f) for f in files ]
90 return [ os.path.join(base,f) for f in files ]
91
91
92
92
93 def parse_test_output(txt):
93 def parse_test_output(txt):
94 """Parse the output of a test run and return errors, failures.
94 """Parse the output of a test run and return errors, failures.
95
95
96 Parameters
96 Parameters
97 ----------
97 ----------
98 txt : str
98 txt : str
99 Text output of a test run, assumed to contain a line of one of the
99 Text output of a test run, assumed to contain a line of one of the
100 following forms::
100 following forms::
101
101
102 'FAILED (errors=1)'
102 'FAILED (errors=1)'
103 'FAILED (failures=1)'
103 'FAILED (failures=1)'
104 'FAILED (errors=1, failures=1)'
104 'FAILED (errors=1, failures=1)'
105
105
106 Returns
106 Returns
107 -------
107 -------
108 nerr, nfail: number of errors and failures.
108 nerr, nfail: number of errors and failures.
109 """
109 """
110
110
111 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
111 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
112 if err_m:
112 if err_m:
113 nerr = int(err_m.group(1))
113 nerr = int(err_m.group(1))
114 nfail = 0
114 nfail = 0
115 return nerr, nfail
115 return nerr, nfail
116
116
117 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
117 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
118 if fail_m:
118 if fail_m:
119 nerr = 0
119 nerr = 0
120 nfail = int(fail_m.group(1))
120 nfail = int(fail_m.group(1))
121 return nerr, nfail
121 return nerr, nfail
122
122
123 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
123 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
124 re.MULTILINE)
124 re.MULTILINE)
125 if both_m:
125 if both_m:
126 nerr = int(both_m.group(1))
126 nerr = int(both_m.group(1))
127 nfail = int(both_m.group(2))
127 nfail = int(both_m.group(2))
128 return nerr, nfail
128 return nerr, nfail
129
129
130 # If the input didn't match any of these forms, assume no error/failures
130 # If the input didn't match any of these forms, assume no error/failures
131 return 0, 0
131 return 0, 0
132
132
133
133
134 # So nose doesn't think this is a test
134 # So nose doesn't think this is a test
135 parse_test_output.__test__ = False
135 parse_test_output.__test__ = False
136
136
137
137
138 def default_argv():
138 def default_argv():
139 """Return a valid default argv for creating testing instances of ipython"""
139 """Return a valid default argv for creating testing instances of ipython"""
140
140
141 return ['--quick', # so no config file is loaded
141 return ['--quick', # so no config file is loaded
142 # Other defaults to minimize side effects on stdout
142 # Other defaults to minimize side effects on stdout
143 '--colors=NoColor', '--no-term-title','--no-banner',
143 '--colors=NoColor', '--no-term-title','--no-banner',
144 '--autocall=0']
144 '--autocall=0']
145
145
146
146
147 def default_config():
147 def default_config():
148 """Return a config object with good defaults for testing."""
148 """Return a config object with good defaults for testing."""
149 config = Config()
149 config = Config()
150 config.TerminalInteractiveShell.colors = 'NoColor'
150 config.TerminalInteractiveShell.colors = 'NoColor'
151 config.TerminalTerminalInteractiveShell.term_title = False,
151 config.TerminalTerminalInteractiveShell.term_title = False,
152 config.TerminalInteractiveShell.autocall = 0
152 config.TerminalInteractiveShell.autocall = 0
153 config.HistoryManager.hist_file = tempfile.mktemp(u'test_hist.sqlite')
153 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
154 config.HistoryManager.hist_file = f.name
155 f.close()
154 config.HistoryManager.db_cache_size = 10000
156 config.HistoryManager.db_cache_size = 10000
155 return config
157 return config
156
158
157
159
158 def get_ipython_cmd(as_string=False):
160 def get_ipython_cmd(as_string=False):
159 """
161 """
160 Return appropriate IPython command line name. By default, this will return
162 Return appropriate IPython command line name. By default, this will return
161 a list that can be used with subprocess.Popen, for example, but passing
163 a list that can be used with subprocess.Popen, for example, but passing
162 `as_string=True` allows for returning the IPython command as a string.
164 `as_string=True` allows for returning the IPython command as a string.
163
165
164 Parameters
166 Parameters
165 ----------
167 ----------
166 as_string: bool
168 as_string: bool
167 Flag to allow to return the command as a string.
169 Flag to allow to return the command as a string.
168 """
170 """
169 # FIXME: remove workaround for 2.6 support
171 # FIXME: remove workaround for 2.6 support
170 if sys.version_info[:2] > (2,6):
172 if sys.version_info[:2] > (2,6):
171 ipython_cmd = [sys.executable, "-m", "IPython"]
173 ipython_cmd = [sys.executable, "-m", "IPython"]
172 else:
174 else:
173 ipython_cmd = ["ipython"]
175 ipython_cmd = ["ipython"]
174
176
175 if as_string:
177 if as_string:
176 ipython_cmd = " ".join(ipython_cmd)
178 ipython_cmd = " ".join(ipython_cmd)
177
179
178 return ipython_cmd
180 return ipython_cmd
179
181
180 def ipexec(fname, options=None):
182 def ipexec(fname, options=None):
181 """Utility to call 'ipython filename'.
183 """Utility to call 'ipython filename'.
182
184
183 Starts IPython with a minimal and safe configuration to make startup as fast
185 Starts IPython with a minimal and safe configuration to make startup as fast
184 as possible.
186 as possible.
185
187
186 Note that this starts IPython in a subprocess!
188 Note that this starts IPython in a subprocess!
187
189
188 Parameters
190 Parameters
189 ----------
191 ----------
190 fname : str
192 fname : str
191 Name of file to be executed (should have .py or .ipy extension).
193 Name of file to be executed (should have .py or .ipy extension).
192
194
193 options : optional, list
195 options : optional, list
194 Extra command-line flags to be passed to IPython.
196 Extra command-line flags to be passed to IPython.
195
197
196 Returns
198 Returns
197 -------
199 -------
198 (stdout, stderr) of ipython subprocess.
200 (stdout, stderr) of ipython subprocess.
199 """
201 """
200 if options is None: options = []
202 if options is None: options = []
201
203
202 # For these subprocess calls, eliminate all prompt printing so we only see
204 # For these subprocess calls, eliminate all prompt printing so we only see
203 # output from script execution
205 # output from script execution
204 prompt_opts = [ '--PromptManager.in_template=""',
206 prompt_opts = [ '--PromptManager.in_template=""',
205 '--PromptManager.in2_template=""',
207 '--PromptManager.in2_template=""',
206 '--PromptManager.out_template=""'
208 '--PromptManager.out_template=""'
207 ]
209 ]
208 cmdargs = default_argv() + prompt_opts + options
210 cmdargs = default_argv() + prompt_opts + options
209
211
210 test_dir = os.path.dirname(__file__)
212 test_dir = os.path.dirname(__file__)
211
213
212 ipython_cmd = get_ipython_cmd()
214 ipython_cmd = get_ipython_cmd()
213 # Absolute path for filename
215 # Absolute path for filename
214 full_fname = os.path.join(test_dir, fname)
216 full_fname = os.path.join(test_dir, fname)
215 full_cmd = ipython_cmd + cmdargs + [full_fname]
217 full_cmd = ipython_cmd + cmdargs + [full_fname]
216 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE)
218 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE)
217 out, err = p.communicate()
219 out, err = p.communicate()
218 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
220 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
219 # `import readline` causes 'ESC[?1034h' to be output sometimes,
221 # `import readline` causes 'ESC[?1034h' to be output sometimes,
220 # so strip that out before doing comparisons
222 # so strip that out before doing comparisons
221 if out:
223 if out:
222 out = re.sub(r'\x1b\[[^h]+h', '', out)
224 out = re.sub(r'\x1b\[[^h]+h', '', out)
223 return out, err
225 return out, err
224
226
225
227
226 def ipexec_validate(fname, expected_out, expected_err='',
228 def ipexec_validate(fname, expected_out, expected_err='',
227 options=None):
229 options=None):
228 """Utility to call 'ipython filename' and validate output/error.
230 """Utility to call 'ipython filename' and validate output/error.
229
231
230 This function raises an AssertionError if the validation fails.
232 This function raises an AssertionError if the validation fails.
231
233
232 Note that this starts IPython in a subprocess!
234 Note that this starts IPython in a subprocess!
233
235
234 Parameters
236 Parameters
235 ----------
237 ----------
236 fname : str
238 fname : str
237 Name of the file to be executed (should have .py or .ipy extension).
239 Name of the file to be executed (should have .py or .ipy extension).
238
240
239 expected_out : str
241 expected_out : str
240 Expected stdout of the process.
242 Expected stdout of the process.
241
243
242 expected_err : optional, str
244 expected_err : optional, str
243 Expected stderr of the process.
245 Expected stderr of the process.
244
246
245 options : optional, list
247 options : optional, list
246 Extra command-line flags to be passed to IPython.
248 Extra command-line flags to be passed to IPython.
247
249
248 Returns
250 Returns
249 -------
251 -------
250 None
252 None
251 """
253 """
252
254
253 import nose.tools as nt
255 import nose.tools as nt
254
256
255 out, err = ipexec(fname, options)
257 out, err = ipexec(fname, options)
256 #print 'OUT', out # dbg
258 #print 'OUT', out # dbg
257 #print 'ERR', err # dbg
259 #print 'ERR', err # dbg
258 # If there are any errors, we must check those befor stdout, as they may be
260 # If there are any errors, we must check those befor stdout, as they may be
259 # more informative than simply having an empty stdout.
261 # more informative than simply having an empty stdout.
260 if err:
262 if err:
261 if expected_err:
263 if expected_err:
262 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
264 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
263 else:
265 else:
264 raise ValueError('Running file %r produced error: %r' %
266 raise ValueError('Running file %r produced error: %r' %
265 (fname, err))
267 (fname, err))
266 # If no errors or output on stderr was expected, match stdout
268 # If no errors or output on stderr was expected, match stdout
267 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
269 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
268
270
269
271
270 class TempFileMixin(object):
272 class TempFileMixin(object):
271 """Utility class to create temporary Python/IPython files.
273 """Utility class to create temporary Python/IPython files.
272
274
273 Meant as a mixin class for test cases."""
275 Meant as a mixin class for test cases."""
274
276
275 def mktmp(self, src, ext='.py'):
277 def mktmp(self, src, ext='.py'):
276 """Make a valid python temp file."""
278 """Make a valid python temp file."""
277 fname, f = temp_pyfile(src, ext)
279 fname, f = temp_pyfile(src, ext)
278 self.tmpfile = f
280 self.tmpfile = f
279 self.fname = fname
281 self.fname = fname
280
282
281 def tearDown(self):
283 def tearDown(self):
282 if hasattr(self, 'tmpfile'):
284 if hasattr(self, 'tmpfile'):
283 # If the tmpfile wasn't made because of skipped tests, like in
285 # If the tmpfile wasn't made because of skipped tests, like in
284 # win32, there's nothing to cleanup.
286 # win32, there's nothing to cleanup.
285 self.tmpfile.close()
287 self.tmpfile.close()
286 try:
288 try:
287 os.unlink(self.fname)
289 os.unlink(self.fname)
288 except:
290 except:
289 # On Windows, even though we close the file, we still can't
291 # On Windows, even though we close the file, we still can't
290 # delete it. I have no clue why
292 # delete it. I have no clue why
291 pass
293 pass
292
294
293 pair_fail_msg = ("Testing {0}\n\n"
295 pair_fail_msg = ("Testing {0}\n\n"
294 "In:\n"
296 "In:\n"
295 " {1!r}\n"
297 " {1!r}\n"
296 "Expected:\n"
298 "Expected:\n"
297 " {2!r}\n"
299 " {2!r}\n"
298 "Got:\n"
300 "Got:\n"
299 " {3!r}\n")
301 " {3!r}\n")
300 def check_pairs(func, pairs):
302 def check_pairs(func, pairs):
301 """Utility function for the common case of checking a function with a
303 """Utility function for the common case of checking a function with a
302 sequence of input/output pairs.
304 sequence of input/output pairs.
303
305
304 Parameters
306 Parameters
305 ----------
307 ----------
306 func : callable
308 func : callable
307 The function to be tested. Should accept a single argument.
309 The function to be tested. Should accept a single argument.
308 pairs : iterable
310 pairs : iterable
309 A list of (input, expected_output) tuples.
311 A list of (input, expected_output) tuples.
310
312
311 Returns
313 Returns
312 -------
314 -------
313 None. Raises an AssertionError if any output does not match the expected
315 None. Raises an AssertionError if any output does not match the expected
314 value.
316 value.
315 """
317 """
316 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
318 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
317 for inp, expected in pairs:
319 for inp, expected in pairs:
318 out = func(inp)
320 out = func(inp)
319 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
321 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
320
322
321
323
322 if py3compat.PY3:
324 if py3compat.PY3:
323 MyStringIO = StringIO
325 MyStringIO = StringIO
324 else:
326 else:
325 # In Python 2, stdout/stderr can have either bytes or unicode written to them,
327 # In Python 2, stdout/stderr can have either bytes or unicode written to them,
326 # so we need a class that can handle both.
328 # so we need a class that can handle both.
327 class MyStringIO(StringIO):
329 class MyStringIO(StringIO):
328 def write(self, s):
330 def write(self, s):
329 s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING)
331 s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING)
330 super(MyStringIO, self).write(s)
332 super(MyStringIO, self).write(s)
331
333
332 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
334 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
333 -------
335 -------
334 {2!s}
336 {2!s}
335 -------
337 -------
336 """
338 """
337
339
338 class AssertPrints(object):
340 class AssertPrints(object):
339 """Context manager for testing that code prints certain text.
341 """Context manager for testing that code prints certain text.
340
342
341 Examples
343 Examples
342 --------
344 --------
343 >>> with AssertPrints("abc", suppress=False):
345 >>> with AssertPrints("abc", suppress=False):
344 ... print "abcd"
346 ... print "abcd"
345 ... print "def"
347 ... print "def"
346 ...
348 ...
347 abcd
349 abcd
348 def
350 def
349 """
351 """
350 def __init__(self, s, channel='stdout', suppress=True):
352 def __init__(self, s, channel='stdout', suppress=True):
351 self.s = s
353 self.s = s
352 if isinstance(self.s, py3compat.string_types):
354 if isinstance(self.s, py3compat.string_types):
353 self.s = [self.s]
355 self.s = [self.s]
354 self.channel = channel
356 self.channel = channel
355 self.suppress = suppress
357 self.suppress = suppress
356
358
357 def __enter__(self):
359 def __enter__(self):
358 self.orig_stream = getattr(sys, self.channel)
360 self.orig_stream = getattr(sys, self.channel)
359 self.buffer = MyStringIO()
361 self.buffer = MyStringIO()
360 self.tee = Tee(self.buffer, channel=self.channel)
362 self.tee = Tee(self.buffer, channel=self.channel)
361 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
363 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
362
364
363 def __exit__(self, etype, value, traceback):
365 def __exit__(self, etype, value, traceback):
364 self.tee.flush()
366 self.tee.flush()
365 setattr(sys, self.channel, self.orig_stream)
367 setattr(sys, self.channel, self.orig_stream)
366 printed = self.buffer.getvalue()
368 printed = self.buffer.getvalue()
367 for s in self.s:
369 for s in self.s:
368 assert s in printed, notprinted_msg.format(s, self.channel, printed)
370 assert s in printed, notprinted_msg.format(s, self.channel, printed)
369 return False
371 return False
370
372
371 printed_msg = """Found {0!r} in printed output (on {1}):
373 printed_msg = """Found {0!r} in printed output (on {1}):
372 -------
374 -------
373 {2!s}
375 {2!s}
374 -------
376 -------
375 """
377 """
376
378
377 class AssertNotPrints(AssertPrints):
379 class AssertNotPrints(AssertPrints):
378 """Context manager for checking that certain output *isn't* produced.
380 """Context manager for checking that certain output *isn't* produced.
379
381
380 Counterpart of AssertPrints"""
382 Counterpart of AssertPrints"""
381 def __exit__(self, etype, value, traceback):
383 def __exit__(self, etype, value, traceback):
382 self.tee.flush()
384 self.tee.flush()
383 setattr(sys, self.channel, self.orig_stream)
385 setattr(sys, self.channel, self.orig_stream)
384 printed = self.buffer.getvalue()
386 printed = self.buffer.getvalue()
385 for s in self.s:
387 for s in self.s:
386 assert s not in printed, printed_msg.format(s, self.channel, printed)
388 assert s not in printed, printed_msg.format(s, self.channel, printed)
387 return False
389 return False
388
390
389 @contextmanager
391 @contextmanager
390 def mute_warn():
392 def mute_warn():
391 from IPython.utils import warn
393 from IPython.utils import warn
392 save_warn = warn.warn
394 save_warn = warn.warn
393 warn.warn = lambda *a, **kw: None
395 warn.warn = lambda *a, **kw: None
394 try:
396 try:
395 yield
397 yield
396 finally:
398 finally:
397 warn.warn = save_warn
399 warn.warn = save_warn
398
400
399 @contextmanager
401 @contextmanager
400 def make_tempfile(name):
402 def make_tempfile(name):
401 """ Create an empty, named, temporary file for the duration of the context.
403 """ Create an empty, named, temporary file for the duration of the context.
402 """
404 """
403 f = open(name, 'w')
405 f = open(name, 'w')
404 f.close()
406 f.close()
405 try:
407 try:
406 yield
408 yield
407 finally:
409 finally:
408 os.unlink(name)
410 os.unlink(name)
409
411
410
412
411 @contextmanager
413 @contextmanager
412 def monkeypatch(obj, name, attr):
414 def monkeypatch(obj, name, attr):
413 """
415 """
414 Context manager to replace attribute named `name` in `obj` with `attr`.
416 Context manager to replace attribute named `name` in `obj` with `attr`.
415 """
417 """
416 orig = getattr(obj, name)
418 orig = getattr(obj, name)
417 setattr(obj, name, attr)
419 setattr(obj, name, attr)
418 yield
420 yield
419 setattr(obj, name, orig)
421 setattr(obj, name, orig)
420
422
421
423
422 def help_output_test(subcommand=''):
424 def help_output_test(subcommand=''):
423 """test that `ipython [subcommand] -h` works"""
425 """test that `ipython [subcommand] -h` works"""
424 cmd = ' '.join(get_ipython_cmd() + [subcommand, '-h'])
426 cmd = ' '.join(get_ipython_cmd() + [subcommand, '-h'])
425 out, err, rc = get_output_error_code(cmd)
427 out, err, rc = get_output_error_code(cmd)
426 nt.assert_equal(rc, 0, err)
428 nt.assert_equal(rc, 0, err)
427 nt.assert_not_in("Traceback", err)
429 nt.assert_not_in("Traceback", err)
428 nt.assert_in("Options", out)
430 nt.assert_in("Options", out)
429 nt.assert_in("--help-all", out)
431 nt.assert_in("--help-all", out)
430 return out, err
432 return out, err
431
433
432
434
433 def help_all_output_test(subcommand=''):
435 def help_all_output_test(subcommand=''):
434 """test that `ipython [subcommand] --help-all` works"""
436 """test that `ipython [subcommand] --help-all` works"""
435 cmd = ' '.join(get_ipython_cmd() + [subcommand, '--help-all'])
437 cmd = ' '.join(get_ipython_cmd() + [subcommand, '--help-all'])
436 out, err, rc = get_output_error_code(cmd)
438 out, err, rc = get_output_error_code(cmd)
437 nt.assert_equal(rc, 0, err)
439 nt.assert_equal(rc, 0, err)
438 nt.assert_not_in("Traceback", err)
440 nt.assert_not_in("Traceback", err)
439 nt.assert_in("Options", out)
441 nt.assert_in("Options", out)
440 nt.assert_in("Class parameters", out)
442 nt.assert_in("Class parameters", out)
441 return out, err
443 return out, err
442
444
General Comments 0
You need to be logged in to leave comments. Login now