##// END OF EJS Templates
remove mktemp usage...
Julian Taylor -
Show More
@@ -1,350 +1,353 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Paging capabilities for IPython.core
3 Paging capabilities for IPython.core
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Fernando Perez
8 * Fernando Perez
9
9
10 Notes
10 Notes
11 -----
11 -----
12
12
13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
14 rid of that dependency, we could move it there.
14 rid of that dependency, we could move it there.
15 -----
15 -----
16 """
16 """
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Copyright (C) 2008-2011 The IPython Development Team
19 # Copyright (C) 2008-2011 The IPython Development Team
20 #
20 #
21 # Distributed under the terms of the BSD License. The full license is in
21 # Distributed under the terms of the BSD License. The full license is in
22 # the file COPYING, distributed as part of this software.
22 # the file COPYING, distributed as part of this software.
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24
24
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 # Imports
26 # Imports
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 from __future__ import print_function
28 from __future__ import print_function
29
29
30 import os
30 import os
31 import re
31 import re
32 import sys
32 import sys
33 import tempfile
33 import tempfile
34
34
35 from io import UnsupportedOperation
35 from io import UnsupportedOperation
36
36
37 from IPython import get_ipython
37 from IPython import get_ipython
38 from IPython.core.error import TryNext
38 from IPython.core.error import TryNext
39 from IPython.utils.data import chop
39 from IPython.utils.data import chop
40 from IPython.utils import io
40 from IPython.utils import io
41 from IPython.utils.process import system
41 from IPython.utils.process import system
42 from IPython.utils.terminal import get_terminal_size
42 from IPython.utils.terminal import get_terminal_size
43 from IPython.utils import py3compat
43 from IPython.utils import py3compat
44
44
45
45
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47 # Classes and functions
47 # Classes and functions
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49
49
50 esc_re = re.compile(r"(\x1b[^m]+m)")
50 esc_re = re.compile(r"(\x1b[^m]+m)")
51
51
52 def page_dumb(strng, start=0, screen_lines=25):
52 def page_dumb(strng, start=0, screen_lines=25):
53 """Very dumb 'pager' in Python, for when nothing else works.
53 """Very dumb 'pager' in Python, for when nothing else works.
54
54
55 Only moves forward, same interface as page(), except for pager_cmd and
55 Only moves forward, same interface as page(), except for pager_cmd and
56 mode."""
56 mode."""
57
57
58 out_ln = strng.splitlines()[start:]
58 out_ln = strng.splitlines()[start:]
59 screens = chop(out_ln,screen_lines-1)
59 screens = chop(out_ln,screen_lines-1)
60 if len(screens) == 1:
60 if len(screens) == 1:
61 print(os.linesep.join(screens[0]), file=io.stdout)
61 print(os.linesep.join(screens[0]), file=io.stdout)
62 else:
62 else:
63 last_escape = ""
63 last_escape = ""
64 for scr in screens[0:-1]:
64 for scr in screens[0:-1]:
65 hunk = os.linesep.join(scr)
65 hunk = os.linesep.join(scr)
66 print(last_escape + hunk, file=io.stdout)
66 print(last_escape + hunk, file=io.stdout)
67 if not page_more():
67 if not page_more():
68 return
68 return
69 esc_list = esc_re.findall(hunk)
69 esc_list = esc_re.findall(hunk)
70 if len(esc_list) > 0:
70 if len(esc_list) > 0:
71 last_escape = esc_list[-1]
71 last_escape = esc_list[-1]
72 print(last_escape + os.linesep.join(screens[-1]), file=io.stdout)
72 print(last_escape + os.linesep.join(screens[-1]), file=io.stdout)
73
73
74 def _detect_screen_size(screen_lines_def):
74 def _detect_screen_size(screen_lines_def):
75 """Attempt to work out the number of lines on the screen.
75 """Attempt to work out the number of lines on the screen.
76
76
77 This is called by page(). It can raise an error (e.g. when run in the
77 This is called by page(). It can raise an error (e.g. when run in the
78 test suite), so it's separated out so it can easily be called in a try block.
78 test suite), so it's separated out so it can easily be called in a try block.
79 """
79 """
80 TERM = os.environ.get('TERM',None)
80 TERM = os.environ.get('TERM',None)
81 if not((TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5'):
81 if not((TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5'):
82 # curses causes problems on many terminals other than xterm, and
82 # curses causes problems on many terminals other than xterm, and
83 # some termios calls lock up on Sun OS5.
83 # some termios calls lock up on Sun OS5.
84 return screen_lines_def
84 return screen_lines_def
85
85
86 try:
86 try:
87 import termios
87 import termios
88 import curses
88 import curses
89 except ImportError:
89 except ImportError:
90 return screen_lines_def
90 return screen_lines_def
91
91
92 # There is a bug in curses, where *sometimes* it fails to properly
92 # There is a bug in curses, where *sometimes* it fails to properly
93 # initialize, and then after the endwin() call is made, the
93 # initialize, and then after the endwin() call is made, the
94 # terminal is left in an unusable state. Rather than trying to
94 # terminal is left in an unusable state. Rather than trying to
95 # check everytime for this (by requesting and comparing termios
95 # check everytime for this (by requesting and comparing termios
96 # flags each time), we just save the initial terminal state and
96 # flags each time), we just save the initial terminal state and
97 # unconditionally reset it every time. It's cheaper than making
97 # unconditionally reset it every time. It's cheaper than making
98 # the checks.
98 # the checks.
99 term_flags = termios.tcgetattr(sys.stdout)
99 term_flags = termios.tcgetattr(sys.stdout)
100
100
101 # Curses modifies the stdout buffer size by default, which messes
101 # Curses modifies the stdout buffer size by default, which messes
102 # up Python's normal stdout buffering. This would manifest itself
102 # up Python's normal stdout buffering. This would manifest itself
103 # to IPython users as delayed printing on stdout after having used
103 # to IPython users as delayed printing on stdout after having used
104 # the pager.
104 # the pager.
105 #
105 #
106 # We can prevent this by manually setting the NCURSES_NO_SETBUF
106 # We can prevent this by manually setting the NCURSES_NO_SETBUF
107 # environment variable. For more details, see:
107 # environment variable. For more details, see:
108 # http://bugs.python.org/issue10144
108 # http://bugs.python.org/issue10144
109 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
109 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
110 os.environ['NCURSES_NO_SETBUF'] = ''
110 os.environ['NCURSES_NO_SETBUF'] = ''
111
111
112 # Proceed with curses initialization
112 # Proceed with curses initialization
113 try:
113 try:
114 scr = curses.initscr()
114 scr = curses.initscr()
115 except AttributeError:
115 except AttributeError:
116 # Curses on Solaris may not be complete, so we can't use it there
116 # Curses on Solaris may not be complete, so we can't use it there
117 return screen_lines_def
117 return screen_lines_def
118
118
119 screen_lines_real,screen_cols = scr.getmaxyx()
119 screen_lines_real,screen_cols = scr.getmaxyx()
120 curses.endwin()
120 curses.endwin()
121
121
122 # Restore environment
122 # Restore environment
123 if NCURSES_NO_SETBUF is None:
123 if NCURSES_NO_SETBUF is None:
124 del os.environ['NCURSES_NO_SETBUF']
124 del os.environ['NCURSES_NO_SETBUF']
125 else:
125 else:
126 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
126 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
127
127
128 # Restore terminal state in case endwin() didn't.
128 # Restore terminal state in case endwin() didn't.
129 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
129 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
130 # Now we have what we needed: the screen size in rows/columns
130 # Now we have what we needed: the screen size in rows/columns
131 return screen_lines_real
131 return screen_lines_real
132 #print '***Screen size:',screen_lines_real,'lines x',\
132 #print '***Screen size:',screen_lines_real,'lines x',\
133 #screen_cols,'columns.' # dbg
133 #screen_cols,'columns.' # dbg
134
134
135 def page(strng, start=0, screen_lines=0, pager_cmd=None):
135 def page(strng, start=0, screen_lines=0, pager_cmd=None):
136 """Print a string, piping through a pager after a certain length.
136 """Print a string, piping through a pager after a certain length.
137
137
138 The screen_lines parameter specifies the number of *usable* lines of your
138 The screen_lines parameter specifies the number of *usable* lines of your
139 terminal screen (total lines minus lines you need to reserve to show other
139 terminal screen (total lines minus lines you need to reserve to show other
140 information).
140 information).
141
141
142 If you set screen_lines to a number <=0, page() will try to auto-determine
142 If you set screen_lines to a number <=0, page() will try to auto-determine
143 your screen size and will only use up to (screen_size+screen_lines) for
143 your screen size and will only use up to (screen_size+screen_lines) for
144 printing, paging after that. That is, if you want auto-detection but need
144 printing, paging after that. That is, if you want auto-detection but need
145 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
145 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
146 auto-detection without any lines reserved simply use screen_lines = 0.
146 auto-detection without any lines reserved simply use screen_lines = 0.
147
147
148 If a string won't fit in the allowed lines, it is sent through the
148 If a string won't fit in the allowed lines, it is sent through the
149 specified pager command. If none given, look for PAGER in the environment,
149 specified pager command. If none given, look for PAGER in the environment,
150 and ultimately default to less.
150 and ultimately default to less.
151
151
152 If no system pager works, the string is sent through a 'dumb pager'
152 If no system pager works, the string is sent through a 'dumb pager'
153 written in python, very simplistic.
153 written in python, very simplistic.
154 """
154 """
155
155
156 # Some routines may auto-compute start offsets incorrectly and pass a
156 # Some routines may auto-compute start offsets incorrectly and pass a
157 # negative value. Offset to 0 for robustness.
157 # negative value. Offset to 0 for robustness.
158 start = max(0, start)
158 start = max(0, start)
159
159
160 # first, try the hook
160 # first, try the hook
161 ip = get_ipython()
161 ip = get_ipython()
162 if ip:
162 if ip:
163 try:
163 try:
164 ip.hooks.show_in_pager(strng)
164 ip.hooks.show_in_pager(strng)
165 return
165 return
166 except TryNext:
166 except TryNext:
167 pass
167 pass
168
168
169 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
169 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
170 TERM = os.environ.get('TERM','dumb')
170 TERM = os.environ.get('TERM','dumb')
171 if TERM in ['dumb','emacs'] and os.name != 'nt':
171 if TERM in ['dumb','emacs'] and os.name != 'nt':
172 print(strng)
172 print(strng)
173 return
173 return
174 # chop off the topmost part of the string we don't want to see
174 # chop off the topmost part of the string we don't want to see
175 str_lines = strng.splitlines()[start:]
175 str_lines = strng.splitlines()[start:]
176 str_toprint = os.linesep.join(str_lines)
176 str_toprint = os.linesep.join(str_lines)
177 num_newlines = len(str_lines)
177 num_newlines = len(str_lines)
178 len_str = len(str_toprint)
178 len_str = len(str_toprint)
179
179
180 # Dumb heuristics to guesstimate number of on-screen lines the string
180 # Dumb heuristics to guesstimate number of on-screen lines the string
181 # takes. Very basic, but good enough for docstrings in reasonable
181 # takes. Very basic, but good enough for docstrings in reasonable
182 # terminals. If someone later feels like refining it, it's not hard.
182 # terminals. If someone later feels like refining it, it's not hard.
183 numlines = max(num_newlines,int(len_str/80)+1)
183 numlines = max(num_newlines,int(len_str/80)+1)
184
184
185 screen_lines_def = get_terminal_size()[1]
185 screen_lines_def = get_terminal_size()[1]
186
186
187 # auto-determine screen size
187 # auto-determine screen size
188 if screen_lines <= 0:
188 if screen_lines <= 0:
189 try:
189 try:
190 screen_lines += _detect_screen_size(screen_lines_def)
190 screen_lines += _detect_screen_size(screen_lines_def)
191 except (TypeError, UnsupportedOperation):
191 except (TypeError, UnsupportedOperation):
192 print(str_toprint, file=io.stdout)
192 print(str_toprint, file=io.stdout)
193 return
193 return
194
194
195 #print 'numlines',numlines,'screenlines',screen_lines # dbg
195 #print 'numlines',numlines,'screenlines',screen_lines # dbg
196 if numlines <= screen_lines :
196 if numlines <= screen_lines :
197 #print '*** normal print' # dbg
197 #print '*** normal print' # dbg
198 print(str_toprint, file=io.stdout)
198 print(str_toprint, file=io.stdout)
199 else:
199 else:
200 # Try to open pager and default to internal one if that fails.
200 # Try to open pager and default to internal one if that fails.
201 # All failure modes are tagged as 'retval=1', to match the return
201 # All failure modes are tagged as 'retval=1', to match the return
202 # value of a failed system command. If any intermediate attempt
202 # value of a failed system command. If any intermediate attempt
203 # sets retval to 1, at the end we resort to our own page_dumb() pager.
203 # sets retval to 1, at the end we resort to our own page_dumb() pager.
204 pager_cmd = get_pager_cmd(pager_cmd)
204 pager_cmd = get_pager_cmd(pager_cmd)
205 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
205 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
206 if os.name == 'nt':
206 if os.name == 'nt':
207 if pager_cmd.startswith('type'):
207 if pager_cmd.startswith('type'):
208 # The default WinXP 'type' command is failing on complex strings.
208 # The default WinXP 'type' command is failing on complex strings.
209 retval = 1
209 retval = 1
210 else:
210 else:
211 tmpname = tempfile.mktemp('.txt')
211 fd, tmpname = tempfile.mkstemp('.txt')
212 tmpfile = open(tmpname,'wt')
212 try:
213 tmpfile.write(strng)
213 os.close(fd)
214 tmpfile.close()
214 with open(tmpname, 'wt') as tmpfile:
215 cmd = "%s < %s" % (pager_cmd,tmpname)
215 tmpfile.write(strng)
216 if os.system(cmd):
216 cmd = "%s < %s" % (pager_cmd, tmpname)
217 retval = 1
217 # tmpfile needs to be closed for windows
218 else:
218 if os.system(cmd):
219 retval = None
219 retval = 1
220 os.remove(tmpname)
220 else:
221 retval = None
222 finally:
223 os.remove(tmpname)
221 else:
224 else:
222 try:
225 try:
223 retval = None
226 retval = None
224 # if I use popen4, things hang. No idea why.
227 # if I use popen4, things hang. No idea why.
225 #pager,shell_out = os.popen4(pager_cmd)
228 #pager,shell_out = os.popen4(pager_cmd)
226 pager = os.popen(pager_cmd, 'w')
229 pager = os.popen(pager_cmd, 'w')
227 try:
230 try:
228 pager_encoding = pager.encoding or sys.stdout.encoding
231 pager_encoding = pager.encoding or sys.stdout.encoding
229 pager.write(py3compat.cast_bytes_py2(
232 pager.write(py3compat.cast_bytes_py2(
230 strng, encoding=pager_encoding))
233 strng, encoding=pager_encoding))
231 finally:
234 finally:
232 retval = pager.close()
235 retval = pager.close()
233 except IOError as msg: # broken pipe when user quits
236 except IOError as msg: # broken pipe when user quits
234 if msg.args == (32, 'Broken pipe'):
237 if msg.args == (32, 'Broken pipe'):
235 retval = None
238 retval = None
236 else:
239 else:
237 retval = 1
240 retval = 1
238 except OSError:
241 except OSError:
239 # Other strange problems, sometimes seen in Win2k/cygwin
242 # Other strange problems, sometimes seen in Win2k/cygwin
240 retval = 1
243 retval = 1
241 if retval is not None:
244 if retval is not None:
242 page_dumb(strng,screen_lines=screen_lines)
245 page_dumb(strng,screen_lines=screen_lines)
243
246
244
247
245 def page_file(fname, start=0, pager_cmd=None):
248 def page_file(fname, start=0, pager_cmd=None):
246 """Page a file, using an optional pager command and starting line.
249 """Page a file, using an optional pager command and starting line.
247 """
250 """
248
251
249 pager_cmd = get_pager_cmd(pager_cmd)
252 pager_cmd = get_pager_cmd(pager_cmd)
250 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
253 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
251
254
252 try:
255 try:
253 if os.environ['TERM'] in ['emacs','dumb']:
256 if os.environ['TERM'] in ['emacs','dumb']:
254 raise EnvironmentError
257 raise EnvironmentError
255 system(pager_cmd + ' ' + fname)
258 system(pager_cmd + ' ' + fname)
256 except:
259 except:
257 try:
260 try:
258 if start > 0:
261 if start > 0:
259 start -= 1
262 start -= 1
260 page(open(fname).read(),start)
263 page(open(fname).read(),start)
261 except:
264 except:
262 print('Unable to show file',repr(fname))
265 print('Unable to show file',repr(fname))
263
266
264
267
265 def get_pager_cmd(pager_cmd=None):
268 def get_pager_cmd(pager_cmd=None):
266 """Return a pager command.
269 """Return a pager command.
267
270
268 Makes some attempts at finding an OS-correct one.
271 Makes some attempts at finding an OS-correct one.
269 """
272 """
270 if os.name == 'posix':
273 if os.name == 'posix':
271 default_pager_cmd = 'less -r' # -r for color control sequences
274 default_pager_cmd = 'less -r' # -r for color control sequences
272 elif os.name in ['nt','dos']:
275 elif os.name in ['nt','dos']:
273 default_pager_cmd = 'type'
276 default_pager_cmd = 'type'
274
277
275 if pager_cmd is None:
278 if pager_cmd is None:
276 try:
279 try:
277 pager_cmd = os.environ['PAGER']
280 pager_cmd = os.environ['PAGER']
278 except:
281 except:
279 pager_cmd = default_pager_cmd
282 pager_cmd = default_pager_cmd
280 return pager_cmd
283 return pager_cmd
281
284
282
285
283 def get_pager_start(pager, start):
286 def get_pager_start(pager, start):
284 """Return the string for paging files with an offset.
287 """Return the string for paging files with an offset.
285
288
286 This is the '+N' argument which less and more (under Unix) accept.
289 This is the '+N' argument which less and more (under Unix) accept.
287 """
290 """
288
291
289 if pager in ['less','more']:
292 if pager in ['less','more']:
290 if start:
293 if start:
291 start_string = '+' + str(start)
294 start_string = '+' + str(start)
292 else:
295 else:
293 start_string = ''
296 start_string = ''
294 else:
297 else:
295 start_string = ''
298 start_string = ''
296 return start_string
299 return start_string
297
300
298
301
299 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
302 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
300 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
303 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
301 import msvcrt
304 import msvcrt
302 def page_more():
305 def page_more():
303 """ Smart pausing between pages
306 """ Smart pausing between pages
304
307
305 @return: True if need print more lines, False if quit
308 @return: True if need print more lines, False if quit
306 """
309 """
307 io.stdout.write('---Return to continue, q to quit--- ')
310 io.stdout.write('---Return to continue, q to quit--- ')
308 ans = msvcrt.getwch()
311 ans = msvcrt.getwch()
309 if ans in ("q", "Q"):
312 if ans in ("q", "Q"):
310 result = False
313 result = False
311 else:
314 else:
312 result = True
315 result = True
313 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
316 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
314 return result
317 return result
315 else:
318 else:
316 def page_more():
319 def page_more():
317 ans = py3compat.input('---Return to continue, q to quit--- ')
320 ans = py3compat.input('---Return to continue, q to quit--- ')
318 if ans.lower().startswith('q'):
321 if ans.lower().startswith('q'):
319 return False
322 return False
320 else:
323 else:
321 return True
324 return True
322
325
323
326
324 def snip_print(str,width = 75,print_full = 0,header = ''):
327 def snip_print(str,width = 75,print_full = 0,header = ''):
325 """Print a string snipping the midsection to fit in width.
328 """Print a string snipping the midsection to fit in width.
326
329
327 print_full: mode control:
330 print_full: mode control:
328
331
329 - 0: only snip long strings
332 - 0: only snip long strings
330 - 1: send to page() directly.
333 - 1: send to page() directly.
331 - 2: snip long strings and ask for full length viewing with page()
334 - 2: snip long strings and ask for full length viewing with page()
332
335
333 Return 1 if snipping was necessary, 0 otherwise."""
336 Return 1 if snipping was necessary, 0 otherwise."""
334
337
335 if print_full == 1:
338 if print_full == 1:
336 page(header+str)
339 page(header+str)
337 return 0
340 return 0
338
341
339 print(header, end=' ')
342 print(header, end=' ')
340 if len(str) < width:
343 if len(str) < width:
341 print(str)
344 print(str)
342 snip = 0
345 snip = 0
343 else:
346 else:
344 whalf = int((width -5)/2)
347 whalf = int((width -5)/2)
345 print(str[:whalf] + ' <...> ' + str[-whalf:])
348 print(str[:whalf] + ' <...> ' + str[-whalf:])
346 snip = 1
349 snip = 1
347 if snip and print_full == 2:
350 if snip and print_full == 2:
348 if py3compat.input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
351 if py3compat.input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
349 page(str)
352 page(str)
350 return snip
353 return snip
@@ -1,559 +1,560 b''
1 """Utilities for connecting to kernels
1 """Utilities for connecting to kernels
2
2
3 Authors:
3 Authors:
4
4
5 * Min Ragan-Kelley
5 * Min Ragan-Kelley
6
6
7 """
7 """
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2013 The IPython Development Team
10 # Copyright (C) 2013 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 from __future__ import absolute_import
20 from __future__ import absolute_import
21
21
22 import glob
22 import glob
23 import json
23 import json
24 import os
24 import os
25 import socket
25 import socket
26 import sys
26 import sys
27 from getpass import getpass
27 from getpass import getpass
28 from subprocess import Popen, PIPE
28 from subprocess import Popen, PIPE
29 import tempfile
29 import tempfile
30
30
31 import zmq
31 import zmq
32
32
33 # external imports
33 # external imports
34 from IPython.external.ssh import tunnel
34 from IPython.external.ssh import tunnel
35
35
36 # IPython imports
36 # IPython imports
37 from IPython.config import Configurable
37 from IPython.config import Configurable
38 from IPython.core.profiledir import ProfileDir
38 from IPython.core.profiledir import ProfileDir
39 from IPython.utils.localinterfaces import localhost
39 from IPython.utils.localinterfaces import localhost
40 from IPython.utils.path import filefind, get_ipython_dir
40 from IPython.utils.path import filefind, get_ipython_dir
41 from IPython.utils.py3compat import (str_to_bytes, bytes_to_str, cast_bytes_py2,
41 from IPython.utils.py3compat import (str_to_bytes, bytes_to_str, cast_bytes_py2,
42 string_types)
42 string_types)
43 from IPython.utils.traitlets import (
43 from IPython.utils.traitlets import (
44 Bool, Integer, Unicode, CaselessStrEnum,
44 Bool, Integer, Unicode, CaselessStrEnum,
45 )
45 )
46
46
47
47
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49 # Working with Connection Files
49 # Working with Connection Files
50 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
51
51
52 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
52 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
53 control_port=0, ip='', key=b'', transport='tcp',
53 control_port=0, ip='', key=b'', transport='tcp',
54 signature_scheme='hmac-sha256',
54 signature_scheme='hmac-sha256',
55 ):
55 ):
56 """Generates a JSON config file, including the selection of random ports.
56 """Generates a JSON config file, including the selection of random ports.
57
57
58 Parameters
58 Parameters
59 ----------
59 ----------
60
60
61 fname : unicode
61 fname : unicode
62 The path to the file to write
62 The path to the file to write
63
63
64 shell_port : int, optional
64 shell_port : int, optional
65 The port to use for ROUTER (shell) channel.
65 The port to use for ROUTER (shell) channel.
66
66
67 iopub_port : int, optional
67 iopub_port : int, optional
68 The port to use for the SUB channel.
68 The port to use for the SUB channel.
69
69
70 stdin_port : int, optional
70 stdin_port : int, optional
71 The port to use for the ROUTER (raw input) channel.
71 The port to use for the ROUTER (raw input) channel.
72
72
73 control_port : int, optional
73 control_port : int, optional
74 The port to use for the ROUTER (control) channel.
74 The port to use for the ROUTER (control) channel.
75
75
76 hb_port : int, optional
76 hb_port : int, optional
77 The port to use for the heartbeat REP channel.
77 The port to use for the heartbeat REP channel.
78
78
79 ip : str, optional
79 ip : str, optional
80 The ip address the kernel will bind to.
80 The ip address the kernel will bind to.
81
81
82 key : str, optional
82 key : str, optional
83 The Session key used for message authentication.
83 The Session key used for message authentication.
84
84
85 signature_scheme : str, optional
85 signature_scheme : str, optional
86 The scheme used for message authentication.
86 The scheme used for message authentication.
87 This has the form 'digest-hash', where 'digest'
87 This has the form 'digest-hash', where 'digest'
88 is the scheme used for digests, and 'hash' is the name of the hash function
88 is the scheme used for digests, and 'hash' is the name of the hash function
89 used by the digest scheme.
89 used by the digest scheme.
90 Currently, 'hmac' is the only supported digest scheme,
90 Currently, 'hmac' is the only supported digest scheme,
91 and 'sha256' is the default hash function.
91 and 'sha256' is the default hash function.
92
92
93 """
93 """
94 if not ip:
94 if not ip:
95 ip = localhost()
95 ip = localhost()
96 # default to temporary connector file
96 # default to temporary connector file
97 if not fname:
97 if not fname:
98 fname = tempfile.mktemp('.json')
98 fd, fname = tempfile.mkstemp('.json')
99 os.close(fd)
99
100
100 # Find open ports as necessary.
101 # Find open ports as necessary.
101
102
102 ports = []
103 ports = []
103 ports_needed = int(shell_port <= 0) + \
104 ports_needed = int(shell_port <= 0) + \
104 int(iopub_port <= 0) + \
105 int(iopub_port <= 0) + \
105 int(stdin_port <= 0) + \
106 int(stdin_port <= 0) + \
106 int(control_port <= 0) + \
107 int(control_port <= 0) + \
107 int(hb_port <= 0)
108 int(hb_port <= 0)
108 if transport == 'tcp':
109 if transport == 'tcp':
109 for i in range(ports_needed):
110 for i in range(ports_needed):
110 sock = socket.socket()
111 sock = socket.socket()
111 sock.bind(('', 0))
112 sock.bind(('', 0))
112 ports.append(sock)
113 ports.append(sock)
113 for i, sock in enumerate(ports):
114 for i, sock in enumerate(ports):
114 port = sock.getsockname()[1]
115 port = sock.getsockname()[1]
115 sock.close()
116 sock.close()
116 ports[i] = port
117 ports[i] = port
117 else:
118 else:
118 N = 1
119 N = 1
119 for i in range(ports_needed):
120 for i in range(ports_needed):
120 while os.path.exists("%s-%s" % (ip, str(N))):
121 while os.path.exists("%s-%s" % (ip, str(N))):
121 N += 1
122 N += 1
122 ports.append(N)
123 ports.append(N)
123 N += 1
124 N += 1
124 if shell_port <= 0:
125 if shell_port <= 0:
125 shell_port = ports.pop(0)
126 shell_port = ports.pop(0)
126 if iopub_port <= 0:
127 if iopub_port <= 0:
127 iopub_port = ports.pop(0)
128 iopub_port = ports.pop(0)
128 if stdin_port <= 0:
129 if stdin_port <= 0:
129 stdin_port = ports.pop(0)
130 stdin_port = ports.pop(0)
130 if control_port <= 0:
131 if control_port <= 0:
131 control_port = ports.pop(0)
132 control_port = ports.pop(0)
132 if hb_port <= 0:
133 if hb_port <= 0:
133 hb_port = ports.pop(0)
134 hb_port = ports.pop(0)
134
135
135 cfg = dict( shell_port=shell_port,
136 cfg = dict( shell_port=shell_port,
136 iopub_port=iopub_port,
137 iopub_port=iopub_port,
137 stdin_port=stdin_port,
138 stdin_port=stdin_port,
138 control_port=control_port,
139 control_port=control_port,
139 hb_port=hb_port,
140 hb_port=hb_port,
140 )
141 )
141 cfg['ip'] = ip
142 cfg['ip'] = ip
142 cfg['key'] = bytes_to_str(key)
143 cfg['key'] = bytes_to_str(key)
143 cfg['transport'] = transport
144 cfg['transport'] = transport
144 cfg['signature_scheme'] = signature_scheme
145 cfg['signature_scheme'] = signature_scheme
145
146
146 with open(fname, 'w') as f:
147 with open(fname, 'w') as f:
147 f.write(json.dumps(cfg, indent=2))
148 f.write(json.dumps(cfg, indent=2))
148
149
149 return fname, cfg
150 return fname, cfg
150
151
151
152
152 def get_connection_file(app=None):
153 def get_connection_file(app=None):
153 """Return the path to the connection file of an app
154 """Return the path to the connection file of an app
154
155
155 Parameters
156 Parameters
156 ----------
157 ----------
157 app : IPKernelApp instance [optional]
158 app : IPKernelApp instance [optional]
158 If unspecified, the currently running app will be used
159 If unspecified, the currently running app will be used
159 """
160 """
160 if app is None:
161 if app is None:
161 from IPython.kernel.zmq.kernelapp import IPKernelApp
162 from IPython.kernel.zmq.kernelapp import IPKernelApp
162 if not IPKernelApp.initialized():
163 if not IPKernelApp.initialized():
163 raise RuntimeError("app not specified, and not in a running Kernel")
164 raise RuntimeError("app not specified, and not in a running Kernel")
164
165
165 app = IPKernelApp.instance()
166 app = IPKernelApp.instance()
166 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
167 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
167
168
168
169
169 def find_connection_file(filename, profile=None):
170 def find_connection_file(filename, profile=None):
170 """find a connection file, and return its absolute path.
171 """find a connection file, and return its absolute path.
171
172
172 The current working directory and the profile's security
173 The current working directory and the profile's security
173 directory will be searched for the file if it is not given by
174 directory will be searched for the file if it is not given by
174 absolute path.
175 absolute path.
175
176
176 If profile is unspecified, then the current running application's
177 If profile is unspecified, then the current running application's
177 profile will be used, or 'default', if not run from IPython.
178 profile will be used, or 'default', if not run from IPython.
178
179
179 If the argument does not match an existing file, it will be interpreted as a
180 If the argument does not match an existing file, it will be interpreted as a
180 fileglob, and the matching file in the profile's security dir with
181 fileglob, and the matching file in the profile's security dir with
181 the latest access time will be used.
182 the latest access time will be used.
182
183
183 Parameters
184 Parameters
184 ----------
185 ----------
185 filename : str
186 filename : str
186 The connection file or fileglob to search for.
187 The connection file or fileglob to search for.
187 profile : str [optional]
188 profile : str [optional]
188 The name of the profile to use when searching for the connection file,
189 The name of the profile to use when searching for the connection file,
189 if different from the current IPython session or 'default'.
190 if different from the current IPython session or 'default'.
190
191
191 Returns
192 Returns
192 -------
193 -------
193 str : The absolute path of the connection file.
194 str : The absolute path of the connection file.
194 """
195 """
195 from IPython.core.application import BaseIPythonApplication as IPApp
196 from IPython.core.application import BaseIPythonApplication as IPApp
196 try:
197 try:
197 # quick check for absolute path, before going through logic
198 # quick check for absolute path, before going through logic
198 return filefind(filename)
199 return filefind(filename)
199 except IOError:
200 except IOError:
200 pass
201 pass
201
202
202 if profile is None:
203 if profile is None:
203 # profile unspecified, check if running from an IPython app
204 # profile unspecified, check if running from an IPython app
204 if IPApp.initialized():
205 if IPApp.initialized():
205 app = IPApp.instance()
206 app = IPApp.instance()
206 profile_dir = app.profile_dir
207 profile_dir = app.profile_dir
207 else:
208 else:
208 # not running in IPython, use default profile
209 # not running in IPython, use default profile
209 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
210 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
210 else:
211 else:
211 # find profiledir by profile name:
212 # find profiledir by profile name:
212 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
213 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
213 security_dir = profile_dir.security_dir
214 security_dir = profile_dir.security_dir
214
215
215 try:
216 try:
216 # first, try explicit name
217 # first, try explicit name
217 return filefind(filename, ['.', security_dir])
218 return filefind(filename, ['.', security_dir])
218 except IOError:
219 except IOError:
219 pass
220 pass
220
221
221 # not found by full name
222 # not found by full name
222
223
223 if '*' in filename:
224 if '*' in filename:
224 # given as a glob already
225 # given as a glob already
225 pat = filename
226 pat = filename
226 else:
227 else:
227 # accept any substring match
228 # accept any substring match
228 pat = '*%s*' % filename
229 pat = '*%s*' % filename
229 matches = glob.glob( os.path.join(security_dir, pat) )
230 matches = glob.glob( os.path.join(security_dir, pat) )
230 if not matches:
231 if not matches:
231 raise IOError("Could not find %r in %r" % (filename, security_dir))
232 raise IOError("Could not find %r in %r" % (filename, security_dir))
232 elif len(matches) == 1:
233 elif len(matches) == 1:
233 return matches[0]
234 return matches[0]
234 else:
235 else:
235 # get most recent match, by access time:
236 # get most recent match, by access time:
236 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
237 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
237
238
238
239
239 def get_connection_info(connection_file=None, unpack=False, profile=None):
240 def get_connection_info(connection_file=None, unpack=False, profile=None):
240 """Return the connection information for the current Kernel.
241 """Return the connection information for the current Kernel.
241
242
242 Parameters
243 Parameters
243 ----------
244 ----------
244 connection_file : str [optional]
245 connection_file : str [optional]
245 The connection file to be used. Can be given by absolute path, or
246 The connection file to be used. Can be given by absolute path, or
246 IPython will search in the security directory of a given profile.
247 IPython will search in the security directory of a given profile.
247 If run from IPython,
248 If run from IPython,
248
249
249 If unspecified, the connection file for the currently running
250 If unspecified, the connection file for the currently running
250 IPython Kernel will be used, which is only allowed from inside a kernel.
251 IPython Kernel will be used, which is only allowed from inside a kernel.
251 unpack : bool [default: False]
252 unpack : bool [default: False]
252 if True, return the unpacked dict, otherwise just the string contents
253 if True, return the unpacked dict, otherwise just the string contents
253 of the file.
254 of the file.
254 profile : str [optional]
255 profile : str [optional]
255 The name of the profile to use when searching for the connection file,
256 The name of the profile to use when searching for the connection file,
256 if different from the current IPython session or 'default'.
257 if different from the current IPython session or 'default'.
257
258
258
259
259 Returns
260 Returns
260 -------
261 -------
261 The connection dictionary of the current kernel, as string or dict,
262 The connection dictionary of the current kernel, as string or dict,
262 depending on `unpack`.
263 depending on `unpack`.
263 """
264 """
264 if connection_file is None:
265 if connection_file is None:
265 # get connection file from current kernel
266 # get connection file from current kernel
266 cf = get_connection_file()
267 cf = get_connection_file()
267 else:
268 else:
268 # connection file specified, allow shortnames:
269 # connection file specified, allow shortnames:
269 cf = find_connection_file(connection_file, profile=profile)
270 cf = find_connection_file(connection_file, profile=profile)
270
271
271 with open(cf) as f:
272 with open(cf) as f:
272 info = f.read()
273 info = f.read()
273
274
274 if unpack:
275 if unpack:
275 info = json.loads(info)
276 info = json.loads(info)
276 # ensure key is bytes:
277 # ensure key is bytes:
277 info['key'] = str_to_bytes(info.get('key', ''))
278 info['key'] = str_to_bytes(info.get('key', ''))
278 return info
279 return info
279
280
280
281
281 def connect_qtconsole(connection_file=None, argv=None, profile=None):
282 def connect_qtconsole(connection_file=None, argv=None, profile=None):
282 """Connect a qtconsole to the current kernel.
283 """Connect a qtconsole to the current kernel.
283
284
284 This is useful for connecting a second qtconsole to a kernel, or to a
285 This is useful for connecting a second qtconsole to a kernel, or to a
285 local notebook.
286 local notebook.
286
287
287 Parameters
288 Parameters
288 ----------
289 ----------
289 connection_file : str [optional]
290 connection_file : str [optional]
290 The connection file to be used. Can be given by absolute path, or
291 The connection file to be used. Can be given by absolute path, or
291 IPython will search in the security directory of a given profile.
292 IPython will search in the security directory of a given profile.
292 If run from IPython,
293 If run from IPython,
293
294
294 If unspecified, the connection file for the currently running
295 If unspecified, the connection file for the currently running
295 IPython Kernel will be used, which is only allowed from inside a kernel.
296 IPython Kernel will be used, which is only allowed from inside a kernel.
296 argv : list [optional]
297 argv : list [optional]
297 Any extra args to be passed to the console.
298 Any extra args to be passed to the console.
298 profile : str [optional]
299 profile : str [optional]
299 The name of the profile to use when searching for the connection file,
300 The name of the profile to use when searching for the connection file,
300 if different from the current IPython session or 'default'.
301 if different from the current IPython session or 'default'.
301
302
302
303
303 Returns
304 Returns
304 -------
305 -------
305 subprocess.Popen instance running the qtconsole frontend
306 subprocess.Popen instance running the qtconsole frontend
306 """
307 """
307 argv = [] if argv is None else argv
308 argv = [] if argv is None else argv
308
309
309 if connection_file is None:
310 if connection_file is None:
310 # get connection file from current kernel
311 # get connection file from current kernel
311 cf = get_connection_file()
312 cf = get_connection_file()
312 else:
313 else:
313 cf = find_connection_file(connection_file, profile=profile)
314 cf = find_connection_file(connection_file, profile=profile)
314
315
315 cmd = ';'.join([
316 cmd = ';'.join([
316 "from IPython.qt.console import qtconsoleapp",
317 "from IPython.qt.console import qtconsoleapp",
317 "qtconsoleapp.main()"
318 "qtconsoleapp.main()"
318 ])
319 ])
319
320
320 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
321 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
321 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
322 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
322 )
323 )
323
324
324
325
325 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
326 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
326 """tunnel connections to a kernel via ssh
327 """tunnel connections to a kernel via ssh
327
328
328 This will open four SSH tunnels from localhost on this machine to the
329 This will open four SSH tunnels from localhost on this machine to the
329 ports associated with the kernel. They can be either direct
330 ports associated with the kernel. They can be either direct
330 localhost-localhost tunnels, or if an intermediate server is necessary,
331 localhost-localhost tunnels, or if an intermediate server is necessary,
331 the kernel must be listening on a public IP.
332 the kernel must be listening on a public IP.
332
333
333 Parameters
334 Parameters
334 ----------
335 ----------
335 connection_info : dict or str (path)
336 connection_info : dict or str (path)
336 Either a connection dict, or the path to a JSON connection file
337 Either a connection dict, or the path to a JSON connection file
337 sshserver : str
338 sshserver : str
338 The ssh sever to use to tunnel to the kernel. Can be a full
339 The ssh sever to use to tunnel to the kernel. Can be a full
339 `user@server:port` string. ssh config aliases are respected.
340 `user@server:port` string. ssh config aliases are respected.
340 sshkey : str [optional]
341 sshkey : str [optional]
341 Path to file containing ssh key to use for authentication.
342 Path to file containing ssh key to use for authentication.
342 Only necessary if your ssh config does not already associate
343 Only necessary if your ssh config does not already associate
343 a keyfile with the host.
344 a keyfile with the host.
344
345
345 Returns
346 Returns
346 -------
347 -------
347
348
348 (shell, iopub, stdin, hb) : ints
349 (shell, iopub, stdin, hb) : ints
349 The four ports on localhost that have been forwarded to the kernel.
350 The four ports on localhost that have been forwarded to the kernel.
350 """
351 """
351 if isinstance(connection_info, string_types):
352 if isinstance(connection_info, string_types):
352 # it's a path, unpack it
353 # it's a path, unpack it
353 with open(connection_info) as f:
354 with open(connection_info) as f:
354 connection_info = json.loads(f.read())
355 connection_info = json.loads(f.read())
355
356
356 cf = connection_info
357 cf = connection_info
357
358
358 lports = tunnel.select_random_ports(4)
359 lports = tunnel.select_random_ports(4)
359 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
360 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
360
361
361 remote_ip = cf['ip']
362 remote_ip = cf['ip']
362
363
363 if tunnel.try_passwordless_ssh(sshserver, sshkey):
364 if tunnel.try_passwordless_ssh(sshserver, sshkey):
364 password=False
365 password=False
365 else:
366 else:
366 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
367 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
367
368
368 for lp,rp in zip(lports, rports):
369 for lp,rp in zip(lports, rports):
369 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
370 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
370
371
371 return tuple(lports)
372 return tuple(lports)
372
373
373
374
374 #-----------------------------------------------------------------------------
375 #-----------------------------------------------------------------------------
375 # Mixin for classes that work with connection files
376 # Mixin for classes that work with connection files
376 #-----------------------------------------------------------------------------
377 #-----------------------------------------------------------------------------
377
378
378 channel_socket_types = {
379 channel_socket_types = {
379 'hb' : zmq.REQ,
380 'hb' : zmq.REQ,
380 'shell' : zmq.DEALER,
381 'shell' : zmq.DEALER,
381 'iopub' : zmq.SUB,
382 'iopub' : zmq.SUB,
382 'stdin' : zmq.DEALER,
383 'stdin' : zmq.DEALER,
383 'control': zmq.DEALER,
384 'control': zmq.DEALER,
384 }
385 }
385
386
386 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
387 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
387
388
388 class ConnectionFileMixin(Configurable):
389 class ConnectionFileMixin(Configurable):
389 """Mixin for configurable classes that work with connection files"""
390 """Mixin for configurable classes that work with connection files"""
390
391
391 # The addresses for the communication channels
392 # The addresses for the communication channels
392 connection_file = Unicode('')
393 connection_file = Unicode('')
393 _connection_file_written = Bool(False)
394 _connection_file_written = Bool(False)
394
395
395 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
396 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
396
397
397 ip = Unicode(config=True,
398 ip = Unicode(config=True,
398 help="""Set the kernel\'s IP address [default localhost].
399 help="""Set the kernel\'s IP address [default localhost].
399 If the IP address is something other than localhost, then
400 If the IP address is something other than localhost, then
400 Consoles on other machines will be able to connect
401 Consoles on other machines will be able to connect
401 to the Kernel, so be careful!"""
402 to the Kernel, so be careful!"""
402 )
403 )
403
404
404 def _ip_default(self):
405 def _ip_default(self):
405 if self.transport == 'ipc':
406 if self.transport == 'ipc':
406 if self.connection_file:
407 if self.connection_file:
407 return os.path.splitext(self.connection_file)[0] + '-ipc'
408 return os.path.splitext(self.connection_file)[0] + '-ipc'
408 else:
409 else:
409 return 'kernel-ipc'
410 return 'kernel-ipc'
410 else:
411 else:
411 return localhost()
412 return localhost()
412
413
413 def _ip_changed(self, name, old, new):
414 def _ip_changed(self, name, old, new):
414 if new == '*':
415 if new == '*':
415 self.ip = '0.0.0.0'
416 self.ip = '0.0.0.0'
416
417
417 # protected traits
418 # protected traits
418
419
419 shell_port = Integer(0)
420 shell_port = Integer(0)
420 iopub_port = Integer(0)
421 iopub_port = Integer(0)
421 stdin_port = Integer(0)
422 stdin_port = Integer(0)
422 control_port = Integer(0)
423 control_port = Integer(0)
423 hb_port = Integer(0)
424 hb_port = Integer(0)
424
425
425 @property
426 @property
426 def ports(self):
427 def ports(self):
427 return [ getattr(self, name) for name in port_names ]
428 return [ getattr(self, name) for name in port_names ]
428
429
429 #--------------------------------------------------------------------------
430 #--------------------------------------------------------------------------
430 # Connection and ipc file management
431 # Connection and ipc file management
431 #--------------------------------------------------------------------------
432 #--------------------------------------------------------------------------
432
433
433 def get_connection_info(self):
434 def get_connection_info(self):
434 """return the connection info as a dict"""
435 """return the connection info as a dict"""
435 return dict(
436 return dict(
436 transport=self.transport,
437 transport=self.transport,
437 ip=self.ip,
438 ip=self.ip,
438 shell_port=self.shell_port,
439 shell_port=self.shell_port,
439 iopub_port=self.iopub_port,
440 iopub_port=self.iopub_port,
440 stdin_port=self.stdin_port,
441 stdin_port=self.stdin_port,
441 hb_port=self.hb_port,
442 hb_port=self.hb_port,
442 control_port=self.control_port,
443 control_port=self.control_port,
443 signature_scheme=self.session.signature_scheme,
444 signature_scheme=self.session.signature_scheme,
444 key=self.session.key,
445 key=self.session.key,
445 )
446 )
446
447
447 def cleanup_connection_file(self):
448 def cleanup_connection_file(self):
448 """Cleanup connection file *if we wrote it*
449 """Cleanup connection file *if we wrote it*
449
450
450 Will not raise if the connection file was already removed somehow.
451 Will not raise if the connection file was already removed somehow.
451 """
452 """
452 if self._connection_file_written:
453 if self._connection_file_written:
453 # cleanup connection files on full shutdown of kernel we started
454 # cleanup connection files on full shutdown of kernel we started
454 self._connection_file_written = False
455 self._connection_file_written = False
455 try:
456 try:
456 os.remove(self.connection_file)
457 os.remove(self.connection_file)
457 except (IOError, OSError, AttributeError):
458 except (IOError, OSError, AttributeError):
458 pass
459 pass
459
460
460 def cleanup_ipc_files(self):
461 def cleanup_ipc_files(self):
461 """Cleanup ipc files if we wrote them."""
462 """Cleanup ipc files if we wrote them."""
462 if self.transport != 'ipc':
463 if self.transport != 'ipc':
463 return
464 return
464 for port in self.ports:
465 for port in self.ports:
465 ipcfile = "%s-%i" % (self.ip, port)
466 ipcfile = "%s-%i" % (self.ip, port)
466 try:
467 try:
467 os.remove(ipcfile)
468 os.remove(ipcfile)
468 except (IOError, OSError):
469 except (IOError, OSError):
469 pass
470 pass
470
471
471 def write_connection_file(self):
472 def write_connection_file(self):
472 """Write connection info to JSON dict in self.connection_file."""
473 """Write connection info to JSON dict in self.connection_file."""
473 if self._connection_file_written:
474 if self._connection_file_written:
474 return
475 return
475
476
476 self.connection_file, cfg = write_connection_file(self.connection_file,
477 self.connection_file, cfg = write_connection_file(self.connection_file,
477 transport=self.transport, ip=self.ip, key=self.session.key,
478 transport=self.transport, ip=self.ip, key=self.session.key,
478 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
479 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
479 shell_port=self.shell_port, hb_port=self.hb_port,
480 shell_port=self.shell_port, hb_port=self.hb_port,
480 control_port=self.control_port,
481 control_port=self.control_port,
481 signature_scheme=self.session.signature_scheme,
482 signature_scheme=self.session.signature_scheme,
482 )
483 )
483 # write_connection_file also sets default ports:
484 # write_connection_file also sets default ports:
484 for name in port_names:
485 for name in port_names:
485 setattr(self, name, cfg[name])
486 setattr(self, name, cfg[name])
486
487
487 self._connection_file_written = True
488 self._connection_file_written = True
488
489
489 def load_connection_file(self):
490 def load_connection_file(self):
490 """Load connection info from JSON dict in self.connection_file."""
491 """Load connection info from JSON dict in self.connection_file."""
491 with open(self.connection_file) as f:
492 with open(self.connection_file) as f:
492 cfg = json.loads(f.read())
493 cfg = json.loads(f.read())
493
494
494 self.transport = cfg.get('transport', 'tcp')
495 self.transport = cfg.get('transport', 'tcp')
495 self.ip = cfg['ip']
496 self.ip = cfg['ip']
496 for name in port_names:
497 for name in port_names:
497 setattr(self, name, cfg[name])
498 setattr(self, name, cfg[name])
498 if 'key' in cfg:
499 if 'key' in cfg:
499 self.session.key = str_to_bytes(cfg['key'])
500 self.session.key = str_to_bytes(cfg['key'])
500 if cfg.get('signature_scheme'):
501 if cfg.get('signature_scheme'):
501 self.session.signature_scheme = cfg['signature_scheme']
502 self.session.signature_scheme = cfg['signature_scheme']
502
503
503 #--------------------------------------------------------------------------
504 #--------------------------------------------------------------------------
504 # Creating connected sockets
505 # Creating connected sockets
505 #--------------------------------------------------------------------------
506 #--------------------------------------------------------------------------
506
507
507 def _make_url(self, channel):
508 def _make_url(self, channel):
508 """Make a ZeroMQ URL for a given channel."""
509 """Make a ZeroMQ URL for a given channel."""
509 transport = self.transport
510 transport = self.transport
510 ip = self.ip
511 ip = self.ip
511 port = getattr(self, '%s_port' % channel)
512 port = getattr(self, '%s_port' % channel)
512
513
513 if transport == 'tcp':
514 if transport == 'tcp':
514 return "tcp://%s:%i" % (ip, port)
515 return "tcp://%s:%i" % (ip, port)
515 else:
516 else:
516 return "%s://%s-%s" % (transport, ip, port)
517 return "%s://%s-%s" % (transport, ip, port)
517
518
518 def _create_connected_socket(self, channel, identity=None):
519 def _create_connected_socket(self, channel, identity=None):
519 """Create a zmq Socket and connect it to the kernel."""
520 """Create a zmq Socket and connect it to the kernel."""
520 url = self._make_url(channel)
521 url = self._make_url(channel)
521 socket_type = channel_socket_types[channel]
522 socket_type = channel_socket_types[channel]
522 self.log.debug("Connecting to: %s" % url)
523 self.log.debug("Connecting to: %s" % url)
523 sock = self.context.socket(socket_type)
524 sock = self.context.socket(socket_type)
524 if identity:
525 if identity:
525 sock.identity = identity
526 sock.identity = identity
526 sock.connect(url)
527 sock.connect(url)
527 return sock
528 return sock
528
529
529 def connect_iopub(self, identity=None):
530 def connect_iopub(self, identity=None):
530 """return zmq Socket connected to the IOPub channel"""
531 """return zmq Socket connected to the IOPub channel"""
531 sock = self._create_connected_socket('iopub', identity=identity)
532 sock = self._create_connected_socket('iopub', identity=identity)
532 sock.setsockopt(zmq.SUBSCRIBE, b'')
533 sock.setsockopt(zmq.SUBSCRIBE, b'')
533 return sock
534 return sock
534
535
535 def connect_shell(self, identity=None):
536 def connect_shell(self, identity=None):
536 """return zmq Socket connected to the Shell channel"""
537 """return zmq Socket connected to the Shell channel"""
537 return self._create_connected_socket('shell', identity=identity)
538 return self._create_connected_socket('shell', identity=identity)
538
539
539 def connect_stdin(self, identity=None):
540 def connect_stdin(self, identity=None):
540 """return zmq Socket connected to the StdIn channel"""
541 """return zmq Socket connected to the StdIn channel"""
541 return self._create_connected_socket('stdin', identity=identity)
542 return self._create_connected_socket('stdin', identity=identity)
542
543
543 def connect_hb(self, identity=None):
544 def connect_hb(self, identity=None):
544 """return zmq Socket connected to the Heartbeat channel"""
545 """return zmq Socket connected to the Heartbeat channel"""
545 return self._create_connected_socket('hb', identity=identity)
546 return self._create_connected_socket('hb', identity=identity)
546
547
547 def connect_control(self, identity=None):
548 def connect_control(self, identity=None):
548 """return zmq Socket connected to the Heartbeat channel"""
549 """return zmq Socket connected to the Heartbeat channel"""
549 return self._create_connected_socket('control', identity=identity)
550 return self._create_connected_socket('control', identity=identity)
550
551
551
552
552 __all__ = [
553 __all__ = [
553 'write_connection_file',
554 'write_connection_file',
554 'get_connection_file',
555 'get_connection_file',
555 'find_connection_file',
556 'find_connection_file',
556 'get_connection_info',
557 'get_connection_info',
557 'connect_qtconsole',
558 'connect_qtconsole',
558 'tunnel_to_kernel',
559 'tunnel_to_kernel',
559 ]
560 ]
@@ -1,547 +1,546 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
24
23
25 import zmq
24 import zmq
26
25
27 from IPython import parallel
26 from IPython import parallel
28 from IPython.parallel.client import client as clientmod
27 from IPython.parallel.client import client as clientmod
29 from IPython.parallel import error
28 from IPython.parallel import error
30 from IPython.parallel import AsyncResult, AsyncHubResult
29 from IPython.parallel import AsyncResult, AsyncHubResult
31 from IPython.parallel import LoadBalancedView, DirectView
30 from IPython.parallel import LoadBalancedView, DirectView
32
31
33 from .clienttest import ClusterTestCase, segfault, wait, add_engines
32 from .clienttest import ClusterTestCase, segfault, wait, add_engines
34
33
35 def setup():
34 def setup():
36 add_engines(4, total=True)
35 add_engines(4, total=True)
37
36
38 class TestClient(ClusterTestCase):
37 class TestClient(ClusterTestCase):
39
38
40 def test_ids(self):
39 def test_ids(self):
41 n = len(self.client.ids)
40 n = len(self.client.ids)
42 self.add_engines(2)
41 self.add_engines(2)
43 self.assertEqual(len(self.client.ids), n+2)
42 self.assertEqual(len(self.client.ids), n+2)
44
43
45 def test_view_indexing(self):
44 def test_view_indexing(self):
46 """test index access for views"""
45 """test index access for views"""
47 self.minimum_engines(4)
46 self.minimum_engines(4)
48 targets = self.client._build_targets('all')[-1]
47 targets = self.client._build_targets('all')[-1]
49 v = self.client[:]
48 v = self.client[:]
50 self.assertEqual(v.targets, targets)
49 self.assertEqual(v.targets, targets)
51 t = self.client.ids[2]
50 t = self.client.ids[2]
52 v = self.client[t]
51 v = self.client[t]
53 self.assertTrue(isinstance(v, DirectView))
52 self.assertTrue(isinstance(v, DirectView))
54 self.assertEqual(v.targets, t)
53 self.assertEqual(v.targets, t)
55 t = self.client.ids[2:4]
54 t = self.client.ids[2:4]
56 v = self.client[t]
55 v = self.client[t]
57 self.assertTrue(isinstance(v, DirectView))
56 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, t)
57 self.assertEqual(v.targets, t)
59 v = self.client[::2]
58 v = self.client[::2]
60 self.assertTrue(isinstance(v, DirectView))
59 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[::2])
60 self.assertEqual(v.targets, targets[::2])
62 v = self.client[1::3]
61 v = self.client[1::3]
63 self.assertTrue(isinstance(v, DirectView))
62 self.assertTrue(isinstance(v, DirectView))
64 self.assertEqual(v.targets, targets[1::3])
63 self.assertEqual(v.targets, targets[1::3])
65 v = self.client[:-3]
64 v = self.client[:-3]
66 self.assertTrue(isinstance(v, DirectView))
65 self.assertTrue(isinstance(v, DirectView))
67 self.assertEqual(v.targets, targets[:-3])
66 self.assertEqual(v.targets, targets[:-3])
68 v = self.client[-1]
67 v = self.client[-1]
69 self.assertTrue(isinstance(v, DirectView))
68 self.assertTrue(isinstance(v, DirectView))
70 self.assertEqual(v.targets, targets[-1])
69 self.assertEqual(v.targets, targets[-1])
71 self.assertRaises(TypeError, lambda : self.client[None])
70 self.assertRaises(TypeError, lambda : self.client[None])
72
71
73 def test_lbview_targets(self):
72 def test_lbview_targets(self):
74 """test load_balanced_view targets"""
73 """test load_balanced_view targets"""
75 v = self.client.load_balanced_view()
74 v = self.client.load_balanced_view()
76 self.assertEqual(v.targets, None)
75 self.assertEqual(v.targets, None)
77 v = self.client.load_balanced_view(-1)
76 v = self.client.load_balanced_view(-1)
78 self.assertEqual(v.targets, [self.client.ids[-1]])
77 self.assertEqual(v.targets, [self.client.ids[-1]])
79 v = self.client.load_balanced_view('all')
78 v = self.client.load_balanced_view('all')
80 self.assertEqual(v.targets, None)
79 self.assertEqual(v.targets, None)
81
80
82 def test_dview_targets(self):
81 def test_dview_targets(self):
83 """test direct_view targets"""
82 """test direct_view targets"""
84 v = self.client.direct_view()
83 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
84 self.assertEqual(v.targets, 'all')
86 v = self.client.direct_view('all')
85 v = self.client.direct_view('all')
87 self.assertEqual(v.targets, 'all')
86 self.assertEqual(v.targets, 'all')
88 v = self.client.direct_view(-1)
87 v = self.client.direct_view(-1)
89 self.assertEqual(v.targets, self.client.ids[-1])
88 self.assertEqual(v.targets, self.client.ids[-1])
90
89
91 def test_lazy_all_targets(self):
90 def test_lazy_all_targets(self):
92 """test lazy evaluation of rc.direct_view('all')"""
91 """test lazy evaluation of rc.direct_view('all')"""
93 v = self.client.direct_view()
92 v = self.client.direct_view()
94 self.assertEqual(v.targets, 'all')
93 self.assertEqual(v.targets, 'all')
95
94
96 def double(x):
95 def double(x):
97 return x*2
96 return x*2
98 seq = list(range(100))
97 seq = list(range(100))
99 ref = [ double(x) for x in seq ]
98 ref = [ double(x) for x in seq ]
100
99
101 # add some engines, which should be used
100 # add some engines, which should be used
102 self.add_engines(1)
101 self.add_engines(1)
103 n1 = len(self.client.ids)
102 n1 = len(self.client.ids)
104
103
105 # simple apply
104 # simple apply
106 r = v.apply_sync(lambda : 1)
105 r = v.apply_sync(lambda : 1)
107 self.assertEqual(r, [1] * n1)
106 self.assertEqual(r, [1] * n1)
108
107
109 # map goes through remotefunction
108 # map goes through remotefunction
110 r = v.map_sync(double, seq)
109 r = v.map_sync(double, seq)
111 self.assertEqual(r, ref)
110 self.assertEqual(r, ref)
112
111
113 # add a couple more engines, and try again
112 # add a couple more engines, and try again
114 self.add_engines(2)
113 self.add_engines(2)
115 n2 = len(self.client.ids)
114 n2 = len(self.client.ids)
116 self.assertNotEqual(n2, n1)
115 self.assertNotEqual(n2, n1)
117
116
118 # apply
117 # apply
119 r = v.apply_sync(lambda : 1)
118 r = v.apply_sync(lambda : 1)
120 self.assertEqual(r, [1] * n2)
119 self.assertEqual(r, [1] * n2)
121
120
122 # map
121 # map
123 r = v.map_sync(double, seq)
122 r = v.map_sync(double, seq)
124 self.assertEqual(r, ref)
123 self.assertEqual(r, ref)
125
124
126 def test_targets(self):
125 def test_targets(self):
127 """test various valid targets arguments"""
126 """test various valid targets arguments"""
128 build = self.client._build_targets
127 build = self.client._build_targets
129 ids = self.client.ids
128 ids = self.client.ids
130 idents,targets = build(None)
129 idents,targets = build(None)
131 self.assertEqual(ids, targets)
130 self.assertEqual(ids, targets)
132
131
133 def test_clear(self):
132 def test_clear(self):
134 """test clear behavior"""
133 """test clear behavior"""
135 self.minimum_engines(2)
134 self.minimum_engines(2)
136 v = self.client[:]
135 v = self.client[:]
137 v.block=True
136 v.block=True
138 v.push(dict(a=5))
137 v.push(dict(a=5))
139 v.pull('a')
138 v.pull('a')
140 id0 = self.client.ids[-1]
139 id0 = self.client.ids[-1]
141 self.client.clear(targets=id0, block=True)
140 self.client.clear(targets=id0, block=True)
142 a = self.client[:-1].get('a')
141 a = self.client[:-1].get('a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 self.client.clear(block=True)
143 self.client.clear(block=True)
145 for i in self.client.ids:
144 for i in self.client.ids:
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147
146
148 def test_get_result(self):
147 def test_get_result(self):
149 """test getting results from the Hub."""
148 """test getting results from the Hub."""
150 c = clientmod.Client(profile='iptest')
149 c = clientmod.Client(profile='iptest')
151 t = c.ids[-1]
150 t = c.ids[-1]
152 ar = c[t].apply_async(wait, 1)
151 ar = c[t].apply_async(wait, 1)
153 # give the monitor time to notice the message
152 # give the monitor time to notice the message
154 time.sleep(.25)
153 time.sleep(.25)
155 ahr = self.client.get_result(ar.msg_ids[0])
154 ahr = self.client.get_result(ar.msg_ids[0])
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEqual(ahr.get(), ar.get())
156 self.assertEqual(ahr.get(), ar.get())
158 ar2 = self.client.get_result(ar.msg_ids[0])
157 ar2 = self.client.get_result(ar.msg_ids[0])
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.close()
159 c.close()
161
160
162 def test_get_execute_result(self):
161 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
162 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
163 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
164 t = c.ids[-1]
166 cell = '\n'.join([
165 cell = '\n'.join([
167 'import time',
166 'import time',
168 'time.sleep(0.25)',
167 'time.sleep(0.25)',
169 '5'
168 '5'
170 ])
169 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
170 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
171 # give the monitor time to notice the message
173 time.sleep(.25)
172 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids[0])
173 ahr = self.client.get_result(ar.msg_ids[0])
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
174 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
175 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 ar2 = self.client.get_result(ar.msg_ids[0])
176 ar2 = self.client.get_result(ar.msg_ids[0])
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
177 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 c.close()
178 c.close()
180
179
181 def test_ids_list(self):
180 def test_ids_list(self):
182 """test client.ids"""
181 """test client.ids"""
183 ids = self.client.ids
182 ids = self.client.ids
184 self.assertEqual(ids, self.client._ids)
183 self.assertEqual(ids, self.client._ids)
185 self.assertFalse(ids is self.client._ids)
184 self.assertFalse(ids is self.client._ids)
186 ids.remove(ids[-1])
185 ids.remove(ids[-1])
187 self.assertNotEqual(ids, self.client._ids)
186 self.assertNotEqual(ids, self.client._ids)
188
187
189 def test_queue_status(self):
188 def test_queue_status(self):
190 ids = self.client.ids
189 ids = self.client.ids
191 id0 = ids[0]
190 id0 = ids[0]
192 qs = self.client.queue_status(targets=id0)
191 qs = self.client.queue_status(targets=id0)
193 self.assertTrue(isinstance(qs, dict))
192 self.assertTrue(isinstance(qs, dict))
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
193 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 allqs = self.client.queue_status()
194 allqs = self.client.queue_status()
196 self.assertTrue(isinstance(allqs, dict))
195 self.assertTrue(isinstance(allqs, dict))
197 intkeys = list(allqs.keys())
196 intkeys = list(allqs.keys())
198 intkeys.remove('unassigned')
197 intkeys.remove('unassigned')
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
198 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 unassigned = allqs.pop('unassigned')
199 unassigned = allqs.pop('unassigned')
201 for eid,qs in allqs.items():
200 for eid,qs in allqs.items():
202 self.assertTrue(isinstance(qs, dict))
201 self.assertTrue(isinstance(qs, dict))
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
202 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204
203
205 def test_shutdown(self):
204 def test_shutdown(self):
206 ids = self.client.ids
205 ids = self.client.ids
207 id0 = ids[0]
206 id0 = ids[0]
208 self.client.shutdown(id0, block=True)
207 self.client.shutdown(id0, block=True)
209 while id0 in self.client.ids:
208 while id0 in self.client.ids:
210 time.sleep(0.1)
209 time.sleep(0.1)
211 self.client.spin()
210 self.client.spin()
212
211
213 self.assertRaises(IndexError, lambda : self.client[id0])
212 self.assertRaises(IndexError, lambda : self.client[id0])
214
213
215 def test_result_status(self):
214 def test_result_status(self):
216 pass
215 pass
217 # to be written
216 # to be written
218
217
219 def test_db_query_dt(self):
218 def test_db_query_dt(self):
220 """test db query by date"""
219 """test db query by date"""
221 hist = self.client.hub_history()
220 hist = self.client.hub_history()
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
221 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 tic = middle['submitted']
222 tic = middle['submitted']
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
223 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
224 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 self.assertEqual(len(before)+len(after),len(hist))
225 self.assertEqual(len(before)+len(after),len(hist))
227 for b in before:
226 for b in before:
228 self.assertTrue(b['submitted'] < tic)
227 self.assertTrue(b['submitted'] < tic)
229 for a in after:
228 for a in after:
230 self.assertTrue(a['submitted'] >= tic)
229 self.assertTrue(a['submitted'] >= tic)
231 same = self.client.db_query({'submitted' : tic})
230 same = self.client.db_query({'submitted' : tic})
232 for s in same:
231 for s in same:
233 self.assertTrue(s['submitted'] == tic)
232 self.assertTrue(s['submitted'] == tic)
234
233
235 def test_db_query_keys(self):
234 def test_db_query_keys(self):
236 """test extracting subset of record keys"""
235 """test extracting subset of record keys"""
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
236 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 for rec in found:
237 for rec in found:
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
238 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240
239
241 def test_db_query_default_keys(self):
240 def test_db_query_default_keys(self):
242 """default db_query excludes buffers"""
241 """default db_query excludes buffers"""
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
242 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 for rec in found:
243 for rec in found:
245 keys = set(rec.keys())
244 keys = set(rec.keys())
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
245 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
246 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248
247
249 def test_db_query_msg_id(self):
248 def test_db_query_msg_id(self):
250 """ensure msg_id is always in db queries"""
249 """ensure msg_id is always in db queries"""
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
250 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 for rec in found:
251 for rec in found:
253 self.assertTrue('msg_id' in rec.keys())
252 self.assertTrue('msg_id' in rec.keys())
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
253 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 for rec in found:
254 for rec in found:
256 self.assertTrue('msg_id' in rec.keys())
255 self.assertTrue('msg_id' in rec.keys())
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
256 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 for rec in found:
257 for rec in found:
259 self.assertTrue('msg_id' in rec.keys())
258 self.assertTrue('msg_id' in rec.keys())
260
259
261 def test_db_query_get_result(self):
260 def test_db_query_get_result(self):
262 """pop in db_query shouldn't pop from result itself"""
261 """pop in db_query shouldn't pop from result itself"""
263 self.client[:].apply_sync(lambda : 1)
262 self.client[:].apply_sync(lambda : 1)
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
263 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 rc2 = clientmod.Client(profile='iptest')
264 rc2 = clientmod.Client(profile='iptest')
266 # If this bug is not fixed, this call will hang:
265 # If this bug is not fixed, this call will hang:
267 ar = rc2.get_result(self.client.history[-1])
266 ar = rc2.get_result(self.client.history[-1])
268 ar.wait(2)
267 ar.wait(2)
269 self.assertTrue(ar.ready())
268 self.assertTrue(ar.ready())
270 ar.get()
269 ar.get()
271 rc2.close()
270 rc2.close()
272
271
273 def test_db_query_in(self):
272 def test_db_query_in(self):
274 """test db query with '$in','$nin' operators"""
273 """test db query with '$in','$nin' operators"""
275 hist = self.client.hub_history()
274 hist = self.client.hub_history()
276 even = hist[::2]
275 even = hist[::2]
277 odd = hist[1::2]
276 odd = hist[1::2]
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
277 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 found = [ r['msg_id'] for r in recs ]
278 found = [ r['msg_id'] for r in recs ]
280 self.assertEqual(set(even), set(found))
279 self.assertEqual(set(even), set(found))
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
280 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 found = [ r['msg_id'] for r in recs ]
281 found = [ r['msg_id'] for r in recs ]
283 self.assertEqual(set(odd), set(found))
282 self.assertEqual(set(odd), set(found))
284
283
285 def test_hub_history(self):
284 def test_hub_history(self):
286 hist = self.client.hub_history()
285 hist = self.client.hub_history()
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
286 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 recdict = {}
287 recdict = {}
289 for rec in recs:
288 for rec in recs:
290 recdict[rec['msg_id']] = rec
289 recdict[rec['msg_id']] = rec
291
290
292 latest = datetime(1984,1,1)
291 latest = datetime(1984,1,1)
293 for msg_id in hist:
292 for msg_id in hist:
294 rec = recdict[msg_id]
293 rec = recdict[msg_id]
295 newt = rec['submitted']
294 newt = rec['submitted']
296 self.assertTrue(newt >= latest)
295 self.assertTrue(newt >= latest)
297 latest = newt
296 latest = newt
298 ar = self.client[-1].apply_async(lambda : 1)
297 ar = self.client[-1].apply_async(lambda : 1)
299 ar.get()
298 ar.get()
300 time.sleep(0.25)
299 time.sleep(0.25)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
300 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302
301
303 def _wait_for_idle(self):
302 def _wait_for_idle(self):
304 """wait for the cluster to become idle, according to the everyone."""
303 """wait for the cluster to become idle, according to the everyone."""
305 rc = self.client
304 rc = self.client
306
305
307 # step 0. wait for local results
306 # step 0. wait for local results
308 # this should be sufficient 99% of the time.
307 # this should be sufficient 99% of the time.
309 rc.wait(timeout=5)
308 rc.wait(timeout=5)
310
309
311 # step 1. wait for all requests to be noticed
310 # step 1. wait for all requests to be noticed
312 # timeout 5s, polling every 100ms
311 # timeout 5s, polling every 100ms
313 msg_ids = set(rc.history)
312 msg_ids = set(rc.history)
314 hub_hist = rc.hub_history()
313 hub_hist = rc.hub_history()
315 for i in range(50):
314 for i in range(50):
316 if msg_ids.difference(hub_hist):
315 if msg_ids.difference(hub_hist):
317 time.sleep(0.1)
316 time.sleep(0.1)
318 hub_hist = rc.hub_history()
317 hub_hist = rc.hub_history()
319 else:
318 else:
320 break
319 break
321
320
322 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
321 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
323
322
324 # step 2. wait for all requests to be done
323 # step 2. wait for all requests to be done
325 # timeout 5s, polling every 100ms
324 # timeout 5s, polling every 100ms
326 qs = rc.queue_status()
325 qs = rc.queue_status()
327 for i in range(50):
326 for i in range(50):
328 if qs['unassigned'] or any(qs[eid]['tasks'] + qs[eid]['queue'] for eid in qs if eid != 'unassigned'):
327 if qs['unassigned'] or any(qs[eid]['tasks'] + qs[eid]['queue'] for eid in qs if eid != 'unassigned'):
329 time.sleep(0.1)
328 time.sleep(0.1)
330 qs = rc.queue_status()
329 qs = rc.queue_status()
331 else:
330 else:
332 break
331 break
333
332
334 # ensure Hub up to date:
333 # ensure Hub up to date:
335 self.assertEqual(qs['unassigned'], 0)
334 self.assertEqual(qs['unassigned'], 0)
336 for eid in [ eid for eid in qs if eid != 'unassigned' ]:
335 for eid in [ eid for eid in qs if eid != 'unassigned' ]:
337 self.assertEqual(qs[eid]['tasks'], 0)
336 self.assertEqual(qs[eid]['tasks'], 0)
338 self.assertEqual(qs[eid]['queue'], 0)
337 self.assertEqual(qs[eid]['queue'], 0)
339
338
340
339
341 def test_resubmit(self):
340 def test_resubmit(self):
342 def f():
341 def f():
343 import random
342 import random
344 return random.random()
343 return random.random()
345 v = self.client.load_balanced_view()
344 v = self.client.load_balanced_view()
346 ar = v.apply_async(f)
345 ar = v.apply_async(f)
347 r1 = ar.get(1)
346 r1 = ar.get(1)
348 # give the Hub a chance to notice:
347 # give the Hub a chance to notice:
349 self._wait_for_idle()
348 self._wait_for_idle()
350 ahr = self.client.resubmit(ar.msg_ids)
349 ahr = self.client.resubmit(ar.msg_ids)
351 r2 = ahr.get(1)
350 r2 = ahr.get(1)
352 self.assertFalse(r1 == r2)
351 self.assertFalse(r1 == r2)
353
352
354 def test_resubmit_chain(self):
353 def test_resubmit_chain(self):
355 """resubmit resubmitted tasks"""
354 """resubmit resubmitted tasks"""
356 v = self.client.load_balanced_view()
355 v = self.client.load_balanced_view()
357 ar = v.apply_async(lambda x: x, 'x'*1024)
356 ar = v.apply_async(lambda x: x, 'x'*1024)
358 ar.get()
357 ar.get()
359 self._wait_for_idle()
358 self._wait_for_idle()
360 ars = [ar]
359 ars = [ar]
361
360
362 for i in range(10):
361 for i in range(10):
363 ar = ars[-1]
362 ar = ars[-1]
364 ar2 = self.client.resubmit(ar.msg_ids)
363 ar2 = self.client.resubmit(ar.msg_ids)
365
364
366 [ ar.get() for ar in ars ]
365 [ ar.get() for ar in ars ]
367
366
368 def test_resubmit_header(self):
367 def test_resubmit_header(self):
369 """resubmit shouldn't clobber the whole header"""
368 """resubmit shouldn't clobber the whole header"""
370 def f():
369 def f():
371 import random
370 import random
372 return random.random()
371 return random.random()
373 v = self.client.load_balanced_view()
372 v = self.client.load_balanced_view()
374 v.retries = 1
373 v.retries = 1
375 ar = v.apply_async(f)
374 ar = v.apply_async(f)
376 r1 = ar.get(1)
375 r1 = ar.get(1)
377 # give the Hub a chance to notice:
376 # give the Hub a chance to notice:
378 self._wait_for_idle()
377 self._wait_for_idle()
379 ahr = self.client.resubmit(ar.msg_ids)
378 ahr = self.client.resubmit(ar.msg_ids)
380 ahr.get(1)
379 ahr.get(1)
381 time.sleep(0.5)
380 time.sleep(0.5)
382 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
381 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
383 h1,h2 = [ r['header'] for r in records ]
382 h1,h2 = [ r['header'] for r in records ]
384 for key in set(h1.keys()).union(set(h2.keys())):
383 for key in set(h1.keys()).union(set(h2.keys())):
385 if key in ('msg_id', 'date'):
384 if key in ('msg_id', 'date'):
386 self.assertNotEqual(h1[key], h2[key])
385 self.assertNotEqual(h1[key], h2[key])
387 else:
386 else:
388 self.assertEqual(h1[key], h2[key])
387 self.assertEqual(h1[key], h2[key])
389
388
390 def test_resubmit_aborted(self):
389 def test_resubmit_aborted(self):
391 def f():
390 def f():
392 import random
391 import random
393 return random.random()
392 return random.random()
394 v = self.client.load_balanced_view()
393 v = self.client.load_balanced_view()
395 # restrict to one engine, so we can put a sleep
394 # restrict to one engine, so we can put a sleep
396 # ahead of the task, so it will get aborted
395 # ahead of the task, so it will get aborted
397 eid = self.client.ids[-1]
396 eid = self.client.ids[-1]
398 v.targets = [eid]
397 v.targets = [eid]
399 sleep = v.apply_async(time.sleep, 0.5)
398 sleep = v.apply_async(time.sleep, 0.5)
400 ar = v.apply_async(f)
399 ar = v.apply_async(f)
401 ar.abort()
400 ar.abort()
402 self.assertRaises(error.TaskAborted, ar.get)
401 self.assertRaises(error.TaskAborted, ar.get)
403 # Give the Hub a chance to get up to date:
402 # Give the Hub a chance to get up to date:
404 self._wait_for_idle()
403 self._wait_for_idle()
405 ahr = self.client.resubmit(ar.msg_ids)
404 ahr = self.client.resubmit(ar.msg_ids)
406 r2 = ahr.get(1)
405 r2 = ahr.get(1)
407
406
408 def test_resubmit_inflight(self):
407 def test_resubmit_inflight(self):
409 """resubmit of inflight task"""
408 """resubmit of inflight task"""
410 v = self.client.load_balanced_view()
409 v = self.client.load_balanced_view()
411 ar = v.apply_async(time.sleep,1)
410 ar = v.apply_async(time.sleep,1)
412 # give the message a chance to arrive
411 # give the message a chance to arrive
413 time.sleep(0.2)
412 time.sleep(0.2)
414 ahr = self.client.resubmit(ar.msg_ids)
413 ahr = self.client.resubmit(ar.msg_ids)
415 ar.get(2)
414 ar.get(2)
416 ahr.get(2)
415 ahr.get(2)
417
416
418 def test_resubmit_badkey(self):
417 def test_resubmit_badkey(self):
419 """ensure KeyError on resubmit of nonexistant task"""
418 """ensure KeyError on resubmit of nonexistant task"""
420 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
419 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
421
420
422 def test_purge_hub_results(self):
421 def test_purge_hub_results(self):
423 # ensure there are some tasks
422 # ensure there are some tasks
424 for i in range(5):
423 for i in range(5):
425 self.client[:].apply_sync(lambda : 1)
424 self.client[:].apply_sync(lambda : 1)
426 # Wait for the Hub to realise the result is done:
425 # Wait for the Hub to realise the result is done:
427 # This prevents a race condition, where we
426 # This prevents a race condition, where we
428 # might purge a result the Hub still thinks is pending.
427 # might purge a result the Hub still thinks is pending.
429 self._wait_for_idle()
428 self._wait_for_idle()
430 rc2 = clientmod.Client(profile='iptest')
429 rc2 = clientmod.Client(profile='iptest')
431 hist = self.client.hub_history()
430 hist = self.client.hub_history()
432 ahr = rc2.get_result([hist[-1]])
431 ahr = rc2.get_result([hist[-1]])
433 ahr.wait(10)
432 ahr.wait(10)
434 self.client.purge_hub_results(hist[-1])
433 self.client.purge_hub_results(hist[-1])
435 newhist = self.client.hub_history()
434 newhist = self.client.hub_history()
436 self.assertEqual(len(newhist)+1,len(hist))
435 self.assertEqual(len(newhist)+1,len(hist))
437 rc2.spin()
436 rc2.spin()
438 rc2.close()
437 rc2.close()
439
438
440 def test_purge_local_results(self):
439 def test_purge_local_results(self):
441 # ensure there are some tasks
440 # ensure there are some tasks
442 res = []
441 res = []
443 for i in range(5):
442 for i in range(5):
444 res.append(self.client[:].apply_async(lambda : 1))
443 res.append(self.client[:].apply_async(lambda : 1))
445 self._wait_for_idle()
444 self._wait_for_idle()
446 self.client.wait(10) # wait for the results to come back
445 self.client.wait(10) # wait for the results to come back
447 before = len(self.client.results)
446 before = len(self.client.results)
448 self.assertEqual(len(self.client.metadata),before)
447 self.assertEqual(len(self.client.metadata),before)
449 self.client.purge_local_results(res[-1])
448 self.client.purge_local_results(res[-1])
450 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
449 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
451 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
450 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
452
451
453 def test_purge_local_results_outstanding(self):
452 def test_purge_local_results_outstanding(self):
454 v = self.client[-1]
453 v = self.client[-1]
455 ar = v.apply_async(lambda : 1)
454 ar = v.apply_async(lambda : 1)
456 msg_id = ar.msg_ids[0]
455 msg_id = ar.msg_ids[0]
457 ar.get()
456 ar.get()
458 self._wait_for_idle()
457 self._wait_for_idle()
459 ar2 = v.apply_async(time.sleep, 1)
458 ar2 = v.apply_async(time.sleep, 1)
460 self.assertIn(msg_id, self.client.results)
459 self.assertIn(msg_id, self.client.results)
461 self.assertIn(msg_id, self.client.metadata)
460 self.assertIn(msg_id, self.client.metadata)
462 self.client.purge_local_results(ar)
461 self.client.purge_local_results(ar)
463 self.assertNotIn(msg_id, self.client.results)
462 self.assertNotIn(msg_id, self.client.results)
464 self.assertNotIn(msg_id, self.client.metadata)
463 self.assertNotIn(msg_id, self.client.metadata)
465 with self.assertRaises(RuntimeError):
464 with self.assertRaises(RuntimeError):
466 self.client.purge_local_results(ar2)
465 self.client.purge_local_results(ar2)
467 ar2.get()
466 ar2.get()
468 self.client.purge_local_results(ar2)
467 self.client.purge_local_results(ar2)
469
468
470 def test_purge_all_local_results_outstanding(self):
469 def test_purge_all_local_results_outstanding(self):
471 v = self.client[-1]
470 v = self.client[-1]
472 ar = v.apply_async(time.sleep, 1)
471 ar = v.apply_async(time.sleep, 1)
473 with self.assertRaises(RuntimeError):
472 with self.assertRaises(RuntimeError):
474 self.client.purge_local_results('all')
473 self.client.purge_local_results('all')
475 ar.get()
474 ar.get()
476 self.client.purge_local_results('all')
475 self.client.purge_local_results('all')
477
476
478 def test_purge_all_hub_results(self):
477 def test_purge_all_hub_results(self):
479 self.client.purge_hub_results('all')
478 self.client.purge_hub_results('all')
480 hist = self.client.hub_history()
479 hist = self.client.hub_history()
481 self.assertEqual(len(hist), 0)
480 self.assertEqual(len(hist), 0)
482
481
483 def test_purge_all_local_results(self):
482 def test_purge_all_local_results(self):
484 self.client.purge_local_results('all')
483 self.client.purge_local_results('all')
485 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
484 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
486 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
485 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
487
486
488 def test_purge_all_results(self):
487 def test_purge_all_results(self):
489 # ensure there are some tasks
488 # ensure there are some tasks
490 for i in range(5):
489 for i in range(5):
491 self.client[:].apply_sync(lambda : 1)
490 self.client[:].apply_sync(lambda : 1)
492 self.client.wait(10)
491 self.client.wait(10)
493 self._wait_for_idle()
492 self._wait_for_idle()
494 self.client.purge_results('all')
493 self.client.purge_results('all')
495 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
494 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
496 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
495 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
497 hist = self.client.hub_history()
496 hist = self.client.hub_history()
498 self.assertEqual(len(hist), 0, msg="hub history not empty")
497 self.assertEqual(len(hist), 0, msg="hub history not empty")
499
498
500 def test_purge_everything(self):
499 def test_purge_everything(self):
501 # ensure there are some tasks
500 # ensure there are some tasks
502 for i in range(5):
501 for i in range(5):
503 self.client[:].apply_sync(lambda : 1)
502 self.client[:].apply_sync(lambda : 1)
504 self.client.wait(10)
503 self.client.wait(10)
505 self._wait_for_idle()
504 self._wait_for_idle()
506 self.client.purge_everything()
505 self.client.purge_everything()
507 # The client results
506 # The client results
508 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
507 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
509 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
508 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
510 # The client "bookkeeping"
509 # The client "bookkeeping"
511 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
510 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
512 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
511 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
513 # the hub results
512 # the hub results
514 hist = self.client.hub_history()
513 hist = self.client.hub_history()
515 self.assertEqual(len(hist), 0, msg="hub history not empty")
514 self.assertEqual(len(hist), 0, msg="hub history not empty")
516
515
517
516
518 def test_spin_thread(self):
517 def test_spin_thread(self):
519 self.client.spin_thread(0.01)
518 self.client.spin_thread(0.01)
520 ar = self.client[-1].apply_async(lambda : 1)
519 ar = self.client[-1].apply_async(lambda : 1)
521 time.sleep(0.1)
520 time.sleep(0.1)
522 self.assertTrue(ar.wall_time < 0.1,
521 self.assertTrue(ar.wall_time < 0.1,
523 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
522 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
524 )
523 )
525
524
526 def test_stop_spin_thread(self):
525 def test_stop_spin_thread(self):
527 self.client.spin_thread(0.01)
526 self.client.spin_thread(0.01)
528 self.client.stop_spin_thread()
527 self.client.stop_spin_thread()
529 ar = self.client[-1].apply_async(lambda : 1)
528 ar = self.client[-1].apply_async(lambda : 1)
530 time.sleep(0.15)
529 time.sleep(0.15)
531 self.assertTrue(ar.wall_time > 0.1,
530 self.assertTrue(ar.wall_time > 0.1,
532 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
531 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
533 )
532 )
534
533
535 def test_activate(self):
534 def test_activate(self):
536 ip = get_ipython()
535 ip = get_ipython()
537 magics = ip.magics_manager.magics
536 magics = ip.magics_manager.magics
538 self.assertTrue('px' in magics['line'])
537 self.assertTrue('px' in magics['line'])
539 self.assertTrue('px' in magics['cell'])
538 self.assertTrue('px' in magics['cell'])
540 v0 = self.client.activate(-1, '0')
539 v0 = self.client.activate(-1, '0')
541 self.assertTrue('px0' in magics['line'])
540 self.assertTrue('px0' in magics['line'])
542 self.assertTrue('px0' in magics['cell'])
541 self.assertTrue('px0' in magics['cell'])
543 self.assertEqual(v0.targets, self.client.ids[-1])
542 self.assertEqual(v0.targets, self.client.ids[-1])
544 v0 = self.client.activate('all', 'all')
543 v0 = self.client.activate('all', 'all')
545 self.assertTrue('pxall' in magics['line'])
544 self.assertTrue('pxall' in magics['line'])
546 self.assertTrue('pxall' in magics['cell'])
545 self.assertTrue('pxall' in magics['cell'])
547 self.assertEqual(v0.targets, 'all')
546 self.assertEqual(v0.targets, 'all')
@@ -1,850 +1,849 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import base64
19 import base64
20 import sys
20 import sys
21 import platform
21 import platform
22 import time
22 import time
23 from collections import namedtuple
23 from collections import namedtuple
24 from tempfile import mktemp
24 from tempfile import NamedTemporaryFile
25
25
26 import zmq
26 import zmq
27 from nose.plugins.attrib import attr
27 from nose.plugins.attrib import attr
28
28
29 from IPython.testing import decorators as dec
29 from IPython.testing import decorators as dec
30 from IPython.utils.io import capture_output
30 from IPython.utils.io import capture_output
31 from IPython.utils.py3compat import unicode_type
31 from IPython.utils.py3compat import unicode_type
32
32
33 from IPython import parallel as pmod
33 from IPython import parallel as pmod
34 from IPython.parallel import error
34 from IPython.parallel import error
35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 from IPython.parallel.util import interactive
36 from IPython.parallel.util import interactive
37
37
38 from IPython.parallel.tests import add_engines
38 from IPython.parallel.tests import add_engines
39
39
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41
41
42 def setup():
42 def setup():
43 add_engines(3, total=True)
43 add_engines(3, total=True)
44
44
45 point = namedtuple("point", "x y")
45 point = namedtuple("point", "x y")
46
46
47 class TestView(ClusterTestCase):
47 class TestView(ClusterTestCase):
48
48
49 def setUp(self):
49 def setUp(self):
50 # On Win XP, wait for resource cleanup, else parallel test group fails
50 # On Win XP, wait for resource cleanup, else parallel test group fails
51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
53 time.sleep(2)
53 time.sleep(2)
54 super(TestView, self).setUp()
54 super(TestView, self).setUp()
55
55
56 @attr('crash')
56 @attr('crash')
57 def test_z_crash_mux(self):
57 def test_z_crash_mux(self):
58 """test graceful handling of engine death (direct)"""
58 """test graceful handling of engine death (direct)"""
59 # self.add_engines(1)
59 # self.add_engines(1)
60 eid = self.client.ids[-1]
60 eid = self.client.ids[-1]
61 ar = self.client[eid].apply_async(crash)
61 ar = self.client[eid].apply_async(crash)
62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
63 eid = ar.engine_id
63 eid = ar.engine_id
64 tic = time.time()
64 tic = time.time()
65 while eid in self.client.ids and time.time()-tic < 5:
65 while eid in self.client.ids and time.time()-tic < 5:
66 time.sleep(.01)
66 time.sleep(.01)
67 self.client.spin()
67 self.client.spin()
68 self.assertFalse(eid in self.client.ids, "Engine should have died")
68 self.assertFalse(eid in self.client.ids, "Engine should have died")
69
69
70 def test_push_pull(self):
70 def test_push_pull(self):
71 """test pushing and pulling"""
71 """test pushing and pulling"""
72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
73 t = self.client.ids[-1]
73 t = self.client.ids[-1]
74 v = self.client[t]
74 v = self.client[t]
75 push = v.push
75 push = v.push
76 pull = v.pull
76 pull = v.pull
77 v.block=True
77 v.block=True
78 nengines = len(self.client)
78 nengines = len(self.client)
79 push({'data':data})
79 push({'data':data})
80 d = pull('data')
80 d = pull('data')
81 self.assertEqual(d, data)
81 self.assertEqual(d, data)
82 self.client[:].push({'data':data})
82 self.client[:].push({'data':data})
83 d = self.client[:].pull('data', block=True)
83 d = self.client[:].pull('data', block=True)
84 self.assertEqual(d, nengines*[data])
84 self.assertEqual(d, nengines*[data])
85 ar = push({'data':data}, block=False)
85 ar = push({'data':data}, block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 ar = self.client[:].pull('data', block=False)
88 ar = self.client[:].pull('data', block=False)
89 self.assertTrue(isinstance(ar, AsyncResult))
89 self.assertTrue(isinstance(ar, AsyncResult))
90 r = ar.get()
90 r = ar.get()
91 self.assertEqual(r, nengines*[data])
91 self.assertEqual(r, nengines*[data])
92 self.client[:].push(dict(a=10,b=20))
92 self.client[:].push(dict(a=10,b=20))
93 r = self.client[:].pull(('a','b'), block=True)
93 r = self.client[:].pull(('a','b'), block=True)
94 self.assertEqual(r, nengines*[[10,20]])
94 self.assertEqual(r, nengines*[[10,20]])
95
95
96 def test_push_pull_function(self):
96 def test_push_pull_function(self):
97 "test pushing and pulling functions"
97 "test pushing and pulling functions"
98 def testf(x):
98 def testf(x):
99 return 2.0*x
99 return 2.0*x
100
100
101 t = self.client.ids[-1]
101 t = self.client.ids[-1]
102 v = self.client[t]
102 v = self.client[t]
103 v.block=True
103 v.block=True
104 push = v.push
104 push = v.push
105 pull = v.pull
105 pull = v.pull
106 execute = v.execute
106 execute = v.execute
107 push({'testf':testf})
107 push({'testf':testf})
108 r = pull('testf')
108 r = pull('testf')
109 self.assertEqual(r(1.0), testf(1.0))
109 self.assertEqual(r(1.0), testf(1.0))
110 execute('r = testf(10)')
110 execute('r = testf(10)')
111 r = pull('r')
111 r = pull('r')
112 self.assertEqual(r, testf(10))
112 self.assertEqual(r, testf(10))
113 ar = self.client[:].push({'testf':testf}, block=False)
113 ar = self.client[:].push({'testf':testf}, block=False)
114 ar.get()
114 ar.get()
115 ar = self.client[:].pull('testf', block=False)
115 ar = self.client[:].pull('testf', block=False)
116 rlist = ar.get()
116 rlist = ar.get()
117 for r in rlist:
117 for r in rlist:
118 self.assertEqual(r(1.0), testf(1.0))
118 self.assertEqual(r(1.0), testf(1.0))
119 execute("def g(x): return x*x")
119 execute("def g(x): return x*x")
120 r = pull(('testf','g'))
120 r = pull(('testf','g'))
121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
122
122
123 def test_push_function_globals(self):
123 def test_push_function_globals(self):
124 """test that pushed functions have access to globals"""
124 """test that pushed functions have access to globals"""
125 @interactive
125 @interactive
126 def geta():
126 def geta():
127 return a
127 return a
128 # self.add_engines(1)
128 # self.add_engines(1)
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = geta
131 v['f'] = geta
132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
133 v.execute('a=5')
133 v.execute('a=5')
134 v.execute('b=f()')
134 v.execute('b=f()')
135 self.assertEqual(v['b'], 5)
135 self.assertEqual(v['b'], 5)
136
136
137 def test_push_function_defaults(self):
137 def test_push_function_defaults(self):
138 """test that pushed functions preserve default args"""
138 """test that pushed functions preserve default args"""
139 def echo(a=10):
139 def echo(a=10):
140 return a
140 return a
141 v = self.client[-1]
141 v = self.client[-1]
142 v.block=True
142 v.block=True
143 v['f'] = echo
143 v['f'] = echo
144 v.execute('b=f()')
144 v.execute('b=f()')
145 self.assertEqual(v['b'], 10)
145 self.assertEqual(v['b'], 10)
146
146
147 def test_get_result(self):
147 def test_get_result(self):
148 """test getting results from the Hub."""
148 """test getting results from the Hub."""
149 c = pmod.Client(profile='iptest')
149 c = pmod.Client(profile='iptest')
150 # self.add_engines(1)
150 # self.add_engines(1)
151 t = c.ids[-1]
151 t = c.ids[-1]
152 v = c[t]
152 v = c[t]
153 v2 = self.client[t]
153 v2 = self.client[t]
154 ar = v.apply_async(wait, 1)
154 ar = v.apply_async(wait, 1)
155 # give the monitor time to notice the message
155 # give the monitor time to notice the message
156 time.sleep(.25)
156 time.sleep(.25)
157 ahr = v2.get_result(ar.msg_ids[0])
157 ahr = v2.get_result(ar.msg_ids[0])
158 self.assertTrue(isinstance(ahr, AsyncHubResult))
158 self.assertTrue(isinstance(ahr, AsyncHubResult))
159 self.assertEqual(ahr.get(), ar.get())
159 self.assertEqual(ahr.get(), ar.get())
160 ar2 = v2.get_result(ar.msg_ids[0])
160 ar2 = v2.get_result(ar.msg_ids[0])
161 self.assertFalse(isinstance(ar2, AsyncHubResult))
161 self.assertFalse(isinstance(ar2, AsyncHubResult))
162 c.spin()
162 c.spin()
163 c.close()
163 c.close()
164
164
165 def test_run_newline(self):
165 def test_run_newline(self):
166 """test that run appends newline to files"""
166 """test that run appends newline to files"""
167 tmpfile = mktemp()
167 with NamedTemporaryFile('w', delete=False) as f:
168 with open(tmpfile, 'w') as f:
169 f.write("""def g():
168 f.write("""def g():
170 return 5
169 return 5
171 """)
170 """)
172 v = self.client[-1]
171 v = self.client[-1]
173 v.run(tmpfile, block=True)
172 v.run(f.name, block=True)
174 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
173 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
175
174
176 def test_apply_tracked(self):
175 def test_apply_tracked(self):
177 """test tracking for apply"""
176 """test tracking for apply"""
178 # self.add_engines(1)
177 # self.add_engines(1)
179 t = self.client.ids[-1]
178 t = self.client.ids[-1]
180 v = self.client[t]
179 v = self.client[t]
181 v.block=False
180 v.block=False
182 def echo(n=1024*1024, **kwargs):
181 def echo(n=1024*1024, **kwargs):
183 with v.temp_flags(**kwargs):
182 with v.temp_flags(**kwargs):
184 return v.apply(lambda x: x, 'x'*n)
183 return v.apply(lambda x: x, 'x'*n)
185 ar = echo(1, track=False)
184 ar = echo(1, track=False)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
185 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(ar.sent)
186 self.assertTrue(ar.sent)
188 ar = echo(track=True)
187 ar = echo(track=True)
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 self.assertEqual(ar.sent, ar._tracker.done)
189 self.assertEqual(ar.sent, ar._tracker.done)
191 ar._tracker.wait()
190 ar._tracker.wait()
192 self.assertTrue(ar.sent)
191 self.assertTrue(ar.sent)
193
192
194 def test_push_tracked(self):
193 def test_push_tracked(self):
195 t = self.client.ids[-1]
194 t = self.client.ids[-1]
196 ns = dict(x='x'*1024*1024)
195 ns = dict(x='x'*1024*1024)
197 v = self.client[t]
196 v = self.client[t]
198 ar = v.push(ns, block=False, track=False)
197 ar = v.push(ns, block=False, track=False)
199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
198 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(ar.sent)
199 self.assertTrue(ar.sent)
201
200
202 ar = v.push(ns, block=False, track=True)
201 ar = v.push(ns, block=False, track=True)
203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 ar._tracker.wait()
203 ar._tracker.wait()
205 self.assertEqual(ar.sent, ar._tracker.done)
204 self.assertEqual(ar.sent, ar._tracker.done)
206 self.assertTrue(ar.sent)
205 self.assertTrue(ar.sent)
207 ar.get()
206 ar.get()
208
207
209 def test_scatter_tracked(self):
208 def test_scatter_tracked(self):
210 t = self.client.ids
209 t = self.client.ids
211 x='x'*1024*1024
210 x='x'*1024*1024
212 ar = self.client[t].scatter('x', x, block=False, track=False)
211 ar = self.client[t].scatter('x', x, block=False, track=False)
213 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
212 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(ar.sent)
213 self.assertTrue(ar.sent)
215
214
216 ar = self.client[t].scatter('x', x, block=False, track=True)
215 ar = self.client[t].scatter('x', x, block=False, track=True)
217 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
218 self.assertEqual(ar.sent, ar._tracker.done)
217 self.assertEqual(ar.sent, ar._tracker.done)
219 ar._tracker.wait()
218 ar._tracker.wait()
220 self.assertTrue(ar.sent)
219 self.assertTrue(ar.sent)
221 ar.get()
220 ar.get()
222
221
223 def test_remote_reference(self):
222 def test_remote_reference(self):
224 v = self.client[-1]
223 v = self.client[-1]
225 v['a'] = 123
224 v['a'] = 123
226 ra = pmod.Reference('a')
225 ra = pmod.Reference('a')
227 b = v.apply_sync(lambda x: x, ra)
226 b = v.apply_sync(lambda x: x, ra)
228 self.assertEqual(b, 123)
227 self.assertEqual(b, 123)
229
228
230
229
231 def test_scatter_gather(self):
230 def test_scatter_gather(self):
232 view = self.client[:]
231 view = self.client[:]
233 seq1 = list(range(16))
232 seq1 = list(range(16))
234 view.scatter('a', seq1)
233 view.scatter('a', seq1)
235 seq2 = view.gather('a', block=True)
234 seq2 = view.gather('a', block=True)
236 self.assertEqual(seq2, seq1)
235 self.assertEqual(seq2, seq1)
237 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
236 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
238
237
239 @skip_without('numpy')
238 @skip_without('numpy')
240 def test_scatter_gather_numpy(self):
239 def test_scatter_gather_numpy(self):
241 import numpy
240 import numpy
242 from numpy.testing.utils import assert_array_equal
241 from numpy.testing.utils import assert_array_equal
243 view = self.client[:]
242 view = self.client[:]
244 a = numpy.arange(64)
243 a = numpy.arange(64)
245 view.scatter('a', a, block=True)
244 view.scatter('a', a, block=True)
246 b = view.gather('a', block=True)
245 b = view.gather('a', block=True)
247 assert_array_equal(b, a)
246 assert_array_equal(b, a)
248
247
249 def test_scatter_gather_lazy(self):
248 def test_scatter_gather_lazy(self):
250 """scatter/gather with targets='all'"""
249 """scatter/gather with targets='all'"""
251 view = self.client.direct_view(targets='all')
250 view = self.client.direct_view(targets='all')
252 x = list(range(64))
251 x = list(range(64))
253 view.scatter('x', x)
252 view.scatter('x', x)
254 gathered = view.gather('x', block=True)
253 gathered = view.gather('x', block=True)
255 self.assertEqual(gathered, x)
254 self.assertEqual(gathered, x)
256
255
257
256
258 @dec.known_failure_py3
257 @dec.known_failure_py3
259 @skip_without('numpy')
258 @skip_without('numpy')
260 def test_push_numpy_nocopy(self):
259 def test_push_numpy_nocopy(self):
261 import numpy
260 import numpy
262 view = self.client[:]
261 view = self.client[:]
263 a = numpy.arange(64)
262 a = numpy.arange(64)
264 view['A'] = a
263 view['A'] = a
265 @interactive
264 @interactive
266 def check_writeable(x):
265 def check_writeable(x):
267 return x.flags.writeable
266 return x.flags.writeable
268
267
269 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
268 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
270 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
269 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271
270
272 view.push(dict(B=a))
271 view.push(dict(B=a))
273 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
272 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
274 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
275
274
276 @skip_without('numpy')
275 @skip_without('numpy')
277 def test_apply_numpy(self):
276 def test_apply_numpy(self):
278 """view.apply(f, ndarray)"""
277 """view.apply(f, ndarray)"""
279 import numpy
278 import numpy
280 from numpy.testing.utils import assert_array_equal
279 from numpy.testing.utils import assert_array_equal
281
280
282 A = numpy.random.random((100,100))
281 A = numpy.random.random((100,100))
283 view = self.client[-1]
282 view = self.client[-1]
284 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
283 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
285 B = A.astype(dt)
284 B = A.astype(dt)
286 C = view.apply_sync(lambda x:x, B)
285 C = view.apply_sync(lambda x:x, B)
287 assert_array_equal(B,C)
286 assert_array_equal(B,C)
288
287
289 @skip_without('numpy')
288 @skip_without('numpy')
290 def test_apply_numpy_object_dtype(self):
289 def test_apply_numpy_object_dtype(self):
291 """view.apply(f, ndarray) with dtype=object"""
290 """view.apply(f, ndarray) with dtype=object"""
292 import numpy
291 import numpy
293 from numpy.testing.utils import assert_array_equal
292 from numpy.testing.utils import assert_array_equal
294 view = self.client[-1]
293 view = self.client[-1]
295
294
296 A = numpy.array([dict(a=5)])
295 A = numpy.array([dict(a=5)])
297 B = view.apply_sync(lambda x:x, A)
296 B = view.apply_sync(lambda x:x, A)
298 assert_array_equal(A,B)
297 assert_array_equal(A,B)
299
298
300 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
299 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
301 B = view.apply_sync(lambda x:x, A)
300 B = view.apply_sync(lambda x:x, A)
302 assert_array_equal(A,B)
301 assert_array_equal(A,B)
303
302
304 @skip_without('numpy')
303 @skip_without('numpy')
305 def test_push_pull_recarray(self):
304 def test_push_pull_recarray(self):
306 """push/pull recarrays"""
305 """push/pull recarrays"""
307 import numpy
306 import numpy
308 from numpy.testing.utils import assert_array_equal
307 from numpy.testing.utils import assert_array_equal
309
308
310 view = self.client[-1]
309 view = self.client[-1]
311
310
312 R = numpy.array([
311 R = numpy.array([
313 (1, 'hi', 0.),
312 (1, 'hi', 0.),
314 (2**30, 'there', 2.5),
313 (2**30, 'there', 2.5),
315 (-99999, 'world', -12345.6789),
314 (-99999, 'world', -12345.6789),
316 ], [('n', int), ('s', '|S10'), ('f', float)])
315 ], [('n', int), ('s', '|S10'), ('f', float)])
317
316
318 view['RR'] = R
317 view['RR'] = R
319 R2 = view['RR']
318 R2 = view['RR']
320
319
321 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
320 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
322 self.assertEqual(r_dtype, R.dtype)
321 self.assertEqual(r_dtype, R.dtype)
323 self.assertEqual(r_shape, R.shape)
322 self.assertEqual(r_shape, R.shape)
324 self.assertEqual(R2.dtype, R.dtype)
323 self.assertEqual(R2.dtype, R.dtype)
325 self.assertEqual(R2.shape, R.shape)
324 self.assertEqual(R2.shape, R.shape)
326 assert_array_equal(R2, R)
325 assert_array_equal(R2, R)
327
326
328 @skip_without('pandas')
327 @skip_without('pandas')
329 def test_push_pull_timeseries(self):
328 def test_push_pull_timeseries(self):
330 """push/pull pandas.TimeSeries"""
329 """push/pull pandas.TimeSeries"""
331 import pandas
330 import pandas
332
331
333 ts = pandas.TimeSeries(list(range(10)))
332 ts = pandas.TimeSeries(list(range(10)))
334
333
335 view = self.client[-1]
334 view = self.client[-1]
336
335
337 view.push(dict(ts=ts), block=True)
336 view.push(dict(ts=ts), block=True)
338 rts = view['ts']
337 rts = view['ts']
339
338
340 self.assertEqual(type(rts), type(ts))
339 self.assertEqual(type(rts), type(ts))
341 self.assertTrue((ts == rts).all())
340 self.assertTrue((ts == rts).all())
342
341
343 def test_map(self):
342 def test_map(self):
344 view = self.client[:]
343 view = self.client[:]
345 def f(x):
344 def f(x):
346 return x**2
345 return x**2
347 data = list(range(16))
346 data = list(range(16))
348 r = view.map_sync(f, data)
347 r = view.map_sync(f, data)
349 self.assertEqual(r, list(map(f, data)))
348 self.assertEqual(r, list(map(f, data)))
350
349
351 def test_map_iterable(self):
350 def test_map_iterable(self):
352 """test map on iterables (direct)"""
351 """test map on iterables (direct)"""
353 view = self.client[:]
352 view = self.client[:]
354 # 101 is prime, so it won't be evenly distributed
353 # 101 is prime, so it won't be evenly distributed
355 arr = range(101)
354 arr = range(101)
356 # ensure it will be an iterator, even in Python 3
355 # ensure it will be an iterator, even in Python 3
357 it = iter(arr)
356 it = iter(arr)
358 r = view.map_sync(lambda x: x, it)
357 r = view.map_sync(lambda x: x, it)
359 self.assertEqual(r, list(arr))
358 self.assertEqual(r, list(arr))
360
359
361 @skip_without('numpy')
360 @skip_without('numpy')
362 def test_map_numpy(self):
361 def test_map_numpy(self):
363 """test map on numpy arrays (direct)"""
362 """test map on numpy arrays (direct)"""
364 import numpy
363 import numpy
365 from numpy.testing.utils import assert_array_equal
364 from numpy.testing.utils import assert_array_equal
366
365
367 view = self.client[:]
366 view = self.client[:]
368 # 101 is prime, so it won't be evenly distributed
367 # 101 is prime, so it won't be evenly distributed
369 arr = numpy.arange(101)
368 arr = numpy.arange(101)
370 r = view.map_sync(lambda x: x, arr)
369 r = view.map_sync(lambda x: x, arr)
371 assert_array_equal(r, arr)
370 assert_array_equal(r, arr)
372
371
373 def test_scatter_gather_nonblocking(self):
372 def test_scatter_gather_nonblocking(self):
374 data = list(range(16))
373 data = list(range(16))
375 view = self.client[:]
374 view = self.client[:]
376 view.scatter('a', data, block=False)
375 view.scatter('a', data, block=False)
377 ar = view.gather('a', block=False)
376 ar = view.gather('a', block=False)
378 self.assertEqual(ar.get(), data)
377 self.assertEqual(ar.get(), data)
379
378
380 @skip_without('numpy')
379 @skip_without('numpy')
381 def test_scatter_gather_numpy_nonblocking(self):
380 def test_scatter_gather_numpy_nonblocking(self):
382 import numpy
381 import numpy
383 from numpy.testing.utils import assert_array_equal
382 from numpy.testing.utils import assert_array_equal
384 a = numpy.arange(64)
383 a = numpy.arange(64)
385 view = self.client[:]
384 view = self.client[:]
386 ar = view.scatter('a', a, block=False)
385 ar = view.scatter('a', a, block=False)
387 self.assertTrue(isinstance(ar, AsyncResult))
386 self.assertTrue(isinstance(ar, AsyncResult))
388 amr = view.gather('a', block=False)
387 amr = view.gather('a', block=False)
389 self.assertTrue(isinstance(amr, AsyncMapResult))
388 self.assertTrue(isinstance(amr, AsyncMapResult))
390 assert_array_equal(amr.get(), a)
389 assert_array_equal(amr.get(), a)
391
390
392 def test_execute(self):
391 def test_execute(self):
393 view = self.client[:]
392 view = self.client[:]
394 # self.client.debug=True
393 # self.client.debug=True
395 execute = view.execute
394 execute = view.execute
396 ar = execute('c=30', block=False)
395 ar = execute('c=30', block=False)
397 self.assertTrue(isinstance(ar, AsyncResult))
396 self.assertTrue(isinstance(ar, AsyncResult))
398 ar = execute('d=[0,1,2]', block=False)
397 ar = execute('d=[0,1,2]', block=False)
399 self.client.wait(ar, 1)
398 self.client.wait(ar, 1)
400 self.assertEqual(len(ar.get()), len(self.client))
399 self.assertEqual(len(ar.get()), len(self.client))
401 for c in view['c']:
400 for c in view['c']:
402 self.assertEqual(c, 30)
401 self.assertEqual(c, 30)
403
402
404 def test_abort(self):
403 def test_abort(self):
405 view = self.client[-1]
404 view = self.client[-1]
406 ar = view.execute('import time; time.sleep(1)', block=False)
405 ar = view.execute('import time; time.sleep(1)', block=False)
407 ar2 = view.apply_async(lambda : 2)
406 ar2 = view.apply_async(lambda : 2)
408 ar3 = view.apply_async(lambda : 3)
407 ar3 = view.apply_async(lambda : 3)
409 view.abort(ar2)
408 view.abort(ar2)
410 view.abort(ar3.msg_ids)
409 view.abort(ar3.msg_ids)
411 self.assertRaises(error.TaskAborted, ar2.get)
410 self.assertRaises(error.TaskAborted, ar2.get)
412 self.assertRaises(error.TaskAborted, ar3.get)
411 self.assertRaises(error.TaskAborted, ar3.get)
413
412
414 def test_abort_all(self):
413 def test_abort_all(self):
415 """view.abort() aborts all outstanding tasks"""
414 """view.abort() aborts all outstanding tasks"""
416 view = self.client[-1]
415 view = self.client[-1]
417 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
416 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
418 view.abort()
417 view.abort()
419 view.wait(timeout=5)
418 view.wait(timeout=5)
420 for ar in ars[5:]:
419 for ar in ars[5:]:
421 self.assertRaises(error.TaskAborted, ar.get)
420 self.assertRaises(error.TaskAborted, ar.get)
422
421
423 def test_temp_flags(self):
422 def test_temp_flags(self):
424 view = self.client[-1]
423 view = self.client[-1]
425 view.block=True
424 view.block=True
426 with view.temp_flags(block=False):
425 with view.temp_flags(block=False):
427 self.assertFalse(view.block)
426 self.assertFalse(view.block)
428 self.assertTrue(view.block)
427 self.assertTrue(view.block)
429
428
430 @dec.known_failure_py3
429 @dec.known_failure_py3
431 def test_importer(self):
430 def test_importer(self):
432 view = self.client[-1]
431 view = self.client[-1]
433 view.clear(block=True)
432 view.clear(block=True)
434 with view.importer:
433 with view.importer:
435 import re
434 import re
436
435
437 @interactive
436 @interactive
438 def findall(pat, s):
437 def findall(pat, s):
439 # this globals() step isn't necessary in real code
438 # this globals() step isn't necessary in real code
440 # only to prevent a closure in the test
439 # only to prevent a closure in the test
441 re = globals()['re']
440 re = globals()['re']
442 return re.findall(pat, s)
441 return re.findall(pat, s)
443
442
444 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
443 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
445
444
446 def test_unicode_execute(self):
445 def test_unicode_execute(self):
447 """test executing unicode strings"""
446 """test executing unicode strings"""
448 v = self.client[-1]
447 v = self.client[-1]
449 v.block=True
448 v.block=True
450 if sys.version_info[0] >= 3:
449 if sys.version_info[0] >= 3:
451 code="a='é'"
450 code="a='é'"
452 else:
451 else:
453 code=u"a=u'é'"
452 code=u"a=u'é'"
454 v.execute(code)
453 v.execute(code)
455 self.assertEqual(v['a'], u'é')
454 self.assertEqual(v['a'], u'é')
456
455
457 def test_unicode_apply_result(self):
456 def test_unicode_apply_result(self):
458 """test unicode apply results"""
457 """test unicode apply results"""
459 v = self.client[-1]
458 v = self.client[-1]
460 r = v.apply_sync(lambda : u'é')
459 r = v.apply_sync(lambda : u'é')
461 self.assertEqual(r, u'é')
460 self.assertEqual(r, u'é')
462
461
463 def test_unicode_apply_arg(self):
462 def test_unicode_apply_arg(self):
464 """test passing unicode arguments to apply"""
463 """test passing unicode arguments to apply"""
465 v = self.client[-1]
464 v = self.client[-1]
466
465
467 @interactive
466 @interactive
468 def check_unicode(a, check):
467 def check_unicode(a, check):
469 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
468 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
470 assert isinstance(check, bytes), "%r is not bytes"%check
469 assert isinstance(check, bytes), "%r is not bytes"%check
471 assert a.encode('utf8') == check, "%s != %s"%(a,check)
470 assert a.encode('utf8') == check, "%s != %s"%(a,check)
472
471
473 for s in [ u'é', u'ßø®∫',u'asdf' ]:
472 for s in [ u'é', u'ßø®∫',u'asdf' ]:
474 try:
473 try:
475 v.apply_sync(check_unicode, s, s.encode('utf8'))
474 v.apply_sync(check_unicode, s, s.encode('utf8'))
476 except error.RemoteError as e:
475 except error.RemoteError as e:
477 if e.ename == 'AssertionError':
476 if e.ename == 'AssertionError':
478 self.fail(e.evalue)
477 self.fail(e.evalue)
479 else:
478 else:
480 raise e
479 raise e
481
480
482 def test_map_reference(self):
481 def test_map_reference(self):
483 """view.map(<Reference>, *seqs) should work"""
482 """view.map(<Reference>, *seqs) should work"""
484 v = self.client[:]
483 v = self.client[:]
485 v.scatter('n', self.client.ids, flatten=True)
484 v.scatter('n', self.client.ids, flatten=True)
486 v.execute("f = lambda x,y: x*y")
485 v.execute("f = lambda x,y: x*y")
487 rf = pmod.Reference('f')
486 rf = pmod.Reference('f')
488 nlist = list(range(10))
487 nlist = list(range(10))
489 mlist = nlist[::-1]
488 mlist = nlist[::-1]
490 expected = [ m*n for m,n in zip(mlist, nlist) ]
489 expected = [ m*n for m,n in zip(mlist, nlist) ]
491 result = v.map_sync(rf, mlist, nlist)
490 result = v.map_sync(rf, mlist, nlist)
492 self.assertEqual(result, expected)
491 self.assertEqual(result, expected)
493
492
494 def test_apply_reference(self):
493 def test_apply_reference(self):
495 """view.apply(<Reference>, *args) should work"""
494 """view.apply(<Reference>, *args) should work"""
496 v = self.client[:]
495 v = self.client[:]
497 v.scatter('n', self.client.ids, flatten=True)
496 v.scatter('n', self.client.ids, flatten=True)
498 v.execute("f = lambda x: n*x")
497 v.execute("f = lambda x: n*x")
499 rf = pmod.Reference('f')
498 rf = pmod.Reference('f')
500 result = v.apply_sync(rf, 5)
499 result = v.apply_sync(rf, 5)
501 expected = [ 5*id for id in self.client.ids ]
500 expected = [ 5*id for id in self.client.ids ]
502 self.assertEqual(result, expected)
501 self.assertEqual(result, expected)
503
502
504 def test_eval_reference(self):
503 def test_eval_reference(self):
505 v = self.client[self.client.ids[0]]
504 v = self.client[self.client.ids[0]]
506 v['g'] = list(range(5))
505 v['g'] = list(range(5))
507 rg = pmod.Reference('g[0]')
506 rg = pmod.Reference('g[0]')
508 echo = lambda x:x
507 echo = lambda x:x
509 self.assertEqual(v.apply_sync(echo, rg), 0)
508 self.assertEqual(v.apply_sync(echo, rg), 0)
510
509
511 def test_reference_nameerror(self):
510 def test_reference_nameerror(self):
512 v = self.client[self.client.ids[0]]
511 v = self.client[self.client.ids[0]]
513 r = pmod.Reference('elvis_has_left')
512 r = pmod.Reference('elvis_has_left')
514 echo = lambda x:x
513 echo = lambda x:x
515 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
514 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
516
515
517 def test_single_engine_map(self):
516 def test_single_engine_map(self):
518 e0 = self.client[self.client.ids[0]]
517 e0 = self.client[self.client.ids[0]]
519 r = list(range(5))
518 r = list(range(5))
520 check = [ -1*i for i in r ]
519 check = [ -1*i for i in r ]
521 result = e0.map_sync(lambda x: -1*x, r)
520 result = e0.map_sync(lambda x: -1*x, r)
522 self.assertEqual(result, check)
521 self.assertEqual(result, check)
523
522
524 def test_len(self):
523 def test_len(self):
525 """len(view) makes sense"""
524 """len(view) makes sense"""
526 e0 = self.client[self.client.ids[0]]
525 e0 = self.client[self.client.ids[0]]
527 self.assertEqual(len(e0), 1)
526 self.assertEqual(len(e0), 1)
528 v = self.client[:]
527 v = self.client[:]
529 self.assertEqual(len(v), len(self.client.ids))
528 self.assertEqual(len(v), len(self.client.ids))
530 v = self.client.direct_view('all')
529 v = self.client.direct_view('all')
531 self.assertEqual(len(v), len(self.client.ids))
530 self.assertEqual(len(v), len(self.client.ids))
532 v = self.client[:2]
531 v = self.client[:2]
533 self.assertEqual(len(v), 2)
532 self.assertEqual(len(v), 2)
534 v = self.client[:1]
533 v = self.client[:1]
535 self.assertEqual(len(v), 1)
534 self.assertEqual(len(v), 1)
536 v = self.client.load_balanced_view()
535 v = self.client.load_balanced_view()
537 self.assertEqual(len(v), len(self.client.ids))
536 self.assertEqual(len(v), len(self.client.ids))
538
537
539
538
540 # begin execute tests
539 # begin execute tests
541
540
542 def test_execute_reply(self):
541 def test_execute_reply(self):
543 e0 = self.client[self.client.ids[0]]
542 e0 = self.client[self.client.ids[0]]
544 e0.block = True
543 e0.block = True
545 ar = e0.execute("5", silent=False)
544 ar = e0.execute("5", silent=False)
546 er = ar.get()
545 er = ar.get()
547 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
546 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
548 self.assertEqual(er.pyout['data']['text/plain'], '5')
547 self.assertEqual(er.pyout['data']['text/plain'], '5')
549
548
550 def test_execute_reply_rich(self):
549 def test_execute_reply_rich(self):
551 e0 = self.client[self.client.ids[0]]
550 e0 = self.client[self.client.ids[0]]
552 e0.block = True
551 e0.block = True
553 e0.execute("from IPython.display import Image, HTML")
552 e0.execute("from IPython.display import Image, HTML")
554 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
553 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
555 er = ar.get()
554 er = ar.get()
556 b64data = base64.encodestring(b'garbage').decode('ascii')
555 b64data = base64.encodestring(b'garbage').decode('ascii')
557 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
556 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
558 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
557 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
559 er = ar.get()
558 er = ar.get()
560 self.assertEqual(er._repr_html_(), "<b>bold</b>")
559 self.assertEqual(er._repr_html_(), "<b>bold</b>")
561
560
562 def test_execute_reply_stdout(self):
561 def test_execute_reply_stdout(self):
563 e0 = self.client[self.client.ids[0]]
562 e0 = self.client[self.client.ids[0]]
564 e0.block = True
563 e0.block = True
565 ar = e0.execute("print (5)", silent=False)
564 ar = e0.execute("print (5)", silent=False)
566 er = ar.get()
565 er = ar.get()
567 self.assertEqual(er.stdout.strip(), '5')
566 self.assertEqual(er.stdout.strip(), '5')
568
567
569 def test_execute_pyout(self):
568 def test_execute_pyout(self):
570 """execute triggers pyout with silent=False"""
569 """execute triggers pyout with silent=False"""
571 view = self.client[:]
570 view = self.client[:]
572 ar = view.execute("5", silent=False, block=True)
571 ar = view.execute("5", silent=False, block=True)
573
572
574 expected = [{'text/plain' : '5'}] * len(view)
573 expected = [{'text/plain' : '5'}] * len(view)
575 mimes = [ out['data'] for out in ar.pyout ]
574 mimes = [ out['data'] for out in ar.pyout ]
576 self.assertEqual(mimes, expected)
575 self.assertEqual(mimes, expected)
577
576
578 def test_execute_silent(self):
577 def test_execute_silent(self):
579 """execute does not trigger pyout with silent=True"""
578 """execute does not trigger pyout with silent=True"""
580 view = self.client[:]
579 view = self.client[:]
581 ar = view.execute("5", block=True)
580 ar = view.execute("5", block=True)
582 expected = [None] * len(view)
581 expected = [None] * len(view)
583 self.assertEqual(ar.pyout, expected)
582 self.assertEqual(ar.pyout, expected)
584
583
585 def test_execute_magic(self):
584 def test_execute_magic(self):
586 """execute accepts IPython commands"""
585 """execute accepts IPython commands"""
587 view = self.client[:]
586 view = self.client[:]
588 view.execute("a = 5")
587 view.execute("a = 5")
589 ar = view.execute("%whos", block=True)
588 ar = view.execute("%whos", block=True)
590 # this will raise, if that failed
589 # this will raise, if that failed
591 ar.get(5)
590 ar.get(5)
592 for stdout in ar.stdout:
591 for stdout in ar.stdout:
593 lines = stdout.splitlines()
592 lines = stdout.splitlines()
594 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
593 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
595 found = False
594 found = False
596 for line in lines[2:]:
595 for line in lines[2:]:
597 split = line.split()
596 split = line.split()
598 if split == ['a', 'int', '5']:
597 if split == ['a', 'int', '5']:
599 found = True
598 found = True
600 break
599 break
601 self.assertTrue(found, "whos output wrong: %s" % stdout)
600 self.assertTrue(found, "whos output wrong: %s" % stdout)
602
601
603 def test_execute_displaypub(self):
602 def test_execute_displaypub(self):
604 """execute tracks display_pub output"""
603 """execute tracks display_pub output"""
605 view = self.client[:]
604 view = self.client[:]
606 view.execute("from IPython.core.display import *")
605 view.execute("from IPython.core.display import *")
607 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
606 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
608
607
609 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
608 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
610 for outputs in ar.outputs:
609 for outputs in ar.outputs:
611 mimes = [ out['data'] for out in outputs ]
610 mimes = [ out['data'] for out in outputs ]
612 self.assertEqual(mimes, expected)
611 self.assertEqual(mimes, expected)
613
612
614 def test_apply_displaypub(self):
613 def test_apply_displaypub(self):
615 """apply tracks display_pub output"""
614 """apply tracks display_pub output"""
616 view = self.client[:]
615 view = self.client[:]
617 view.execute("from IPython.core.display import *")
616 view.execute("from IPython.core.display import *")
618
617
619 @interactive
618 @interactive
620 def publish():
619 def publish():
621 [ display(i) for i in range(5) ]
620 [ display(i) for i in range(5) ]
622
621
623 ar = view.apply_async(publish)
622 ar = view.apply_async(publish)
624 ar.get(5)
623 ar.get(5)
625 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
624 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
626 for outputs in ar.outputs:
625 for outputs in ar.outputs:
627 mimes = [ out['data'] for out in outputs ]
626 mimes = [ out['data'] for out in outputs ]
628 self.assertEqual(mimes, expected)
627 self.assertEqual(mimes, expected)
629
628
630 def test_execute_raises(self):
629 def test_execute_raises(self):
631 """exceptions in execute requests raise appropriately"""
630 """exceptions in execute requests raise appropriately"""
632 view = self.client[-1]
631 view = self.client[-1]
633 ar = view.execute("1/0")
632 ar = view.execute("1/0")
634 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
633 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
635
634
636 def test_remoteerror_render_exception(self):
635 def test_remoteerror_render_exception(self):
637 """RemoteErrors get nice tracebacks"""
636 """RemoteErrors get nice tracebacks"""
638 view = self.client[-1]
637 view = self.client[-1]
639 ar = view.execute("1/0")
638 ar = view.execute("1/0")
640 ip = get_ipython()
639 ip = get_ipython()
641 ip.user_ns['ar'] = ar
640 ip.user_ns['ar'] = ar
642 with capture_output() as io:
641 with capture_output() as io:
643 ip.run_cell("ar.get(2)")
642 ip.run_cell("ar.get(2)")
644
643
645 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
644 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
646
645
647 def test_compositeerror_render_exception(self):
646 def test_compositeerror_render_exception(self):
648 """CompositeErrors get nice tracebacks"""
647 """CompositeErrors get nice tracebacks"""
649 view = self.client[:]
648 view = self.client[:]
650 ar = view.execute("1/0")
649 ar = view.execute("1/0")
651 ip = get_ipython()
650 ip = get_ipython()
652 ip.user_ns['ar'] = ar
651 ip.user_ns['ar'] = ar
653
652
654 with capture_output() as io:
653 with capture_output() as io:
655 ip.run_cell("ar.get(2)")
654 ip.run_cell("ar.get(2)")
656
655
657 count = min(error.CompositeError.tb_limit, len(view))
656 count = min(error.CompositeError.tb_limit, len(view))
658
657
659 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
658 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
660 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
659 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
661 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
660 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
662
661
663 def test_compositeerror_truncate(self):
662 def test_compositeerror_truncate(self):
664 """Truncate CompositeErrors with many exceptions"""
663 """Truncate CompositeErrors with many exceptions"""
665 view = self.client[:]
664 view = self.client[:]
666 msg_ids = []
665 msg_ids = []
667 for i in range(10):
666 for i in range(10):
668 ar = view.execute("1/0")
667 ar = view.execute("1/0")
669 msg_ids.extend(ar.msg_ids)
668 msg_ids.extend(ar.msg_ids)
670
669
671 ar = self.client.get_result(msg_ids)
670 ar = self.client.get_result(msg_ids)
672 try:
671 try:
673 ar.get()
672 ar.get()
674 except error.CompositeError as _e:
673 except error.CompositeError as _e:
675 e = _e
674 e = _e
676 else:
675 else:
677 self.fail("Should have raised CompositeError")
676 self.fail("Should have raised CompositeError")
678
677
679 lines = e.render_traceback()
678 lines = e.render_traceback()
680 with capture_output() as io:
679 with capture_output() as io:
681 e.print_traceback()
680 e.print_traceback()
682
681
683 self.assertTrue("more exceptions" in lines[-1])
682 self.assertTrue("more exceptions" in lines[-1])
684 count = e.tb_limit
683 count = e.tb_limit
685
684
686 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
685 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
687 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
686 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
688 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
687 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
689
688
690 @dec.skipif_not_matplotlib
689 @dec.skipif_not_matplotlib
691 def test_magic_pylab(self):
690 def test_magic_pylab(self):
692 """%pylab works on engines"""
691 """%pylab works on engines"""
693 view = self.client[-1]
692 view = self.client[-1]
694 ar = view.execute("%pylab inline")
693 ar = view.execute("%pylab inline")
695 # at least check if this raised:
694 # at least check if this raised:
696 reply = ar.get(5)
695 reply = ar.get(5)
697 # include imports, in case user config
696 # include imports, in case user config
698 ar = view.execute("plot(rand(100))", silent=False)
697 ar = view.execute("plot(rand(100))", silent=False)
699 reply = ar.get(5)
698 reply = ar.get(5)
700 self.assertEqual(len(reply.outputs), 1)
699 self.assertEqual(len(reply.outputs), 1)
701 output = reply.outputs[0]
700 output = reply.outputs[0]
702 self.assertTrue("data" in output)
701 self.assertTrue("data" in output)
703 data = output['data']
702 data = output['data']
704 self.assertTrue("image/png" in data)
703 self.assertTrue("image/png" in data)
705
704
706 def test_func_default_func(self):
705 def test_func_default_func(self):
707 """interactively defined function as apply func default"""
706 """interactively defined function as apply func default"""
708 def foo():
707 def foo():
709 return 'foo'
708 return 'foo'
710
709
711 def bar(f=foo):
710 def bar(f=foo):
712 return f()
711 return f()
713
712
714 view = self.client[-1]
713 view = self.client[-1]
715 ar = view.apply_async(bar)
714 ar = view.apply_async(bar)
716 r = ar.get(10)
715 r = ar.get(10)
717 self.assertEqual(r, 'foo')
716 self.assertEqual(r, 'foo')
718 def test_data_pub_single(self):
717 def test_data_pub_single(self):
719 view = self.client[-1]
718 view = self.client[-1]
720 ar = view.execute('\n'.join([
719 ar = view.execute('\n'.join([
721 'from IPython.kernel.zmq.datapub import publish_data',
720 'from IPython.kernel.zmq.datapub import publish_data',
722 'for i in range(5):',
721 'for i in range(5):',
723 ' publish_data(dict(i=i))'
722 ' publish_data(dict(i=i))'
724 ]), block=False)
723 ]), block=False)
725 self.assertTrue(isinstance(ar.data, dict))
724 self.assertTrue(isinstance(ar.data, dict))
726 ar.get(5)
725 ar.get(5)
727 self.assertEqual(ar.data, dict(i=4))
726 self.assertEqual(ar.data, dict(i=4))
728
727
729 def test_data_pub(self):
728 def test_data_pub(self):
730 view = self.client[:]
729 view = self.client[:]
731 ar = view.execute('\n'.join([
730 ar = view.execute('\n'.join([
732 'from IPython.kernel.zmq.datapub import publish_data',
731 'from IPython.kernel.zmq.datapub import publish_data',
733 'for i in range(5):',
732 'for i in range(5):',
734 ' publish_data(dict(i=i))'
733 ' publish_data(dict(i=i))'
735 ]), block=False)
734 ]), block=False)
736 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
735 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
737 ar.get(5)
736 ar.get(5)
738 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
737 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
739
738
740 def test_can_list_arg(self):
739 def test_can_list_arg(self):
741 """args in lists are canned"""
740 """args in lists are canned"""
742 view = self.client[-1]
741 view = self.client[-1]
743 view['a'] = 128
742 view['a'] = 128
744 rA = pmod.Reference('a')
743 rA = pmod.Reference('a')
745 ar = view.apply_async(lambda x: x, [rA])
744 ar = view.apply_async(lambda x: x, [rA])
746 r = ar.get(5)
745 r = ar.get(5)
747 self.assertEqual(r, [128])
746 self.assertEqual(r, [128])
748
747
749 def test_can_dict_arg(self):
748 def test_can_dict_arg(self):
750 """args in dicts are canned"""
749 """args in dicts are canned"""
751 view = self.client[-1]
750 view = self.client[-1]
752 view['a'] = 128
751 view['a'] = 128
753 rA = pmod.Reference('a')
752 rA = pmod.Reference('a')
754 ar = view.apply_async(lambda x: x, dict(foo=rA))
753 ar = view.apply_async(lambda x: x, dict(foo=rA))
755 r = ar.get(5)
754 r = ar.get(5)
756 self.assertEqual(r, dict(foo=128))
755 self.assertEqual(r, dict(foo=128))
757
756
758 def test_can_list_kwarg(self):
757 def test_can_list_kwarg(self):
759 """kwargs in lists are canned"""
758 """kwargs in lists are canned"""
760 view = self.client[-1]
759 view = self.client[-1]
761 view['a'] = 128
760 view['a'] = 128
762 rA = pmod.Reference('a')
761 rA = pmod.Reference('a')
763 ar = view.apply_async(lambda x=5: x, x=[rA])
762 ar = view.apply_async(lambda x=5: x, x=[rA])
764 r = ar.get(5)
763 r = ar.get(5)
765 self.assertEqual(r, [128])
764 self.assertEqual(r, [128])
766
765
767 def test_can_dict_kwarg(self):
766 def test_can_dict_kwarg(self):
768 """kwargs in dicts are canned"""
767 """kwargs in dicts are canned"""
769 view = self.client[-1]
768 view = self.client[-1]
770 view['a'] = 128
769 view['a'] = 128
771 rA = pmod.Reference('a')
770 rA = pmod.Reference('a')
772 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
771 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
773 r = ar.get(5)
772 r = ar.get(5)
774 self.assertEqual(r, dict(foo=128))
773 self.assertEqual(r, dict(foo=128))
775
774
776 def test_map_ref(self):
775 def test_map_ref(self):
777 """view.map works with references"""
776 """view.map works with references"""
778 view = self.client[:]
777 view = self.client[:]
779 ranks = sorted(self.client.ids)
778 ranks = sorted(self.client.ids)
780 view.scatter('rank', ranks, flatten=True)
779 view.scatter('rank', ranks, flatten=True)
781 rrank = pmod.Reference('rank')
780 rrank = pmod.Reference('rank')
782
781
783 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
782 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
784 drank = amr.get(5)
783 drank = amr.get(5)
785 self.assertEqual(drank, [ r*2 for r in ranks ])
784 self.assertEqual(drank, [ r*2 for r in ranks ])
786
785
787 def test_nested_getitem_setitem(self):
786 def test_nested_getitem_setitem(self):
788 """get and set with view['a.b']"""
787 """get and set with view['a.b']"""
789 view = self.client[-1]
788 view = self.client[-1]
790 view.execute('\n'.join([
789 view.execute('\n'.join([
791 'class A(object): pass',
790 'class A(object): pass',
792 'a = A()',
791 'a = A()',
793 'a.b = 128',
792 'a.b = 128',
794 ]), block=True)
793 ]), block=True)
795 ra = pmod.Reference('a')
794 ra = pmod.Reference('a')
796
795
797 r = view.apply_sync(lambda x: x.b, ra)
796 r = view.apply_sync(lambda x: x.b, ra)
798 self.assertEqual(r, 128)
797 self.assertEqual(r, 128)
799 self.assertEqual(view['a.b'], 128)
798 self.assertEqual(view['a.b'], 128)
800
799
801 view['a.b'] = 0
800 view['a.b'] = 0
802
801
803 r = view.apply_sync(lambda x: x.b, ra)
802 r = view.apply_sync(lambda x: x.b, ra)
804 self.assertEqual(r, 0)
803 self.assertEqual(r, 0)
805 self.assertEqual(view['a.b'], 0)
804 self.assertEqual(view['a.b'], 0)
806
805
807 def test_return_namedtuple(self):
806 def test_return_namedtuple(self):
808 def namedtuplify(x, y):
807 def namedtuplify(x, y):
809 from IPython.parallel.tests.test_view import point
808 from IPython.parallel.tests.test_view import point
810 return point(x, y)
809 return point(x, y)
811
810
812 view = self.client[-1]
811 view = self.client[-1]
813 p = view.apply_sync(namedtuplify, 1, 2)
812 p = view.apply_sync(namedtuplify, 1, 2)
814 self.assertEqual(p.x, 1)
813 self.assertEqual(p.x, 1)
815 self.assertEqual(p.y, 2)
814 self.assertEqual(p.y, 2)
816
815
817 def test_apply_namedtuple(self):
816 def test_apply_namedtuple(self):
818 def echoxy(p):
817 def echoxy(p):
819 return p.y, p.x
818 return p.y, p.x
820
819
821 view = self.client[-1]
820 view = self.client[-1]
822 tup = view.apply_sync(echoxy, point(1, 2))
821 tup = view.apply_sync(echoxy, point(1, 2))
823 self.assertEqual(tup, (2,1))
822 self.assertEqual(tup, (2,1))
824
823
825 def test_sync_imports(self):
824 def test_sync_imports(self):
826 view = self.client[-1]
825 view = self.client[-1]
827 with capture_output() as io:
826 with capture_output() as io:
828 with view.sync_imports():
827 with view.sync_imports():
829 import IPython
828 import IPython
830 self.assertIn("IPython", io.stdout)
829 self.assertIn("IPython", io.stdout)
831
830
832 @interactive
831 @interactive
833 def find_ipython():
832 def find_ipython():
834 return 'IPython' in globals()
833 return 'IPython' in globals()
835
834
836 assert view.apply_sync(find_ipython)
835 assert view.apply_sync(find_ipython)
837
836
838 def test_sync_imports_quiet(self):
837 def test_sync_imports_quiet(self):
839 view = self.client[-1]
838 view = self.client[-1]
840 with capture_output() as io:
839 with capture_output() as io:
841 with view.sync_imports(quiet=True):
840 with view.sync_imports(quiet=True):
842 import IPython
841 import IPython
843 self.assertEqual(io.stdout, '')
842 self.assertEqual(io.stdout, '')
844
843
845 @interactive
844 @interactive
846 def find_ipython():
845 def find_ipython():
847 return 'IPython' in globals()
846 return 'IPython' in globals()
848
847
849 assert view.apply_sync(find_ipython)
848 assert view.apply_sync(find_ipython)
850
849
@@ -1,454 +1,456 b''
1 """Generic testing tools.
1 """Generic testing tools.
2
2
3 Authors
3 Authors
4 -------
4 -------
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 """
6 """
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2009 The IPython Development Team
11 # Copyright (C) 2009 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 import os
21 import os
22 import re
22 import re
23 import sys
23 import sys
24 import tempfile
24 import tempfile
25
25
26 from contextlib import contextmanager
26 from contextlib import contextmanager
27 from io import StringIO
27 from io import StringIO
28 from subprocess import Popen, PIPE
28 from subprocess import Popen, PIPE
29
29
30 try:
30 try:
31 # These tools are used by parts of the runtime, so we make the nose
31 # These tools are used by parts of the runtime, so we make the nose
32 # dependency optional at this point. Nose is a hard dependency to run the
32 # dependency optional at this point. Nose is a hard dependency to run the
33 # test suite, but NOT to use ipython itself.
33 # test suite, but NOT to use ipython itself.
34 import nose.tools as nt
34 import nose.tools as nt
35 has_nose = True
35 has_nose = True
36 except ImportError:
36 except ImportError:
37 has_nose = False
37 has_nose = False
38
38
39 from IPython.config.loader import Config
39 from IPython.config.loader import Config
40 from IPython.utils.process import get_output_error_code
40 from IPython.utils.process import get_output_error_code
41 from IPython.utils.text import list_strings
41 from IPython.utils.text import list_strings
42 from IPython.utils.io import temp_pyfile, Tee
42 from IPython.utils.io import temp_pyfile, Tee
43 from IPython.utils import py3compat
43 from IPython.utils import py3compat
44 from IPython.utils.encoding import DEFAULT_ENCODING
44 from IPython.utils.encoding import DEFAULT_ENCODING
45
45
46 from . import decorators as dec
46 from . import decorators as dec
47 from . import skipdoctest
47 from . import skipdoctest
48
48
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50 # Functions and classes
50 # Functions and classes
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52
52
53 # The docstring for full_path doctests differently on win32 (different path
53 # The docstring for full_path doctests differently on win32 (different path
54 # separator) so just skip the doctest there. The example remains informative.
54 # separator) so just skip the doctest there. The example remains informative.
55 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
55 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
56
56
57 @doctest_deco
57 @doctest_deco
58 def full_path(startPath,files):
58 def full_path(startPath,files):
59 """Make full paths for all the listed files, based on startPath.
59 """Make full paths for all the listed files, based on startPath.
60
60
61 Only the base part of startPath is kept, since this routine is typically
61 Only the base part of startPath is kept, since this routine is typically
62 used with a script's ``__file__`` variable as startPath. The base of startPath
62 used with a script's ``__file__`` variable as startPath. The base of startPath
63 is then prepended to all the listed files, forming the output list.
63 is then prepended to all the listed files, forming the output list.
64
64
65 Parameters
65 Parameters
66 ----------
66 ----------
67 startPath : string
67 startPath : string
68 Initial path to use as the base for the results. This path is split
68 Initial path to use as the base for the results. This path is split
69 using os.path.split() and only its first component is kept.
69 using os.path.split() and only its first component is kept.
70
70
71 files : string or list
71 files : string or list
72 One or more files.
72 One or more files.
73
73
74 Examples
74 Examples
75 --------
75 --------
76
76
77 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
77 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
78 ['/foo/a.txt', '/foo/b.txt']
78 ['/foo/a.txt', '/foo/b.txt']
79
79
80 >>> full_path('/foo',['a.txt','b.txt'])
80 >>> full_path('/foo',['a.txt','b.txt'])
81 ['/a.txt', '/b.txt']
81 ['/a.txt', '/b.txt']
82
82
83 If a single file is given, the output is still a list::
83 If a single file is given, the output is still a list::
84
84
85 >>> full_path('/foo','a.txt')
85 >>> full_path('/foo','a.txt')
86 ['/a.txt']
86 ['/a.txt']
87 """
87 """
88
88
89 files = list_strings(files)
89 files = list_strings(files)
90 base = os.path.split(startPath)[0]
90 base = os.path.split(startPath)[0]
91 return [ os.path.join(base,f) for f in files ]
91 return [ os.path.join(base,f) for f in files ]
92
92
93
93
94 def parse_test_output(txt):
94 def parse_test_output(txt):
95 """Parse the output of a test run and return errors, failures.
95 """Parse the output of a test run and return errors, failures.
96
96
97 Parameters
97 Parameters
98 ----------
98 ----------
99 txt : str
99 txt : str
100 Text output of a test run, assumed to contain a line of one of the
100 Text output of a test run, assumed to contain a line of one of the
101 following forms::
101 following forms::
102
102
103 'FAILED (errors=1)'
103 'FAILED (errors=1)'
104 'FAILED (failures=1)'
104 'FAILED (failures=1)'
105 'FAILED (errors=1, failures=1)'
105 'FAILED (errors=1, failures=1)'
106
106
107 Returns
107 Returns
108 -------
108 -------
109 nerr, nfail
109 nerr, nfail
110 number of errors and failures.
110 number of errors and failures.
111 """
111 """
112
112
113 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
113 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
114 if err_m:
114 if err_m:
115 nerr = int(err_m.group(1))
115 nerr = int(err_m.group(1))
116 nfail = 0
116 nfail = 0
117 return nerr, nfail
117 return nerr, nfail
118
118
119 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
119 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
120 if fail_m:
120 if fail_m:
121 nerr = 0
121 nerr = 0
122 nfail = int(fail_m.group(1))
122 nfail = int(fail_m.group(1))
123 return nerr, nfail
123 return nerr, nfail
124
124
125 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
125 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
126 re.MULTILINE)
126 re.MULTILINE)
127 if both_m:
127 if both_m:
128 nerr = int(both_m.group(1))
128 nerr = int(both_m.group(1))
129 nfail = int(both_m.group(2))
129 nfail = int(both_m.group(2))
130 return nerr, nfail
130 return nerr, nfail
131
131
132 # If the input didn't match any of these forms, assume no error/failures
132 # If the input didn't match any of these forms, assume no error/failures
133 return 0, 0
133 return 0, 0
134
134
135
135
136 # So nose doesn't think this is a test
136 # So nose doesn't think this is a test
137 parse_test_output.__test__ = False
137 parse_test_output.__test__ = False
138
138
139
139
140 def default_argv():
140 def default_argv():
141 """Return a valid default argv for creating testing instances of ipython"""
141 """Return a valid default argv for creating testing instances of ipython"""
142
142
143 return ['--quick', # so no config file is loaded
143 return ['--quick', # so no config file is loaded
144 # Other defaults to minimize side effects on stdout
144 # Other defaults to minimize side effects on stdout
145 '--colors=NoColor', '--no-term-title','--no-banner',
145 '--colors=NoColor', '--no-term-title','--no-banner',
146 '--autocall=0']
146 '--autocall=0']
147
147
148
148
149 def default_config():
149 def default_config():
150 """Return a config object with good defaults for testing."""
150 """Return a config object with good defaults for testing."""
151 config = Config()
151 config = Config()
152 config.TerminalInteractiveShell.colors = 'NoColor'
152 config.TerminalInteractiveShell.colors = 'NoColor'
153 config.TerminalTerminalInteractiveShell.term_title = False,
153 config.TerminalTerminalInteractiveShell.term_title = False,
154 config.TerminalInteractiveShell.autocall = 0
154 config.TerminalInteractiveShell.autocall = 0
155 config.HistoryManager.hist_file = tempfile.mktemp(u'test_hist.sqlite')
155 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
156 config.HistoryManager.hist_file = f.name
157 f.close()
156 config.HistoryManager.db_cache_size = 10000
158 config.HistoryManager.db_cache_size = 10000
157 return config
159 return config
158
160
159
161
160 def get_ipython_cmd(as_string=False):
162 def get_ipython_cmd(as_string=False):
161 """
163 """
162 Return appropriate IPython command line name. By default, this will return
164 Return appropriate IPython command line name. By default, this will return
163 a list that can be used with subprocess.Popen, for example, but passing
165 a list that can be used with subprocess.Popen, for example, but passing
164 `as_string=True` allows for returning the IPython command as a string.
166 `as_string=True` allows for returning the IPython command as a string.
165
167
166 Parameters
168 Parameters
167 ----------
169 ----------
168 as_string: bool
170 as_string: bool
169 Flag to allow to return the command as a string.
171 Flag to allow to return the command as a string.
170 """
172 """
171 ipython_cmd = [sys.executable, "-m", "IPython"]
173 ipython_cmd = [sys.executable, "-m", "IPython"]
172
174
173 if as_string:
175 if as_string:
174 ipython_cmd = " ".join(ipython_cmd)
176 ipython_cmd = " ".join(ipython_cmd)
175
177
176 return ipython_cmd
178 return ipython_cmd
177
179
178 def ipexec(fname, options=None):
180 def ipexec(fname, options=None):
179 """Utility to call 'ipython filename'.
181 """Utility to call 'ipython filename'.
180
182
181 Starts IPython with a minimal and safe configuration to make startup as fast
183 Starts IPython with a minimal and safe configuration to make startup as fast
182 as possible.
184 as possible.
183
185
184 Note that this starts IPython in a subprocess!
186 Note that this starts IPython in a subprocess!
185
187
186 Parameters
188 Parameters
187 ----------
189 ----------
188 fname : str
190 fname : str
189 Name of file to be executed (should have .py or .ipy extension).
191 Name of file to be executed (should have .py or .ipy extension).
190
192
191 options : optional, list
193 options : optional, list
192 Extra command-line flags to be passed to IPython.
194 Extra command-line flags to be passed to IPython.
193
195
194 Returns
196 Returns
195 -------
197 -------
196 (stdout, stderr) of ipython subprocess.
198 (stdout, stderr) of ipython subprocess.
197 """
199 """
198 if options is None: options = []
200 if options is None: options = []
199
201
200 # For these subprocess calls, eliminate all prompt printing so we only see
202 # For these subprocess calls, eliminate all prompt printing so we only see
201 # output from script execution
203 # output from script execution
202 prompt_opts = [ '--PromptManager.in_template=""',
204 prompt_opts = [ '--PromptManager.in_template=""',
203 '--PromptManager.in2_template=""',
205 '--PromptManager.in2_template=""',
204 '--PromptManager.out_template=""'
206 '--PromptManager.out_template=""'
205 ]
207 ]
206 cmdargs = default_argv() + prompt_opts + options
208 cmdargs = default_argv() + prompt_opts + options
207
209
208 test_dir = os.path.dirname(__file__)
210 test_dir = os.path.dirname(__file__)
209
211
210 ipython_cmd = get_ipython_cmd()
212 ipython_cmd = get_ipython_cmd()
211 # Absolute path for filename
213 # Absolute path for filename
212 full_fname = os.path.join(test_dir, fname)
214 full_fname = os.path.join(test_dir, fname)
213 full_cmd = ipython_cmd + cmdargs + [full_fname]
215 full_cmd = ipython_cmd + cmdargs + [full_fname]
214 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE)
216 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE)
215 out, err = p.communicate()
217 out, err = p.communicate()
216 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
218 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
217 # `import readline` causes 'ESC[?1034h' to be output sometimes,
219 # `import readline` causes 'ESC[?1034h' to be output sometimes,
218 # so strip that out before doing comparisons
220 # so strip that out before doing comparisons
219 if out:
221 if out:
220 out = re.sub(r'\x1b\[[^h]+h', '', out)
222 out = re.sub(r'\x1b\[[^h]+h', '', out)
221 return out, err
223 return out, err
222
224
223
225
224 def ipexec_validate(fname, expected_out, expected_err='',
226 def ipexec_validate(fname, expected_out, expected_err='',
225 options=None):
227 options=None):
226 """Utility to call 'ipython filename' and validate output/error.
228 """Utility to call 'ipython filename' and validate output/error.
227
229
228 This function raises an AssertionError if the validation fails.
230 This function raises an AssertionError if the validation fails.
229
231
230 Note that this starts IPython in a subprocess!
232 Note that this starts IPython in a subprocess!
231
233
232 Parameters
234 Parameters
233 ----------
235 ----------
234 fname : str
236 fname : str
235 Name of the file to be executed (should have .py or .ipy extension).
237 Name of the file to be executed (should have .py or .ipy extension).
236
238
237 expected_out : str
239 expected_out : str
238 Expected stdout of the process.
240 Expected stdout of the process.
239
241
240 expected_err : optional, str
242 expected_err : optional, str
241 Expected stderr of the process.
243 Expected stderr of the process.
242
244
243 options : optional, list
245 options : optional, list
244 Extra command-line flags to be passed to IPython.
246 Extra command-line flags to be passed to IPython.
245
247
246 Returns
248 Returns
247 -------
249 -------
248 None
250 None
249 """
251 """
250
252
251 import nose.tools as nt
253 import nose.tools as nt
252
254
253 out, err = ipexec(fname, options)
255 out, err = ipexec(fname, options)
254 #print 'OUT', out # dbg
256 #print 'OUT', out # dbg
255 #print 'ERR', err # dbg
257 #print 'ERR', err # dbg
256 # If there are any errors, we must check those befor stdout, as they may be
258 # If there are any errors, we must check those befor stdout, as they may be
257 # more informative than simply having an empty stdout.
259 # more informative than simply having an empty stdout.
258 if err:
260 if err:
259 if expected_err:
261 if expected_err:
260 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
262 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
261 else:
263 else:
262 raise ValueError('Running file %r produced error: %r' %
264 raise ValueError('Running file %r produced error: %r' %
263 (fname, err))
265 (fname, err))
264 # If no errors or output on stderr was expected, match stdout
266 # If no errors or output on stderr was expected, match stdout
265 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
267 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
266
268
267
269
268 class TempFileMixin(object):
270 class TempFileMixin(object):
269 """Utility class to create temporary Python/IPython files.
271 """Utility class to create temporary Python/IPython files.
270
272
271 Meant as a mixin class for test cases."""
273 Meant as a mixin class for test cases."""
272
274
273 def mktmp(self, src, ext='.py'):
275 def mktmp(self, src, ext='.py'):
274 """Make a valid python temp file."""
276 """Make a valid python temp file."""
275 fname, f = temp_pyfile(src, ext)
277 fname, f = temp_pyfile(src, ext)
276 self.tmpfile = f
278 self.tmpfile = f
277 self.fname = fname
279 self.fname = fname
278
280
279 def tearDown(self):
281 def tearDown(self):
280 if hasattr(self, 'tmpfile'):
282 if hasattr(self, 'tmpfile'):
281 # If the tmpfile wasn't made because of skipped tests, like in
283 # If the tmpfile wasn't made because of skipped tests, like in
282 # win32, there's nothing to cleanup.
284 # win32, there's nothing to cleanup.
283 self.tmpfile.close()
285 self.tmpfile.close()
284 try:
286 try:
285 os.unlink(self.fname)
287 os.unlink(self.fname)
286 except:
288 except:
287 # On Windows, even though we close the file, we still can't
289 # On Windows, even though we close the file, we still can't
288 # delete it. I have no clue why
290 # delete it. I have no clue why
289 pass
291 pass
290
292
291 pair_fail_msg = ("Testing {0}\n\n"
293 pair_fail_msg = ("Testing {0}\n\n"
292 "In:\n"
294 "In:\n"
293 " {1!r}\n"
295 " {1!r}\n"
294 "Expected:\n"
296 "Expected:\n"
295 " {2!r}\n"
297 " {2!r}\n"
296 "Got:\n"
298 "Got:\n"
297 " {3!r}\n")
299 " {3!r}\n")
298 def check_pairs(func, pairs):
300 def check_pairs(func, pairs):
299 """Utility function for the common case of checking a function with a
301 """Utility function for the common case of checking a function with a
300 sequence of input/output pairs.
302 sequence of input/output pairs.
301
303
302 Parameters
304 Parameters
303 ----------
305 ----------
304 func : callable
306 func : callable
305 The function to be tested. Should accept a single argument.
307 The function to be tested. Should accept a single argument.
306 pairs : iterable
308 pairs : iterable
307 A list of (input, expected_output) tuples.
309 A list of (input, expected_output) tuples.
308
310
309 Returns
311 Returns
310 -------
312 -------
311 None. Raises an AssertionError if any output does not match the expected
313 None. Raises an AssertionError if any output does not match the expected
312 value.
314 value.
313 """
315 """
314 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
316 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
315 for inp, expected in pairs:
317 for inp, expected in pairs:
316 out = func(inp)
318 out = func(inp)
317 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
319 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
318
320
319
321
320 if py3compat.PY3:
322 if py3compat.PY3:
321 MyStringIO = StringIO
323 MyStringIO = StringIO
322 else:
324 else:
323 # In Python 2, stdout/stderr can have either bytes or unicode written to them,
325 # In Python 2, stdout/stderr can have either bytes or unicode written to them,
324 # so we need a class that can handle both.
326 # so we need a class that can handle both.
325 class MyStringIO(StringIO):
327 class MyStringIO(StringIO):
326 def write(self, s):
328 def write(self, s):
327 s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING)
329 s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING)
328 super(MyStringIO, self).write(s)
330 super(MyStringIO, self).write(s)
329
331
330 _re_type = type(re.compile(r''))
332 _re_type = type(re.compile(r''))
331
333
332 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
334 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
333 -------
335 -------
334 {2!s}
336 {2!s}
335 -------
337 -------
336 """
338 """
337
339
338 class AssertPrints(object):
340 class AssertPrints(object):
339 """Context manager for testing that code prints certain text.
341 """Context manager for testing that code prints certain text.
340
342
341 Examples
343 Examples
342 --------
344 --------
343 >>> with AssertPrints("abc", suppress=False):
345 >>> with AssertPrints("abc", suppress=False):
344 ... print("abcd")
346 ... print("abcd")
345 ... print("def")
347 ... print("def")
346 ...
348 ...
347 abcd
349 abcd
348 def
350 def
349 """
351 """
350 def __init__(self, s, channel='stdout', suppress=True):
352 def __init__(self, s, channel='stdout', suppress=True):
351 self.s = s
353 self.s = s
352 if isinstance(self.s, (py3compat.string_types, _re_type)):
354 if isinstance(self.s, (py3compat.string_types, _re_type)):
353 self.s = [self.s]
355 self.s = [self.s]
354 self.channel = channel
356 self.channel = channel
355 self.suppress = suppress
357 self.suppress = suppress
356
358
357 def __enter__(self):
359 def __enter__(self):
358 self.orig_stream = getattr(sys, self.channel)
360 self.orig_stream = getattr(sys, self.channel)
359 self.buffer = MyStringIO()
361 self.buffer = MyStringIO()
360 self.tee = Tee(self.buffer, channel=self.channel)
362 self.tee = Tee(self.buffer, channel=self.channel)
361 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
363 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
362
364
363 def __exit__(self, etype, value, traceback):
365 def __exit__(self, etype, value, traceback):
364 if value is not None:
366 if value is not None:
365 # If an error was raised, don't check anything else
367 # If an error was raised, don't check anything else
366 return False
368 return False
367 self.tee.flush()
369 self.tee.flush()
368 setattr(sys, self.channel, self.orig_stream)
370 setattr(sys, self.channel, self.orig_stream)
369 printed = self.buffer.getvalue()
371 printed = self.buffer.getvalue()
370 for s in self.s:
372 for s in self.s:
371 if isinstance(s, _re_type):
373 if isinstance(s, _re_type):
372 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
374 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
373 else:
375 else:
374 assert s in printed, notprinted_msg.format(s, self.channel, printed)
376 assert s in printed, notprinted_msg.format(s, self.channel, printed)
375 return False
377 return False
376
378
377 printed_msg = """Found {0!r} in printed output (on {1}):
379 printed_msg = """Found {0!r} in printed output (on {1}):
378 -------
380 -------
379 {2!s}
381 {2!s}
380 -------
382 -------
381 """
383 """
382
384
383 class AssertNotPrints(AssertPrints):
385 class AssertNotPrints(AssertPrints):
384 """Context manager for checking that certain output *isn't* produced.
386 """Context manager for checking that certain output *isn't* produced.
385
387
386 Counterpart of AssertPrints"""
388 Counterpart of AssertPrints"""
387 def __exit__(self, etype, value, traceback):
389 def __exit__(self, etype, value, traceback):
388 if value is not None:
390 if value is not None:
389 # If an error was raised, don't check anything else
391 # If an error was raised, don't check anything else
390 return False
392 return False
391 self.tee.flush()
393 self.tee.flush()
392 setattr(sys, self.channel, self.orig_stream)
394 setattr(sys, self.channel, self.orig_stream)
393 printed = self.buffer.getvalue()
395 printed = self.buffer.getvalue()
394 for s in self.s:
396 for s in self.s:
395 if isinstance(s, _re_type):
397 if isinstance(s, _re_type):
396 assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed)
398 assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed)
397 else:
399 else:
398 assert s not in printed, printed_msg.format(s, self.channel, printed)
400 assert s not in printed, printed_msg.format(s, self.channel, printed)
399 return False
401 return False
400
402
401 @contextmanager
403 @contextmanager
402 def mute_warn():
404 def mute_warn():
403 from IPython.utils import warn
405 from IPython.utils import warn
404 save_warn = warn.warn
406 save_warn = warn.warn
405 warn.warn = lambda *a, **kw: None
407 warn.warn = lambda *a, **kw: None
406 try:
408 try:
407 yield
409 yield
408 finally:
410 finally:
409 warn.warn = save_warn
411 warn.warn = save_warn
410
412
411 @contextmanager
413 @contextmanager
412 def make_tempfile(name):
414 def make_tempfile(name):
413 """ Create an empty, named, temporary file for the duration of the context.
415 """ Create an empty, named, temporary file for the duration of the context.
414 """
416 """
415 f = open(name, 'w')
417 f = open(name, 'w')
416 f.close()
418 f.close()
417 try:
419 try:
418 yield
420 yield
419 finally:
421 finally:
420 os.unlink(name)
422 os.unlink(name)
421
423
422
424
423 @contextmanager
425 @contextmanager
424 def monkeypatch(obj, name, attr):
426 def monkeypatch(obj, name, attr):
425 """
427 """
426 Context manager to replace attribute named `name` in `obj` with `attr`.
428 Context manager to replace attribute named `name` in `obj` with `attr`.
427 """
429 """
428 orig = getattr(obj, name)
430 orig = getattr(obj, name)
429 setattr(obj, name, attr)
431 setattr(obj, name, attr)
430 yield
432 yield
431 setattr(obj, name, orig)
433 setattr(obj, name, orig)
432
434
433
435
434 def help_output_test(subcommand=''):
436 def help_output_test(subcommand=''):
435 """test that `ipython [subcommand] -h` works"""
437 """test that `ipython [subcommand] -h` works"""
436 cmd = get_ipython_cmd() + [subcommand, '-h']
438 cmd = get_ipython_cmd() + [subcommand, '-h']
437 out, err, rc = get_output_error_code(cmd)
439 out, err, rc = get_output_error_code(cmd)
438 nt.assert_equal(rc, 0, err)
440 nt.assert_equal(rc, 0, err)
439 nt.assert_not_in("Traceback", err)
441 nt.assert_not_in("Traceback", err)
440 nt.assert_in("Options", out)
442 nt.assert_in("Options", out)
441 nt.assert_in("--help-all", out)
443 nt.assert_in("--help-all", out)
442 return out, err
444 return out, err
443
445
444
446
445 def help_all_output_test(subcommand=''):
447 def help_all_output_test(subcommand=''):
446 """test that `ipython [subcommand] --help-all` works"""
448 """test that `ipython [subcommand] --help-all` works"""
447 cmd = get_ipython_cmd() + [subcommand, '--help-all']
449 cmd = get_ipython_cmd() + [subcommand, '--help-all']
448 out, err, rc = get_output_error_code(cmd)
450 out, err, rc = get_output_error_code(cmd)
449 nt.assert_equal(rc, 0, err)
451 nt.assert_equal(rc, 0, err)
450 nt.assert_not_in("Traceback", err)
452 nt.assert_not_in("Traceback", err)
451 nt.assert_in("Options", out)
453 nt.assert_in("Options", out)
452 nt.assert_in("Class parameters", out)
454 nt.assert_in("Class parameters", out)
453 return out, err
455 return out, err
454
456
General Comments 0
You need to be logged in to leave comments. Login now