Backport PR #5166: remove mktemp usage...
MinRK
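This backport replaces the remaining tempfile.mktemp() calls with tempfile.mkstemp() or NamedTemporaryFile. mktemp() only returns an unused file name; between that call and the later open() another process can create a file with the same name, which is a race condition and a classic symlink-attack vector. mkstemp() creates and opens the file atomically with restrictive permissions. A minimal sketch of the pattern the diffs below apply, independent of IPython:

    import os
    import tempfile

    # Old, unsafe pattern: mktemp() only suggests a name; the file does not
    # exist yet when open() runs, leaving a window for another process.
    #   name = tempfile.mktemp('.txt'); f = open(name, 'wt')

    # Replacement: mkstemp() atomically creates and opens the file.
    fd, tmpname = tempfile.mkstemp('.txt')
    try:
        os.close(fd)                        # only the path is needed below
        with open(tmpname, 'wt') as tmpfile:
            tmpfile.write('some text\n')
        # ... hand tmpname to an external command here ...
    finally:
        os.remove(tmpname)                  # always clean up the temp file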
@@ -1,348 +1,351 b''
1 1 # encoding: utf-8
2 2 """
3 3 Paging capabilities for IPython.core
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Fernando Perez
9 9
10 10 Notes
11 11 -----
12 12
13 13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
14 14 rid of that dependency, we could move it there.
15 15 -----
16 16 """
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Copyright (C) 2008-2011 The IPython Development Team
20 20 #
21 21 # Distributed under the terms of the BSD License. The full license is in
22 22 # the file COPYING, distributed as part of this software.
23 23 #-----------------------------------------------------------------------------
24 24
25 25 #-----------------------------------------------------------------------------
26 26 # Imports
27 27 #-----------------------------------------------------------------------------
28 28 from __future__ import print_function
29 29
30 30 import os
31 31 import re
32 32 import sys
33 33 import tempfile
34 34
35 35 from io import UnsupportedOperation
36 36
37 37 from IPython import get_ipython
38 38 from IPython.core.error import TryNext
39 39 from IPython.utils.data import chop
40 40 from IPython.utils import io
41 41 from IPython.utils.process import system
42 42 from IPython.utils.terminal import get_terminal_size
43 43 from IPython.utils import py3compat
44 44
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Classes and functions
48 48 #-----------------------------------------------------------------------------
49 49
50 50 esc_re = re.compile(r"(\x1b[^m]+m)")
51 51
52 52 def page_dumb(strng, start=0, screen_lines=25):
53 53 """Very dumb 'pager' in Python, for when nothing else works.
54 54
55 55 Only moves forward, same interface as page(), except for pager_cmd and
56 56 mode."""
57 57
58 58 out_ln = strng.splitlines()[start:]
59 59 screens = chop(out_ln,screen_lines-1)
60 60 if len(screens) == 1:
61 61 print(os.linesep.join(screens[0]), file=io.stdout)
62 62 else:
63 63 last_escape = ""
64 64 for scr in screens[0:-1]:
65 65 hunk = os.linesep.join(scr)
66 66 print(last_escape + hunk, file=io.stdout)
67 67 if not page_more():
68 68 return
69 69 esc_list = esc_re.findall(hunk)
70 70 if len(esc_list) > 0:
71 71 last_escape = esc_list[-1]
72 72 print(last_escape + os.linesep.join(screens[-1]), file=io.stdout)
73 73
74 74 def _detect_screen_size(screen_lines_def):
75 75 """Attempt to work out the number of lines on the screen.
76 76
77 77 This is called by page(). It can raise an error (e.g. when run in the
78 78 test suite), so it's separated out so it can easily be called in a try block.
79 79 """
80 80 TERM = os.environ.get('TERM',None)
81 81 if not((TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5'):
82 82 # curses causes problems on many terminals other than xterm, and
83 83 # some termios calls lock up on Sun OS5.
84 84 return screen_lines_def
85 85
86 86 try:
87 87 import termios
88 88 import curses
89 89 except ImportError:
90 90 return screen_lines_def
91 91
92 92 # There is a bug in curses, where *sometimes* it fails to properly
93 93 # initialize, and then after the endwin() call is made, the
94 94 # terminal is left in an unusable state. Rather than trying to
95 95 # check every time for this (by requesting and comparing termios
96 96 # flags each time), we just save the initial terminal state and
97 97 # unconditionally reset it every time. It's cheaper than making
98 98 # the checks.
99 99 term_flags = termios.tcgetattr(sys.stdout)
100 100
101 101 # Curses modifies the stdout buffer size by default, which messes
102 102 # up Python's normal stdout buffering. This would manifest itself
103 103 # to IPython users as delayed printing on stdout after having used
104 104 # the pager.
105 105 #
106 106 # We can prevent this by manually setting the NCURSES_NO_SETBUF
107 107 # environment variable. For more details, see:
108 108 # http://bugs.python.org/issue10144
109 109 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
110 110 os.environ['NCURSES_NO_SETBUF'] = ''
111 111
112 112 # Proceed with curses initialization
113 113 try:
114 114 scr = curses.initscr()
115 115 except AttributeError:
116 116 # Curses on Solaris may not be complete, so we can't use it there
117 117 return screen_lines_def
118 118
119 119 screen_lines_real,screen_cols = scr.getmaxyx()
120 120 curses.endwin()
121 121
122 122 # Restore environment
123 123 if NCURSES_NO_SETBUF is None:
124 124 del os.environ['NCURSES_NO_SETBUF']
125 125 else:
126 126 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
127 127
128 128 # Restore terminal state in case endwin() didn't.
129 129 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
130 130 # Now we have what we needed: the screen size in rows/columns
131 131 return screen_lines_real
132 132 #print '***Screen size:',screen_lines_real,'lines x',\
133 133 #screen_cols,'columns.' # dbg
134 134
135 135 def page(strng, start=0, screen_lines=0, pager_cmd=None):
136 136 """Print a string, piping through a pager after a certain length.
137 137
138 138 The screen_lines parameter specifies the number of *usable* lines of your
139 139 terminal screen (total lines minus lines you need to reserve to show other
140 140 information).
141 141
142 142 If you set screen_lines to a number <=0, page() will try to auto-determine
143 143 your screen size and will only use up to (screen_size+screen_lines) for
144 144 printing, paging after that. That is, if you want auto-detection but need
145 145 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
146 146 auto-detection without any lines reserved simply use screen_lines = 0.
147 147
148 148 If a string won't fit in the allowed lines, it is sent through the
149 149 specified pager command. If none given, look for PAGER in the environment,
150 150 and ultimately default to less.
151 151
152 152 If no system pager works, the string is sent through a very simplistic
153 153 'dumb pager' written in Python.
154 154 """
155 155
156 156 # Some routines may auto-compute start offsets incorrectly and pass a
157 157 # negative value. Offset to 0 for robustness.
158 158 start = max(0, start)
159 159
160 160 # first, try the hook
161 161 ip = get_ipython()
162 162 if ip:
163 163 try:
164 164 ip.hooks.show_in_pager(strng)
165 165 return
166 166 except TryNext:
167 167 pass
168 168
169 169 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
170 170 TERM = os.environ.get('TERM','dumb')
171 171 if TERM in ['dumb','emacs'] and os.name != 'nt':
172 172 print(strng)
173 173 return
174 174 # chop off the topmost part of the string we don't want to see
175 175 str_lines = strng.splitlines()[start:]
176 176 str_toprint = os.linesep.join(str_lines)
177 177 num_newlines = len(str_lines)
178 178 len_str = len(str_toprint)
179 179
180 180 # Dumb heuristics to guesstimate number of on-screen lines the string
181 181 # takes. Very basic, but good enough for docstrings in reasonable
182 182 # terminals. If someone later feels like refining it, it's not hard.
183 183 numlines = max(num_newlines,int(len_str/80)+1)
184 184
185 185 screen_lines_def = get_terminal_size()[1]
186 186
187 187 # auto-determine screen size
188 188 if screen_lines <= 0:
189 189 try:
190 190 screen_lines += _detect_screen_size(screen_lines_def)
191 191 except (TypeError, UnsupportedOperation):
192 192 print(str_toprint, file=io.stdout)
193 193 return
194 194
195 195 #print 'numlines',numlines,'screenlines',screen_lines # dbg
196 196 if numlines <= screen_lines :
197 197 #print '*** normal print' # dbg
198 198 print(str_toprint, file=io.stdout)
199 199 else:
200 200 # Try to open pager and default to internal one if that fails.
201 201 # All failure modes are tagged as 'retval=1', to match the return
202 202 # value of a failed system command. If any intermediate attempt
203 203 # sets retval to 1, at the end we resort to our own page_dumb() pager.
204 204 pager_cmd = get_pager_cmd(pager_cmd)
205 205 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
206 206 if os.name == 'nt':
207 207 if pager_cmd.startswith('type'):
208 208 # The default WinXP 'type' command is failing on complex strings.
209 209 retval = 1
210 210 else:
211 tmpname = tempfile.mktemp('.txt')
212 tmpfile = open(tmpname,'wt')
213 tmpfile.write(strng)
214 tmpfile.close()
215 cmd = "%s < %s" % (pager_cmd,tmpname)
216 if os.system(cmd):
217 retval = 1
218 else:
219 retval = None
220 os.remove(tmpname)
211 fd, tmpname = tempfile.mkstemp('.txt')
212 try:
213 os.close(fd)
214 with open(tmpname, 'wt') as tmpfile:
215 tmpfile.write(strng)
216 cmd = "%s < %s" % (pager_cmd, tmpname)
217 # tmpfile needs to be closed for windows
218 if os.system(cmd):
219 retval = 1
220 else:
221 retval = None
222 finally:
223 os.remove(tmpname)
221 224 else:
222 225 try:
223 226 retval = None
224 227 # if I use popen4, things hang. No idea why.
225 228 #pager,shell_out = os.popen4(pager_cmd)
226 229 pager = os.popen(pager_cmd, 'w')
227 230 try:
228 231 pager_encoding = pager.encoding or sys.stdout.encoding
229 232 pager.write(py3compat.cast_bytes_py2(
230 233 strng, encoding=pager_encoding))
231 234 finally:
232 235 retval = pager.close()
233 236 except IOError as msg: # broken pipe when user quits
234 237 if msg.args == (32, 'Broken pipe'):
235 238 retval = None
236 239 else:
237 240 retval = 1
238 241 except OSError:
239 242 # Other strange problems, sometimes seen in Win2k/cygwin
240 243 retval = 1
241 244 if retval is not None:
242 245 page_dumb(strng,screen_lines=screen_lines)
243 246
244 247
245 248 def page_file(fname, start=0, pager_cmd=None):
246 249 """Page a file, using an optional pager command and starting line.
247 250 """
248 251
249 252 pager_cmd = get_pager_cmd(pager_cmd)
250 253 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
251 254
252 255 try:
253 256 if os.environ['TERM'] in ['emacs','dumb']:
254 257 raise EnvironmentError
255 258 system(pager_cmd + ' ' + fname)
256 259 except:
257 260 try:
258 261 if start > 0:
259 262 start -= 1
260 263 page(open(fname).read(),start)
261 264 except:
262 265 print('Unable to show file',repr(fname))
263 266
264 267
265 268 def get_pager_cmd(pager_cmd=None):
266 269 """Return a pager command.
267 270
268 271 Makes some attempts at finding an OS-correct one.
269 272 """
270 273 if os.name == 'posix':
271 274 default_pager_cmd = 'less -r' # -r for color control sequences
272 275 elif os.name in ['nt','dos']:
273 276 default_pager_cmd = 'type'
274 277
275 278 if pager_cmd is None:
276 279 try:
277 280 pager_cmd = os.environ['PAGER']
278 281 except:
279 282 pager_cmd = default_pager_cmd
280 283 return pager_cmd
281 284
282 285
283 286 def get_pager_start(pager, start):
284 287 """Return the string for paging files with an offset.
285 288
286 289 This is the '+N' argument which less and more (under Unix) accept.
287 290 """
288 291
289 292 if pager in ['less','more']:
290 293 if start:
291 294 start_string = '+' + str(start)
292 295 else:
293 296 start_string = ''
294 297 else:
295 298 start_string = ''
296 299 return start_string
297 300
298 301
299 302 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
300 303 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
301 304 import msvcrt
302 305 def page_more():
303 306 """ Smart pausing between pages
304 307
305 308 @return: True if more lines should be printed, False if the user quit
306 309 """
307 310 io.stdout.write('---Return to continue, q to quit--- ')
308 311 ans = msvcrt.getch()
309 312 if ans in ("q", "Q"):
310 313 result = False
311 314 else:
312 315 result = True
313 316 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
314 317 return result
315 318 else:
316 319 def page_more():
317 320 ans = raw_input('---Return to continue, q to quit--- ')
318 321 if ans.lower().startswith('q'):
319 322 return False
320 323 else:
321 324 return True
322 325
323 326
324 327 def snip_print(str,width = 75,print_full = 0,header = ''):
325 328 """Print a string snipping the midsection to fit in width.
326 329
327 330 print_full: mode control:
328 331 - 0: only snip long strings
329 332 - 1: send to page() directly.
330 333 - 2: snip long strings and ask for full length viewing with page()
331 334 Return 1 if snipping was necessary, 0 otherwise."""
332 335
333 336 if print_full == 1:
334 337 page(header+str)
335 338 return 0
336 339
337 340 print(header, end=' ')
338 341 if len(str) < width:
339 342 print(str)
340 343 snip = 0
341 344 else:
342 345 whalf = int((width -5)/2)
343 346 print(str[:whalf] + ' <...> ' + str[-whalf:])
344 347 snip = 1
345 348 if snip and print_full == 2:
346 349 if raw_input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
347 350 page(str)
348 351 return snip
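For context on the pager API changed above: page() prints directly when the text fits on screen and otherwise pipes it through a pager, and screen_lines <= 0 means "auto-detect the terminal height, minus that many reserved lines". A hedged usage sketch, assuming the module path IPython.core.page used by this branch:

    from IPython.core.page import page

    # Something longer than one screen.
    long_text = '\n'.join('line %d' % i for i in range(200))

    # Auto-detect the terminal height but reserve the bottom three lines,
    # i.e. screen_lines = -3 as described in the docstring above.
    page(long_text, screen_lines=-3)

    # Force a specific pager instead of $PAGER / the platform default.
    page(long_text, pager_cmd='less -r')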
@@ -1,556 +1,557 b''
1 1 """Utilities for connecting to kernels
2 2
3 3 Authors:
4 4
5 5 * Min Ragan-Kelley
6 6
7 7 """
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2013 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19
20 20 from __future__ import absolute_import
21 21
22 22 import glob
23 23 import json
24 24 import os
25 25 import socket
26 26 import sys
27 27 from getpass import getpass
28 28 from subprocess import Popen, PIPE
29 29 import tempfile
30 30
31 31 import zmq
32 32
33 33 # external imports
34 34 from IPython.external.ssh import tunnel
35 35
36 36 # IPython imports
37 37 from IPython.config import Configurable
38 38 from IPython.core.profiledir import ProfileDir
39 39 from IPython.utils.localinterfaces import LOCALHOST
40 40 from IPython.utils.path import filefind, get_ipython_dir
41 41 from IPython.utils.py3compat import str_to_bytes, bytes_to_str, cast_bytes_py2
42 42 from IPython.utils.traitlets import (
43 43 Bool, Integer, Unicode, CaselessStrEnum,
44 44 )
45 45
46 46
47 47 #-----------------------------------------------------------------------------
48 48 # Working with Connection Files
49 49 #-----------------------------------------------------------------------------
50 50
51 51 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
52 52 control_port=0, ip=LOCALHOST, key=b'', transport='tcp',
53 53 signature_scheme='hmac-sha256',
54 54 ):
55 55 """Generates a JSON config file, including the selection of random ports.
56 56
57 57 Parameters
58 58 ----------
59 59
60 60 fname : unicode
61 61 The path to the file to write
62 62
63 63 shell_port : int, optional
64 64 The port to use for ROUTER (shell) channel.
65 65
66 66 iopub_port : int, optional
67 67 The port to use for the SUB channel.
68 68
69 69 stdin_port : int, optional
70 70 The port to use for the ROUTER (raw input) channel.
71 71
72 72 control_port : int, optional
73 73 The port to use for the ROUTER (control) channel.
74 74
75 75 hb_port : int, optional
76 76 The port to use for the heartbeat REP channel.
77 77
78 78 ip : str, optional
79 79 The ip address the kernel will bind to.
80 80
81 81 key : str, optional
82 82 The Session key used for message authentication.
83 83
84 84 signature_scheme : str, optional
85 85 The scheme used for message authentication.
86 86 This has the form 'digest-hash', where 'digest'
87 87 is the scheme used for digests, and 'hash' is the name of the hash function
88 88 used by the digest scheme.
89 89 Currently, 'hmac' is the only supported digest scheme,
90 90 and 'sha256' is the default hash function.
91 91
92 92 """
93 93 # default to a temporary connection file
94 94 if not fname:
95 fname = tempfile.mktemp('.json')
95 fd, fname = tempfile.mkstemp('.json')
96 os.close(fd)
96 97
97 98 # Find open ports as necessary.
98 99
99 100 ports = []
100 101 ports_needed = int(shell_port <= 0) + \
101 102 int(iopub_port <= 0) + \
102 103 int(stdin_port <= 0) + \
103 104 int(control_port <= 0) + \
104 105 int(hb_port <= 0)
105 106 if transport == 'tcp':
106 107 for i in range(ports_needed):
107 108 sock = socket.socket()
108 109 sock.bind(('', 0))
109 110 ports.append(sock)
110 111 for i, sock in enumerate(ports):
111 112 port = sock.getsockname()[1]
112 113 sock.close()
113 114 ports[i] = port
114 115 else:
115 116 N = 1
116 117 for i in range(ports_needed):
117 118 while os.path.exists("%s-%s" % (ip, str(N))):
118 119 N += 1
119 120 ports.append(N)
120 121 N += 1
121 122 if shell_port <= 0:
122 123 shell_port = ports.pop(0)
123 124 if iopub_port <= 0:
124 125 iopub_port = ports.pop(0)
125 126 if stdin_port <= 0:
126 127 stdin_port = ports.pop(0)
127 128 if control_port <= 0:
128 129 control_port = ports.pop(0)
129 130 if hb_port <= 0:
130 131 hb_port = ports.pop(0)
131 132
132 133 cfg = dict( shell_port=shell_port,
133 134 iopub_port=iopub_port,
134 135 stdin_port=stdin_port,
135 136 control_port=control_port,
136 137 hb_port=hb_port,
137 138 )
138 139 cfg['ip'] = ip
139 140 cfg['key'] = bytes_to_str(key)
140 141 cfg['transport'] = transport
141 142 cfg['signature_scheme'] = signature_scheme
142 143
143 144 with open(fname, 'w') as f:
144 145 f.write(json.dumps(cfg, indent=2))
145 146
146 147 return fname, cfg
147 148
148 149
149 150 def get_connection_file(app=None):
150 151 """Return the path to the connection file of an app
151 152
152 153 Parameters
153 154 ----------
154 155 app : IPKernelApp instance [optional]
155 156 If unspecified, the currently running app will be used
156 157 """
157 158 if app is None:
158 159 from IPython.kernel.zmq.kernelapp import IPKernelApp
159 160 if not IPKernelApp.initialized():
160 161 raise RuntimeError("app not specified, and not in a running Kernel")
161 162
162 163 app = IPKernelApp.instance()
163 164 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
164 165
165 166
166 167 def find_connection_file(filename, profile=None):
167 168 """find a connection file, and return its absolute path.
168 169
169 170 The current working directory and the profile's security
170 171 directory will be searched for the file if it is not given by
171 172 absolute path.
172 173
173 174 If profile is unspecified, then the current running application's
174 175 profile will be used, or 'default', if not run from IPython.
175 176
176 177 If the argument does not match an existing file, it will be interpreted as a
177 178 fileglob, and the matching file in the profile's security dir with
178 179 the latest access time will be used.
179 180
180 181 Parameters
181 182 ----------
182 183 filename : str
183 184 The connection file or fileglob to search for.
184 185 profile : str [optional]
185 186 The name of the profile to use when searching for the connection file,
186 187 if different from the current IPython session or 'default'.
187 188
188 189 Returns
189 190 -------
190 191 str : The absolute path of the connection file.
191 192 """
192 193 from IPython.core.application import BaseIPythonApplication as IPApp
193 194 try:
194 195 # quick check for absolute path, before going through logic
195 196 return filefind(filename)
196 197 except IOError:
197 198 pass
198 199
199 200 if profile is None:
200 201 # profile unspecified, check if running from an IPython app
201 202 if IPApp.initialized():
202 203 app = IPApp.instance()
203 204 profile_dir = app.profile_dir
204 205 else:
205 206 # not running in IPython, use default profile
206 207 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
207 208 else:
208 209 # find profiledir by profile name:
209 210 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
210 211 security_dir = profile_dir.security_dir
211 212
212 213 try:
213 214 # first, try explicit name
214 215 return filefind(filename, ['.', security_dir])
215 216 except IOError:
216 217 pass
217 218
218 219 # not found by full name
219 220
220 221 if '*' in filename:
221 222 # given as a glob already
222 223 pat = filename
223 224 else:
224 225 # accept any substring match
225 226 pat = '*%s*' % filename
226 227 matches = glob.glob( os.path.join(security_dir, pat) )
227 228 if not matches:
228 229 raise IOError("Could not find %r in %r" % (filename, security_dir))
229 230 elif len(matches) == 1:
230 231 return matches[0]
231 232 else:
232 233 # get most recent match, by access time:
233 234 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
234 235
235 236
236 237 def get_connection_info(connection_file=None, unpack=False, profile=None):
237 238 """Return the connection information for the current Kernel.
238 239
239 240 Parameters
240 241 ----------
241 242 connection_file : str [optional]
242 243 The connection file to be used. Can be given by absolute path, or
243 244 IPython will search in the security directory of a given profile.
244 245 If run from IPython,
245 246
246 247 If unspecified, the connection file for the currently running
247 248 IPython Kernel will be used, which is only allowed from inside a kernel.
248 249 unpack : bool [default: False]
249 250 if True, return the unpacked dict, otherwise just the string contents
250 251 of the file.
251 252 profile : str [optional]
252 253 The name of the profile to use when searching for the connection file,
253 254 if different from the current IPython session or 'default'.
254 255
255 256
256 257 Returns
257 258 -------
258 259 The connection dictionary of the current kernel, as string or dict,
259 260 depending on `unpack`.
260 261 """
261 262 if connection_file is None:
262 263 # get connection file from current kernel
263 264 cf = get_connection_file()
264 265 else:
265 266 # connection file specified, allow shortnames:
266 267 cf = find_connection_file(connection_file, profile=profile)
267 268
268 269 with open(cf) as f:
269 270 info = f.read()
270 271
271 272 if unpack:
272 273 info = json.loads(info)
273 274 # ensure key is bytes:
274 275 info['key'] = str_to_bytes(info.get('key', ''))
275 276 return info
276 277
277 278
278 279 def connect_qtconsole(connection_file=None, argv=None, profile=None):
279 280 """Connect a qtconsole to the current kernel.
280 281
281 282 This is useful for connecting a second qtconsole to a kernel, or to a
282 283 local notebook.
283 284
284 285 Parameters
285 286 ----------
286 287 connection_file : str [optional]
287 288 The connection file to be used. Can be given by absolute path, or
288 289 IPython will search in the security directory of a given profile.
289 290 If run from IPython,
290 291
291 292 If unspecified, the connection file for the currently running
292 293 IPython Kernel will be used, which is only allowed from inside a kernel.
293 294 argv : list [optional]
294 295 Any extra args to be passed to the console.
295 296 profile : str [optional]
296 297 The name of the profile to use when searching for the connection file,
297 298 if different from the current IPython session or 'default'.
298 299
299 300
300 301 Returns
301 302 -------
302 303 subprocess.Popen instance running the qtconsole frontend
303 304 """
304 305 argv = [] if argv is None else argv
305 306
306 307 if connection_file is None:
307 308 # get connection file from current kernel
308 309 cf = get_connection_file()
309 310 else:
310 311 cf = find_connection_file(connection_file, profile=profile)
311 312
312 313 cmd = ';'.join([
313 314 "from IPython.qt.console import qtconsoleapp",
314 315 "qtconsoleapp.main()"
315 316 ])
316 317
317 318 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
318 319 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
319 320 )
320 321
321 322
322 323 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
323 324 """tunnel connections to a kernel via ssh
324 325
325 326 This will open four SSH tunnels from localhost on this machine to the
326 327 ports associated with the kernel. They can be either direct
327 328 localhost-localhost tunnels, or if an intermediate server is necessary,
328 329 the kernel must be listening on a public IP.
329 330
330 331 Parameters
331 332 ----------
332 333 connection_info : dict or str (path)
333 334 Either a connection dict, or the path to a JSON connection file
334 335 sshserver : str
335 336 The ssh server to use to tunnel to the kernel. Can be a full
336 337 `user@server:port` string. ssh config aliases are respected.
337 338 sshkey : str [optional]
338 339 Path to file containing ssh key to use for authentication.
339 340 Only necessary if your ssh config does not already associate
340 341 a keyfile with the host.
341 342
342 343 Returns
343 344 -------
344 345
345 346 (shell, iopub, stdin, hb) : ints
346 347 The four ports on localhost that have been forwarded to the kernel.
347 348 """
348 349 if isinstance(connection_info, basestring):
349 350 # it's a path, unpack it
350 351 with open(connection_info) as f:
351 352 connection_info = json.loads(f.read())
352 353
353 354 cf = connection_info
354 355
355 356 lports = tunnel.select_random_ports(4)
356 357 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
357 358
358 359 remote_ip = cf['ip']
359 360
360 361 if tunnel.try_passwordless_ssh(sshserver, sshkey):
361 362 password=False
362 363 else:
363 364 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
364 365
365 366 for lp,rp in zip(lports, rports):
366 367 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
367 368
368 369 return tuple(lports)
369 370
370 371
371 372 #-----------------------------------------------------------------------------
372 373 # Mixin for classes that work with connection files
373 374 #-----------------------------------------------------------------------------
374 375
375 376 channel_socket_types = {
376 377 'hb' : zmq.REQ,
377 378 'shell' : zmq.DEALER,
378 379 'iopub' : zmq.SUB,
379 380 'stdin' : zmq.DEALER,
380 381 'control': zmq.DEALER,
381 382 }
382 383
383 384 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
384 385
385 386 class ConnectionFileMixin(Configurable):
386 387 """Mixin for configurable classes that work with connection files"""
387 388
388 389 # The addresses for the communication channels
389 390 connection_file = Unicode('')
390 391 _connection_file_written = Bool(False)
391 392
392 393 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
393 394
394 395 ip = Unicode(LOCALHOST, config=True,
395 396 help="""Set the kernel\'s IP address [default localhost].
396 397 If the IP address is something other than localhost, then
397 398 Consoles on other machines will be able to connect
398 399 to the Kernel, so be careful!"""
399 400 )
400 401
401 402 def _ip_default(self):
402 403 if self.transport == 'ipc':
403 404 if self.connection_file:
404 405 return os.path.splitext(self.connection_file)[0] + '-ipc'
405 406 else:
406 407 return 'kernel-ipc'
407 408 else:
408 409 return LOCALHOST
409 410
410 411 def _ip_changed(self, name, old, new):
411 412 if new == '*':
412 413 self.ip = '0.0.0.0'
413 414
414 415 # protected traits
415 416
416 417 shell_port = Integer(0)
417 418 iopub_port = Integer(0)
418 419 stdin_port = Integer(0)
419 420 control_port = Integer(0)
420 421 hb_port = Integer(0)
421 422
422 423 @property
423 424 def ports(self):
424 425 return [ getattr(self, name) for name in port_names ]
425 426
426 427 #--------------------------------------------------------------------------
427 428 # Connection and ipc file management
428 429 #--------------------------------------------------------------------------
429 430
430 431 def get_connection_info(self):
431 432 """return the connection info as a dict"""
432 433 return dict(
433 434 transport=self.transport,
434 435 ip=self.ip,
435 436 shell_port=self.shell_port,
436 437 iopub_port=self.iopub_port,
437 438 stdin_port=self.stdin_port,
438 439 hb_port=self.hb_port,
439 440 control_port=self.control_port,
440 441 signature_scheme=self.session.signature_scheme,
441 442 key=self.session.key,
442 443 )
443 444
444 445 def cleanup_connection_file(self):
445 446 """Cleanup connection file *if we wrote it*
446 447
447 448 Will not raise if the connection file was already removed somehow.
448 449 """
449 450 if self._connection_file_written:
450 451 # cleanup connection files on full shutdown of kernel we started
451 452 self._connection_file_written = False
452 453 try:
453 454 os.remove(self.connection_file)
454 455 except (IOError, OSError, AttributeError):
455 456 pass
456 457
457 458 def cleanup_ipc_files(self):
458 459 """Cleanup ipc files if we wrote them."""
459 460 if self.transport != 'ipc':
460 461 return
461 462 for port in self.ports:
462 463 ipcfile = "%s-%i" % (self.ip, port)
463 464 try:
464 465 os.remove(ipcfile)
465 466 except (IOError, OSError):
466 467 pass
467 468
468 469 def write_connection_file(self):
469 470 """Write connection info to JSON dict in self.connection_file."""
470 471 if self._connection_file_written:
471 472 return
472 473
473 474 self.connection_file, cfg = write_connection_file(self.connection_file,
474 475 transport=self.transport, ip=self.ip, key=self.session.key,
475 476 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
476 477 shell_port=self.shell_port, hb_port=self.hb_port,
477 478 control_port=self.control_port,
478 479 signature_scheme=self.session.signature_scheme,
479 480 )
480 481 # write_connection_file also sets default ports:
481 482 for name in port_names:
482 483 setattr(self, name, cfg[name])
483 484
484 485 self._connection_file_written = True
485 486
486 487 def load_connection_file(self):
487 488 """Load connection info from JSON dict in self.connection_file."""
488 489 with open(self.connection_file) as f:
489 490 cfg = json.loads(f.read())
490 491
491 492 self.transport = cfg.get('transport', 'tcp')
492 493 self.ip = cfg['ip']
493 494 for name in port_names:
494 495 setattr(self, name, cfg[name])
495 496 if 'key' in cfg:
496 497 self.session.key = str_to_bytes(cfg['key'])
497 498 if cfg.get('signature_scheme'):
498 499 self.session.signature_scheme = cfg['signature_scheme']
499 500
500 501 #--------------------------------------------------------------------------
501 502 # Creating connected sockets
502 503 #--------------------------------------------------------------------------
503 504
504 505 def _make_url(self, channel):
505 506 """Make a ZeroMQ URL for a given channel."""
506 507 transport = self.transport
507 508 ip = self.ip
508 509 port = getattr(self, '%s_port' % channel)
509 510
510 511 if transport == 'tcp':
511 512 return "tcp://%s:%i" % (ip, port)
512 513 else:
513 514 return "%s://%s-%s" % (transport, ip, port)
514 515
515 516 def _create_connected_socket(self, channel, identity=None):
516 517 """Create a zmq Socket and connect it to the kernel."""
517 518 url = self._make_url(channel)
518 519 socket_type = channel_socket_types[channel]
519 520 self.log.info("Connecting to: %s" % url)
520 521 sock = self.context.socket(socket_type)
521 522 if identity:
522 523 sock.identity = identity
523 524 sock.connect(url)
524 525 return sock
525 526
526 527 def connect_iopub(self, identity=None):
527 528 """return zmq Socket connected to the IOPub channel"""
528 529 sock = self._create_connected_socket('iopub', identity=identity)
529 530 sock.setsockopt(zmq.SUBSCRIBE, b'')
530 531 return sock
531 532
532 533 def connect_shell(self, identity=None):
533 534 """return zmq Socket connected to the Shell channel"""
534 535 return self._create_connected_socket('shell', identity=identity)
535 536
536 537 def connect_stdin(self, identity=None):
537 538 """return zmq Socket connected to the StdIn channel"""
538 539 return self._create_connected_socket('stdin', identity=identity)
539 540
540 541 def connect_hb(self, identity=None):
541 542 """return zmq Socket connected to the Heartbeat channel"""
542 543 return self._create_connected_socket('hb', identity=identity)
543 544
544 545 def connect_control(self, identity=None):
545 546 """return zmq Socket connected to the Heartbeat channel"""
546 547 return self._create_connected_socket('control', identity=identity)
547 548
548 549
549 550 __all__ = [
550 551 'write_connection_file',
551 552 'get_connection_file',
552 553 'find_connection_file',
553 554 'get_connection_info',
554 555 'connect_qtconsole',
555 556 'tunnel_to_kernel',
556 557 ]
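The same mkstemp() fix lands in write_connection_file() above: with no fname, the JSON file is now created securely and the descriptor closed right away, since the function reopens the path for writing at the end. A hedged usage sketch, assuming the IPython.kernel.connect module path used by this branch:

    from __future__ import print_function
    from IPython.kernel.connect import write_connection_file

    # Ports left at 0 are filled with free TCP ports, found by binding a
    # throwaway socket to port 0 and reading back the assigned port.
    fname, cfg = write_connection_file(ip='127.0.0.1', key=b'supersecret')

    print(fname)                  # e.g. /tmp/tmpXXXXXX.json (made by mkstemp)
    print(cfg['shell_port'], cfg['iopub_port'], cfg['hb_port'])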
@@ -1,542 +1,541 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 from tempfile import mktemp
24 23
25 24 import zmq
26 25
27 26 from IPython import parallel
28 27 from IPython.parallel.client import client as clientmod
29 28 from IPython.parallel import error
30 29 from IPython.parallel import AsyncResult, AsyncHubResult
31 30 from IPython.parallel import LoadBalancedView, DirectView
32 31
33 32 from clienttest import ClusterTestCase, segfault, wait, add_engines
34 33
35 34 def setup():
36 35 add_engines(4, total=True)
37 36
38 37 class TestClient(ClusterTestCase):
39 38
40 39 def test_ids(self):
41 40 n = len(self.client.ids)
42 41 self.add_engines(2)
43 42 self.assertEqual(len(self.client.ids), n+2)
44 43
45 44 def test_view_indexing(self):
46 45 """test index access for views"""
47 46 self.minimum_engines(4)
48 47 targets = self.client._build_targets('all')[-1]
49 48 v = self.client[:]
50 49 self.assertEqual(v.targets, targets)
51 50 t = self.client.ids[2]
52 51 v = self.client[t]
53 52 self.assertTrue(isinstance(v, DirectView))
54 53 self.assertEqual(v.targets, t)
55 54 t = self.client.ids[2:4]
56 55 v = self.client[t]
57 56 self.assertTrue(isinstance(v, DirectView))
58 57 self.assertEqual(v.targets, t)
59 58 v = self.client[::2]
60 59 self.assertTrue(isinstance(v, DirectView))
61 60 self.assertEqual(v.targets, targets[::2])
62 61 v = self.client[1::3]
63 62 self.assertTrue(isinstance(v, DirectView))
64 63 self.assertEqual(v.targets, targets[1::3])
65 64 v = self.client[:-3]
66 65 self.assertTrue(isinstance(v, DirectView))
67 66 self.assertEqual(v.targets, targets[:-3])
68 67 v = self.client[-1]
69 68 self.assertTrue(isinstance(v, DirectView))
70 69 self.assertEqual(v.targets, targets[-1])
71 70 self.assertRaises(TypeError, lambda : self.client[None])
72 71
73 72 def test_lbview_targets(self):
74 73 """test load_balanced_view targets"""
75 74 v = self.client.load_balanced_view()
76 75 self.assertEqual(v.targets, None)
77 76 v = self.client.load_balanced_view(-1)
78 77 self.assertEqual(v.targets, [self.client.ids[-1]])
79 78 v = self.client.load_balanced_view('all')
80 79 self.assertEqual(v.targets, None)
81 80
82 81 def test_dview_targets(self):
83 82 """test direct_view targets"""
84 83 v = self.client.direct_view()
85 84 self.assertEqual(v.targets, 'all')
86 85 v = self.client.direct_view('all')
87 86 self.assertEqual(v.targets, 'all')
88 87 v = self.client.direct_view(-1)
89 88 self.assertEqual(v.targets, self.client.ids[-1])
90 89
91 90 def test_lazy_all_targets(self):
92 91 """test lazy evaluation of rc.direct_view('all')"""
93 92 v = self.client.direct_view()
94 93 self.assertEqual(v.targets, 'all')
95 94
96 95 def double(x):
97 96 return x*2
98 97 seq = range(100)
99 98 ref = [ double(x) for x in seq ]
100 99
101 100 # add some engines, which should be used
102 101 self.add_engines(1)
103 102 n1 = len(self.client.ids)
104 103
105 104 # simple apply
106 105 r = v.apply_sync(lambda : 1)
107 106 self.assertEqual(r, [1] * n1)
108 107
109 108 # map goes through remotefunction
110 109 r = v.map_sync(double, seq)
111 110 self.assertEqual(r, ref)
112 111
113 112 # add a couple more engines, and try again
114 113 self.add_engines(2)
115 114 n2 = len(self.client.ids)
116 115 self.assertNotEqual(n2, n1)
117 116
118 117 # apply
119 118 r = v.apply_sync(lambda : 1)
120 119 self.assertEqual(r, [1] * n2)
121 120
122 121 # map
123 122 r = v.map_sync(double, seq)
124 123 self.assertEqual(r, ref)
125 124
126 125 def test_targets(self):
127 126 """test various valid targets arguments"""
128 127 build = self.client._build_targets
129 128 ids = self.client.ids
130 129 idents,targets = build(None)
131 130 self.assertEqual(ids, targets)
132 131
133 132 def test_clear(self):
134 133 """test clear behavior"""
135 134 self.minimum_engines(2)
136 135 v = self.client[:]
137 136 v.block=True
138 137 v.push(dict(a=5))
139 138 v.pull('a')
140 139 id0 = self.client.ids[-1]
141 140 self.client.clear(targets=id0, block=True)
142 141 a = self.client[:-1].get('a')
143 142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 143 self.client.clear(block=True)
145 144 for i in self.client.ids:
146 145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147 146
148 147 def test_get_result(self):
149 148 """test getting results from the Hub."""
150 149 c = clientmod.Client(profile='iptest')
151 150 t = c.ids[-1]
152 151 ar = c[t].apply_async(wait, 1)
153 152 # give the monitor time to notice the message
154 153 time.sleep(.25)
155 154 ahr = self.client.get_result(ar.msg_ids[0])
156 155 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 156 self.assertEqual(ahr.get(), ar.get())
158 157 ar2 = self.client.get_result(ar.msg_ids[0])
159 158 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 159 c.close()
161 160
162 161 def test_get_execute_result(self):
163 162 """test getting execute results from the Hub."""
164 163 c = clientmod.Client(profile='iptest')
165 164 t = c.ids[-1]
166 165 cell = '\n'.join([
167 166 'import time',
168 167 'time.sleep(0.25)',
169 168 '5'
170 169 ])
171 170 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 171 # give the monitor time to notice the message
173 172 time.sleep(.25)
174 173 ahr = self.client.get_result(ar.msg_ids[0])
175 174 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 175 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 176 ar2 = self.client.get_result(ar.msg_ids[0])
178 177 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 178 c.close()
180 179
181 180 def test_ids_list(self):
182 181 """test client.ids"""
183 182 ids = self.client.ids
184 183 self.assertEqual(ids, self.client._ids)
185 184 self.assertFalse(ids is self.client._ids)
186 185 ids.remove(ids[-1])
187 186 self.assertNotEqual(ids, self.client._ids)
188 187
189 188 def test_queue_status(self):
190 189 ids = self.client.ids
191 190 id0 = ids[0]
192 191 qs = self.client.queue_status(targets=id0)
193 192 self.assertTrue(isinstance(qs, dict))
194 193 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 194 allqs = self.client.queue_status()
196 195 self.assertTrue(isinstance(allqs, dict))
197 196 intkeys = list(allqs.keys())
198 197 intkeys.remove('unassigned')
199 198 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 199 unassigned = allqs.pop('unassigned')
201 200 for eid,qs in allqs.items():
202 201 self.assertTrue(isinstance(qs, dict))
203 202 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204 203
205 204 def test_shutdown(self):
206 205 ids = self.client.ids
207 206 id0 = ids[0]
208 207 self.client.shutdown(id0, block=True)
209 208 while id0 in self.client.ids:
210 209 time.sleep(0.1)
211 210 self.client.spin()
212 211
213 212 self.assertRaises(IndexError, lambda : self.client[id0])
214 213
215 214 def test_result_status(self):
216 215 pass
217 216 # to be written
218 217
219 218 def test_db_query_dt(self):
220 219 """test db query by date"""
221 220 hist = self.client.hub_history()
222 221 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 222 tic = middle['submitted']
224 223 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 224 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 225 self.assertEqual(len(before)+len(after),len(hist))
227 226 for b in before:
228 227 self.assertTrue(b['submitted'] < tic)
229 228 for a in after:
230 229 self.assertTrue(a['submitted'] >= tic)
231 230 same = self.client.db_query({'submitted' : tic})
232 231 for s in same:
233 232 self.assertTrue(s['submitted'] == tic)
234 233
235 234 def test_db_query_keys(self):
236 235 """test extracting subset of record keys"""
237 236 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 237 for rec in found:
239 238 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240 239
241 240 def test_db_query_default_keys(self):
242 241 """default db_query excludes buffers"""
243 242 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 243 for rec in found:
245 244 keys = set(rec.keys())
246 245 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 246 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248 247
249 248 def test_db_query_msg_id(self):
250 249 """ensure msg_id is always in db queries"""
251 250 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 251 for rec in found:
253 252 self.assertTrue('msg_id' in rec.keys())
254 253 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 254 for rec in found:
256 255 self.assertTrue('msg_id' in rec.keys())
257 256 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 257 for rec in found:
259 258 self.assertTrue('msg_id' in rec.keys())
260 259
261 260 def test_db_query_get_result(self):
262 261 """pop in db_query shouldn't pop from result itself"""
263 262 self.client[:].apply_sync(lambda : 1)
264 263 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 264 rc2 = clientmod.Client(profile='iptest')
266 265 # If this bug is not fixed, this call will hang:
267 266 ar = rc2.get_result(self.client.history[-1])
268 267 ar.wait(2)
269 268 self.assertTrue(ar.ready())
270 269 ar.get()
271 270 rc2.close()
272 271
273 272 def test_db_query_in(self):
274 273 """test db query with '$in','$nin' operators"""
275 274 hist = self.client.hub_history()
276 275 even = hist[::2]
277 276 odd = hist[1::2]
278 277 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 278 found = [ r['msg_id'] for r in recs ]
280 279 self.assertEqual(set(even), set(found))
281 280 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 281 found = [ r['msg_id'] for r in recs ]
283 282 self.assertEqual(set(odd), set(found))
284 283
285 284 def test_hub_history(self):
286 285 hist = self.client.hub_history()
287 286 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 287 recdict = {}
289 288 for rec in recs:
290 289 recdict[rec['msg_id']] = rec
291 290
292 291 latest = datetime(1984,1,1)
293 292 for msg_id in hist:
294 293 rec = recdict[msg_id]
295 294 newt = rec['submitted']
296 295 self.assertTrue(newt >= latest)
297 296 latest = newt
298 297 ar = self.client[-1].apply_async(lambda : 1)
299 298 ar.get()
300 299 time.sleep(0.25)
301 300 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302 301
303 302 def _wait_for_idle(self):
304 303 """wait for an engine to become idle, according to the Hub"""
305 304 rc = self.client
306 305
307 306 # step 1. wait for all requests to be noticed
308 307 # timeout 5s, polling every 100ms
309 308 msg_ids = set(rc.history)
310 309 hub_hist = rc.hub_history()
311 310 for i in range(50):
312 311 if msg_ids.difference(hub_hist):
313 312 time.sleep(0.1)
314 313 hub_hist = rc.hub_history()
315 314 else:
316 315 break
317 316
318 317 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319 318
320 319 # step 2. wait for all requests to be done
321 320 # timeout 5s, polling every 100ms
322 321 qs = rc.queue_status()
323 322 for i in range(50):
324 323 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
325 324 time.sleep(0.1)
326 325 qs = rc.queue_status()
327 326 else:
328 327 break
329 328
330 329 # ensure Hub up to date:
331 330 self.assertEqual(qs['unassigned'], 0)
332 331 for eid in rc.ids:
333 332 self.assertEqual(qs[eid]['tasks'], 0)
334 333
335 334
336 335 def test_resubmit(self):
337 336 def f():
338 337 import random
339 338 return random.random()
340 339 v = self.client.load_balanced_view()
341 340 ar = v.apply_async(f)
342 341 r1 = ar.get(1)
343 342 # give the Hub a chance to notice:
344 343 self._wait_for_idle()
345 344 ahr = self.client.resubmit(ar.msg_ids)
346 345 r2 = ahr.get(1)
347 346 self.assertFalse(r1 == r2)
348 347
349 348 def test_resubmit_chain(self):
350 349 """resubmit resubmitted tasks"""
351 350 v = self.client.load_balanced_view()
352 351 ar = v.apply_async(lambda x: x, 'x'*1024)
353 352 ar.get()
354 353 self._wait_for_idle()
355 354 ars = [ar]
356 355
357 356 for i in range(10):
358 357 ar = ars[-1]
359 358 ar2 = self.client.resubmit(ar.msg_ids)
360 359
361 360 [ ar.get() for ar in ars ]
362 361
363 362 def test_resubmit_header(self):
364 363 """resubmit shouldn't clobber the whole header"""
365 364 def f():
366 365 import random
367 366 return random.random()
368 367 v = self.client.load_balanced_view()
369 368 v.retries = 1
370 369 ar = v.apply_async(f)
371 370 r1 = ar.get(1)
372 371 # give the Hub a chance to notice:
373 372 self._wait_for_idle()
374 373 ahr = self.client.resubmit(ar.msg_ids)
375 374 ahr.get(1)
376 375 time.sleep(0.5)
377 376 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
378 377 h1,h2 = [ r['header'] for r in records ]
379 378 for key in set(h1.keys()).union(set(h2.keys())):
380 379 if key in ('msg_id', 'date'):
381 380 self.assertNotEqual(h1[key], h2[key])
382 381 else:
383 382 self.assertEqual(h1[key], h2[key])
384 383
385 384 def test_resubmit_aborted(self):
386 385 def f():
387 386 import random
388 387 return random.random()
389 388 v = self.client.load_balanced_view()
390 389 # restrict to one engine, so we can put a sleep
391 390 # ahead of the task, so it will get aborted
392 391 eid = self.client.ids[-1]
393 392 v.targets = [eid]
394 393 sleep = v.apply_async(time.sleep, 0.5)
395 394 ar = v.apply_async(f)
396 395 ar.abort()
397 396 self.assertRaises(error.TaskAborted, ar.get)
398 397 # Give the Hub a chance to get up to date:
399 398 self._wait_for_idle()
400 399 ahr = self.client.resubmit(ar.msg_ids)
401 400 r2 = ahr.get(1)
402 401
403 402 def test_resubmit_inflight(self):
404 403 """resubmit of inflight task"""
405 404 v = self.client.load_balanced_view()
406 405 ar = v.apply_async(time.sleep,1)
407 406 # give the message a chance to arrive
408 407 time.sleep(0.2)
409 408 ahr = self.client.resubmit(ar.msg_ids)
410 409 ar.get(2)
411 410 ahr.get(2)
412 411
413 412 def test_resubmit_badkey(self):
414 413 """ensure KeyError on resubmit of nonexistant task"""
415 414 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
416 415
417 416 def test_purge_hub_results(self):
418 417 # ensure there are some tasks
419 418 for i in range(5):
420 419 self.client[:].apply_sync(lambda : 1)
421 420 # Wait for the Hub to realise the result is done:
422 421 # This prevents a race condition, where we
423 422 # might purge a result the Hub still thinks is pending.
424 423 self._wait_for_idle()
425 424 rc2 = clientmod.Client(profile='iptest')
426 425 hist = self.client.hub_history()
427 426 ahr = rc2.get_result([hist[-1]])
428 427 ahr.wait(10)
429 428 self.client.purge_hub_results(hist[-1])
430 429 newhist = self.client.hub_history()
431 430 self.assertEqual(len(newhist)+1,len(hist))
432 431 rc2.spin()
433 432 rc2.close()
434 433
435 434 def test_purge_local_results(self):
436 435 # ensure there are some tasks
437 436 res = []
438 437 for i in range(5):
439 438 res.append(self.client[:].apply_async(lambda : 1))
440 439 self._wait_for_idle()
441 440 self.client.wait(10) # wait for the results to come back
442 441 before = len(self.client.results)
443 442 self.assertEqual(len(self.client.metadata),before)
444 443 self.client.purge_local_results(res[-1])
445 444 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
446 445 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
447 446
448 447 def test_purge_local_results_outstanding(self):
449 448 v = self.client[-1]
450 449 ar = v.apply_async(lambda : 1)
451 450 msg_id = ar.msg_ids[0]
452 451 ar.get()
453 452 self._wait_for_idle()
454 453 ar2 = v.apply_async(time.sleep, 1)
455 454 self.assertIn(msg_id, self.client.results)
456 455 self.assertIn(msg_id, self.client.metadata)
457 456 self.client.purge_local_results(ar)
458 457 self.assertNotIn(msg_id, self.client.results)
459 458 self.assertNotIn(msg_id, self.client.metadata)
460 459 with self.assertRaises(RuntimeError):
461 460 self.client.purge_local_results(ar2)
462 461 ar2.get()
463 462 self.client.purge_local_results(ar2)
464 463
465 464 def test_purge_all_local_results_outstanding(self):
466 465 v = self.client[-1]
467 466 ar = v.apply_async(time.sleep, 1)
468 467 with self.assertRaises(RuntimeError):
469 468 self.client.purge_local_results('all')
470 469 ar.get()
471 470 self.client.purge_local_results('all')
472 471
473 472 def test_purge_all_hub_results(self):
474 473 self.client.purge_hub_results('all')
475 474 hist = self.client.hub_history()
476 475 self.assertEqual(len(hist), 0)
477 476
478 477 def test_purge_all_local_results(self):
479 478 self.client.purge_local_results('all')
480 479 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
481 480 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
482 481
483 482 def test_purge_all_results(self):
484 483 # ensure there are some tasks
485 484 for i in range(5):
486 485 self.client[:].apply_sync(lambda : 1)
487 486 self.client.wait(10)
488 487 self._wait_for_idle()
489 488 self.client.purge_results('all')
490 489 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
491 490 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
492 491 hist = self.client.hub_history()
493 492 self.assertEqual(len(hist), 0, msg="hub history not empty")
494 493
495 494 def test_purge_everything(self):
496 495 # ensure there are some tasks
497 496 for i in range(5):
498 497 self.client[:].apply_sync(lambda : 1)
499 498 self.client.wait(10)
500 499 self._wait_for_idle()
501 500 self.client.purge_everything()
502 501 # The client results
503 502 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
504 503 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
505 504 # The client "bookkeeping"
506 505 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
507 506 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
508 507 # the hub results
509 508 hist = self.client.hub_history()
510 509 self.assertEqual(len(hist), 0, msg="hub history not empty")
511 510
512 511
513 512 def test_spin_thread(self):
514 513 self.client.spin_thread(0.01)
515 514 ar = self.client[-1].apply_async(lambda : 1)
516 515 time.sleep(0.1)
517 516 self.assertTrue(ar.wall_time < 0.1,
518 517 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
519 518 )
520 519
521 520 def test_stop_spin_thread(self):
522 521 self.client.spin_thread(0.01)
523 522 self.client.stop_spin_thread()
524 523 ar = self.client[-1].apply_async(lambda : 1)
525 524 time.sleep(0.15)
526 525 self.assertTrue(ar.wall_time > 0.1,
527 526 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
528 527 )
529 528
530 529 def test_activate(self):
531 530 ip = get_ipython()
532 531 magics = ip.magics_manager.magics
533 532 self.assertTrue('px' in magics['line'])
534 533 self.assertTrue('px' in magics['cell'])
535 534 v0 = self.client.activate(-1, '0')
536 535 self.assertTrue('px0' in magics['line'])
537 536 self.assertTrue('px0' in magics['cell'])
538 537 self.assertEqual(v0.targets, self.client.ids[-1])
539 538 v0 = self.client.activate('all', 'all')
540 539 self.assertTrue('pxall' in magics['line'])
541 540 self.assertTrue('pxall' in magics['cell'])
542 541 self.assertEqual(v0.targets, 'all')
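The only change to these client tests is dropping the now-unused mktemp import. For readers following the db_query tests, the Hub database accepts MongoDB-style operators and an optional key filter; a hedged sketch of the calls being exercised, assuming a running 'iptest' cluster as in the tests:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    hist = rc.hub_history()

    # Restrict the returned record keys (msg_id is always included),
    # as in test_db_query_keys / test_db_query_msg_id.
    recs = rc.db_query({'msg_id': {'$in': hist[-5:]}},
                       keys=['submitted', 'completed'])

    # Date comparisons work on 'submitted', as in test_db_query_dt.
    tic = recs[0]['submitted']
    earlier = rc.db_query({'submitted': {'$lt': tic}})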
@@ -1,801 +1,800 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from collections import namedtuple
23 from tempfile import mktemp
23 from tempfile import NamedTemporaryFile
24 24 from StringIO import StringIO
25 25
26 26 import zmq
27 27 from nose import SkipTest
28 28 from nose.plugins.attrib import attr
29 29
30 30 from IPython.testing import decorators as dec
31 31 from IPython.testing.ipunittest import ParametricTestCase
32 32 from IPython.utils.io import capture_output
33 33
34 34 from IPython import parallel as pmod
35 35 from IPython.parallel import error
36 36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 37 from IPython.parallel import DirectView
38 38 from IPython.parallel.util import interactive
39 39
40 40 from IPython.parallel.tests import add_engines
41 41
42 42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43 43
44 44 def setup():
45 45 add_engines(3, total=True)
46 46
47 47 point = namedtuple("point", "x y")
48 48
49 49 class TestView(ClusterTestCase, ParametricTestCase):
50 50
51 51 def setUp(self):
52 52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 55 time.sleep(2)
56 56 super(TestView, self).setUp()
57 57
58 58 @attr('crash')
59 59 def test_z_crash_mux(self):
60 60 """test graceful handling of engine death (direct)"""
61 61 # self.add_engines(1)
62 62 eid = self.client.ids[-1]
63 63 ar = self.client[eid].apply_async(crash)
64 64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 65 eid = ar.engine_id
66 66 tic = time.time()
67 67 while eid in self.client.ids and time.time()-tic < 5:
68 68 time.sleep(.01)
69 69 self.client.spin()
70 70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71 71
72 72 def test_push_pull(self):
73 73 """test pushing and pulling"""
74 74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 75 t = self.client.ids[-1]
76 76 v = self.client[t]
77 77 push = v.push
78 78 pull = v.pull
79 79 v.block=True
80 80 nengines = len(self.client)
81 81 push({'data':data})
82 82 d = pull('data')
83 83 self.assertEqual(d, data)
84 84 self.client[:].push({'data':data})
85 85 d = self.client[:].pull('data', block=True)
86 86 self.assertEqual(d, nengines*[data])
87 87 ar = push({'data':data}, block=False)
88 88 self.assertTrue(isinstance(ar, AsyncResult))
89 89 r = ar.get()
90 90 ar = self.client[:].pull('data', block=False)
91 91 self.assertTrue(isinstance(ar, AsyncResult))
92 92 r = ar.get()
93 93 self.assertEqual(r, nengines*[data])
94 94 self.client[:].push(dict(a=10,b=20))
95 95 r = self.client[:].pull(('a','b'), block=True)
96 96 self.assertEqual(r, nengines*[[10,20]])
97 97
98 98 def test_push_pull_function(self):
99 99 "test pushing and pulling functions"
100 100 def testf(x):
101 101 return 2.0*x
102 102
103 103 t = self.client.ids[-1]
104 104 v = self.client[t]
105 105 v.block=True
106 106 push = v.push
107 107 pull = v.pull
108 108 execute = v.execute
109 109 push({'testf':testf})
110 110 r = pull('testf')
111 111 self.assertEqual(r(1.0), testf(1.0))
112 112 execute('r = testf(10)')
113 113 r = pull('r')
114 114 self.assertEqual(r, testf(10))
115 115 ar = self.client[:].push({'testf':testf}, block=False)
116 116 ar.get()
117 117 ar = self.client[:].pull('testf', block=False)
118 118 rlist = ar.get()
119 119 for r in rlist:
120 120 self.assertEqual(r(1.0), testf(1.0))
121 121 execute("def g(x): return x*x")
122 122 r = pull(('testf','g'))
123 123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124 124
125 125 def test_push_function_globals(self):
126 126 """test that pushed functions have access to globals"""
127 127 @interactive
128 128 def geta():
129 129 return a
130 130 # self.add_engines(1)
131 131 v = self.client[-1]
132 132 v.block=True
133 133 v['f'] = geta
134 134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 135 v.execute('a=5')
136 136 v.execute('b=f()')
137 137 self.assertEqual(v['b'], 5)
138 138
139 139 def test_push_function_defaults(self):
140 140 """test that pushed functions preserve default args"""
141 141 def echo(a=10):
142 142 return a
143 143 v = self.client[-1]
144 144 v.block=True
145 145 v['f'] = echo
146 146 v.execute('b=f()')
147 147 self.assertEqual(v['b'], 10)
148 148
149 149 def test_get_result(self):
150 150 """test getting results from the Hub."""
151 151 c = pmod.Client(profile='iptest')
152 152 # self.add_engines(1)
153 153 t = c.ids[-1]
154 154 v = c[t]
155 155 v2 = self.client[t]
156 156 ar = v.apply_async(wait, 1)
157 157 # give the monitor time to notice the message
158 158 time.sleep(.25)
159 159 ahr = v2.get_result(ar.msg_ids[0])
160 160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 161 self.assertEqual(ahr.get(), ar.get())
162 162 ar2 = v2.get_result(ar.msg_ids[0])
163 163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 164 c.spin()
165 165 c.close()
166 166
167 167 def test_run_newline(self):
168 168 """test that run appends newline to files"""
169 tmpfile = mktemp()
170 with open(tmpfile, 'w') as f:
169 with NamedTemporaryFile('w', delete=False) as f:
171 170 f.write("""def g():
172 171 return 5
173 172 """)
174 173 v = self.client[-1]
175 v.run(tmpfile, block=True)
174 v.run(f.name, block=True)
176 175 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177 176
178 177 def test_apply_tracked(self):
179 178 """test tracking for apply"""
180 179 # self.add_engines(1)
181 180 t = self.client.ids[-1]
182 181 v = self.client[t]
183 182 v.block=False
184 183 def echo(n=1024*1024, **kwargs):
185 184 with v.temp_flags(**kwargs):
186 185 return v.apply(lambda x: x, 'x'*n)
187 186 ar = echo(1, track=False)
188 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 188 self.assertTrue(ar.sent)
190 189 ar = echo(track=True)
191 190 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 191 self.assertEqual(ar.sent, ar._tracker.done)
193 192 ar._tracker.wait()
194 193 self.assertTrue(ar.sent)
195 194
196 195 def test_push_tracked(self):
197 196 t = self.client.ids[-1]
198 197 ns = dict(x='x'*1024*1024)
199 198 v = self.client[t]
200 199 ar = v.push(ns, block=False, track=False)
201 200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 201 self.assertTrue(ar.sent)
203 202
204 203 ar = v.push(ns, block=False, track=True)
205 204 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 205 ar._tracker.wait()
207 206 self.assertEqual(ar.sent, ar._tracker.done)
208 207 self.assertTrue(ar.sent)
209 208 ar.get()
210 209
211 210 def test_scatter_tracked(self):
212 211 t = self.client.ids
213 212 x='x'*1024*1024
214 213 ar = self.client[t].scatter('x', x, block=False, track=False)
215 214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 215 self.assertTrue(ar.sent)
217 216
218 217 ar = self.client[t].scatter('x', x, block=False, track=True)
219 218 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 219 self.assertEqual(ar.sent, ar._tracker.done)
221 220 ar._tracker.wait()
222 221 self.assertTrue(ar.sent)
223 222 ar.get()
224 223
225 224 def test_remote_reference(self):
226 225 v = self.client[-1]
227 226 v['a'] = 123
228 227 ra = pmod.Reference('a')
229 228 b = v.apply_sync(lambda x: x, ra)
230 229 self.assertEqual(b, 123)
231 230
232 231
233 232 def test_scatter_gather(self):
234 233 view = self.client[:]
235 234 seq1 = range(16)
236 235 view.scatter('a', seq1)
237 236 seq2 = view.gather('a', block=True)
238 237 self.assertEqual(seq2, seq1)
239 238 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240 239
241 240 @skip_without('numpy')
242 241 def test_scatter_gather_numpy(self):
243 242 import numpy
244 243 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 244 view = self.client[:]
246 245 a = numpy.arange(64)
247 246 view.scatter('a', a, block=True)
248 247 b = view.gather('a', block=True)
249 248 assert_array_equal(b, a)
250 249
251 250 def test_scatter_gather_lazy(self):
252 251 """scatter/gather with targets='all'"""
253 252 view = self.client.direct_view(targets='all')
254 253 x = range(64)
255 254 view.scatter('x', x)
256 255 gathered = view.gather('x', block=True)
257 256 self.assertEqual(gathered, x)
258 257
259 258
260 259 @dec.known_failure_py3
261 260 @skip_without('numpy')
262 261 def test_push_numpy_nocopy(self):
263 262 import numpy
264 263 view = self.client[:]
265 264 a = numpy.arange(64)
266 265 view['A'] = a
267 266 @interactive
268 267 def check_writeable(x):
269 268 return x.flags.writeable
270 269
271 270 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 272
274 273 view.push(dict(B=a))
275 274 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 275 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277 276
278 277 @skip_without('numpy')
279 278 def test_apply_numpy(self):
280 279 """view.apply(f, ndarray)"""
281 280 import numpy
282 281 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283 282
284 283 A = numpy.random.random((100,100))
285 284 view = self.client[-1]
286 285 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 286 B = A.astype(dt)
288 287 C = view.apply_sync(lambda x:x, B)
289 288 assert_array_equal(B,C)
290 289
291 290 @skip_without('numpy')
292 291 def test_push_pull_recarray(self):
293 292 """push/pull recarrays"""
294 293 import numpy
295 294 from numpy.testing.utils import assert_array_equal
296 295
297 296 view = self.client[-1]
298 297
299 298 R = numpy.array([
300 299 (1, 'hi', 0.),
301 300 (2**30, 'there', 2.5),
302 301 (-99999, 'world', -12345.6789),
303 302 ], [('n', int), ('s', '|S10'), ('f', float)])
304 303
305 304 view['RR'] = R
306 305 R2 = view['RR']
307 306
308 307 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 308 self.assertEqual(r_dtype, R.dtype)
310 309 self.assertEqual(r_shape, R.shape)
311 310 self.assertEqual(R2.dtype, R.dtype)
312 311 self.assertEqual(R2.shape, R.shape)
313 312 assert_array_equal(R2, R)
314 313
315 314 @skip_without('pandas')
316 315 def test_push_pull_timeseries(self):
317 316 """push/pull pandas.TimeSeries"""
318 317 import pandas
319 318
320 319 ts = pandas.TimeSeries(range(10))
321 320
322 321 view = self.client[-1]
323 322
324 323 view.push(dict(ts=ts), block=True)
325 324 rts = view['ts']
326 325
327 326 self.assertEqual(type(rts), type(ts))
328 327 self.assertTrue((ts == rts).all())
329 328
330 329 def test_map(self):
331 330 view = self.client[:]
332 331 def f(x):
333 332 return x**2
334 333 data = range(16)
335 334 r = view.map_sync(f, data)
336 335 self.assertEqual(r, map(f, data))
337 336
338 337 def test_map_iterable(self):
339 338 """test map on iterables (direct)"""
340 339 view = self.client[:]
341 340 # 101 is prime, so it won't be evenly distributed
342 341 arr = range(101)
343 342 # ensure it will be an iterator, even in Python 3
344 343 it = iter(arr)
345 344 r = view.map_sync(lambda x: x, it)
346 345 self.assertEqual(r, list(arr))
347 346
348 347 @skip_without('numpy')
349 348 def test_map_numpy(self):
350 349 """test map on numpy arrays (direct)"""
351 350 import numpy
352 351 from numpy.testing.utils import assert_array_equal
353 352
354 353 view = self.client[:]
355 354 # 101 is prime, so it won't be evenly distributed
356 355 arr = numpy.arange(101)
357 356 r = view.map_sync(lambda x: x, arr)
358 357 assert_array_equal(r, arr)
359 358
360 359 def test_scatter_gather_nonblocking(self):
361 360 data = range(16)
362 361 view = self.client[:]
363 362 view.scatter('a', data, block=False)
364 363 ar = view.gather('a', block=False)
365 364 self.assertEqual(ar.get(), data)
366 365
367 366 @skip_without('numpy')
368 367 def test_scatter_gather_numpy_nonblocking(self):
369 368 import numpy
370 369 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
371 370 a = numpy.arange(64)
372 371 view = self.client[:]
373 372 ar = view.scatter('a', a, block=False)
374 373 self.assertTrue(isinstance(ar, AsyncResult))
375 374 amr = view.gather('a', block=False)
376 375 self.assertTrue(isinstance(amr, AsyncMapResult))
377 376 assert_array_equal(amr.get(), a)
378 377
379 378 def test_execute(self):
380 379 view = self.client[:]
381 380 # self.client.debug=True
382 381 execute = view.execute
383 382 ar = execute('c=30', block=False)
384 383 self.assertTrue(isinstance(ar, AsyncResult))
385 384 ar = execute('d=[0,1,2]', block=False)
386 385 self.client.wait(ar, 1)
387 386 self.assertEqual(len(ar.get()), len(self.client))
388 387 for c in view['c']:
389 388 self.assertEqual(c, 30)
390 389
391 390 def test_abort(self):
392 391 view = self.client[-1]
393 392 ar = view.execute('import time; time.sleep(1)', block=False)
394 393 ar2 = view.apply_async(lambda : 2)
395 394 ar3 = view.apply_async(lambda : 3)
396 395 view.abort(ar2)
397 396 view.abort(ar3.msg_ids)
398 397 self.assertRaises(error.TaskAborted, ar2.get)
399 398 self.assertRaises(error.TaskAborted, ar3.get)
400 399
401 400 def test_abort_all(self):
402 401 """view.abort() aborts all outstanding tasks"""
403 402 view = self.client[-1]
404 403 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
405 404 view.abort()
406 405 view.wait(timeout=5)
407 406 for ar in ars[5:]:
408 407 self.assertRaises(error.TaskAborted, ar.get)
409 408
410 409 def test_temp_flags(self):
411 410 view = self.client[-1]
412 411 view.block=True
413 412 with view.temp_flags(block=False):
414 413 self.assertFalse(view.block)
415 414 self.assertTrue(view.block)
416 415
417 416 @dec.known_failure_py3
418 417 def test_importer(self):
419 418 view = self.client[-1]
420 419 view.clear(block=True)
421 420 with view.importer:
422 421 import re
423 422
424 423 @interactive
425 424 def findall(pat, s):
426 425 # this globals() step isn't necessary in real code
427 426 # only to prevent a closure in the test
428 427 re = globals()['re']
429 428 return re.findall(pat, s)
430 429
431 430 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
432 431
433 432 def test_unicode_execute(self):
434 433 """test executing unicode strings"""
435 434 v = self.client[-1]
436 435 v.block=True
437 436 if sys.version_info[0] >= 3:
438 437 code="a='é'"
439 438 else:
440 439 code=u"a=u'é'"
441 440 v.execute(code)
442 441 self.assertEqual(v['a'], u'é')
443 442
444 443 def test_unicode_apply_result(self):
445 444 """test unicode apply results"""
446 445 v = self.client[-1]
447 446 r = v.apply_sync(lambda : u'é')
448 447 self.assertEqual(r, u'é')
449 448
450 449 def test_unicode_apply_arg(self):
451 450 """test passing unicode arguments to apply"""
452 451 v = self.client[-1]
453 452
454 453 @interactive
455 454 def check_unicode(a, check):
456 455 assert isinstance(a, unicode), "%r is not unicode"%a
457 456 assert isinstance(check, bytes), "%r is not bytes"%check
458 457 assert a.encode('utf8') == check, "%s != %s"%(a,check)
459 458
460 459 for s in [ u'é', u'ßø®∫',u'asdf' ]:
461 460 try:
462 461 v.apply_sync(check_unicode, s, s.encode('utf8'))
463 462 except error.RemoteError as e:
464 463 if e.ename == 'AssertionError':
465 464 self.fail(e.evalue)
466 465 else:
467 466 raise e
468 467
469 468 def test_map_reference(self):
470 469 """view.map(<Reference>, *seqs) should work"""
471 470 v = self.client[:]
472 471 v.scatter('n', self.client.ids, flatten=True)
473 472 v.execute("f = lambda x,y: x*y")
474 473 rf = pmod.Reference('f')
475 474 nlist = list(range(10))
476 475 mlist = nlist[::-1]
477 476 expected = [ m*n for m,n in zip(mlist, nlist) ]
478 477 result = v.map_sync(rf, mlist, nlist)
479 478 self.assertEqual(result, expected)
480 479
481 480 def test_apply_reference(self):
482 481 """view.apply(<Reference>, *args) should work"""
483 482 v = self.client[:]
484 483 v.scatter('n', self.client.ids, flatten=True)
485 484 v.execute("f = lambda x: n*x")
486 485 rf = pmod.Reference('f')
487 486 result = v.apply_sync(rf, 5)
488 487 expected = [ 5*id for id in self.client.ids ]
489 488 self.assertEqual(result, expected)
490 489
491 490 def test_eval_reference(self):
492 491 v = self.client[self.client.ids[0]]
493 492 v['g'] = range(5)
494 493 rg = pmod.Reference('g[0]')
495 494 echo = lambda x:x
496 495 self.assertEqual(v.apply_sync(echo, rg), 0)
497 496
498 497 def test_reference_nameerror(self):
499 498 v = self.client[self.client.ids[0]]
500 499 r = pmod.Reference('elvis_has_left')
501 500 echo = lambda x:x
502 501 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
503 502
504 503 def test_single_engine_map(self):
505 504 e0 = self.client[self.client.ids[0]]
506 505 r = range(5)
507 506 check = [ -1*i for i in r ]
508 507 result = e0.map_sync(lambda x: -1*x, r)
509 508 self.assertEqual(result, check)
510 509
511 510 def test_len(self):
512 511 """len(view) makes sense"""
513 512 e0 = self.client[self.client.ids[0]]
514 513 yield self.assertEqual(len(e0), 1)
515 514 v = self.client[:]
516 515 yield self.assertEqual(len(v), len(self.client.ids))
517 516 v = self.client.direct_view('all')
518 517 yield self.assertEqual(len(v), len(self.client.ids))
519 518 v = self.client[:2]
520 519 yield self.assertEqual(len(v), 2)
521 520 v = self.client[:1]
522 521 yield self.assertEqual(len(v), 1)
523 522 v = self.client.load_balanced_view()
524 523 yield self.assertEqual(len(v), len(self.client.ids))
525 524 # parametric tests seem to require manual closing?
526 525 self.client.close()
527 526
528 527
529 528 # begin execute tests
530 529
531 530 def test_execute_reply(self):
532 531 e0 = self.client[self.client.ids[0]]
533 532 e0.block = True
534 533 ar = e0.execute("5", silent=False)
535 534 er = ar.get()
536 535 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
537 536 self.assertEqual(er.pyout['data']['text/plain'], '5')
538 537
539 538 def test_execute_reply_stdout(self):
540 539 e0 = self.client[self.client.ids[0]]
541 540 e0.block = True
542 541 ar = e0.execute("print (5)", silent=False)
543 542 er = ar.get()
544 543 self.assertEqual(er.stdout.strip(), '5')
545 544
546 545 def test_execute_pyout(self):
547 546 """execute triggers pyout with silent=False"""
548 547 view = self.client[:]
549 548 ar = view.execute("5", silent=False, block=True)
550 549
551 550 expected = [{'text/plain' : '5'}] * len(view)
552 551 mimes = [ out['data'] for out in ar.pyout ]
553 552 self.assertEqual(mimes, expected)
554 553
555 554 def test_execute_silent(self):
556 555 """execute does not trigger pyout with silent=True"""
557 556 view = self.client[:]
558 557 ar = view.execute("5", block=True)
559 558 expected = [None] * len(view)
560 559 self.assertEqual(ar.pyout, expected)
561 560
562 561 def test_execute_magic(self):
563 562 """execute accepts IPython commands"""
564 563 view = self.client[:]
565 564 view.execute("a = 5")
566 565 ar = view.execute("%whos", block=True)
 567 566 # this will raise if that failed
568 567 ar.get(5)
569 568 for stdout in ar.stdout:
570 569 lines = stdout.splitlines()
571 570 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
572 571 found = False
573 572 for line in lines[2:]:
574 573 split = line.split()
575 574 if split == ['a', 'int', '5']:
576 575 found = True
577 576 break
578 577 self.assertTrue(found, "whos output wrong: %s" % stdout)
579 578
580 579 def test_execute_displaypub(self):
581 580 """execute tracks display_pub output"""
582 581 view = self.client[:]
583 582 view.execute("from IPython.core.display import *")
584 583 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
585 584
586 585 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
587 586 for outputs in ar.outputs:
588 587 mimes = [ out['data'] for out in outputs ]
589 588 self.assertEqual(mimes, expected)
590 589
591 590 def test_apply_displaypub(self):
592 591 """apply tracks display_pub output"""
593 592 view = self.client[:]
594 593 view.execute("from IPython.core.display import *")
595 594
596 595 @interactive
597 596 def publish():
598 597 [ display(i) for i in range(5) ]
599 598
600 599 ar = view.apply_async(publish)
601 600 ar.get(5)
602 601 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
603 602 for outputs in ar.outputs:
604 603 mimes = [ out['data'] for out in outputs ]
605 604 self.assertEqual(mimes, expected)
606 605
607 606 def test_execute_raises(self):
608 607 """exceptions in execute requests raise appropriately"""
609 608 view = self.client[-1]
610 609 ar = view.execute("1/0")
611 610 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
612 611
613 612 def test_remoteerror_render_exception(self):
614 613 """RemoteErrors get nice tracebacks"""
615 614 view = self.client[-1]
616 615 ar = view.execute("1/0")
617 616 ip = get_ipython()
618 617 ip.user_ns['ar'] = ar
619 618 with capture_output() as io:
620 619 ip.run_cell("ar.get(2)")
621 620
622 621 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
623 622
624 623 def test_compositeerror_render_exception(self):
625 624 """CompositeErrors get nice tracebacks"""
626 625 view = self.client[:]
627 626 ar = view.execute("1/0")
628 627 ip = get_ipython()
629 628 ip.user_ns['ar'] = ar
630 629
631 630 with capture_output() as io:
632 631 ip.run_cell("ar.get(2)")
633 632
634 633 count = min(error.CompositeError.tb_limit, len(view))
635 634
636 635 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
637 636 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
638 637 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
639 638
640 639 def test_compositeerror_truncate(self):
641 640 """Truncate CompositeErrors with many exceptions"""
642 641 view = self.client[:]
643 642 msg_ids = []
644 643 for i in range(10):
645 644 ar = view.execute("1/0")
646 645 msg_ids.extend(ar.msg_ids)
647 646
648 647 ar = self.client.get_result(msg_ids)
649 648 try:
650 649 ar.get()
651 650 except error.CompositeError as _e:
652 651 e = _e
653 652 else:
654 653 self.fail("Should have raised CompositeError")
655 654
656 655 lines = e.render_traceback()
657 656 with capture_output() as io:
658 657 e.print_traceback()
659 658
660 659 self.assertTrue("more exceptions" in lines[-1])
661 660 count = e.tb_limit
662 661
663 662 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
664 663 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
665 664 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
666 665
667 666 @dec.skipif_not_matplotlib
668 667 def test_magic_pylab(self):
669 668 """%pylab works on engines"""
670 669 view = self.client[-1]
671 670 ar = view.execute("%pylab inline")
672 671 # at least check if this raised:
673 672 reply = ar.get(5)
674 673 # include imports, in case user config
675 674 ar = view.execute("plot(rand(100))", silent=False)
676 675 reply = ar.get(5)
677 676 self.assertEqual(len(reply.outputs), 1)
678 677 output = reply.outputs[0]
679 678 self.assertTrue("data" in output)
680 679 data = output['data']
681 680 self.assertTrue("image/png" in data)
682 681
683 682 def test_func_default_func(self):
684 683 """interactively defined function as apply func default"""
685 684 def foo():
686 685 return 'foo'
687 686
688 687 def bar(f=foo):
689 688 return f()
690 689
691 690 view = self.client[-1]
692 691 ar = view.apply_async(bar)
693 692 r = ar.get(10)
694 693 self.assertEqual(r, 'foo')
695 694 def test_data_pub_single(self):
696 695 view = self.client[-1]
697 696 ar = view.execute('\n'.join([
698 697 'from IPython.kernel.zmq.datapub import publish_data',
699 698 'for i in range(5):',
700 699 ' publish_data(dict(i=i))'
701 700 ]), block=False)
702 701 self.assertTrue(isinstance(ar.data, dict))
703 702 ar.get(5)
704 703 self.assertEqual(ar.data, dict(i=4))
705 704
706 705 def test_data_pub(self):
707 706 view = self.client[:]
708 707 ar = view.execute('\n'.join([
709 708 'from IPython.kernel.zmq.datapub import publish_data',
710 709 'for i in range(5):',
711 710 ' publish_data(dict(i=i))'
712 711 ]), block=False)
713 712 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
714 713 ar.get(5)
715 714 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
716 715
717 716 def test_can_list_arg(self):
718 717 """args in lists are canned"""
719 718 view = self.client[-1]
720 719 view['a'] = 128
721 720 rA = pmod.Reference('a')
722 721 ar = view.apply_async(lambda x: x, [rA])
723 722 r = ar.get(5)
724 723 self.assertEqual(r, [128])
725 724
726 725 def test_can_dict_arg(self):
727 726 """args in dicts are canned"""
728 727 view = self.client[-1]
729 728 view['a'] = 128
730 729 rA = pmod.Reference('a')
731 730 ar = view.apply_async(lambda x: x, dict(foo=rA))
732 731 r = ar.get(5)
733 732 self.assertEqual(r, dict(foo=128))
734 733
735 734 def test_can_list_kwarg(self):
736 735 """kwargs in lists are canned"""
737 736 view = self.client[-1]
738 737 view['a'] = 128
739 738 rA = pmod.Reference('a')
740 739 ar = view.apply_async(lambda x=5: x, x=[rA])
741 740 r = ar.get(5)
742 741 self.assertEqual(r, [128])
743 742
744 743 def test_can_dict_kwarg(self):
745 744 """kwargs in dicts are canned"""
746 745 view = self.client[-1]
747 746 view['a'] = 128
748 747 rA = pmod.Reference('a')
749 748 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
750 749 r = ar.get(5)
751 750 self.assertEqual(r, dict(foo=128))
752 751
753 752 def test_map_ref(self):
754 753 """view.map works with references"""
755 754 view = self.client[:]
756 755 ranks = sorted(self.client.ids)
757 756 view.scatter('rank', ranks, flatten=True)
758 757 rrank = pmod.Reference('rank')
759 758
760 759 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
761 760 drank = amr.get(5)
762 761 self.assertEqual(drank, [ r*2 for r in ranks ])
763 762
764 763 def test_nested_getitem_setitem(self):
765 764 """get and set with view['a.b']"""
766 765 view = self.client[-1]
767 766 view.execute('\n'.join([
768 767 'class A(object): pass',
769 768 'a = A()',
770 769 'a.b = 128',
771 770 ]), block=True)
772 771 ra = pmod.Reference('a')
773 772
774 773 r = view.apply_sync(lambda x: x.b, ra)
775 774 self.assertEqual(r, 128)
776 775 self.assertEqual(view['a.b'], 128)
777 776
778 777 view['a.b'] = 0
779 778
780 779 r = view.apply_sync(lambda x: x.b, ra)
781 780 self.assertEqual(r, 0)
782 781 self.assertEqual(view['a.b'], 0)
783 782
784 783 def test_return_namedtuple(self):
785 784 def namedtuplify(x, y):
786 785 from IPython.parallel.tests.test_view import point
787 786 return point(x, y)
788 787
789 788 view = self.client[-1]
790 789 p = view.apply_sync(namedtuplify, 1, 2)
791 790 self.assertEqual(p.x, 1)
792 791 self.assertEqual(p.y, 2)
793 792
794 793 def test_apply_namedtuple(self):
795 794 def echoxy(p):
796 795 return p.y, p.x
797 796
798 797 view = self.client[-1]
799 798 tup = view.apply_sync(echoxy, point(1, 2))
800 799 self.assertEqual(tup, (2,1))
801 800
@@ -1,442 +1,444 b''
1 1 """Generic testing tools.
2 2
3 3 Authors
4 4 -------
5 5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 6 """
7 7
8 8 from __future__ import absolute_import
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2009 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Imports
19 19 #-----------------------------------------------------------------------------
20 20
21 21 import os
22 22 import re
23 23 import sys
24 24 import tempfile
25 25
26 26 from contextlib import contextmanager
27 27 from io import StringIO
28 28 from subprocess import Popen, PIPE
29 29
30 30 try:
31 31 # These tools are used by parts of the runtime, so we make the nose
32 32 # dependency optional at this point. Nose is a hard dependency to run the
33 33 # test suite, but NOT to use ipython itself.
34 34 import nose.tools as nt
35 35 has_nose = True
36 36 except ImportError:
37 37 has_nose = False
38 38
39 39 from IPython.config.loader import Config
40 40 from IPython.utils.process import get_output_error_code
41 41 from IPython.utils.text import list_strings
42 42 from IPython.utils.io import temp_pyfile, Tee
43 43 from IPython.utils import py3compat
44 44 from IPython.utils.encoding import DEFAULT_ENCODING
45 45
46 46 from . import decorators as dec
47 47 from . import skipdoctest
48 48
49 49 #-----------------------------------------------------------------------------
50 50 # Functions and classes
51 51 #-----------------------------------------------------------------------------
52 52
53 53 # The docstring for full_path doctests differently on win32 (different path
54 54 # separator) so just skip the doctest there. The example remains informative.
55 55 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
56 56
57 57 @doctest_deco
58 58 def full_path(startPath,files):
59 59 """Make full paths for all the listed files, based on startPath.
60 60
61 61 Only the base part of startPath is kept, since this routine is typically
62 62 used with a script's __file__ variable as startPath. The base of startPath
63 63 is then prepended to all the listed files, forming the output list.
64 64
65 65 Parameters
66 66 ----------
67 67 startPath : string
68 68 Initial path to use as the base for the results. This path is split
69 69 using os.path.split() and only its first component is kept.
70 70
71 71 files : string or list
72 72 One or more files.
73 73
74 74 Examples
75 75 --------
76 76
77 77 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
78 78 ['/foo/a.txt', '/foo/b.txt']
79 79
80 80 >>> full_path('/foo',['a.txt','b.txt'])
81 81 ['/a.txt', '/b.txt']
82 82
83 83 If a single file is given, the output is still a list:
84 84 >>> full_path('/foo','a.txt')
85 85 ['/a.txt']
86 86 """
87 87
88 88 files = list_strings(files)
89 89 base = os.path.split(startPath)[0]
90 90 return [ os.path.join(base,f) for f in files ]
91 91
92 92
93 93 def parse_test_output(txt):
94 94 """Parse the output of a test run and return errors, failures.
95 95
96 96 Parameters
97 97 ----------
98 98 txt : str
99 99 Text output of a test run, assumed to contain a line of one of the
100 100 following forms::
101 101
102 102 'FAILED (errors=1)'
103 103 'FAILED (failures=1)'
104 104 'FAILED (errors=1, failures=1)'
105 105
106 106 Returns
107 107 -------
108 108 nerr, nfail: number of errors and failures.
109 109 """
110 110
111 111 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
112 112 if err_m:
113 113 nerr = int(err_m.group(1))
114 114 nfail = 0
115 115 return nerr, nfail
116 116
117 117 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
118 118 if fail_m:
119 119 nerr = 0
120 120 nfail = int(fail_m.group(1))
121 121 return nerr, nfail
122 122
123 123 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
124 124 re.MULTILINE)
125 125 if both_m:
126 126 nerr = int(both_m.group(1))
127 127 nfail = int(both_m.group(2))
128 128 return nerr, nfail
129 129
130 130 # If the input didn't match any of these forms, assume no error/failures
131 131 return 0, 0
132 132
133 133
134 134 # So nose doesn't think this is a test
135 135 parse_test_output.__test__ = False
136 136
137 137
138 138 def default_argv():
139 139 """Return a valid default argv for creating testing instances of ipython"""
140 140
141 141 return ['--quick', # so no config file is loaded
142 142 # Other defaults to minimize side effects on stdout
143 143 '--colors=NoColor', '--no-term-title','--no-banner',
144 144 '--autocall=0']
145 145
146 146
147 147 def default_config():
148 148 """Return a config object with good defaults for testing."""
149 149 config = Config()
150 150 config.TerminalInteractiveShell.colors = 'NoColor'
 151 151 config.TerminalInteractiveShell.term_title = False
152 152 config.TerminalInteractiveShell.autocall = 0
153 config.HistoryManager.hist_file = tempfile.mktemp(u'test_hist.sqlite')
153 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
154 config.HistoryManager.hist_file = f.name
155 f.close()
154 156 config.HistoryManager.db_cache_size = 10000
155 157 return config
156 158
157 159
158 160 def get_ipython_cmd(as_string=False):
159 161 """
160 162 Return appropriate IPython command line name. By default, this will return
161 163 a list that can be used with subprocess.Popen, for example, but passing
162 164 `as_string=True` allows for returning the IPython command as a string.
163 165
164 166 Parameters
165 167 ----------
166 168 as_string: bool
 167 169 Flag to return the command as a string rather than a list.
168 170 """
169 171 # FIXME: remove workaround for 2.6 support
170 172 if sys.version_info[:2] > (2,6):
171 173 ipython_cmd = [sys.executable, "-m", "IPython"]
172 174 else:
173 175 ipython_cmd = ["ipython"]
174 176
175 177 if as_string:
176 178 ipython_cmd = " ".join(ipython_cmd)
177 179
178 180 return ipython_cmd
179 181
180 182 def ipexec(fname, options=None):
181 183 """Utility to call 'ipython filename'.
182 184
183 185 Starts IPython with a minimal and safe configuration to make startup as fast
184 186 as possible.
185 187
186 188 Note that this starts IPython in a subprocess!
187 189
188 190 Parameters
189 191 ----------
190 192 fname : str
191 193 Name of file to be executed (should have .py or .ipy extension).
192 194
193 195 options : optional, list
194 196 Extra command-line flags to be passed to IPython.
195 197
196 198 Returns
197 199 -------
198 200 (stdout, stderr) of ipython subprocess.
199 201 """
200 202 if options is None: options = []
201 203
202 204 # For these subprocess calls, eliminate all prompt printing so we only see
203 205 # output from script execution
204 206 prompt_opts = [ '--PromptManager.in_template=""',
205 207 '--PromptManager.in2_template=""',
206 208 '--PromptManager.out_template=""'
207 209 ]
208 210 cmdargs = default_argv() + prompt_opts + options
209 211
210 212 test_dir = os.path.dirname(__file__)
211 213
212 214 ipython_cmd = get_ipython_cmd()
213 215 # Absolute path for filename
214 216 full_fname = os.path.join(test_dir, fname)
215 217 full_cmd = ipython_cmd + cmdargs + [full_fname]
216 218 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE)
217 219 out, err = p.communicate()
218 220 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
219 221 # `import readline` causes 'ESC[?1034h' to be output sometimes,
220 222 # so strip that out before doing comparisons
221 223 if out:
222 224 out = re.sub(r'\x1b\[[^h]+h', '', out)
223 225 return out, err
224 226
225 227
226 228 def ipexec_validate(fname, expected_out, expected_err='',
227 229 options=None):
228 230 """Utility to call 'ipython filename' and validate output/error.
229 231
230 232 This function raises an AssertionError if the validation fails.
231 233
232 234 Note that this starts IPython in a subprocess!
233 235
234 236 Parameters
235 237 ----------
236 238 fname : str
237 239 Name of the file to be executed (should have .py or .ipy extension).
238 240
239 241 expected_out : str
240 242 Expected stdout of the process.
241 243
242 244 expected_err : optional, str
243 245 Expected stderr of the process.
244 246
245 247 options : optional, list
246 248 Extra command-line flags to be passed to IPython.
247 249
248 250 Returns
249 251 -------
250 252 None
251 253 """
252 254
253 255 import nose.tools as nt
254 256
255 257 out, err = ipexec(fname, options)
256 258 #print 'OUT', out # dbg
257 259 #print 'ERR', err # dbg
 258 260 # If there are any errors, we must check those before stdout, as they may be
259 261 # more informative than simply having an empty stdout.
260 262 if err:
261 263 if expected_err:
262 264 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
263 265 else:
264 266 raise ValueError('Running file %r produced error: %r' %
265 267 (fname, err))
266 268 # If no errors or output on stderr was expected, match stdout
267 269 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
268 270
269 271
270 272 class TempFileMixin(object):
271 273 """Utility class to create temporary Python/IPython files.
272 274
273 275 Meant as a mixin class for test cases."""
274 276
275 277 def mktmp(self, src, ext='.py'):
276 278 """Make a valid python temp file."""
277 279 fname, f = temp_pyfile(src, ext)
278 280 self.tmpfile = f
279 281 self.fname = fname
280 282
281 283 def tearDown(self):
282 284 if hasattr(self, 'tmpfile'):
283 285 # If the tmpfile wasn't made because of skipped tests, like in
284 286 # win32, there's nothing to cleanup.
285 287 self.tmpfile.close()
286 288 try:
287 289 os.unlink(self.fname)
288 290 except:
289 291 # On Windows, even though we close the file, we still can't
290 292 # delete it. I have no clue why
291 293 pass
292 294
293 295 pair_fail_msg = ("Testing {0}\n\n"
294 296 "In:\n"
295 297 " {1!r}\n"
296 298 "Expected:\n"
297 299 " {2!r}\n"
298 300 "Got:\n"
299 301 " {3!r}\n")
300 302 def check_pairs(func, pairs):
301 303 """Utility function for the common case of checking a function with a
302 304 sequence of input/output pairs.
303 305
304 306 Parameters
305 307 ----------
306 308 func : callable
307 309 The function to be tested. Should accept a single argument.
308 310 pairs : iterable
309 311 A list of (input, expected_output) tuples.
310 312
311 313 Returns
312 314 -------
313 315 None. Raises an AssertionError if any output does not match the expected
314 316 value.
315 317 """
316 318 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
317 319 for inp, expected in pairs:
318 320 out = func(inp)
319 321 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
320 322
321 323
322 324 if py3compat.PY3:
323 325 MyStringIO = StringIO
324 326 else:
325 327 # In Python 2, stdout/stderr can have either bytes or unicode written to them,
326 328 # so we need a class that can handle both.
327 329 class MyStringIO(StringIO):
328 330 def write(self, s):
329 331 s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING)
330 332 super(MyStringIO, self).write(s)
331 333
332 334 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
333 335 -------
334 336 {2!s}
335 337 -------
336 338 """
337 339
338 340 class AssertPrints(object):
339 341 """Context manager for testing that code prints certain text.
340 342
341 343 Examples
342 344 --------
343 345 >>> with AssertPrints("abc", suppress=False):
344 346 ... print "abcd"
345 347 ... print "def"
346 348 ...
347 349 abcd
348 350 def
349 351 """
350 352 def __init__(self, s, channel='stdout', suppress=True):
351 353 self.s = s
352 354 if isinstance(self.s, py3compat.string_types):
353 355 self.s = [self.s]
354 356 self.channel = channel
355 357 self.suppress = suppress
356 358
357 359 def __enter__(self):
358 360 self.orig_stream = getattr(sys, self.channel)
359 361 self.buffer = MyStringIO()
360 362 self.tee = Tee(self.buffer, channel=self.channel)
361 363 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
362 364
363 365 def __exit__(self, etype, value, traceback):
364 366 self.tee.flush()
365 367 setattr(sys, self.channel, self.orig_stream)
366 368 printed = self.buffer.getvalue()
367 369 for s in self.s:
368 370 assert s in printed, notprinted_msg.format(s, self.channel, printed)
369 371 return False
370 372
371 373 printed_msg = """Found {0!r} in printed output (on {1}):
372 374 -------
373 375 {2!s}
374 376 -------
375 377 """
376 378
377 379 class AssertNotPrints(AssertPrints):
378 380 """Context manager for checking that certain output *isn't* produced.
379 381
380 382 Counterpart of AssertPrints"""
381 383 def __exit__(self, etype, value, traceback):
382 384 self.tee.flush()
383 385 setattr(sys, self.channel, self.orig_stream)
384 386 printed = self.buffer.getvalue()
385 387 for s in self.s:
386 388 assert s not in printed, printed_msg.format(s, self.channel, printed)
387 389 return False
388 390
389 391 @contextmanager
390 392 def mute_warn():
391 393 from IPython.utils import warn
392 394 save_warn = warn.warn
393 395 warn.warn = lambda *a, **kw: None
394 396 try:
395 397 yield
396 398 finally:
397 399 warn.warn = save_warn
398 400
399 401 @contextmanager
400 402 def make_tempfile(name):
401 403 """ Create an empty, named, temporary file for the duration of the context.
402 404 """
403 405 f = open(name, 'w')
404 406 f.close()
405 407 try:
406 408 yield
407 409 finally:
408 410 os.unlink(name)
409 411
410 412
411 413 @contextmanager
412 414 def monkeypatch(obj, name, attr):
413 415 """
414 416 Context manager to replace attribute named `name` in `obj` with `attr`.
415 417 """
416 418 orig = getattr(obj, name)
417 419 setattr(obj, name, attr)
418 420 yield
419 421 setattr(obj, name, orig)
420 422
421 423
422 424 def help_output_test(subcommand=''):
423 425 """test that `ipython [subcommand] -h` works"""
424 426 cmd = ' '.join(get_ipython_cmd() + [subcommand, '-h'])
425 427 out, err, rc = get_output_error_code(cmd)
426 428 nt.assert_equal(rc, 0, err)
427 429 nt.assert_not_in("Traceback", err)
428 430 nt.assert_in("Options", out)
429 431 nt.assert_in("--help-all", out)
430 432 return out, err
431 433
432 434
433 435 def help_all_output_test(subcommand=''):
434 436 """test that `ipython [subcommand] --help-all` works"""
435 437 cmd = ' '.join(get_ipython_cmd() + [subcommand, '--help-all'])
436 438 out, err, rc = get_output_error_code(cmd)
437 439 nt.assert_equal(rc, 0, err)
438 440 nt.assert_not_in("Traceback", err)
439 441 nt.assert_in("Options", out)
440 442 nt.assert_in("Class parameters", out)
441 443 return out, err
442 444
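
The default_config() change above applies the same pattern to the history database path. To see why mktemp is being retired: it returns a name without creating anything, so another process can claim that name before the caller opens it, whereas NamedTemporaryFile creates and opens the file atomically. A small hedged sketch (hist_file is just a local variable here; only HistoryManager.hist_file is real IPython config):

import os
import tempfile

# Old, racy style: nothing reserves the returned name.
#   hist_file = tempfile.mktemp(u'test_hist.sqlite')

# New style, mirroring default_config() above: create the file, keep only its name.
f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
hist_file = f.name
f.close()  # the empty file stays behind and holds the name

# ... use hist_file (e.g. config.HistoryManager.hist_file = hist_file) ...
os.unlink(hist_file)  # remove it once the test configuration is done with it
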