##// END OF EJS Templates
remove mktemp usage...
Julian Taylor -
Show More
@@ -1,350 +1,353 b''
1 1 # encoding: utf-8
2 2 """
3 3 Paging capabilities for IPython.core
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Fernando Perez
9 9
10 10 Notes
11 11 -----
12 12
13 13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
14 14 rid of that dependency, we could move it there.
15 15 -----
16 16 """
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Copyright (C) 2008-2011 The IPython Development Team
20 20 #
21 21 # Distributed under the terms of the BSD License. The full license is in
22 22 # the file COPYING, distributed as part of this software.
23 23 #-----------------------------------------------------------------------------
24 24
25 25 #-----------------------------------------------------------------------------
26 26 # Imports
27 27 #-----------------------------------------------------------------------------
28 28 from __future__ import print_function
29 29
30 30 import os
31 31 import re
32 32 import sys
33 33 import tempfile
34 34
35 35 from io import UnsupportedOperation
36 36
37 37 from IPython import get_ipython
38 38 from IPython.core.error import TryNext
39 39 from IPython.utils.data import chop
40 40 from IPython.utils import io
41 41 from IPython.utils.process import system
42 42 from IPython.utils.terminal import get_terminal_size
43 43 from IPython.utils import py3compat
44 44
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Classes and functions
48 48 #-----------------------------------------------------------------------------
49 49
50 50 esc_re = re.compile(r"(\x1b[^m]+m)")
51 51
52 52 def page_dumb(strng, start=0, screen_lines=25):
53 53 """Very dumb 'pager' in Python, for when nothing else works.
54 54
55 55 Only moves forward, same interface as page(), except for pager_cmd and
56 56 mode."""
57 57
58 58 out_ln = strng.splitlines()[start:]
59 59 screens = chop(out_ln,screen_lines-1)
60 60 if len(screens) == 1:
61 61 print(os.linesep.join(screens[0]), file=io.stdout)
62 62 else:
63 63 last_escape = ""
64 64 for scr in screens[0:-1]:
65 65 hunk = os.linesep.join(scr)
66 66 print(last_escape + hunk, file=io.stdout)
67 67 if not page_more():
68 68 return
69 69 esc_list = esc_re.findall(hunk)
70 70 if len(esc_list) > 0:
71 71 last_escape = esc_list[-1]
72 72 print(last_escape + os.linesep.join(screens[-1]), file=io.stdout)
73 73
74 74 def _detect_screen_size(screen_lines_def):
75 75 """Attempt to work out the number of lines on the screen.
76 76
77 77 This is called by page(). It can raise an error (e.g. when run in the
78 78 test suite), so it's separated out so it can easily be called in a try block.
79 79 """
80 80 TERM = os.environ.get('TERM',None)
81 81 if not((TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5'):
82 82 # curses causes problems on many terminals other than xterm, and
83 83 # some termios calls lock up on Sun OS5.
84 84 return screen_lines_def
85 85
86 86 try:
87 87 import termios
88 88 import curses
89 89 except ImportError:
90 90 return screen_lines_def
91 91
92 92 # There is a bug in curses, where *sometimes* it fails to properly
93 93 # initialize, and then after the endwin() call is made, the
94 94 # terminal is left in an unusable state. Rather than trying to
95 95 # check everytime for this (by requesting and comparing termios
96 96 # flags each time), we just save the initial terminal state and
97 97 # unconditionally reset it every time. It's cheaper than making
98 98 # the checks.
99 99 term_flags = termios.tcgetattr(sys.stdout)
100 100
101 101 # Curses modifies the stdout buffer size by default, which messes
102 102 # up Python's normal stdout buffering. This would manifest itself
103 103 # to IPython users as delayed printing on stdout after having used
104 104 # the pager.
105 105 #
106 106 # We can prevent this by manually setting the NCURSES_NO_SETBUF
107 107 # environment variable. For more details, see:
108 108 # http://bugs.python.org/issue10144
109 109 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
110 110 os.environ['NCURSES_NO_SETBUF'] = ''
111 111
112 112 # Proceed with curses initialization
113 113 try:
114 114 scr = curses.initscr()
115 115 except AttributeError:
116 116 # Curses on Solaris may not be complete, so we can't use it there
117 117 return screen_lines_def
118 118
119 119 screen_lines_real,screen_cols = scr.getmaxyx()
120 120 curses.endwin()
121 121
122 122 # Restore environment
123 123 if NCURSES_NO_SETBUF is None:
124 124 del os.environ['NCURSES_NO_SETBUF']
125 125 else:
126 126 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
127 127
128 128 # Restore terminal state in case endwin() didn't.
129 129 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
130 130 # Now we have what we needed: the screen size in rows/columns
131 131 return screen_lines_real
132 132 #print '***Screen size:',screen_lines_real,'lines x',\
133 133 #screen_cols,'columns.' # dbg
134 134
135 135 def page(strng, start=0, screen_lines=0, pager_cmd=None):
136 136 """Print a string, piping through a pager after a certain length.
137 137
138 138 The screen_lines parameter specifies the number of *usable* lines of your
139 139 terminal screen (total lines minus lines you need to reserve to show other
140 140 information).
141 141
142 142 If you set screen_lines to a number <=0, page() will try to auto-determine
143 143 your screen size and will only use up to (screen_size+screen_lines) for
144 144 printing, paging after that. That is, if you want auto-detection but need
145 145 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
146 146 auto-detection without any lines reserved simply use screen_lines = 0.
147 147
148 148 If a string won't fit in the allowed lines, it is sent through the
149 149 specified pager command. If none given, look for PAGER in the environment,
150 150 and ultimately default to less.
151 151
152 152 If no system pager works, the string is sent through a 'dumb pager'
153 153 written in python, very simplistic.
154 154 """
155 155
156 156 # Some routines may auto-compute start offsets incorrectly and pass a
157 157 # negative value. Offset to 0 for robustness.
158 158 start = max(0, start)
159 159
160 160 # first, try the hook
161 161 ip = get_ipython()
162 162 if ip:
163 163 try:
164 164 ip.hooks.show_in_pager(strng)
165 165 return
166 166 except TryNext:
167 167 pass
168 168
169 169 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
170 170 TERM = os.environ.get('TERM','dumb')
171 171 if TERM in ['dumb','emacs'] and os.name != 'nt':
172 172 print(strng)
173 173 return
174 174 # chop off the topmost part of the string we don't want to see
175 175 str_lines = strng.splitlines()[start:]
176 176 str_toprint = os.linesep.join(str_lines)
177 177 num_newlines = len(str_lines)
178 178 len_str = len(str_toprint)
179 179
180 180 # Dumb heuristics to guesstimate number of on-screen lines the string
181 181 # takes. Very basic, but good enough for docstrings in reasonable
182 182 # terminals. If someone later feels like refining it, it's not hard.
183 183 numlines = max(num_newlines,int(len_str/80)+1)
184 184
185 185 screen_lines_def = get_terminal_size()[1]
186 186
187 187 # auto-determine screen size
188 188 if screen_lines <= 0:
189 189 try:
190 190 screen_lines += _detect_screen_size(screen_lines_def)
191 191 except (TypeError, UnsupportedOperation):
192 192 print(str_toprint, file=io.stdout)
193 193 return
194 194
195 195 #print 'numlines',numlines,'screenlines',screen_lines # dbg
196 196 if numlines <= screen_lines :
197 197 #print '*** normal print' # dbg
198 198 print(str_toprint, file=io.stdout)
199 199 else:
200 200 # Try to open pager and default to internal one if that fails.
201 201 # All failure modes are tagged as 'retval=1', to match the return
202 202 # value of a failed system command. If any intermediate attempt
203 203 # sets retval to 1, at the end we resort to our own page_dumb() pager.
204 204 pager_cmd = get_pager_cmd(pager_cmd)
205 205 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
206 206 if os.name == 'nt':
207 207 if pager_cmd.startswith('type'):
208 208 # The default WinXP 'type' command is failing on complex strings.
209 209 retval = 1
210 210 else:
211 tmpname = tempfile.mktemp('.txt')
212 tmpfile = open(tmpname,'wt')
211 fd, tmpname = tempfile.mkstemp('.txt')
212 try:
213 os.close(fd)
214 with open(tmpname, 'wt') as tmpfile:
213 215 tmpfile.write(strng)
214 tmpfile.close()
215 216 cmd = "%s < %s" % (pager_cmd,tmpname)
217 # tmpfile needs to be closed for windows
216 218 if os.system(cmd):
217 219 retval = 1
218 220 else:
219 221 retval = None
222 finally:
220 223 os.remove(tmpname)
221 224 else:
222 225 try:
223 226 retval = None
224 227 # if I use popen4, things hang. No idea why.
225 228 #pager,shell_out = os.popen4(pager_cmd)
226 229 pager = os.popen(pager_cmd, 'w')
227 230 try:
228 231 pager_encoding = pager.encoding or sys.stdout.encoding
229 232 pager.write(py3compat.cast_bytes_py2(
230 233 strng, encoding=pager_encoding))
231 234 finally:
232 235 retval = pager.close()
233 236 except IOError as msg: # broken pipe when user quits
234 237 if msg.args == (32, 'Broken pipe'):
235 238 retval = None
236 239 else:
237 240 retval = 1
238 241 except OSError:
239 242 # Other strange problems, sometimes seen in Win2k/cygwin
240 243 retval = 1
241 244 if retval is not None:
242 245 page_dumb(strng,screen_lines=screen_lines)
243 246
244 247
245 248 def page_file(fname, start=0, pager_cmd=None):
246 249 """Page a file, using an optional pager command and starting line.
247 250 """
248 251
249 252 pager_cmd = get_pager_cmd(pager_cmd)
250 253 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
251 254
252 255 try:
253 256 if os.environ['TERM'] in ['emacs','dumb']:
254 257 raise EnvironmentError
255 258 system(pager_cmd + ' ' + fname)
256 259 except:
257 260 try:
258 261 if start > 0:
259 262 start -= 1
260 263 page(open(fname).read(),start)
261 264 except:
262 265 print('Unable to show file',repr(fname))
263 266
264 267
265 268 def get_pager_cmd(pager_cmd=None):
266 269 """Return a pager command.
267 270
268 271 Makes some attempts at finding an OS-correct one.
269 272 """
270 273 if os.name == 'posix':
271 274 default_pager_cmd = 'less -r' # -r for color control sequences
272 275 elif os.name in ['nt','dos']:
273 276 default_pager_cmd = 'type'
274 277
275 278 if pager_cmd is None:
276 279 try:
277 280 pager_cmd = os.environ['PAGER']
278 281 except:
279 282 pager_cmd = default_pager_cmd
280 283 return pager_cmd
281 284
282 285
283 286 def get_pager_start(pager, start):
284 287 """Return the string for paging files with an offset.
285 288
286 289 This is the '+N' argument which less and more (under Unix) accept.
287 290 """
288 291
289 292 if pager in ['less','more']:
290 293 if start:
291 294 start_string = '+' + str(start)
292 295 else:
293 296 start_string = ''
294 297 else:
295 298 start_string = ''
296 299 return start_string
297 300
298 301
299 302 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
300 303 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
301 304 import msvcrt
302 305 def page_more():
303 306 """ Smart pausing between pages
304 307
305 308 @return: True if need print more lines, False if quit
306 309 """
307 310 io.stdout.write('---Return to continue, q to quit--- ')
308 311 ans = msvcrt.getwch()
309 312 if ans in ("q", "Q"):
310 313 result = False
311 314 else:
312 315 result = True
313 316 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
314 317 return result
315 318 else:
316 319 def page_more():
317 320 ans = py3compat.input('---Return to continue, q to quit--- ')
318 321 if ans.lower().startswith('q'):
319 322 return False
320 323 else:
321 324 return True
322 325
323 326
324 327 def snip_print(str,width = 75,print_full = 0,header = ''):
325 328 """Print a string snipping the midsection to fit in width.
326 329
327 330 print_full: mode control:
328 331
329 332 - 0: only snip long strings
330 333 - 1: send to page() directly.
331 334 - 2: snip long strings and ask for full length viewing with page()
332 335
333 336 Return 1 if snipping was necessary, 0 otherwise."""
334 337
335 338 if print_full == 1:
336 339 page(header+str)
337 340 return 0
338 341
339 342 print(header, end=' ')
340 343 if len(str) < width:
341 344 print(str)
342 345 snip = 0
343 346 else:
344 347 whalf = int((width -5)/2)
345 348 print(str[:whalf] + ' <...> ' + str[-whalf:])
346 349 snip = 1
347 350 if snip and print_full == 2:
348 351 if py3compat.input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
349 352 page(str)
350 353 return snip
@@ -1,559 +1,560 b''
1 1 """Utilities for connecting to kernels
2 2
3 3 Authors:
4 4
5 5 * Min Ragan-Kelley
6 6
7 7 """
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2013 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19
20 20 from __future__ import absolute_import
21 21
22 22 import glob
23 23 import json
24 24 import os
25 25 import socket
26 26 import sys
27 27 from getpass import getpass
28 28 from subprocess import Popen, PIPE
29 29 import tempfile
30 30
31 31 import zmq
32 32
33 33 # external imports
34 34 from IPython.external.ssh import tunnel
35 35
36 36 # IPython imports
37 37 from IPython.config import Configurable
38 38 from IPython.core.profiledir import ProfileDir
39 39 from IPython.utils.localinterfaces import localhost
40 40 from IPython.utils.path import filefind, get_ipython_dir
41 41 from IPython.utils.py3compat import (str_to_bytes, bytes_to_str, cast_bytes_py2,
42 42 string_types)
43 43 from IPython.utils.traitlets import (
44 44 Bool, Integer, Unicode, CaselessStrEnum,
45 45 )
46 46
47 47
48 48 #-----------------------------------------------------------------------------
49 49 # Working with Connection Files
50 50 #-----------------------------------------------------------------------------
51 51
52 52 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
53 53 control_port=0, ip='', key=b'', transport='tcp',
54 54 signature_scheme='hmac-sha256',
55 55 ):
56 56 """Generates a JSON config file, including the selection of random ports.
57 57
58 58 Parameters
59 59 ----------
60 60
61 61 fname : unicode
62 62 The path to the file to write
63 63
64 64 shell_port : int, optional
65 65 The port to use for ROUTER (shell) channel.
66 66
67 67 iopub_port : int, optional
68 68 The port to use for the SUB channel.
69 69
70 70 stdin_port : int, optional
71 71 The port to use for the ROUTER (raw input) channel.
72 72
73 73 control_port : int, optional
74 74 The port to use for the ROUTER (control) channel.
75 75
76 76 hb_port : int, optional
77 77 The port to use for the heartbeat REP channel.
78 78
79 79 ip : str, optional
80 80 The ip address the kernel will bind to.
81 81
82 82 key : str, optional
83 83 The Session key used for message authentication.
84 84
85 85 signature_scheme : str, optional
86 86 The scheme used for message authentication.
87 87 This has the form 'digest-hash', where 'digest'
88 88 is the scheme used for digests, and 'hash' is the name of the hash function
89 89 used by the digest scheme.
90 90 Currently, 'hmac' is the only supported digest scheme,
91 91 and 'sha256' is the default hash function.
92 92
93 93 """
94 94 if not ip:
95 95 ip = localhost()
96 96 # default to temporary connector file
97 97 if not fname:
98 fname = tempfile.mktemp('.json')
98 fd, fname = tempfile.mkstemp('.json')
99 os.close(fd)
99 100
100 101 # Find open ports as necessary.
101 102
102 103 ports = []
103 104 ports_needed = int(shell_port <= 0) + \
104 105 int(iopub_port <= 0) + \
105 106 int(stdin_port <= 0) + \
106 107 int(control_port <= 0) + \
107 108 int(hb_port <= 0)
108 109 if transport == 'tcp':
109 110 for i in range(ports_needed):
110 111 sock = socket.socket()
111 112 sock.bind(('', 0))
112 113 ports.append(sock)
113 114 for i, sock in enumerate(ports):
114 115 port = sock.getsockname()[1]
115 116 sock.close()
116 117 ports[i] = port
117 118 else:
118 119 N = 1
119 120 for i in range(ports_needed):
120 121 while os.path.exists("%s-%s" % (ip, str(N))):
121 122 N += 1
122 123 ports.append(N)
123 124 N += 1
124 125 if shell_port <= 0:
125 126 shell_port = ports.pop(0)
126 127 if iopub_port <= 0:
127 128 iopub_port = ports.pop(0)
128 129 if stdin_port <= 0:
129 130 stdin_port = ports.pop(0)
130 131 if control_port <= 0:
131 132 control_port = ports.pop(0)
132 133 if hb_port <= 0:
133 134 hb_port = ports.pop(0)
134 135
135 136 cfg = dict( shell_port=shell_port,
136 137 iopub_port=iopub_port,
137 138 stdin_port=stdin_port,
138 139 control_port=control_port,
139 140 hb_port=hb_port,
140 141 )
141 142 cfg['ip'] = ip
142 143 cfg['key'] = bytes_to_str(key)
143 144 cfg['transport'] = transport
144 145 cfg['signature_scheme'] = signature_scheme
145 146
146 147 with open(fname, 'w') as f:
147 148 f.write(json.dumps(cfg, indent=2))
148 149
149 150 return fname, cfg
150 151
151 152
152 153 def get_connection_file(app=None):
153 154 """Return the path to the connection file of an app
154 155
155 156 Parameters
156 157 ----------
157 158 app : IPKernelApp instance [optional]
158 159 If unspecified, the currently running app will be used
159 160 """
160 161 if app is None:
161 162 from IPython.kernel.zmq.kernelapp import IPKernelApp
162 163 if not IPKernelApp.initialized():
163 164 raise RuntimeError("app not specified, and not in a running Kernel")
164 165
165 166 app = IPKernelApp.instance()
166 167 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
167 168
168 169
169 170 def find_connection_file(filename, profile=None):
170 171 """find a connection file, and return its absolute path.
171 172
172 173 The current working directory and the profile's security
173 174 directory will be searched for the file if it is not given by
174 175 absolute path.
175 176
176 177 If profile is unspecified, then the current running application's
177 178 profile will be used, or 'default', if not run from IPython.
178 179
179 180 If the argument does not match an existing file, it will be interpreted as a
180 181 fileglob, and the matching file in the profile's security dir with
181 182 the latest access time will be used.
182 183
183 184 Parameters
184 185 ----------
185 186 filename : str
186 187 The connection file or fileglob to search for.
187 188 profile : str [optional]
188 189 The name of the profile to use when searching for the connection file,
189 190 if different from the current IPython session or 'default'.
190 191
191 192 Returns
192 193 -------
193 194 str : The absolute path of the connection file.
194 195 """
195 196 from IPython.core.application import BaseIPythonApplication as IPApp
196 197 try:
197 198 # quick check for absolute path, before going through logic
198 199 return filefind(filename)
199 200 except IOError:
200 201 pass
201 202
202 203 if profile is None:
203 204 # profile unspecified, check if running from an IPython app
204 205 if IPApp.initialized():
205 206 app = IPApp.instance()
206 207 profile_dir = app.profile_dir
207 208 else:
208 209 # not running in IPython, use default profile
209 210 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
210 211 else:
211 212 # find profiledir by profile name:
212 213 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
213 214 security_dir = profile_dir.security_dir
214 215
215 216 try:
216 217 # first, try explicit name
217 218 return filefind(filename, ['.', security_dir])
218 219 except IOError:
219 220 pass
220 221
221 222 # not found by full name
222 223
223 224 if '*' in filename:
224 225 # given as a glob already
225 226 pat = filename
226 227 else:
227 228 # accept any substring match
228 229 pat = '*%s*' % filename
229 230 matches = glob.glob( os.path.join(security_dir, pat) )
230 231 if not matches:
231 232 raise IOError("Could not find %r in %r" % (filename, security_dir))
232 233 elif len(matches) == 1:
233 234 return matches[0]
234 235 else:
235 236 # get most recent match, by access time:
236 237 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
237 238
238 239
239 240 def get_connection_info(connection_file=None, unpack=False, profile=None):
240 241 """Return the connection information for the current Kernel.
241 242
242 243 Parameters
243 244 ----------
244 245 connection_file : str [optional]
245 246 The connection file to be used. Can be given by absolute path, or
246 247 IPython will search in the security directory of a given profile.
247 248 If run from IPython,
248 249
249 250 If unspecified, the connection file for the currently running
250 251 IPython Kernel will be used, which is only allowed from inside a kernel.
251 252 unpack : bool [default: False]
252 253 if True, return the unpacked dict, otherwise just the string contents
253 254 of the file.
254 255 profile : str [optional]
255 256 The name of the profile to use when searching for the connection file,
256 257 if different from the current IPython session or 'default'.
257 258
258 259
259 260 Returns
260 261 -------
261 262 The connection dictionary of the current kernel, as string or dict,
262 263 depending on `unpack`.
263 264 """
264 265 if connection_file is None:
265 266 # get connection file from current kernel
266 267 cf = get_connection_file()
267 268 else:
268 269 # connection file specified, allow shortnames:
269 270 cf = find_connection_file(connection_file, profile=profile)
270 271
271 272 with open(cf) as f:
272 273 info = f.read()
273 274
274 275 if unpack:
275 276 info = json.loads(info)
276 277 # ensure key is bytes:
277 278 info['key'] = str_to_bytes(info.get('key', ''))
278 279 return info
279 280
280 281
281 282 def connect_qtconsole(connection_file=None, argv=None, profile=None):
282 283 """Connect a qtconsole to the current kernel.
283 284
284 285 This is useful for connecting a second qtconsole to a kernel, or to a
285 286 local notebook.
286 287
287 288 Parameters
288 289 ----------
289 290 connection_file : str [optional]
290 291 The connection file to be used. Can be given by absolute path, or
291 292 IPython will search in the security directory of a given profile.
292 293 If run from IPython,
293 294
294 295 If unspecified, the connection file for the currently running
295 296 IPython Kernel will be used, which is only allowed from inside a kernel.
296 297 argv : list [optional]
297 298 Any extra args to be passed to the console.
298 299 profile : str [optional]
299 300 The name of the profile to use when searching for the connection file,
300 301 if different from the current IPython session or 'default'.
301 302
302 303
303 304 Returns
304 305 -------
305 306 subprocess.Popen instance running the qtconsole frontend
306 307 """
307 308 argv = [] if argv is None else argv
308 309
309 310 if connection_file is None:
310 311 # get connection file from current kernel
311 312 cf = get_connection_file()
312 313 else:
313 314 cf = find_connection_file(connection_file, profile=profile)
314 315
315 316 cmd = ';'.join([
316 317 "from IPython.qt.console import qtconsoleapp",
317 318 "qtconsoleapp.main()"
318 319 ])
319 320
320 321 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
321 322 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
322 323 )
323 324
324 325
325 326 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
326 327 """tunnel connections to a kernel via ssh
327 328
328 329 This will open four SSH tunnels from localhost on this machine to the
329 330 ports associated with the kernel. They can be either direct
330 331 localhost-localhost tunnels, or if an intermediate server is necessary,
331 332 the kernel must be listening on a public IP.
332 333
333 334 Parameters
334 335 ----------
335 336 connection_info : dict or str (path)
336 337 Either a connection dict, or the path to a JSON connection file
337 338 sshserver : str
338 339 The ssh sever to use to tunnel to the kernel. Can be a full
339 340 `user@server:port` string. ssh config aliases are respected.
340 341 sshkey : str [optional]
341 342 Path to file containing ssh key to use for authentication.
342 343 Only necessary if your ssh config does not already associate
343 344 a keyfile with the host.
344 345
345 346 Returns
346 347 -------
347 348
348 349 (shell, iopub, stdin, hb) : ints
349 350 The four ports on localhost that have been forwarded to the kernel.
350 351 """
351 352 if isinstance(connection_info, string_types):
352 353 # it's a path, unpack it
353 354 with open(connection_info) as f:
354 355 connection_info = json.loads(f.read())
355 356
356 357 cf = connection_info
357 358
358 359 lports = tunnel.select_random_ports(4)
359 360 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
360 361
361 362 remote_ip = cf['ip']
362 363
363 364 if tunnel.try_passwordless_ssh(sshserver, sshkey):
364 365 password=False
365 366 else:
366 367 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
367 368
368 369 for lp,rp in zip(lports, rports):
369 370 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
370 371
371 372 return tuple(lports)
372 373
373 374
374 375 #-----------------------------------------------------------------------------
375 376 # Mixin for classes that work with connection files
376 377 #-----------------------------------------------------------------------------
377 378
378 379 channel_socket_types = {
379 380 'hb' : zmq.REQ,
380 381 'shell' : zmq.DEALER,
381 382 'iopub' : zmq.SUB,
382 383 'stdin' : zmq.DEALER,
383 384 'control': zmq.DEALER,
384 385 }
385 386
386 387 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
387 388
388 389 class ConnectionFileMixin(Configurable):
389 390 """Mixin for configurable classes that work with connection files"""
390 391
391 392 # The addresses for the communication channels
392 393 connection_file = Unicode('')
393 394 _connection_file_written = Bool(False)
394 395
395 396 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
396 397
397 398 ip = Unicode(config=True,
398 399 help="""Set the kernel\'s IP address [default localhost].
399 400 If the IP address is something other than localhost, then
400 401 Consoles on other machines will be able to connect
401 402 to the Kernel, so be careful!"""
402 403 )
403 404
404 405 def _ip_default(self):
405 406 if self.transport == 'ipc':
406 407 if self.connection_file:
407 408 return os.path.splitext(self.connection_file)[0] + '-ipc'
408 409 else:
409 410 return 'kernel-ipc'
410 411 else:
411 412 return localhost()
412 413
413 414 def _ip_changed(self, name, old, new):
414 415 if new == '*':
415 416 self.ip = '0.0.0.0'
416 417
417 418 # protected traits
418 419
419 420 shell_port = Integer(0)
420 421 iopub_port = Integer(0)
421 422 stdin_port = Integer(0)
422 423 control_port = Integer(0)
423 424 hb_port = Integer(0)
424 425
425 426 @property
426 427 def ports(self):
427 428 return [ getattr(self, name) for name in port_names ]
428 429
429 430 #--------------------------------------------------------------------------
430 431 # Connection and ipc file management
431 432 #--------------------------------------------------------------------------
432 433
433 434 def get_connection_info(self):
434 435 """return the connection info as a dict"""
435 436 return dict(
436 437 transport=self.transport,
437 438 ip=self.ip,
438 439 shell_port=self.shell_port,
439 440 iopub_port=self.iopub_port,
440 441 stdin_port=self.stdin_port,
441 442 hb_port=self.hb_port,
442 443 control_port=self.control_port,
443 444 signature_scheme=self.session.signature_scheme,
444 445 key=self.session.key,
445 446 )
446 447
447 448 def cleanup_connection_file(self):
448 449 """Cleanup connection file *if we wrote it*
449 450
450 451 Will not raise if the connection file was already removed somehow.
451 452 """
452 453 if self._connection_file_written:
453 454 # cleanup connection files on full shutdown of kernel we started
454 455 self._connection_file_written = False
455 456 try:
456 457 os.remove(self.connection_file)
457 458 except (IOError, OSError, AttributeError):
458 459 pass
459 460
460 461 def cleanup_ipc_files(self):
461 462 """Cleanup ipc files if we wrote them."""
462 463 if self.transport != 'ipc':
463 464 return
464 465 for port in self.ports:
465 466 ipcfile = "%s-%i" % (self.ip, port)
466 467 try:
467 468 os.remove(ipcfile)
468 469 except (IOError, OSError):
469 470 pass
470 471
471 472 def write_connection_file(self):
472 473 """Write connection info to JSON dict in self.connection_file."""
473 474 if self._connection_file_written:
474 475 return
475 476
476 477 self.connection_file, cfg = write_connection_file(self.connection_file,
477 478 transport=self.transport, ip=self.ip, key=self.session.key,
478 479 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
479 480 shell_port=self.shell_port, hb_port=self.hb_port,
480 481 control_port=self.control_port,
481 482 signature_scheme=self.session.signature_scheme,
482 483 )
483 484 # write_connection_file also sets default ports:
484 485 for name in port_names:
485 486 setattr(self, name, cfg[name])
486 487
487 488 self._connection_file_written = True
488 489
489 490 def load_connection_file(self):
490 491 """Load connection info from JSON dict in self.connection_file."""
491 492 with open(self.connection_file) as f:
492 493 cfg = json.loads(f.read())
493 494
494 495 self.transport = cfg.get('transport', 'tcp')
495 496 self.ip = cfg['ip']
496 497 for name in port_names:
497 498 setattr(self, name, cfg[name])
498 499 if 'key' in cfg:
499 500 self.session.key = str_to_bytes(cfg['key'])
500 501 if cfg.get('signature_scheme'):
501 502 self.session.signature_scheme = cfg['signature_scheme']
502 503
503 504 #--------------------------------------------------------------------------
504 505 # Creating connected sockets
505 506 #--------------------------------------------------------------------------
506 507
507 508 def _make_url(self, channel):
508 509 """Make a ZeroMQ URL for a given channel."""
509 510 transport = self.transport
510 511 ip = self.ip
511 512 port = getattr(self, '%s_port' % channel)
512 513
513 514 if transport == 'tcp':
514 515 return "tcp://%s:%i" % (ip, port)
515 516 else:
516 517 return "%s://%s-%s" % (transport, ip, port)
517 518
518 519 def _create_connected_socket(self, channel, identity=None):
519 520 """Create a zmq Socket and connect it to the kernel."""
520 521 url = self._make_url(channel)
521 522 socket_type = channel_socket_types[channel]
522 523 self.log.debug("Connecting to: %s" % url)
523 524 sock = self.context.socket(socket_type)
524 525 if identity:
525 526 sock.identity = identity
526 527 sock.connect(url)
527 528 return sock
528 529
529 530 def connect_iopub(self, identity=None):
530 531 """return zmq Socket connected to the IOPub channel"""
531 532 sock = self._create_connected_socket('iopub', identity=identity)
532 533 sock.setsockopt(zmq.SUBSCRIBE, b'')
533 534 return sock
534 535
535 536 def connect_shell(self, identity=None):
536 537 """return zmq Socket connected to the Shell channel"""
537 538 return self._create_connected_socket('shell', identity=identity)
538 539
539 540 def connect_stdin(self, identity=None):
540 541 """return zmq Socket connected to the StdIn channel"""
541 542 return self._create_connected_socket('stdin', identity=identity)
542 543
543 544 def connect_hb(self, identity=None):
544 545 """return zmq Socket connected to the Heartbeat channel"""
545 546 return self._create_connected_socket('hb', identity=identity)
546 547
547 548 def connect_control(self, identity=None):
548 549 """return zmq Socket connected to the Heartbeat channel"""
549 550 return self._create_connected_socket('control', identity=identity)
550 551
551 552
552 553 __all__ = [
553 554 'write_connection_file',
554 555 'get_connection_file',
555 556 'find_connection_file',
556 557 'get_connection_info',
557 558 'connect_qtconsole',
558 559 'tunnel_to_kernel',
559 560 ]
@@ -1,547 +1,546 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 from tempfile import mktemp
24 23
25 24 import zmq
26 25
27 26 from IPython import parallel
28 27 from IPython.parallel.client import client as clientmod
29 28 from IPython.parallel import error
30 29 from IPython.parallel import AsyncResult, AsyncHubResult
31 30 from IPython.parallel import LoadBalancedView, DirectView
32 31
33 32 from .clienttest import ClusterTestCase, segfault, wait, add_engines
34 33
35 34 def setup():
36 35 add_engines(4, total=True)
37 36
38 37 class TestClient(ClusterTestCase):
39 38
40 39 def test_ids(self):
41 40 n = len(self.client.ids)
42 41 self.add_engines(2)
43 42 self.assertEqual(len(self.client.ids), n+2)
44 43
45 44 def test_view_indexing(self):
46 45 """test index access for views"""
47 46 self.minimum_engines(4)
48 47 targets = self.client._build_targets('all')[-1]
49 48 v = self.client[:]
50 49 self.assertEqual(v.targets, targets)
51 50 t = self.client.ids[2]
52 51 v = self.client[t]
53 52 self.assertTrue(isinstance(v, DirectView))
54 53 self.assertEqual(v.targets, t)
55 54 t = self.client.ids[2:4]
56 55 v = self.client[t]
57 56 self.assertTrue(isinstance(v, DirectView))
58 57 self.assertEqual(v.targets, t)
59 58 v = self.client[::2]
60 59 self.assertTrue(isinstance(v, DirectView))
61 60 self.assertEqual(v.targets, targets[::2])
62 61 v = self.client[1::3]
63 62 self.assertTrue(isinstance(v, DirectView))
64 63 self.assertEqual(v.targets, targets[1::3])
65 64 v = self.client[:-3]
66 65 self.assertTrue(isinstance(v, DirectView))
67 66 self.assertEqual(v.targets, targets[:-3])
68 67 v = self.client[-1]
69 68 self.assertTrue(isinstance(v, DirectView))
70 69 self.assertEqual(v.targets, targets[-1])
71 70 self.assertRaises(TypeError, lambda : self.client[None])
72 71
73 72 def test_lbview_targets(self):
74 73 """test load_balanced_view targets"""
75 74 v = self.client.load_balanced_view()
76 75 self.assertEqual(v.targets, None)
77 76 v = self.client.load_balanced_view(-1)
78 77 self.assertEqual(v.targets, [self.client.ids[-1]])
79 78 v = self.client.load_balanced_view('all')
80 79 self.assertEqual(v.targets, None)
81 80
82 81 def test_dview_targets(self):
83 82 """test direct_view targets"""
84 83 v = self.client.direct_view()
85 84 self.assertEqual(v.targets, 'all')
86 85 v = self.client.direct_view('all')
87 86 self.assertEqual(v.targets, 'all')
88 87 v = self.client.direct_view(-1)
89 88 self.assertEqual(v.targets, self.client.ids[-1])
90 89
91 90 def test_lazy_all_targets(self):
92 91 """test lazy evaluation of rc.direct_view('all')"""
93 92 v = self.client.direct_view()
94 93 self.assertEqual(v.targets, 'all')
95 94
96 95 def double(x):
97 96 return x*2
98 97 seq = list(range(100))
99 98 ref = [ double(x) for x in seq ]
100 99
101 100 # add some engines, which should be used
102 101 self.add_engines(1)
103 102 n1 = len(self.client.ids)
104 103
105 104 # simple apply
106 105 r = v.apply_sync(lambda : 1)
107 106 self.assertEqual(r, [1] * n1)
108 107
109 108 # map goes through remotefunction
110 109 r = v.map_sync(double, seq)
111 110 self.assertEqual(r, ref)
112 111
113 112 # add a couple more engines, and try again
114 113 self.add_engines(2)
115 114 n2 = len(self.client.ids)
116 115 self.assertNotEqual(n2, n1)
117 116
118 117 # apply
119 118 r = v.apply_sync(lambda : 1)
120 119 self.assertEqual(r, [1] * n2)
121 120
122 121 # map
123 122 r = v.map_sync(double, seq)
124 123 self.assertEqual(r, ref)
125 124
126 125 def test_targets(self):
127 126 """test various valid targets arguments"""
128 127 build = self.client._build_targets
129 128 ids = self.client.ids
130 129 idents,targets = build(None)
131 130 self.assertEqual(ids, targets)
132 131
133 132 def test_clear(self):
134 133 """test clear behavior"""
135 134 self.minimum_engines(2)
136 135 v = self.client[:]
137 136 v.block=True
138 137 v.push(dict(a=5))
139 138 v.pull('a')
140 139 id0 = self.client.ids[-1]
141 140 self.client.clear(targets=id0, block=True)
142 141 a = self.client[:-1].get('a')
143 142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 143 self.client.clear(block=True)
145 144 for i in self.client.ids:
146 145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147 146
148 147 def test_get_result(self):
149 148 """test getting results from the Hub."""
150 149 c = clientmod.Client(profile='iptest')
151 150 t = c.ids[-1]
152 151 ar = c[t].apply_async(wait, 1)
153 152 # give the monitor time to notice the message
154 153 time.sleep(.25)
155 154 ahr = self.client.get_result(ar.msg_ids[0])
156 155 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 156 self.assertEqual(ahr.get(), ar.get())
158 157 ar2 = self.client.get_result(ar.msg_ids[0])
159 158 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 159 c.close()
161 160
162 161 def test_get_execute_result(self):
163 162 """test getting execute results from the Hub."""
164 163 c = clientmod.Client(profile='iptest')
165 164 t = c.ids[-1]
166 165 cell = '\n'.join([
167 166 'import time',
168 167 'time.sleep(0.25)',
169 168 '5'
170 169 ])
171 170 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 171 # give the monitor time to notice the message
173 172 time.sleep(.25)
174 173 ahr = self.client.get_result(ar.msg_ids[0])
175 174 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 175 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 176 ar2 = self.client.get_result(ar.msg_ids[0])
178 177 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 178 c.close()
180 179
181 180 def test_ids_list(self):
182 181 """test client.ids"""
183 182 ids = self.client.ids
184 183 self.assertEqual(ids, self.client._ids)
185 184 self.assertFalse(ids is self.client._ids)
186 185 ids.remove(ids[-1])
187 186 self.assertNotEqual(ids, self.client._ids)
188 187
189 188 def test_queue_status(self):
190 189 ids = self.client.ids
191 190 id0 = ids[0]
192 191 qs = self.client.queue_status(targets=id0)
193 192 self.assertTrue(isinstance(qs, dict))
194 193 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 194 allqs = self.client.queue_status()
196 195 self.assertTrue(isinstance(allqs, dict))
197 196 intkeys = list(allqs.keys())
198 197 intkeys.remove('unassigned')
199 198 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 199 unassigned = allqs.pop('unassigned')
201 200 for eid,qs in allqs.items():
202 201 self.assertTrue(isinstance(qs, dict))
203 202 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204 203
205 204 def test_shutdown(self):
206 205 ids = self.client.ids
207 206 id0 = ids[0]
208 207 self.client.shutdown(id0, block=True)
209 208 while id0 in self.client.ids:
210 209 time.sleep(0.1)
211 210 self.client.spin()
212 211
213 212 self.assertRaises(IndexError, lambda : self.client[id0])
214 213
215 214 def test_result_status(self):
216 215 pass
217 216 # to be written
218 217
219 218 def test_db_query_dt(self):
220 219 """test db query by date"""
221 220 hist = self.client.hub_history()
222 221 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 222 tic = middle['submitted']
224 223 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 224 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 225 self.assertEqual(len(before)+len(after),len(hist))
227 226 for b in before:
228 227 self.assertTrue(b['submitted'] < tic)
229 228 for a in after:
230 229 self.assertTrue(a['submitted'] >= tic)
231 230 same = self.client.db_query({'submitted' : tic})
232 231 for s in same:
233 232 self.assertTrue(s['submitted'] == tic)
234 233
235 234 def test_db_query_keys(self):
236 235 """test extracting subset of record keys"""
237 236 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 237 for rec in found:
239 238 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240 239
241 240 def test_db_query_default_keys(self):
242 241 """default db_query excludes buffers"""
243 242 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 243 for rec in found:
245 244 keys = set(rec.keys())
246 245 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 246 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248 247
249 248 def test_db_query_msg_id(self):
250 249 """ensure msg_id is always in db queries"""
251 250 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 251 for rec in found:
253 252 self.assertTrue('msg_id' in rec.keys())
254 253 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 254 for rec in found:
256 255 self.assertTrue('msg_id' in rec.keys())
257 256 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 257 for rec in found:
259 258 self.assertTrue('msg_id' in rec.keys())
260 259
261 260 def test_db_query_get_result(self):
262 261 """pop in db_query shouldn't pop from result itself"""
263 262 self.client[:].apply_sync(lambda : 1)
264 263 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 264 rc2 = clientmod.Client(profile='iptest')
266 265 # If this bug is not fixed, this call will hang:
267 266 ar = rc2.get_result(self.client.history[-1])
268 267 ar.wait(2)
269 268 self.assertTrue(ar.ready())
270 269 ar.get()
271 270 rc2.close()
272 271
273 272 def test_db_query_in(self):
274 273 """test db query with '$in','$nin' operators"""
275 274 hist = self.client.hub_history()
276 275 even = hist[::2]
277 276 odd = hist[1::2]
278 277 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 278 found = [ r['msg_id'] for r in recs ]
280 279 self.assertEqual(set(even), set(found))
281 280 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 281 found = [ r['msg_id'] for r in recs ]
283 282 self.assertEqual(set(odd), set(found))
284 283
285 284 def test_hub_history(self):
286 285 hist = self.client.hub_history()
287 286 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 287 recdict = {}
289 288 for rec in recs:
290 289 recdict[rec['msg_id']] = rec
291 290
292 291 latest = datetime(1984,1,1)
293 292 for msg_id in hist:
294 293 rec = recdict[msg_id]
295 294 newt = rec['submitted']
296 295 self.assertTrue(newt >= latest)
297 296 latest = newt
298 297 ar = self.client[-1].apply_async(lambda : 1)
299 298 ar.get()
300 299 time.sleep(0.25)
301 300 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302 301
303 302 def _wait_for_idle(self):
304 303 """wait for the cluster to become idle, according to the everyone."""
305 304 rc = self.client
306 305
307 306 # step 0. wait for local results
308 307 # this should be sufficient 99% of the time.
309 308 rc.wait(timeout=5)
310 309
311 310 # step 1. wait for all requests to be noticed
312 311 # timeout 5s, polling every 100ms
313 312 msg_ids = set(rc.history)
314 313 hub_hist = rc.hub_history()
315 314 for i in range(50):
316 315 if msg_ids.difference(hub_hist):
317 316 time.sleep(0.1)
318 317 hub_hist = rc.hub_history()
319 318 else:
320 319 break
321 320
322 321 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
323 322
324 323 # step 2. wait for all requests to be done
325 324 # timeout 5s, polling every 100ms
326 325 qs = rc.queue_status()
327 326 for i in range(50):
328 327 if qs['unassigned'] or any(qs[eid]['tasks'] + qs[eid]['queue'] for eid in qs if eid != 'unassigned'):
329 328 time.sleep(0.1)
330 329 qs = rc.queue_status()
331 330 else:
332 331 break
333 332
334 333 # ensure Hub up to date:
335 334 self.assertEqual(qs['unassigned'], 0)
336 335 for eid in [ eid for eid in qs if eid != 'unassigned' ]:
337 336 self.assertEqual(qs[eid]['tasks'], 0)
338 337 self.assertEqual(qs[eid]['queue'], 0)
339 338
340 339
341 340 def test_resubmit(self):
342 341 def f():
343 342 import random
344 343 return random.random()
345 344 v = self.client.load_balanced_view()
346 345 ar = v.apply_async(f)
347 346 r1 = ar.get(1)
348 347 # give the Hub a chance to notice:
349 348 self._wait_for_idle()
350 349 ahr = self.client.resubmit(ar.msg_ids)
351 350 r2 = ahr.get(1)
352 351 self.assertFalse(r1 == r2)
353 352
354 353 def test_resubmit_chain(self):
355 354 """resubmit resubmitted tasks"""
356 355 v = self.client.load_balanced_view()
357 356 ar = v.apply_async(lambda x: x, 'x'*1024)
358 357 ar.get()
359 358 self._wait_for_idle()
360 359 ars = [ar]
361 360
362 361 for i in range(10):
363 362 ar = ars[-1]
364 363 ar2 = self.client.resubmit(ar.msg_ids)
365 364
366 365 [ ar.get() for ar in ars ]
367 366
368 367 def test_resubmit_header(self):
369 368 """resubmit shouldn't clobber the whole header"""
370 369 def f():
371 370 import random
372 371 return random.random()
373 372 v = self.client.load_balanced_view()
374 373 v.retries = 1
375 374 ar = v.apply_async(f)
376 375 r1 = ar.get(1)
377 376 # give the Hub a chance to notice:
378 377 self._wait_for_idle()
379 378 ahr = self.client.resubmit(ar.msg_ids)
380 379 ahr.get(1)
381 380 time.sleep(0.5)
382 381 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
383 382 h1,h2 = [ r['header'] for r in records ]
384 383 for key in set(h1.keys()).union(set(h2.keys())):
385 384 if key in ('msg_id', 'date'):
386 385 self.assertNotEqual(h1[key], h2[key])
387 386 else:
388 387 self.assertEqual(h1[key], h2[key])
389 388
390 389 def test_resubmit_aborted(self):
391 390 def f():
392 391 import random
393 392 return random.random()
394 393 v = self.client.load_balanced_view()
395 394 # restrict to one engine, so we can put a sleep
396 395 # ahead of the task, so it will get aborted
397 396 eid = self.client.ids[-1]
398 397 v.targets = [eid]
399 398 sleep = v.apply_async(time.sleep, 0.5)
400 399 ar = v.apply_async(f)
401 400 ar.abort()
402 401 self.assertRaises(error.TaskAborted, ar.get)
403 402 # Give the Hub a chance to get up to date:
404 403 self._wait_for_idle()
405 404 ahr = self.client.resubmit(ar.msg_ids)
406 405 r2 = ahr.get(1)
407 406
408 407 def test_resubmit_inflight(self):
409 408 """resubmit of inflight task"""
410 409 v = self.client.load_balanced_view()
411 410 ar = v.apply_async(time.sleep,1)
412 411 # give the message a chance to arrive
413 412 time.sleep(0.2)
414 413 ahr = self.client.resubmit(ar.msg_ids)
415 414 ar.get(2)
416 415 ahr.get(2)
417 416
418 417 def test_resubmit_badkey(self):
419 418 """ensure KeyError on resubmit of nonexistant task"""
420 419 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
421 420
422 421 def test_purge_hub_results(self):
423 422 # ensure there are some tasks
424 423 for i in range(5):
425 424 self.client[:].apply_sync(lambda : 1)
426 425 # Wait for the Hub to realise the result is done:
427 426 # This prevents a race condition, where we
428 427 # might purge a result the Hub still thinks is pending.
429 428 self._wait_for_idle()
430 429 rc2 = clientmod.Client(profile='iptest')
431 430 hist = self.client.hub_history()
432 431 ahr = rc2.get_result([hist[-1]])
433 432 ahr.wait(10)
434 433 self.client.purge_hub_results(hist[-1])
435 434 newhist = self.client.hub_history()
436 435 self.assertEqual(len(newhist)+1,len(hist))
437 436 rc2.spin()
438 437 rc2.close()
439 438
440 439 def test_purge_local_results(self):
441 440 # ensure there are some tasks
442 441 res = []
443 442 for i in range(5):
444 443 res.append(self.client[:].apply_async(lambda : 1))
445 444 self._wait_for_idle()
446 445 self.client.wait(10) # wait for the results to come back
447 446 before = len(self.client.results)
448 447 self.assertEqual(len(self.client.metadata),before)
449 448 self.client.purge_local_results(res[-1])
450 449 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
451 450 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
452 451
453 452 def test_purge_local_results_outstanding(self):
454 453 v = self.client[-1]
455 454 ar = v.apply_async(lambda : 1)
456 455 msg_id = ar.msg_ids[0]
457 456 ar.get()
458 457 self._wait_for_idle()
459 458 ar2 = v.apply_async(time.sleep, 1)
460 459 self.assertIn(msg_id, self.client.results)
461 460 self.assertIn(msg_id, self.client.metadata)
462 461 self.client.purge_local_results(ar)
463 462 self.assertNotIn(msg_id, self.client.results)
464 463 self.assertNotIn(msg_id, self.client.metadata)
465 464 with self.assertRaises(RuntimeError):
466 465 self.client.purge_local_results(ar2)
467 466 ar2.get()
468 467 self.client.purge_local_results(ar2)
469 468
470 469 def test_purge_all_local_results_outstanding(self):
471 470 v = self.client[-1]
472 471 ar = v.apply_async(time.sleep, 1)
473 472 with self.assertRaises(RuntimeError):
474 473 self.client.purge_local_results('all')
475 474 ar.get()
476 475 self.client.purge_local_results('all')
477 476
478 477 def test_purge_all_hub_results(self):
479 478 self.client.purge_hub_results('all')
480 479 hist = self.client.hub_history()
481 480 self.assertEqual(len(hist), 0)
482 481
483 482 def test_purge_all_local_results(self):
484 483 self.client.purge_local_results('all')
485 484 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
486 485 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
487 486
488 487 def test_purge_all_results(self):
489 488 # ensure there are some tasks
490 489 for i in range(5):
491 490 self.client[:].apply_sync(lambda : 1)
492 491 self.client.wait(10)
493 492 self._wait_for_idle()
494 493 self.client.purge_results('all')
495 494 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
496 495 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
497 496 hist = self.client.hub_history()
498 497 self.assertEqual(len(hist), 0, msg="hub history not empty")
499 498
500 499 def test_purge_everything(self):
501 500 # ensure there are some tasks
502 501 for i in range(5):
503 502 self.client[:].apply_sync(lambda : 1)
504 503 self.client.wait(10)
505 504 self._wait_for_idle()
506 505 self.client.purge_everything()
507 506 # The client results
508 507 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
509 508 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
510 509 # The client "bookkeeping"
511 510 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
512 511 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
513 512 # the hub results
514 513 hist = self.client.hub_history()
515 514 self.assertEqual(len(hist), 0, msg="hub history not empty")
516 515
517 516
518 517 def test_spin_thread(self):
519 518 self.client.spin_thread(0.01)
520 519 ar = self.client[-1].apply_async(lambda : 1)
521 520 time.sleep(0.1)
522 521 self.assertTrue(ar.wall_time < 0.1,
523 522 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
524 523 )
525 524
526 525 def test_stop_spin_thread(self):
527 526 self.client.spin_thread(0.01)
528 527 self.client.stop_spin_thread()
529 528 ar = self.client[-1].apply_async(lambda : 1)
530 529 time.sleep(0.15)
531 530 self.assertTrue(ar.wall_time > 0.1,
532 531 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
533 532 )
534 533
535 534 def test_activate(self):
536 535 ip = get_ipython()
537 536 magics = ip.magics_manager.magics
538 537 self.assertTrue('px' in magics['line'])
539 538 self.assertTrue('px' in magics['cell'])
540 539 v0 = self.client.activate(-1, '0')
541 540 self.assertTrue('px0' in magics['line'])
542 541 self.assertTrue('px0' in magics['cell'])
543 542 self.assertEqual(v0.targets, self.client.ids[-1])
544 543 v0 = self.client.activate('all', 'all')
545 544 self.assertTrue('pxall' in magics['line'])
546 545 self.assertTrue('pxall' in magics['cell'])
547 546 self.assertEqual(v0.targets, 'all')
@@ -1,850 +1,849 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import base64
20 20 import sys
21 21 import platform
22 22 import time
23 23 from collections import namedtuple
24 from tempfile import mktemp
24 from tempfile import NamedTemporaryFile
25 25
26 26 import zmq
27 27 from nose.plugins.attrib import attr
28 28
29 29 from IPython.testing import decorators as dec
30 30 from IPython.utils.io import capture_output
31 31 from IPython.utils.py3compat import unicode_type
32 32
33 33 from IPython import parallel as pmod
34 34 from IPython.parallel import error
35 35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 36 from IPython.parallel.util import interactive
37 37
38 38 from IPython.parallel.tests import add_engines
39 39
40 40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41 41
42 42 def setup():
43 43 add_engines(3, total=True)
44 44
45 45 point = namedtuple("point", "x y")
46 46
47 47 class TestView(ClusterTestCase):
48 48
49 49 def setUp(self):
50 50 # On Win XP, wait for resource cleanup, else parallel test group fails
51 51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
52 52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
53 53 time.sleep(2)
54 54 super(TestView, self).setUp()
55 55
56 56 @attr('crash')
57 57 def test_z_crash_mux(self):
58 58 """test graceful handling of engine death (direct)"""
59 59 # self.add_engines(1)
60 60 eid = self.client.ids[-1]
61 61 ar = self.client[eid].apply_async(crash)
62 62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
63 63 eid = ar.engine_id
64 64 tic = time.time()
65 65 while eid in self.client.ids and time.time()-tic < 5:
66 66 time.sleep(.01)
67 67 self.client.spin()
68 68 self.assertFalse(eid in self.client.ids, "Engine should have died")
69 69
70 70 def test_push_pull(self):
71 71 """test pushing and pulling"""
72 72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
73 73 t = self.client.ids[-1]
74 74 v = self.client[t]
75 75 push = v.push
76 76 pull = v.pull
77 77 v.block=True
78 78 nengines = len(self.client)
79 79 push({'data':data})
80 80 d = pull('data')
81 81 self.assertEqual(d, data)
82 82 self.client[:].push({'data':data})
83 83 d = self.client[:].pull('data', block=True)
84 84 self.assertEqual(d, nengines*[data])
85 85 ar = push({'data':data}, block=False)
86 86 self.assertTrue(isinstance(ar, AsyncResult))
87 87 r = ar.get()
88 88 ar = self.client[:].pull('data', block=False)
89 89 self.assertTrue(isinstance(ar, AsyncResult))
90 90 r = ar.get()
91 91 self.assertEqual(r, nengines*[data])
92 92 self.client[:].push(dict(a=10,b=20))
93 93 r = self.client[:].pull(('a','b'), block=True)
94 94 self.assertEqual(r, nengines*[[10,20]])
95 95
96 96 def test_push_pull_function(self):
97 97 "test pushing and pulling functions"
98 98 def testf(x):
99 99 return 2.0*x
100 100
101 101 t = self.client.ids[-1]
102 102 v = self.client[t]
103 103 v.block=True
104 104 push = v.push
105 105 pull = v.pull
106 106 execute = v.execute
107 107 push({'testf':testf})
108 108 r = pull('testf')
109 109 self.assertEqual(r(1.0), testf(1.0))
110 110 execute('r = testf(10)')
111 111 r = pull('r')
112 112 self.assertEqual(r, testf(10))
113 113 ar = self.client[:].push({'testf':testf}, block=False)
114 114 ar.get()
115 115 ar = self.client[:].pull('testf', block=False)
116 116 rlist = ar.get()
117 117 for r in rlist:
118 118 self.assertEqual(r(1.0), testf(1.0))
119 119 execute("def g(x): return x*x")
120 120 r = pull(('testf','g'))
121 121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
122 122
123 123 def test_push_function_globals(self):
124 124 """test that pushed functions have access to globals"""
125 125 @interactive
126 126 def geta():
127 127 return a
128 128 # self.add_engines(1)
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = geta
132 132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
133 133 v.execute('a=5')
134 134 v.execute('b=f()')
135 135 self.assertEqual(v['b'], 5)
136 136
137 137 def test_push_function_defaults(self):
138 138 """test that pushed functions preserve default args"""
139 139 def echo(a=10):
140 140 return a
141 141 v = self.client[-1]
142 142 v.block=True
143 143 v['f'] = echo
144 144 v.execute('b=f()')
145 145 self.assertEqual(v['b'], 10)
146 146
147 147 def test_get_result(self):
148 148 """test getting results from the Hub."""
149 149 c = pmod.Client(profile='iptest')
150 150 # self.add_engines(1)
151 151 t = c.ids[-1]
152 152 v = c[t]
153 153 v2 = self.client[t]
154 154 ar = v.apply_async(wait, 1)
155 155 # give the monitor time to notice the message
156 156 time.sleep(.25)
157 157 ahr = v2.get_result(ar.msg_ids[0])
158 158 self.assertTrue(isinstance(ahr, AsyncHubResult))
159 159 self.assertEqual(ahr.get(), ar.get())
160 160 ar2 = v2.get_result(ar.msg_ids[0])
161 161 self.assertFalse(isinstance(ar2, AsyncHubResult))
162 162 c.spin()
163 163 c.close()
164 164
165 165 def test_run_newline(self):
166 166 """test that run appends newline to files"""
167 tmpfile = mktemp()
168 with open(tmpfile, 'w') as f:
167 with NamedTemporaryFile('w', delete=False) as f:
169 168 f.write("""def g():
170 169 return 5
171 170 """)
172 171 v = self.client[-1]
173 v.run(tmpfile, block=True)
172 v.run(f.name, block=True)
174 173 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
175 174
176 175 def test_apply_tracked(self):
177 176 """test tracking for apply"""
178 177 # self.add_engines(1)
179 178 t = self.client.ids[-1]
180 179 v = self.client[t]
181 180 v.block=False
182 181 def echo(n=1024*1024, **kwargs):
183 182 with v.temp_flags(**kwargs):
184 183 return v.apply(lambda x: x, 'x'*n)
185 184 ar = echo(1, track=False)
186 185 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 186 self.assertTrue(ar.sent)
188 187 ar = echo(track=True)
189 188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 189 self.assertEqual(ar.sent, ar._tracker.done)
191 190 ar._tracker.wait()
192 191 self.assertTrue(ar.sent)
193 192
194 193 def test_push_tracked(self):
195 194 t = self.client.ids[-1]
196 195 ns = dict(x='x'*1024*1024)
197 196 v = self.client[t]
198 197 ar = v.push(ns, block=False, track=False)
199 198 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 199 self.assertTrue(ar.sent)
201 200
202 201 ar = v.push(ns, block=False, track=True)
203 202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 203 ar._tracker.wait()
205 204 self.assertEqual(ar.sent, ar._tracker.done)
206 205 self.assertTrue(ar.sent)
207 206 ar.get()
208 207
209 208 def test_scatter_tracked(self):
210 209 t = self.client.ids
211 210 x='x'*1024*1024
212 211 ar = self.client[t].scatter('x', x, block=False, track=False)
213 212 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 213 self.assertTrue(ar.sent)
215 214
216 215 ar = self.client[t].scatter('x', x, block=False, track=True)
217 216 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
218 217 self.assertEqual(ar.sent, ar._tracker.done)
219 218 ar._tracker.wait()
220 219 self.assertTrue(ar.sent)
221 220 ar.get()
222 221
223 222 def test_remote_reference(self):
224 223 v = self.client[-1]
225 224 v['a'] = 123
226 225 ra = pmod.Reference('a')
227 226 b = v.apply_sync(lambda x: x, ra)
228 227 self.assertEqual(b, 123)
229 228
230 229
231 230 def test_scatter_gather(self):
232 231 view = self.client[:]
233 232 seq1 = list(range(16))
234 233 view.scatter('a', seq1)
235 234 seq2 = view.gather('a', block=True)
236 235 self.assertEqual(seq2, seq1)
237 236 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
238 237
239 238 @skip_without('numpy')
240 239 def test_scatter_gather_numpy(self):
241 240 import numpy
242 241 from numpy.testing.utils import assert_array_equal
243 242 view = self.client[:]
244 243 a = numpy.arange(64)
245 244 view.scatter('a', a, block=True)
246 245 b = view.gather('a', block=True)
247 246 assert_array_equal(b, a)
248 247
249 248 def test_scatter_gather_lazy(self):
250 249 """scatter/gather with targets='all'"""
251 250 view = self.client.direct_view(targets='all')
252 251 x = list(range(64))
253 252 view.scatter('x', x)
254 253 gathered = view.gather('x', block=True)
255 254 self.assertEqual(gathered, x)
256 255
257 256
258 257 @dec.known_failure_py3
259 258 @skip_without('numpy')
260 259 def test_push_numpy_nocopy(self):
261 260 import numpy
262 261 view = self.client[:]
263 262 a = numpy.arange(64)
264 263 view['A'] = a
265 264 @interactive
266 265 def check_writeable(x):
267 266 return x.flags.writeable
268 267
269 268 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
270 269 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 270
272 271 view.push(dict(B=a))
273 272 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
274 273 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
275 274
276 275 @skip_without('numpy')
277 276 def test_apply_numpy(self):
278 277 """view.apply(f, ndarray)"""
279 278 import numpy
280 279 from numpy.testing.utils import assert_array_equal
281 280
282 281 A = numpy.random.random((100,100))
283 282 view = self.client[-1]
284 283 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
285 284 B = A.astype(dt)
286 285 C = view.apply_sync(lambda x:x, B)
287 286 assert_array_equal(B,C)
288 287
289 288 @skip_without('numpy')
290 289 def test_apply_numpy_object_dtype(self):
291 290 """view.apply(f, ndarray) with dtype=object"""
292 291 import numpy
293 292 from numpy.testing.utils import assert_array_equal
294 293 view = self.client[-1]
295 294
296 295 A = numpy.array([dict(a=5)])
297 296 B = view.apply_sync(lambda x:x, A)
298 297 assert_array_equal(A,B)
299 298
300 299 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
301 300 B = view.apply_sync(lambda x:x, A)
302 301 assert_array_equal(A,B)
303 302
304 303 @skip_without('numpy')
305 304 def test_push_pull_recarray(self):
306 305 """push/pull recarrays"""
307 306 import numpy
308 307 from numpy.testing.utils import assert_array_equal
309 308
310 309 view = self.client[-1]
311 310
312 311 R = numpy.array([
313 312 (1, 'hi', 0.),
314 313 (2**30, 'there', 2.5),
315 314 (-99999, 'world', -12345.6789),
316 315 ], [('n', int), ('s', '|S10'), ('f', float)])
317 316
318 317 view['RR'] = R
319 318 R2 = view['RR']
320 319
321 320 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
322 321 self.assertEqual(r_dtype, R.dtype)
323 322 self.assertEqual(r_shape, R.shape)
324 323 self.assertEqual(R2.dtype, R.dtype)
325 324 self.assertEqual(R2.shape, R.shape)
326 325 assert_array_equal(R2, R)
327 326
328 327 @skip_without('pandas')
329 328 def test_push_pull_timeseries(self):
330 329 """push/pull pandas.TimeSeries"""
331 330 import pandas
332 331
333 332 ts = pandas.TimeSeries(list(range(10)))
334 333
335 334 view = self.client[-1]
336 335
337 336 view.push(dict(ts=ts), block=True)
338 337 rts = view['ts']
339 338
340 339 self.assertEqual(type(rts), type(ts))
341 340 self.assertTrue((ts == rts).all())
342 341
343 342 def test_map(self):
344 343 view = self.client[:]
345 344 def f(x):
346 345 return x**2
347 346 data = list(range(16))
348 347 r = view.map_sync(f, data)
349 348 self.assertEqual(r, list(map(f, data)))
350 349
351 350 def test_map_iterable(self):
352 351 """test map on iterables (direct)"""
353 352 view = self.client[:]
354 353 # 101 is prime, so it won't be evenly distributed
355 354 arr = range(101)
356 355 # ensure it will be an iterator, even in Python 3
357 356 it = iter(arr)
358 357 r = view.map_sync(lambda x: x, it)
359 358 self.assertEqual(r, list(arr))
360 359
361 360 @skip_without('numpy')
362 361 def test_map_numpy(self):
363 362 """test map on numpy arrays (direct)"""
364 363 import numpy
365 364 from numpy.testing.utils import assert_array_equal
366 365
367 366 view = self.client[:]
368 367 # 101 is prime, so it won't be evenly distributed
369 368 arr = numpy.arange(101)
370 369 r = view.map_sync(lambda x: x, arr)
371 370 assert_array_equal(r, arr)
372 371
373 372 def test_scatter_gather_nonblocking(self):
374 373 data = list(range(16))
375 374 view = self.client[:]
376 375 view.scatter('a', data, block=False)
377 376 ar = view.gather('a', block=False)
378 377 self.assertEqual(ar.get(), data)
379 378
380 379 @skip_without('numpy')
381 380 def test_scatter_gather_numpy_nonblocking(self):
382 381 import numpy
383 382 from numpy.testing.utils import assert_array_equal
384 383 a = numpy.arange(64)
385 384 view = self.client[:]
386 385 ar = view.scatter('a', a, block=False)
387 386 self.assertTrue(isinstance(ar, AsyncResult))
388 387 amr = view.gather('a', block=False)
389 388 self.assertTrue(isinstance(amr, AsyncMapResult))
390 389 assert_array_equal(amr.get(), a)
391 390
392 391 def test_execute(self):
393 392 view = self.client[:]
394 393 # self.client.debug=True
395 394 execute = view.execute
396 395 ar = execute('c=30', block=False)
397 396 self.assertTrue(isinstance(ar, AsyncResult))
398 397 ar = execute('d=[0,1,2]', block=False)
399 398 self.client.wait(ar, 1)
400 399 self.assertEqual(len(ar.get()), len(self.client))
401 400 for c in view['c']:
402 401 self.assertEqual(c, 30)
403 402
404 403 def test_abort(self):
405 404 view = self.client[-1]
406 405 ar = view.execute('import time; time.sleep(1)', block=False)
407 406 ar2 = view.apply_async(lambda : 2)
408 407 ar3 = view.apply_async(lambda : 3)
409 408 view.abort(ar2)
410 409 view.abort(ar3.msg_ids)
411 410 self.assertRaises(error.TaskAborted, ar2.get)
412 411 self.assertRaises(error.TaskAborted, ar3.get)
413 412
414 413 def test_abort_all(self):
415 414 """view.abort() aborts all outstanding tasks"""
416 415 view = self.client[-1]
417 416 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
418 417 view.abort()
419 418 view.wait(timeout=5)
420 419 for ar in ars[5:]:
421 420 self.assertRaises(error.TaskAborted, ar.get)
422 421
423 422 def test_temp_flags(self):
424 423 view = self.client[-1]
425 424 view.block=True
426 425 with view.temp_flags(block=False):
427 426 self.assertFalse(view.block)
428 427 self.assertTrue(view.block)
429 428
430 429 @dec.known_failure_py3
431 430 def test_importer(self):
432 431 view = self.client[-1]
433 432 view.clear(block=True)
434 433 with view.importer:
435 434 import re
436 435
437 436 @interactive
438 437 def findall(pat, s):
439 438 # this globals() step isn't necessary in real code
440 439 # only to prevent a closure in the test
441 440 re = globals()['re']
442 441 return re.findall(pat, s)
443 442
444 443 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
445 444
446 445 def test_unicode_execute(self):
447 446 """test executing unicode strings"""
448 447 v = self.client[-1]
449 448 v.block=True
450 449 if sys.version_info[0] >= 3:
451 450 code="a='é'"
452 451 else:
453 452 code=u"a=u'é'"
454 453 v.execute(code)
455 454 self.assertEqual(v['a'], u'é')
456 455
457 456 def test_unicode_apply_result(self):
458 457 """test unicode apply results"""
459 458 v = self.client[-1]
460 459 r = v.apply_sync(lambda : u'é')
461 460 self.assertEqual(r, u'é')
462 461
463 462 def test_unicode_apply_arg(self):
464 463 """test passing unicode arguments to apply"""
465 464 v = self.client[-1]
466 465
467 466 @interactive
468 467 def check_unicode(a, check):
469 468 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
470 469 assert isinstance(check, bytes), "%r is not bytes"%check
471 470 assert a.encode('utf8') == check, "%s != %s"%(a,check)
472 471
473 472 for s in [ u'é', u'ßø®∫',u'asdf' ]:
474 473 try:
475 474 v.apply_sync(check_unicode, s, s.encode('utf8'))
476 475 except error.RemoteError as e:
477 476 if e.ename == 'AssertionError':
478 477 self.fail(e.evalue)
479 478 else:
480 479 raise e
481 480
482 481 def test_map_reference(self):
483 482 """view.map(<Reference>, *seqs) should work"""
484 483 v = self.client[:]
485 484 v.scatter('n', self.client.ids, flatten=True)
486 485 v.execute("f = lambda x,y: x*y")
487 486 rf = pmod.Reference('f')
488 487 nlist = list(range(10))
489 488 mlist = nlist[::-1]
490 489 expected = [ m*n for m,n in zip(mlist, nlist) ]
491 490 result = v.map_sync(rf, mlist, nlist)
492 491 self.assertEqual(result, expected)
493 492
494 493 def test_apply_reference(self):
495 494 """view.apply(<Reference>, *args) should work"""
496 495 v = self.client[:]
497 496 v.scatter('n', self.client.ids, flatten=True)
498 497 v.execute("f = lambda x: n*x")
499 498 rf = pmod.Reference('f')
500 499 result = v.apply_sync(rf, 5)
501 500 expected = [ 5*id for id in self.client.ids ]
502 501 self.assertEqual(result, expected)
503 502
504 503 def test_eval_reference(self):
505 504 v = self.client[self.client.ids[0]]
506 505 v['g'] = list(range(5))
507 506 rg = pmod.Reference('g[0]')
508 507 echo = lambda x:x
509 508 self.assertEqual(v.apply_sync(echo, rg), 0)
510 509
511 510 def test_reference_nameerror(self):
512 511 v = self.client[self.client.ids[0]]
513 512 r = pmod.Reference('elvis_has_left')
514 513 echo = lambda x:x
515 514 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
516 515
517 516 def test_single_engine_map(self):
518 517 e0 = self.client[self.client.ids[0]]
519 518 r = list(range(5))
520 519 check = [ -1*i for i in r ]
521 520 result = e0.map_sync(lambda x: -1*x, r)
522 521 self.assertEqual(result, check)
523 522
524 523 def test_len(self):
525 524 """len(view) makes sense"""
526 525 e0 = self.client[self.client.ids[0]]
527 526 self.assertEqual(len(e0), 1)
528 527 v = self.client[:]
529 528 self.assertEqual(len(v), len(self.client.ids))
530 529 v = self.client.direct_view('all')
531 530 self.assertEqual(len(v), len(self.client.ids))
532 531 v = self.client[:2]
533 532 self.assertEqual(len(v), 2)
534 533 v = self.client[:1]
535 534 self.assertEqual(len(v), 1)
536 535 v = self.client.load_balanced_view()
537 536 self.assertEqual(len(v), len(self.client.ids))
538 537
539 538
540 539 # begin execute tests
541 540
542 541 def test_execute_reply(self):
543 542 e0 = self.client[self.client.ids[0]]
544 543 e0.block = True
545 544 ar = e0.execute("5", silent=False)
546 545 er = ar.get()
547 546 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
548 547 self.assertEqual(er.pyout['data']['text/plain'], '5')
549 548
550 549 def test_execute_reply_rich(self):
551 550 e0 = self.client[self.client.ids[0]]
552 551 e0.block = True
553 552 e0.execute("from IPython.display import Image, HTML")
554 553 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
555 554 er = ar.get()
556 555 b64data = base64.encodestring(b'garbage').decode('ascii')
557 556 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
558 557 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
559 558 er = ar.get()
560 559 self.assertEqual(er._repr_html_(), "<b>bold</b>")
561 560
562 561 def test_execute_reply_stdout(self):
563 562 e0 = self.client[self.client.ids[0]]
564 563 e0.block = True
565 564 ar = e0.execute("print (5)", silent=False)
566 565 er = ar.get()
567 566 self.assertEqual(er.stdout.strip(), '5')
568 567
569 568 def test_execute_pyout(self):
570 569 """execute triggers pyout with silent=False"""
571 570 view = self.client[:]
572 571 ar = view.execute("5", silent=False, block=True)
573 572
574 573 expected = [{'text/plain' : '5'}] * len(view)
575 574 mimes = [ out['data'] for out in ar.pyout ]
576 575 self.assertEqual(mimes, expected)
577 576
578 577 def test_execute_silent(self):
579 578 """execute does not trigger pyout with silent=True"""
580 579 view = self.client[:]
581 580 ar = view.execute("5", block=True)
582 581 expected = [None] * len(view)
583 582 self.assertEqual(ar.pyout, expected)
584 583
585 584 def test_execute_magic(self):
586 585 """execute accepts IPython commands"""
587 586 view = self.client[:]
588 587 view.execute("a = 5")
589 588 ar = view.execute("%whos", block=True)
590 589 # this will raise, if that failed
591 590 ar.get(5)
592 591 for stdout in ar.stdout:
593 592 lines = stdout.splitlines()
594 593 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
595 594 found = False
596 595 for line in lines[2:]:
597 596 split = line.split()
598 597 if split == ['a', 'int', '5']:
599 598 found = True
600 599 break
601 600 self.assertTrue(found, "whos output wrong: %s" % stdout)
602 601
603 602 def test_execute_displaypub(self):
604 603 """execute tracks display_pub output"""
605 604 view = self.client[:]
606 605 view.execute("from IPython.core.display import *")
607 606 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
608 607
609 608 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
610 609 for outputs in ar.outputs:
611 610 mimes = [ out['data'] for out in outputs ]
612 611 self.assertEqual(mimes, expected)
613 612
614 613 def test_apply_displaypub(self):
615 614 """apply tracks display_pub output"""
616 615 view = self.client[:]
617 616 view.execute("from IPython.core.display import *")
618 617
619 618 @interactive
620 619 def publish():
621 620 [ display(i) for i in range(5) ]
622 621
623 622 ar = view.apply_async(publish)
624 623 ar.get(5)
625 624 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
626 625 for outputs in ar.outputs:
627 626 mimes = [ out['data'] for out in outputs ]
628 627 self.assertEqual(mimes, expected)
629 628
630 629 def test_execute_raises(self):
631 630 """exceptions in execute requests raise appropriately"""
632 631 view = self.client[-1]
633 632 ar = view.execute("1/0")
634 633 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
635 634
636 635 def test_remoteerror_render_exception(self):
637 636 """RemoteErrors get nice tracebacks"""
638 637 view = self.client[-1]
639 638 ar = view.execute("1/0")
640 639 ip = get_ipython()
641 640 ip.user_ns['ar'] = ar
642 641 with capture_output() as io:
643 642 ip.run_cell("ar.get(2)")
644 643
645 644 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
646 645
647 646 def test_compositeerror_render_exception(self):
648 647 """CompositeErrors get nice tracebacks"""
649 648 view = self.client[:]
650 649 ar = view.execute("1/0")
651 650 ip = get_ipython()
652 651 ip.user_ns['ar'] = ar
653 652
654 653 with capture_output() as io:
655 654 ip.run_cell("ar.get(2)")
656 655
657 656 count = min(error.CompositeError.tb_limit, len(view))
658 657
659 658 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
660 659 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
661 660 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
662 661
663 662 def test_compositeerror_truncate(self):
664 663 """Truncate CompositeErrors with many exceptions"""
665 664 view = self.client[:]
666 665 msg_ids = []
667 666 for i in range(10):
668 667 ar = view.execute("1/0")
669 668 msg_ids.extend(ar.msg_ids)
670 669
671 670 ar = self.client.get_result(msg_ids)
672 671 try:
673 672 ar.get()
674 673 except error.CompositeError as _e:
675 674 e = _e
676 675 else:
677 676 self.fail("Should have raised CompositeError")
678 677
679 678 lines = e.render_traceback()
680 679 with capture_output() as io:
681 680 e.print_traceback()
682 681
683 682 self.assertTrue("more exceptions" in lines[-1])
684 683 count = e.tb_limit
685 684
686 685 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
687 686 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
688 687 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
689 688
690 689 @dec.skipif_not_matplotlib
691 690 def test_magic_pylab(self):
692 691 """%pylab works on engines"""
693 692 view = self.client[-1]
694 693 ar = view.execute("%pylab inline")
695 694 # at least check if this raised:
696 695 reply = ar.get(5)
697 696 # include imports, in case user config
698 697 ar = view.execute("plot(rand(100))", silent=False)
699 698 reply = ar.get(5)
700 699 self.assertEqual(len(reply.outputs), 1)
701 700 output = reply.outputs[0]
702 701 self.assertTrue("data" in output)
703 702 data = output['data']
704 703 self.assertTrue("image/png" in data)
705 704
706 705 def test_func_default_func(self):
707 706 """interactively defined function as apply func default"""
708 707 def foo():
709 708 return 'foo'
710 709
711 710 def bar(f=foo):
712 711 return f()
713 712
714 713 view = self.client[-1]
715 714 ar = view.apply_async(bar)
716 715 r = ar.get(10)
717 716 self.assertEqual(r, 'foo')
718 717 def test_data_pub_single(self):
719 718 view = self.client[-1]
720 719 ar = view.execute('\n'.join([
721 720 'from IPython.kernel.zmq.datapub import publish_data',
722 721 'for i in range(5):',
723 722 ' publish_data(dict(i=i))'
724 723 ]), block=False)
725 724 self.assertTrue(isinstance(ar.data, dict))
726 725 ar.get(5)
727 726 self.assertEqual(ar.data, dict(i=4))
728 727
729 728 def test_data_pub(self):
730 729 view = self.client[:]
731 730 ar = view.execute('\n'.join([
732 731 'from IPython.kernel.zmq.datapub import publish_data',
733 732 'for i in range(5):',
734 733 ' publish_data(dict(i=i))'
735 734 ]), block=False)
736 735 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
737 736 ar.get(5)
738 737 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
739 738
740 739 def test_can_list_arg(self):
741 740 """args in lists are canned"""
742 741 view = self.client[-1]
743 742 view['a'] = 128
744 743 rA = pmod.Reference('a')
745 744 ar = view.apply_async(lambda x: x, [rA])
746 745 r = ar.get(5)
747 746 self.assertEqual(r, [128])
748 747
749 748 def test_can_dict_arg(self):
750 749 """args in dicts are canned"""
751 750 view = self.client[-1]
752 751 view['a'] = 128
753 752 rA = pmod.Reference('a')
754 753 ar = view.apply_async(lambda x: x, dict(foo=rA))
755 754 r = ar.get(5)
756 755 self.assertEqual(r, dict(foo=128))
757 756
758 757 def test_can_list_kwarg(self):
759 758 """kwargs in lists are canned"""
760 759 view = self.client[-1]
761 760 view['a'] = 128
762 761 rA = pmod.Reference('a')
763 762 ar = view.apply_async(lambda x=5: x, x=[rA])
764 763 r = ar.get(5)
765 764 self.assertEqual(r, [128])
766 765
767 766 def test_can_dict_kwarg(self):
768 767 """kwargs in dicts are canned"""
769 768 view = self.client[-1]
770 769 view['a'] = 128
771 770 rA = pmod.Reference('a')
772 771 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
773 772 r = ar.get(5)
774 773 self.assertEqual(r, dict(foo=128))
775 774
776 775 def test_map_ref(self):
777 776 """view.map works with references"""
778 777 view = self.client[:]
779 778 ranks = sorted(self.client.ids)
780 779 view.scatter('rank', ranks, flatten=True)
781 780 rrank = pmod.Reference('rank')
782 781
783 782 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
784 783 drank = amr.get(5)
785 784 self.assertEqual(drank, [ r*2 for r in ranks ])
786 785
787 786 def test_nested_getitem_setitem(self):
788 787 """get and set with view['a.b']"""
789 788 view = self.client[-1]
790 789 view.execute('\n'.join([
791 790 'class A(object): pass',
792 791 'a = A()',
793 792 'a.b = 128',
794 793 ]), block=True)
795 794 ra = pmod.Reference('a')
796 795
797 796 r = view.apply_sync(lambda x: x.b, ra)
798 797 self.assertEqual(r, 128)
799 798 self.assertEqual(view['a.b'], 128)
800 799
801 800 view['a.b'] = 0
802 801
803 802 r = view.apply_sync(lambda x: x.b, ra)
804 803 self.assertEqual(r, 0)
805 804 self.assertEqual(view['a.b'], 0)
806 805
807 806 def test_return_namedtuple(self):
808 807 def namedtuplify(x, y):
809 808 from IPython.parallel.tests.test_view import point
810 809 return point(x, y)
811 810
812 811 view = self.client[-1]
813 812 p = view.apply_sync(namedtuplify, 1, 2)
814 813 self.assertEqual(p.x, 1)
815 814 self.assertEqual(p.y, 2)
816 815
817 816 def test_apply_namedtuple(self):
818 817 def echoxy(p):
819 818 return p.y, p.x
820 819
821 820 view = self.client[-1]
822 821 tup = view.apply_sync(echoxy, point(1, 2))
823 822 self.assertEqual(tup, (2,1))
824 823
825 824 def test_sync_imports(self):
826 825 view = self.client[-1]
827 826 with capture_output() as io:
828 827 with view.sync_imports():
829 828 import IPython
830 829 self.assertIn("IPython", io.stdout)
831 830
832 831 @interactive
833 832 def find_ipython():
834 833 return 'IPython' in globals()
835 834
836 835 assert view.apply_sync(find_ipython)
837 836
838 837 def test_sync_imports_quiet(self):
839 838 view = self.client[-1]
840 839 with capture_output() as io:
841 840 with view.sync_imports(quiet=True):
842 841 import IPython
843 842 self.assertEqual(io.stdout, '')
844 843
845 844 @interactive
846 845 def find_ipython():
847 846 return 'IPython' in globals()
848 847
849 848 assert view.apply_sync(find_ipython)
850 849
@@ -1,454 +1,456 b''
1 1 """Generic testing tools.
2 2
3 3 Authors
4 4 -------
5 5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 6 """
7 7
8 8 from __future__ import absolute_import
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2009 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Imports
19 19 #-----------------------------------------------------------------------------
20 20
21 21 import os
22 22 import re
23 23 import sys
24 24 import tempfile
25 25
26 26 from contextlib import contextmanager
27 27 from io import StringIO
28 28 from subprocess import Popen, PIPE
29 29
30 30 try:
31 31 # These tools are used by parts of the runtime, so we make the nose
32 32 # dependency optional at this point. Nose is a hard dependency to run the
33 33 # test suite, but NOT to use ipython itself.
34 34 import nose.tools as nt
35 35 has_nose = True
36 36 except ImportError:
37 37 has_nose = False
38 38
39 39 from IPython.config.loader import Config
40 40 from IPython.utils.process import get_output_error_code
41 41 from IPython.utils.text import list_strings
42 42 from IPython.utils.io import temp_pyfile, Tee
43 43 from IPython.utils import py3compat
44 44 from IPython.utils.encoding import DEFAULT_ENCODING
45 45
46 46 from . import decorators as dec
47 47 from . import skipdoctest
48 48
49 49 #-----------------------------------------------------------------------------
50 50 # Functions and classes
51 51 #-----------------------------------------------------------------------------
52 52
53 53 # The docstring for full_path doctests differently on win32 (different path
54 54 # separator) so just skip the doctest there. The example remains informative.
55 55 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
56 56
57 57 @doctest_deco
58 58 def full_path(startPath,files):
59 59 """Make full paths for all the listed files, based on startPath.
60 60
61 61 Only the base part of startPath is kept, since this routine is typically
62 62 used with a script's ``__file__`` variable as startPath. The base of startPath
63 63 is then prepended to all the listed files, forming the output list.
64 64
65 65 Parameters
66 66 ----------
67 67 startPath : string
68 68 Initial path to use as the base for the results. This path is split
69 69 using os.path.split() and only its first component is kept.
70 70
71 71 files : string or list
72 72 One or more files.
73 73
74 74 Examples
75 75 --------
76 76
77 77 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
78 78 ['/foo/a.txt', '/foo/b.txt']
79 79
80 80 >>> full_path('/foo',['a.txt','b.txt'])
81 81 ['/a.txt', '/b.txt']
82 82
83 83 If a single file is given, the output is still a list::
84 84
85 85 >>> full_path('/foo','a.txt')
86 86 ['/a.txt']
87 87 """
88 88
89 89 files = list_strings(files)
90 90 base = os.path.split(startPath)[0]
91 91 return [ os.path.join(base,f) for f in files ]
92 92
93 93
94 94 def parse_test_output(txt):
95 95 """Parse the output of a test run and return errors, failures.
96 96
97 97 Parameters
98 98 ----------
99 99 txt : str
100 100 Text output of a test run, assumed to contain a line of one of the
101 101 following forms::
102 102
103 103 'FAILED (errors=1)'
104 104 'FAILED (failures=1)'
105 105 'FAILED (errors=1, failures=1)'
106 106
107 107 Returns
108 108 -------
109 109 nerr, nfail
110 110 number of errors and failures.
111 111 """
112 112
113 113 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
114 114 if err_m:
115 115 nerr = int(err_m.group(1))
116 116 nfail = 0
117 117 return nerr, nfail
118 118
119 119 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
120 120 if fail_m:
121 121 nerr = 0
122 122 nfail = int(fail_m.group(1))
123 123 return nerr, nfail
124 124
125 125 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
126 126 re.MULTILINE)
127 127 if both_m:
128 128 nerr = int(both_m.group(1))
129 129 nfail = int(both_m.group(2))
130 130 return nerr, nfail
131 131
132 132 # If the input didn't match any of these forms, assume no error/failures
133 133 return 0, 0
134 134
135 135
136 136 # So nose doesn't think this is a test
137 137 parse_test_output.__test__ = False
138 138
139 139
140 140 def default_argv():
141 141 """Return a valid default argv for creating testing instances of ipython"""
142 142
143 143 return ['--quick', # so no config file is loaded
144 144 # Other defaults to minimize side effects on stdout
145 145 '--colors=NoColor', '--no-term-title','--no-banner',
146 146 '--autocall=0']
147 147
148 148
149 149 def default_config():
150 150 """Return a config object with good defaults for testing."""
151 151 config = Config()
152 152 config.TerminalInteractiveShell.colors = 'NoColor'
153 153 config.TerminalTerminalInteractiveShell.term_title = False,
154 154 config.TerminalInteractiveShell.autocall = 0
155 config.HistoryManager.hist_file = tempfile.mktemp(u'test_hist.sqlite')
155 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
156 config.HistoryManager.hist_file = f.name
157 f.close()
156 158 config.HistoryManager.db_cache_size = 10000
157 159 return config
158 160
159 161
160 162 def get_ipython_cmd(as_string=False):
161 163 """
162 164 Return appropriate IPython command line name. By default, this will return
163 165 a list that can be used with subprocess.Popen, for example, but passing
164 166 `as_string=True` allows for returning the IPython command as a string.
165 167
166 168 Parameters
167 169 ----------
168 170 as_string: bool
169 171 Flag to allow to return the command as a string.
170 172 """
171 173 ipython_cmd = [sys.executable, "-m", "IPython"]
172 174
173 175 if as_string:
174 176 ipython_cmd = " ".join(ipython_cmd)
175 177
176 178 return ipython_cmd
177 179
178 180 def ipexec(fname, options=None):
179 181 """Utility to call 'ipython filename'.
180 182
181 183 Starts IPython with a minimal and safe configuration to make startup as fast
182 184 as possible.
183 185
184 186 Note that this starts IPython in a subprocess!
185 187
186 188 Parameters
187 189 ----------
188 190 fname : str
189 191 Name of file to be executed (should have .py or .ipy extension).
190 192
191 193 options : optional, list
192 194 Extra command-line flags to be passed to IPython.
193 195
194 196 Returns
195 197 -------
196 198 (stdout, stderr) of ipython subprocess.
197 199 """
198 200 if options is None: options = []
199 201
200 202 # For these subprocess calls, eliminate all prompt printing so we only see
201 203 # output from script execution
202 204 prompt_opts = [ '--PromptManager.in_template=""',
203 205 '--PromptManager.in2_template=""',
204 206 '--PromptManager.out_template=""'
205 207 ]
206 208 cmdargs = default_argv() + prompt_opts + options
207 209
208 210 test_dir = os.path.dirname(__file__)
209 211
210 212 ipython_cmd = get_ipython_cmd()
211 213 # Absolute path for filename
212 214 full_fname = os.path.join(test_dir, fname)
213 215 full_cmd = ipython_cmd + cmdargs + [full_fname]
214 216 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE)
215 217 out, err = p.communicate()
216 218 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
217 219 # `import readline` causes 'ESC[?1034h' to be output sometimes,
218 220 # so strip that out before doing comparisons
219 221 if out:
220 222 out = re.sub(r'\x1b\[[^h]+h', '', out)
221 223 return out, err
222 224
223 225
224 226 def ipexec_validate(fname, expected_out, expected_err='',
225 227 options=None):
226 228 """Utility to call 'ipython filename' and validate output/error.
227 229
228 230 This function raises an AssertionError if the validation fails.
229 231
230 232 Note that this starts IPython in a subprocess!
231 233
232 234 Parameters
233 235 ----------
234 236 fname : str
235 237 Name of the file to be executed (should have .py or .ipy extension).
236 238
237 239 expected_out : str
238 240 Expected stdout of the process.
239 241
240 242 expected_err : optional, str
241 243 Expected stderr of the process.
242 244
243 245 options : optional, list
244 246 Extra command-line flags to be passed to IPython.
245 247
246 248 Returns
247 249 -------
248 250 None
249 251 """
250 252
251 253 import nose.tools as nt
252 254
253 255 out, err = ipexec(fname, options)
254 256 #print 'OUT', out # dbg
255 257 #print 'ERR', err # dbg
256 258 # If there are any errors, we must check those befor stdout, as they may be
257 259 # more informative than simply having an empty stdout.
258 260 if err:
259 261 if expected_err:
260 262 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
261 263 else:
262 264 raise ValueError('Running file %r produced error: %r' %
263 265 (fname, err))
264 266 # If no errors or output on stderr was expected, match stdout
265 267 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
266 268
267 269
268 270 class TempFileMixin(object):
269 271 """Utility class to create temporary Python/IPython files.
270 272
271 273 Meant as a mixin class for test cases."""
272 274
273 275 def mktmp(self, src, ext='.py'):
274 276 """Make a valid python temp file."""
275 277 fname, f = temp_pyfile(src, ext)
276 278 self.tmpfile = f
277 279 self.fname = fname
278 280
279 281 def tearDown(self):
280 282 if hasattr(self, 'tmpfile'):
281 283 # If the tmpfile wasn't made because of skipped tests, like in
282 284 # win32, there's nothing to cleanup.
283 285 self.tmpfile.close()
284 286 try:
285 287 os.unlink(self.fname)
286 288 except:
287 289 # On Windows, even though we close the file, we still can't
288 290 # delete it. I have no clue why
289 291 pass
290 292
291 293 pair_fail_msg = ("Testing {0}\n\n"
292 294 "In:\n"
293 295 " {1!r}\n"
294 296 "Expected:\n"
295 297 " {2!r}\n"
296 298 "Got:\n"
297 299 " {3!r}\n")
298 300 def check_pairs(func, pairs):
299 301 """Utility function for the common case of checking a function with a
300 302 sequence of input/output pairs.
301 303
302 304 Parameters
303 305 ----------
304 306 func : callable
305 307 The function to be tested. Should accept a single argument.
306 308 pairs : iterable
307 309 A list of (input, expected_output) tuples.
308 310
309 311 Returns
310 312 -------
311 313 None. Raises an AssertionError if any output does not match the expected
312 314 value.
313 315 """
314 316 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
315 317 for inp, expected in pairs:
316 318 out = func(inp)
317 319 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
318 320
319 321
320 322 if py3compat.PY3:
321 323 MyStringIO = StringIO
322 324 else:
323 325 # In Python 2, stdout/stderr can have either bytes or unicode written to them,
324 326 # so we need a class that can handle both.
325 327 class MyStringIO(StringIO):
326 328 def write(self, s):
327 329 s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING)
328 330 super(MyStringIO, self).write(s)
329 331
330 332 _re_type = type(re.compile(r''))
331 333
332 334 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
333 335 -------
334 336 {2!s}
335 337 -------
336 338 """
337 339
338 340 class AssertPrints(object):
339 341 """Context manager for testing that code prints certain text.
340 342
341 343 Examples
342 344 --------
343 345 >>> with AssertPrints("abc", suppress=False):
344 346 ... print("abcd")
345 347 ... print("def")
346 348 ...
347 349 abcd
348 350 def
349 351 """
350 352 def __init__(self, s, channel='stdout', suppress=True):
351 353 self.s = s
352 354 if isinstance(self.s, (py3compat.string_types, _re_type)):
353 355 self.s = [self.s]
354 356 self.channel = channel
355 357 self.suppress = suppress
356 358
357 359 def __enter__(self):
358 360 self.orig_stream = getattr(sys, self.channel)
359 361 self.buffer = MyStringIO()
360 362 self.tee = Tee(self.buffer, channel=self.channel)
361 363 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
362 364
363 365 def __exit__(self, etype, value, traceback):
364 366 if value is not None:
365 367 # If an error was raised, don't check anything else
366 368 return False
367 369 self.tee.flush()
368 370 setattr(sys, self.channel, self.orig_stream)
369 371 printed = self.buffer.getvalue()
370 372 for s in self.s:
371 373 if isinstance(s, _re_type):
372 374 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
373 375 else:
374 376 assert s in printed, notprinted_msg.format(s, self.channel, printed)
375 377 return False
376 378
377 379 printed_msg = """Found {0!r} in printed output (on {1}):
378 380 -------
379 381 {2!s}
380 382 -------
381 383 """
382 384
383 385 class AssertNotPrints(AssertPrints):
384 386 """Context manager for checking that certain output *isn't* produced.
385 387
386 388 Counterpart of AssertPrints"""
387 389 def __exit__(self, etype, value, traceback):
388 390 if value is not None:
389 391 # If an error was raised, don't check anything else
390 392 return False
391 393 self.tee.flush()
392 394 setattr(sys, self.channel, self.orig_stream)
393 395 printed = self.buffer.getvalue()
394 396 for s in self.s:
395 397 if isinstance(s, _re_type):
396 398 assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed)
397 399 else:
398 400 assert s not in printed, printed_msg.format(s, self.channel, printed)
399 401 return False
400 402
401 403 @contextmanager
402 404 def mute_warn():
403 405 from IPython.utils import warn
404 406 save_warn = warn.warn
405 407 warn.warn = lambda *a, **kw: None
406 408 try:
407 409 yield
408 410 finally:
409 411 warn.warn = save_warn
410 412
411 413 @contextmanager
412 414 def make_tempfile(name):
413 415 """ Create an empty, named, temporary file for the duration of the context.
414 416 """
415 417 f = open(name, 'w')
416 418 f.close()
417 419 try:
418 420 yield
419 421 finally:
420 422 os.unlink(name)
421 423
422 424
423 425 @contextmanager
424 426 def monkeypatch(obj, name, attr):
425 427 """
426 428 Context manager to replace attribute named `name` in `obj` with `attr`.
427 429 """
428 430 orig = getattr(obj, name)
429 431 setattr(obj, name, attr)
430 432 yield
431 433 setattr(obj, name, orig)
432 434
433 435
434 436 def help_output_test(subcommand=''):
435 437 """test that `ipython [subcommand] -h` works"""
436 438 cmd = get_ipython_cmd() + [subcommand, '-h']
437 439 out, err, rc = get_output_error_code(cmd)
438 440 nt.assert_equal(rc, 0, err)
439 441 nt.assert_not_in("Traceback", err)
440 442 nt.assert_in("Options", out)
441 443 nt.assert_in("--help-all", out)
442 444 return out, err
443 445
444 446
445 447 def help_all_output_test(subcommand=''):
446 448 """test that `ipython [subcommand] --help-all` works"""
447 449 cmd = get_ipython_cmd() + [subcommand, '--help-all']
448 450 out, err, rc = get_output_error_code(cmd)
449 451 nt.assert_equal(rc, 0, err)
450 452 nt.assert_not_in("Traceback", err)
451 453 nt.assert_in("Options", out)
452 454 nt.assert_in("Class parameters", out)
453 455 return out, err
454 456
General Comments 0
You need to be logged in to leave comments. Login now