##// END OF EJS Templates
Remove unused imports in IPython.utils
Thomas Kluyver -
Show More
@@ -1,188 +1,187 b''
1 """Windows-specific implementation of process utilities.
1 """Windows-specific implementation of process utilities.
2
2
3 This file is only meant to be imported by process.py, not by end-users.
3 This file is only meant to be imported by process.py, not by end-users.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
7 # Copyright (C) 2010-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 # stdlib
18 # stdlib
19 import os
19 import os
20 import sys
20 import sys
21 import ctypes
21 import ctypes
22 import msvcrt
23
22
24 from ctypes import c_int, POINTER
23 from ctypes import c_int, POINTER
25 from ctypes.wintypes import LPCWSTR, HLOCAL
24 from ctypes.wintypes import LPCWSTR, HLOCAL
26 from subprocess import STDOUT
25 from subprocess import STDOUT
27
26
28 # our own imports
27 # our own imports
29 from ._process_common import read_no_interrupt, process_handler, arg_split as py_arg_split
28 from ._process_common import read_no_interrupt, process_handler, arg_split as py_arg_split
30 from . import py3compat
29 from . import py3compat
31 from .encoding import DEFAULT_ENCODING
30 from .encoding import DEFAULT_ENCODING
32
31
33 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
34 # Function definitions
33 # Function definitions
35 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
36
35
37 class AvoidUNCPath(object):
36 class AvoidUNCPath(object):
38 """A context manager to protect command execution from UNC paths.
37 """A context manager to protect command execution from UNC paths.
39
38
40 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
39 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
41 This context manager temporarily changes directory to the 'C:' drive on
40 This context manager temporarily changes directory to the 'C:' drive on
42 entering, and restores the original working directory on exit.
41 entering, and restores the original working directory on exit.
43
42
44 The context manager returns the starting working directory *if* it made a
43 The context manager returns the starting working directory *if* it made a
45 change and None otherwise, so that users can apply the necessary adjustment
44 change and None otherwise, so that users can apply the necessary adjustment
46 to their system calls in the event of a change.
45 to their system calls in the event of a change.
47
46
48 Example
47 Example
49 -------
48 -------
50 ::
49 ::
51 cmd = 'dir'
50 cmd = 'dir'
52 with AvoidUNCPath() as path:
51 with AvoidUNCPath() as path:
53 if path is not None:
52 if path is not None:
54 cmd = '"pushd %s &&"%s' % (path, cmd)
53 cmd = '"pushd %s &&"%s' % (path, cmd)
55 os.system(cmd)
54 os.system(cmd)
56 """
55 """
57 def __enter__(self):
56 def __enter__(self):
58 self.path = os.getcwdu()
57 self.path = os.getcwdu()
59 self.is_unc_path = self.path.startswith(r"\\")
58 self.is_unc_path = self.path.startswith(r"\\")
60 if self.is_unc_path:
59 if self.is_unc_path:
61 # change to c drive (as cmd.exe cannot handle UNC addresses)
60 # change to c drive (as cmd.exe cannot handle UNC addresses)
62 os.chdir("C:")
61 os.chdir("C:")
63 return self.path
62 return self.path
64 else:
63 else:
65 # We return None to signal that there was no change in the working
64 # We return None to signal that there was no change in the working
66 # directory
65 # directory
67 return None
66 return None
68
67
69 def __exit__(self, exc_type, exc_value, traceback):
68 def __exit__(self, exc_type, exc_value, traceback):
70 if self.is_unc_path:
69 if self.is_unc_path:
71 os.chdir(self.path)
70 os.chdir(self.path)
72
71
73
72
74 def _find_cmd(cmd):
73 def _find_cmd(cmd):
75 """Find the full path to a .bat or .exe using the win32api module."""
74 """Find the full path to a .bat or .exe using the win32api module."""
76 try:
75 try:
77 from win32api import SearchPath
76 from win32api import SearchPath
78 except ImportError:
77 except ImportError:
79 raise ImportError('you need to have pywin32 installed for this to work')
78 raise ImportError('you need to have pywin32 installed for this to work')
80 else:
79 else:
81 PATH = os.environ['PATH']
80 PATH = os.environ['PATH']
82 extensions = ['.exe', '.com', '.bat', '.py']
81 extensions = ['.exe', '.com', '.bat', '.py']
83 path = None
82 path = None
84 for ext in extensions:
83 for ext in extensions:
85 try:
84 try:
86 path = SearchPath(PATH, cmd + ext)[0]
85 path = SearchPath(PATH, cmd + ext)[0]
87 except:
86 except:
88 pass
87 pass
89 if path is None:
88 if path is None:
90 raise OSError("command %r not found" % cmd)
89 raise OSError("command %r not found" % cmd)
91 else:
90 else:
92 return path
91 return path
93
92
94
93
95 def _system_body(p):
94 def _system_body(p):
96 """Callback for _system."""
95 """Callback for _system."""
97 enc = DEFAULT_ENCODING
96 enc = DEFAULT_ENCODING
98 for line in read_no_interrupt(p.stdout).splitlines():
97 for line in read_no_interrupt(p.stdout).splitlines():
99 line = line.decode(enc, 'replace')
98 line = line.decode(enc, 'replace')
100 print(line, file=sys.stdout)
99 print(line, file=sys.stdout)
101 for line in read_no_interrupt(p.stderr).splitlines():
100 for line in read_no_interrupt(p.stderr).splitlines():
102 line = line.decode(enc, 'replace')
101 line = line.decode(enc, 'replace')
103 print(line, file=sys.stderr)
102 print(line, file=sys.stderr)
104
103
105 # Wait to finish for returncode
104 # Wait to finish for returncode
106 return p.wait()
105 return p.wait()
107
106
108
107
109 def system(cmd):
108 def system(cmd):
110 """Win32 version of os.system() that works with network shares.
109 """Win32 version of os.system() that works with network shares.
111
110
112 Note that this implementation returns None, as meant for use in IPython.
111 Note that this implementation returns None, as meant for use in IPython.
113
112
114 Parameters
113 Parameters
115 ----------
114 ----------
116 cmd : str
115 cmd : str
117 A command to be executed in the system shell.
116 A command to be executed in the system shell.
118
117
119 Returns
118 Returns
120 -------
119 -------
121 None : we explicitly do NOT return the subprocess status code, as this
120 None : we explicitly do NOT return the subprocess status code, as this
122 utility is meant to be used extensively in IPython, where any return value
121 utility is meant to be used extensively in IPython, where any return value
123 would trigger :func:`sys.displayhook` calls.
122 would trigger :func:`sys.displayhook` calls.
124 """
123 """
125 # The controller provides interactivity with both
124 # The controller provides interactivity with both
126 # stdin and stdout
125 # stdin and stdout
127 #import _process_win32_controller
126 #import _process_win32_controller
128 #_process_win32_controller.system(cmd)
127 #_process_win32_controller.system(cmd)
129
128
130 with AvoidUNCPath() as path:
129 with AvoidUNCPath() as path:
131 if path is not None:
130 if path is not None:
132 cmd = '"pushd %s &&"%s' % (path, cmd)
131 cmd = '"pushd %s &&"%s' % (path, cmd)
133 return process_handler(cmd, _system_body)
132 return process_handler(cmd, _system_body)
134
133
135 def getoutput(cmd):
134 def getoutput(cmd):
136 """Return standard output of executing cmd in a shell.
135 """Return standard output of executing cmd in a shell.
137
136
138 Accepts the same arguments as os.system().
137 Accepts the same arguments as os.system().
139
138
140 Parameters
139 Parameters
141 ----------
140 ----------
142 cmd : str
141 cmd : str
143 A command to be executed in the system shell.
142 A command to be executed in the system shell.
144
143
145 Returns
144 Returns
146 -------
145 -------
147 stdout : str
146 stdout : str
148 """
147 """
149
148
150 with AvoidUNCPath() as path:
149 with AvoidUNCPath() as path:
151 if path is not None:
150 if path is not None:
152 cmd = '"pushd %s &&"%s' % (path, cmd)
151 cmd = '"pushd %s &&"%s' % (path, cmd)
153 out = process_handler(cmd, lambda p: p.communicate()[0], STDOUT)
152 out = process_handler(cmd, lambda p: p.communicate()[0], STDOUT)
154
153
155 if out is None:
154 if out is None:
156 out = b''
155 out = b''
157 return py3compat.bytes_to_str(out)
156 return py3compat.bytes_to_str(out)
158
157
159 try:
158 try:
160 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
159 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
161 CommandLineToArgvW.arg_types = [LPCWSTR, POINTER(c_int)]
160 CommandLineToArgvW.arg_types = [LPCWSTR, POINTER(c_int)]
162 CommandLineToArgvW.restype = POINTER(LPCWSTR)
161 CommandLineToArgvW.restype = POINTER(LPCWSTR)
163 LocalFree = ctypes.windll.kernel32.LocalFree
162 LocalFree = ctypes.windll.kernel32.LocalFree
164 LocalFree.res_type = HLOCAL
163 LocalFree.res_type = HLOCAL
165 LocalFree.arg_types = [HLOCAL]
164 LocalFree.arg_types = [HLOCAL]
166
165
167 def arg_split(commandline, posix=False, strict=True):
166 def arg_split(commandline, posix=False, strict=True):
168 """Split a command line's arguments in a shell-like manner.
167 """Split a command line's arguments in a shell-like manner.
169
168
170 This is a special version for windows that use a ctypes call to CommandLineToArgvW
169 This is a special version for windows that use a ctypes call to CommandLineToArgvW
171 to do the argv splitting. The posix paramter is ignored.
170 to do the argv splitting. The posix paramter is ignored.
172
171
173 If strict=False, process_common.arg_split(...strict=False) is used instead.
172 If strict=False, process_common.arg_split(...strict=False) is used instead.
174 """
173 """
175 #CommandLineToArgvW returns path to executable if called with empty string.
174 #CommandLineToArgvW returns path to executable if called with empty string.
176 if commandline.strip() == "":
175 if commandline.strip() == "":
177 return []
176 return []
178 if not strict:
177 if not strict:
179 # not really a cl-arg, fallback on _process_common
178 # not really a cl-arg, fallback on _process_common
180 return py_arg_split(commandline, posix=posix, strict=strict)
179 return py_arg_split(commandline, posix=posix, strict=strict)
181 argvn = c_int()
180 argvn = c_int()
182 result_pointer = CommandLineToArgvW(py3compat.cast_unicode(commandline.lstrip()), ctypes.byref(argvn))
181 result_pointer = CommandLineToArgvW(py3compat.cast_unicode(commandline.lstrip()), ctypes.byref(argvn))
183 result_array_type = LPCWSTR * argvn.value
182 result_array_type = LPCWSTR * argvn.value
184 result = [arg for arg in result_array_type.from_address(ctypes.addressof(result_pointer.contents))]
183 result = [arg for arg in result_array_type.from_address(ctypes.addressof(result_pointer.contents))]
185 retval = LocalFree(result_pointer)
184 retval = LocalFree(result_pointer)
186 return result
185 return result
187 except AttributeError:
186 except AttributeError:
188 arg_split = py_arg_split
187 arg_split = py_arg_split
@@ -1,574 +1,574 b''
1 """Windows-specific implementation of process utilities with direct WinAPI.
1 """Windows-specific implementation of process utilities with direct WinAPI.
2
2
3 This file is meant to be used by process.py
3 This file is meant to be used by process.py
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
7 # Copyright (C) 2010-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 from __future__ import print_function
13 from __future__ import print_function
14
14
15 # stdlib
15 # stdlib
16 import os, sys, time, threading
16 import os, sys, threading
17 import ctypes, msvcrt
17 import ctypes, msvcrt
18
18
19 # Win32 API types needed for the API calls
19 # Win32 API types needed for the API calls
20 from ctypes import POINTER
20 from ctypes import POINTER
21 from ctypes.wintypes import HANDLE, HLOCAL, LPVOID, WORD, DWORD, BOOL, \
21 from ctypes.wintypes import HANDLE, HLOCAL, LPVOID, WORD, DWORD, BOOL, \
22 ULONG, LPCWSTR
22 ULONG, LPCWSTR
23 LPDWORD = POINTER(DWORD)
23 LPDWORD = POINTER(DWORD)
24 LPHANDLE = POINTER(HANDLE)
24 LPHANDLE = POINTER(HANDLE)
25 ULONG_PTR = POINTER(ULONG)
25 ULONG_PTR = POINTER(ULONG)
26 class SECURITY_ATTRIBUTES(ctypes.Structure):
26 class SECURITY_ATTRIBUTES(ctypes.Structure):
27 _fields_ = [("nLength", DWORD),
27 _fields_ = [("nLength", DWORD),
28 ("lpSecurityDescriptor", LPVOID),
28 ("lpSecurityDescriptor", LPVOID),
29 ("bInheritHandle", BOOL)]
29 ("bInheritHandle", BOOL)]
30 LPSECURITY_ATTRIBUTES = POINTER(SECURITY_ATTRIBUTES)
30 LPSECURITY_ATTRIBUTES = POINTER(SECURITY_ATTRIBUTES)
31 class STARTUPINFO(ctypes.Structure):
31 class STARTUPINFO(ctypes.Structure):
32 _fields_ = [("cb", DWORD),
32 _fields_ = [("cb", DWORD),
33 ("lpReserved", LPCWSTR),
33 ("lpReserved", LPCWSTR),
34 ("lpDesktop", LPCWSTR),
34 ("lpDesktop", LPCWSTR),
35 ("lpTitle", LPCWSTR),
35 ("lpTitle", LPCWSTR),
36 ("dwX", DWORD),
36 ("dwX", DWORD),
37 ("dwY", DWORD),
37 ("dwY", DWORD),
38 ("dwXSize", DWORD),
38 ("dwXSize", DWORD),
39 ("dwYSize", DWORD),
39 ("dwYSize", DWORD),
40 ("dwXCountChars", DWORD),
40 ("dwXCountChars", DWORD),
41 ("dwYCountChars", DWORD),
41 ("dwYCountChars", DWORD),
42 ("dwFillAttribute", DWORD),
42 ("dwFillAttribute", DWORD),
43 ("dwFlags", DWORD),
43 ("dwFlags", DWORD),
44 ("wShowWindow", WORD),
44 ("wShowWindow", WORD),
45 ("cbReserved2", WORD),
45 ("cbReserved2", WORD),
46 ("lpReserved2", LPVOID),
46 ("lpReserved2", LPVOID),
47 ("hStdInput", HANDLE),
47 ("hStdInput", HANDLE),
48 ("hStdOutput", HANDLE),
48 ("hStdOutput", HANDLE),
49 ("hStdError", HANDLE)]
49 ("hStdError", HANDLE)]
50 LPSTARTUPINFO = POINTER(STARTUPINFO)
50 LPSTARTUPINFO = POINTER(STARTUPINFO)
51 class PROCESS_INFORMATION(ctypes.Structure):
51 class PROCESS_INFORMATION(ctypes.Structure):
52 _fields_ = [("hProcess", HANDLE),
52 _fields_ = [("hProcess", HANDLE),
53 ("hThread", HANDLE),
53 ("hThread", HANDLE),
54 ("dwProcessId", DWORD),
54 ("dwProcessId", DWORD),
55 ("dwThreadId", DWORD)]
55 ("dwThreadId", DWORD)]
56 LPPROCESS_INFORMATION = POINTER(PROCESS_INFORMATION)
56 LPPROCESS_INFORMATION = POINTER(PROCESS_INFORMATION)
57
57
58 # Win32 API constants needed
58 # Win32 API constants needed
59 ERROR_HANDLE_EOF = 38
59 ERROR_HANDLE_EOF = 38
60 ERROR_BROKEN_PIPE = 109
60 ERROR_BROKEN_PIPE = 109
61 ERROR_NO_DATA = 232
61 ERROR_NO_DATA = 232
62 HANDLE_FLAG_INHERIT = 0x0001
62 HANDLE_FLAG_INHERIT = 0x0001
63 STARTF_USESTDHANDLES = 0x0100
63 STARTF_USESTDHANDLES = 0x0100
64 CREATE_SUSPENDED = 0x0004
64 CREATE_SUSPENDED = 0x0004
65 CREATE_NEW_CONSOLE = 0x0010
65 CREATE_NEW_CONSOLE = 0x0010
66 CREATE_NO_WINDOW = 0x08000000
66 CREATE_NO_WINDOW = 0x08000000
67 STILL_ACTIVE = 259
67 STILL_ACTIVE = 259
68 WAIT_TIMEOUT = 0x0102
68 WAIT_TIMEOUT = 0x0102
69 WAIT_FAILED = 0xFFFFFFFF
69 WAIT_FAILED = 0xFFFFFFFF
70 INFINITE = 0xFFFFFFFF
70 INFINITE = 0xFFFFFFFF
71 DUPLICATE_SAME_ACCESS = 0x00000002
71 DUPLICATE_SAME_ACCESS = 0x00000002
72 ENABLE_ECHO_INPUT = 0x0004
72 ENABLE_ECHO_INPUT = 0x0004
73 ENABLE_LINE_INPUT = 0x0002
73 ENABLE_LINE_INPUT = 0x0002
74 ENABLE_PROCESSED_INPUT = 0x0001
74 ENABLE_PROCESSED_INPUT = 0x0001
75
75
76 # Win32 API functions needed
76 # Win32 API functions needed
77 GetLastError = ctypes.windll.kernel32.GetLastError
77 GetLastError = ctypes.windll.kernel32.GetLastError
78 GetLastError.argtypes = []
78 GetLastError.argtypes = []
79 GetLastError.restype = DWORD
79 GetLastError.restype = DWORD
80
80
81 CreateFile = ctypes.windll.kernel32.CreateFileW
81 CreateFile = ctypes.windll.kernel32.CreateFileW
82 CreateFile.argtypes = [LPCWSTR, DWORD, DWORD, LPVOID, DWORD, DWORD, HANDLE]
82 CreateFile.argtypes = [LPCWSTR, DWORD, DWORD, LPVOID, DWORD, DWORD, HANDLE]
83 CreateFile.restype = HANDLE
83 CreateFile.restype = HANDLE
84
84
85 CreatePipe = ctypes.windll.kernel32.CreatePipe
85 CreatePipe = ctypes.windll.kernel32.CreatePipe
86 CreatePipe.argtypes = [POINTER(HANDLE), POINTER(HANDLE),
86 CreatePipe.argtypes = [POINTER(HANDLE), POINTER(HANDLE),
87 LPSECURITY_ATTRIBUTES, DWORD]
87 LPSECURITY_ATTRIBUTES, DWORD]
88 CreatePipe.restype = BOOL
88 CreatePipe.restype = BOOL
89
89
90 CreateProcess = ctypes.windll.kernel32.CreateProcessW
90 CreateProcess = ctypes.windll.kernel32.CreateProcessW
91 CreateProcess.argtypes = [LPCWSTR, LPCWSTR, LPSECURITY_ATTRIBUTES,
91 CreateProcess.argtypes = [LPCWSTR, LPCWSTR, LPSECURITY_ATTRIBUTES,
92 LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID, LPCWSTR, LPSTARTUPINFO,
92 LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID, LPCWSTR, LPSTARTUPINFO,
93 LPPROCESS_INFORMATION]
93 LPPROCESS_INFORMATION]
94 CreateProcess.restype = BOOL
94 CreateProcess.restype = BOOL
95
95
96 GetExitCodeProcess = ctypes.windll.kernel32.GetExitCodeProcess
96 GetExitCodeProcess = ctypes.windll.kernel32.GetExitCodeProcess
97 GetExitCodeProcess.argtypes = [HANDLE, LPDWORD]
97 GetExitCodeProcess.argtypes = [HANDLE, LPDWORD]
98 GetExitCodeProcess.restype = BOOL
98 GetExitCodeProcess.restype = BOOL
99
99
100 GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
100 GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
101 GetCurrentProcess.argtypes = []
101 GetCurrentProcess.argtypes = []
102 GetCurrentProcess.restype = HANDLE
102 GetCurrentProcess.restype = HANDLE
103
103
104 ResumeThread = ctypes.windll.kernel32.ResumeThread
104 ResumeThread = ctypes.windll.kernel32.ResumeThread
105 ResumeThread.argtypes = [HANDLE]
105 ResumeThread.argtypes = [HANDLE]
106 ResumeThread.restype = DWORD
106 ResumeThread.restype = DWORD
107
107
108 ReadFile = ctypes.windll.kernel32.ReadFile
108 ReadFile = ctypes.windll.kernel32.ReadFile
109 ReadFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
109 ReadFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
110 ReadFile.restype = BOOL
110 ReadFile.restype = BOOL
111
111
112 WriteFile = ctypes.windll.kernel32.WriteFile
112 WriteFile = ctypes.windll.kernel32.WriteFile
113 WriteFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
113 WriteFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
114 WriteFile.restype = BOOL
114 WriteFile.restype = BOOL
115
115
116 GetConsoleMode = ctypes.windll.kernel32.GetConsoleMode
116 GetConsoleMode = ctypes.windll.kernel32.GetConsoleMode
117 GetConsoleMode.argtypes = [HANDLE, LPDWORD]
117 GetConsoleMode.argtypes = [HANDLE, LPDWORD]
118 GetConsoleMode.restype = BOOL
118 GetConsoleMode.restype = BOOL
119
119
120 SetConsoleMode = ctypes.windll.kernel32.SetConsoleMode
120 SetConsoleMode = ctypes.windll.kernel32.SetConsoleMode
121 SetConsoleMode.argtypes = [HANDLE, DWORD]
121 SetConsoleMode.argtypes = [HANDLE, DWORD]
122 SetConsoleMode.restype = BOOL
122 SetConsoleMode.restype = BOOL
123
123
124 FlushConsoleInputBuffer = ctypes.windll.kernel32.FlushConsoleInputBuffer
124 FlushConsoleInputBuffer = ctypes.windll.kernel32.FlushConsoleInputBuffer
125 FlushConsoleInputBuffer.argtypes = [HANDLE]
125 FlushConsoleInputBuffer.argtypes = [HANDLE]
126 FlushConsoleInputBuffer.restype = BOOL
126 FlushConsoleInputBuffer.restype = BOOL
127
127
128 WaitForSingleObject = ctypes.windll.kernel32.WaitForSingleObject
128 WaitForSingleObject = ctypes.windll.kernel32.WaitForSingleObject
129 WaitForSingleObject.argtypes = [HANDLE, DWORD]
129 WaitForSingleObject.argtypes = [HANDLE, DWORD]
130 WaitForSingleObject.restype = DWORD
130 WaitForSingleObject.restype = DWORD
131
131
132 DuplicateHandle = ctypes.windll.kernel32.DuplicateHandle
132 DuplicateHandle = ctypes.windll.kernel32.DuplicateHandle
133 DuplicateHandle.argtypes = [HANDLE, HANDLE, HANDLE, LPHANDLE,
133 DuplicateHandle.argtypes = [HANDLE, HANDLE, HANDLE, LPHANDLE,
134 DWORD, BOOL, DWORD]
134 DWORD, BOOL, DWORD]
135 DuplicateHandle.restype = BOOL
135 DuplicateHandle.restype = BOOL
136
136
137 SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
137 SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
138 SetHandleInformation.argtypes = [HANDLE, DWORD, DWORD]
138 SetHandleInformation.argtypes = [HANDLE, DWORD, DWORD]
139 SetHandleInformation.restype = BOOL
139 SetHandleInformation.restype = BOOL
140
140
141 CloseHandle = ctypes.windll.kernel32.CloseHandle
141 CloseHandle = ctypes.windll.kernel32.CloseHandle
142 CloseHandle.argtypes = [HANDLE]
142 CloseHandle.argtypes = [HANDLE]
143 CloseHandle.restype = BOOL
143 CloseHandle.restype = BOOL
144
144
145 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
145 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
146 CommandLineToArgvW.argtypes = [LPCWSTR, POINTER(ctypes.c_int)]
146 CommandLineToArgvW.argtypes = [LPCWSTR, POINTER(ctypes.c_int)]
147 CommandLineToArgvW.restype = POINTER(LPCWSTR)
147 CommandLineToArgvW.restype = POINTER(LPCWSTR)
148
148
149 LocalFree = ctypes.windll.kernel32.LocalFree
149 LocalFree = ctypes.windll.kernel32.LocalFree
150 LocalFree.argtypes = [HLOCAL]
150 LocalFree.argtypes = [HLOCAL]
151 LocalFree.restype = HLOCAL
151 LocalFree.restype = HLOCAL
152
152
153 class AvoidUNCPath(object):
153 class AvoidUNCPath(object):
154 """A context manager to protect command execution from UNC paths.
154 """A context manager to protect command execution from UNC paths.
155
155
156 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
156 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
157 This context manager temporarily changes directory to the 'C:' drive on
157 This context manager temporarily changes directory to the 'C:' drive on
158 entering, and restores the original working directory on exit.
158 entering, and restores the original working directory on exit.
159
159
160 The context manager returns the starting working directory *if* it made a
160 The context manager returns the starting working directory *if* it made a
161 change and None otherwise, so that users can apply the necessary adjustment
161 change and None otherwise, so that users can apply the necessary adjustment
162 to their system calls in the event of a change.
162 to their system calls in the event of a change.
163
163
164 Example
164 Example
165 -------
165 -------
166 ::
166 ::
167 cmd = 'dir'
167 cmd = 'dir'
168 with AvoidUNCPath() as path:
168 with AvoidUNCPath() as path:
169 if path is not None:
169 if path is not None:
170 cmd = '"pushd %s &&"%s' % (path, cmd)
170 cmd = '"pushd %s &&"%s' % (path, cmd)
171 os.system(cmd)
171 os.system(cmd)
172 """
172 """
173 def __enter__(self):
173 def __enter__(self):
174 self.path = os.getcwdu()
174 self.path = os.getcwdu()
175 self.is_unc_path = self.path.startswith(r"\\")
175 self.is_unc_path = self.path.startswith(r"\\")
176 if self.is_unc_path:
176 if self.is_unc_path:
177 # change to c drive (as cmd.exe cannot handle UNC addresses)
177 # change to c drive (as cmd.exe cannot handle UNC addresses)
178 os.chdir("C:")
178 os.chdir("C:")
179 return self.path
179 return self.path
180 else:
180 else:
181 # We return None to signal that there was no change in the working
181 # We return None to signal that there was no change in the working
182 # directory
182 # directory
183 return None
183 return None
184
184
185 def __exit__(self, exc_type, exc_value, traceback):
185 def __exit__(self, exc_type, exc_value, traceback):
186 if self.is_unc_path:
186 if self.is_unc_path:
187 os.chdir(self.path)
187 os.chdir(self.path)
188
188
189
189
190 class Win32ShellCommandController(object):
190 class Win32ShellCommandController(object):
191 """Runs a shell command in a 'with' context.
191 """Runs a shell command in a 'with' context.
192
192
193 This implementation is Win32-specific.
193 This implementation is Win32-specific.
194
194
195 Example:
195 Example:
196 # Runs the command interactively with default console stdin/stdout
196 # Runs the command interactively with default console stdin/stdout
197 with ShellCommandController('python -i') as scc:
197 with ShellCommandController('python -i') as scc:
198 scc.run()
198 scc.run()
199
199
200 # Runs the command using the provided functions for stdin/stdout
200 # Runs the command using the provided functions for stdin/stdout
201 def my_stdout_func(s):
201 def my_stdout_func(s):
202 # print or save the string 's'
202 # print or save the string 's'
203 write_to_stdout(s)
203 write_to_stdout(s)
204 def my_stdin_func():
204 def my_stdin_func():
205 # If input is available, return it as a string.
205 # If input is available, return it as a string.
206 if input_available():
206 if input_available():
207 return get_input()
207 return get_input()
208 # If no input available, return None after a short delay to
208 # If no input available, return None after a short delay to
209 # keep from blocking.
209 # keep from blocking.
210 else:
210 else:
211 time.sleep(0.01)
211 time.sleep(0.01)
212 return None
212 return None
213
213
214 with ShellCommandController('python -i') as scc:
214 with ShellCommandController('python -i') as scc:
215 scc.run(my_stdout_func, my_stdin_func)
215 scc.run(my_stdout_func, my_stdin_func)
216 """
216 """
217
217
218 def __init__(self, cmd, mergeout = True):
218 def __init__(self, cmd, mergeout = True):
219 """Initializes the shell command controller.
219 """Initializes the shell command controller.
220
220
221 The cmd is the program to execute, and mergeout is
221 The cmd is the program to execute, and mergeout is
222 whether to blend stdout and stderr into one output
222 whether to blend stdout and stderr into one output
223 in stdout. Merging them together in this fashion more
223 in stdout. Merging them together in this fashion more
224 reliably keeps stdout and stderr in the correct order
224 reliably keeps stdout and stderr in the correct order
225 especially for interactive shell usage.
225 especially for interactive shell usage.
226 """
226 """
227 self.cmd = cmd
227 self.cmd = cmd
228 self.mergeout = mergeout
228 self.mergeout = mergeout
229
229
230 def __enter__(self):
230 def __enter__(self):
231 cmd = self.cmd
231 cmd = self.cmd
232 mergeout = self.mergeout
232 mergeout = self.mergeout
233
233
234 self.hstdout, self.hstdin, self.hstderr = None, None, None
234 self.hstdout, self.hstdin, self.hstderr = None, None, None
235 self.piProcInfo = None
235 self.piProcInfo = None
236 try:
236 try:
237 p_hstdout, c_hstdout, p_hstderr, \
237 p_hstdout, c_hstdout, p_hstderr, \
238 c_hstderr, p_hstdin, c_hstdin = [None]*6
238 c_hstderr, p_hstdin, c_hstdin = [None]*6
239
239
240 # SECURITY_ATTRIBUTES with inherit handle set to True
240 # SECURITY_ATTRIBUTES with inherit handle set to True
241 saAttr = SECURITY_ATTRIBUTES()
241 saAttr = SECURITY_ATTRIBUTES()
242 saAttr.nLength = ctypes.sizeof(saAttr)
242 saAttr.nLength = ctypes.sizeof(saAttr)
243 saAttr.bInheritHandle = True
243 saAttr.bInheritHandle = True
244 saAttr.lpSecurityDescriptor = None
244 saAttr.lpSecurityDescriptor = None
245
245
246 def create_pipe(uninherit):
246 def create_pipe(uninherit):
247 """Creates a Windows pipe, which consists of two handles.
247 """Creates a Windows pipe, which consists of two handles.
248
248
249 The 'uninherit' parameter controls which handle is not
249 The 'uninherit' parameter controls which handle is not
250 inherited by the child process.
250 inherited by the child process.
251 """
251 """
252 handles = HANDLE(), HANDLE()
252 handles = HANDLE(), HANDLE()
253 if not CreatePipe(ctypes.byref(handles[0]),
253 if not CreatePipe(ctypes.byref(handles[0]),
254 ctypes.byref(handles[1]), ctypes.byref(saAttr), 0):
254 ctypes.byref(handles[1]), ctypes.byref(saAttr), 0):
255 raise ctypes.WinError()
255 raise ctypes.WinError()
256 if not SetHandleInformation(handles[uninherit],
256 if not SetHandleInformation(handles[uninherit],
257 HANDLE_FLAG_INHERIT, 0):
257 HANDLE_FLAG_INHERIT, 0):
258 raise ctypes.WinError()
258 raise ctypes.WinError()
259 return handles[0].value, handles[1].value
259 return handles[0].value, handles[1].value
260
260
261 p_hstdout, c_hstdout = create_pipe(uninherit=0)
261 p_hstdout, c_hstdout = create_pipe(uninherit=0)
262 # 'mergeout' signals that stdout and stderr should be merged.
262 # 'mergeout' signals that stdout and stderr should be merged.
263 # We do that by using one pipe for both of them.
263 # We do that by using one pipe for both of them.
264 if mergeout:
264 if mergeout:
265 c_hstderr = HANDLE()
265 c_hstderr = HANDLE()
266 if not DuplicateHandle(GetCurrentProcess(), c_hstdout,
266 if not DuplicateHandle(GetCurrentProcess(), c_hstdout,
267 GetCurrentProcess(), ctypes.byref(c_hstderr),
267 GetCurrentProcess(), ctypes.byref(c_hstderr),
268 0, True, DUPLICATE_SAME_ACCESS):
268 0, True, DUPLICATE_SAME_ACCESS):
269 raise ctypes.WinError()
269 raise ctypes.WinError()
270 else:
270 else:
271 p_hstderr, c_hstderr = create_pipe(uninherit=0)
271 p_hstderr, c_hstderr = create_pipe(uninherit=0)
272 c_hstdin, p_hstdin = create_pipe(uninherit=1)
272 c_hstdin, p_hstdin = create_pipe(uninherit=1)
273
273
274 # Create the process object
274 # Create the process object
275 piProcInfo = PROCESS_INFORMATION()
275 piProcInfo = PROCESS_INFORMATION()
276 siStartInfo = STARTUPINFO()
276 siStartInfo = STARTUPINFO()
277 siStartInfo.cb = ctypes.sizeof(siStartInfo)
277 siStartInfo.cb = ctypes.sizeof(siStartInfo)
278 siStartInfo.hStdInput = c_hstdin
278 siStartInfo.hStdInput = c_hstdin
279 siStartInfo.hStdOutput = c_hstdout
279 siStartInfo.hStdOutput = c_hstdout
280 siStartInfo.hStdError = c_hstderr
280 siStartInfo.hStdError = c_hstderr
281 siStartInfo.dwFlags = STARTF_USESTDHANDLES
281 siStartInfo.dwFlags = STARTF_USESTDHANDLES
282 dwCreationFlags = CREATE_SUSPENDED | CREATE_NO_WINDOW # | CREATE_NEW_CONSOLE
282 dwCreationFlags = CREATE_SUSPENDED | CREATE_NO_WINDOW # | CREATE_NEW_CONSOLE
283
283
284 if not CreateProcess(None,
284 if not CreateProcess(None,
285 u"cmd.exe /c " + cmd,
285 u"cmd.exe /c " + cmd,
286 None, None, True, dwCreationFlags,
286 None, None, True, dwCreationFlags,
287 None, None, ctypes.byref(siStartInfo),
287 None, None, ctypes.byref(siStartInfo),
288 ctypes.byref(piProcInfo)):
288 ctypes.byref(piProcInfo)):
289 raise ctypes.WinError()
289 raise ctypes.WinError()
290
290
291 # Close this process's versions of the child handles
291 # Close this process's versions of the child handles
292 CloseHandle(c_hstdin)
292 CloseHandle(c_hstdin)
293 c_hstdin = None
293 c_hstdin = None
294 CloseHandle(c_hstdout)
294 CloseHandle(c_hstdout)
295 c_hstdout = None
295 c_hstdout = None
296 if c_hstderr != None:
296 if c_hstderr != None:
297 CloseHandle(c_hstderr)
297 CloseHandle(c_hstderr)
298 c_hstderr = None
298 c_hstderr = None
299
299
300 # Transfer ownership of the parent handles to the object
300 # Transfer ownership of the parent handles to the object
301 self.hstdin = p_hstdin
301 self.hstdin = p_hstdin
302 p_hstdin = None
302 p_hstdin = None
303 self.hstdout = p_hstdout
303 self.hstdout = p_hstdout
304 p_hstdout = None
304 p_hstdout = None
305 if not mergeout:
305 if not mergeout:
306 self.hstderr = p_hstderr
306 self.hstderr = p_hstderr
307 p_hstderr = None
307 p_hstderr = None
308 self.piProcInfo = piProcInfo
308 self.piProcInfo = piProcInfo
309
309
310 finally:
310 finally:
311 if p_hstdin:
311 if p_hstdin:
312 CloseHandle(p_hstdin)
312 CloseHandle(p_hstdin)
313 if c_hstdin:
313 if c_hstdin:
314 CloseHandle(c_hstdin)
314 CloseHandle(c_hstdin)
315 if p_hstdout:
315 if p_hstdout:
316 CloseHandle(p_hstdout)
316 CloseHandle(p_hstdout)
317 if c_hstdout:
317 if c_hstdout:
318 CloseHandle(c_hstdout)
318 CloseHandle(c_hstdout)
319 if p_hstderr:
319 if p_hstderr:
320 CloseHandle(p_hstderr)
320 CloseHandle(p_hstderr)
321 if c_hstderr:
321 if c_hstderr:
322 CloseHandle(c_hstderr)
322 CloseHandle(c_hstderr)
323
323
324 return self
324 return self
325
325
326 def _stdin_thread(self, handle, hprocess, func, stdout_func):
326 def _stdin_thread(self, handle, hprocess, func, stdout_func):
327 exitCode = DWORD()
327 exitCode = DWORD()
328 bytesWritten = DWORD(0)
328 bytesWritten = DWORD(0)
329 while True:
329 while True:
330 #print("stdin thread loop start")
330 #print("stdin thread loop start")
331 # Get the input string (may be bytes or unicode)
331 # Get the input string (may be bytes or unicode)
332 data = func()
332 data = func()
333
333
334 # None signals to poll whether the process has exited
334 # None signals to poll whether the process has exited
335 if data is None:
335 if data is None:
336 #print("checking for process completion")
336 #print("checking for process completion")
337 if not GetExitCodeProcess(hprocess, ctypes.byref(exitCode)):
337 if not GetExitCodeProcess(hprocess, ctypes.byref(exitCode)):
338 raise ctypes.WinError()
338 raise ctypes.WinError()
339 if exitCode.value != STILL_ACTIVE:
339 if exitCode.value != STILL_ACTIVE:
340 return
340 return
341 # TESTING: Does zero-sized writefile help?
341 # TESTING: Does zero-sized writefile help?
342 if not WriteFile(handle, "", 0,
342 if not WriteFile(handle, "", 0,
343 ctypes.byref(bytesWritten), None):
343 ctypes.byref(bytesWritten), None):
344 raise ctypes.WinError()
344 raise ctypes.WinError()
345 continue
345 continue
346 #print("\nGot str %s\n" % repr(data), file=sys.stderr)
346 #print("\nGot str %s\n" % repr(data), file=sys.stderr)
347
347
348 # Encode the string to the console encoding
348 # Encode the string to the console encoding
349 if isinstance(data, unicode): #FIXME: Python3
349 if isinstance(data, unicode): #FIXME: Python3
350 data = data.encode('utf_8')
350 data = data.encode('utf_8')
351
351
352 # What we have now must be a string of bytes
352 # What we have now must be a string of bytes
353 if not isinstance(data, str): #FIXME: Python3
353 if not isinstance(data, str): #FIXME: Python3
354 raise RuntimeError("internal stdin function string error")
354 raise RuntimeError("internal stdin function string error")
355
355
356 # An empty string signals EOF
356 # An empty string signals EOF
357 if len(data) == 0:
357 if len(data) == 0:
358 return
358 return
359
359
360 # In a windows console, sometimes the input is echoed,
360 # In a windows console, sometimes the input is echoed,
361 # but sometimes not. How do we determine when to do this?
361 # but sometimes not. How do we determine when to do this?
362 stdout_func(data)
362 stdout_func(data)
363 # WriteFile may not accept all the data at once.
363 # WriteFile may not accept all the data at once.
364 # Loop until everything is processed
364 # Loop until everything is processed
365 while len(data) != 0:
365 while len(data) != 0:
366 #print("Calling writefile")
366 #print("Calling writefile")
367 if not WriteFile(handle, data, len(data),
367 if not WriteFile(handle, data, len(data),
368 ctypes.byref(bytesWritten), None):
368 ctypes.byref(bytesWritten), None):
369 # This occurs at exit
369 # This occurs at exit
370 if GetLastError() == ERROR_NO_DATA:
370 if GetLastError() == ERROR_NO_DATA:
371 return
371 return
372 raise ctypes.WinError()
372 raise ctypes.WinError()
373 #print("Called writefile")
373 #print("Called writefile")
374 data = data[bytesWritten.value:]
374 data = data[bytesWritten.value:]
375
375
376 def _stdout_thread(self, handle, func):
376 def _stdout_thread(self, handle, func):
377 # Allocate the output buffer
377 # Allocate the output buffer
378 data = ctypes.create_string_buffer(4096)
378 data = ctypes.create_string_buffer(4096)
379 while True:
379 while True:
380 bytesRead = DWORD(0)
380 bytesRead = DWORD(0)
381 if not ReadFile(handle, data, 4096,
381 if not ReadFile(handle, data, 4096,
382 ctypes.byref(bytesRead), None):
382 ctypes.byref(bytesRead), None):
383 le = GetLastError()
383 le = GetLastError()
384 if le == ERROR_BROKEN_PIPE:
384 if le == ERROR_BROKEN_PIPE:
385 return
385 return
386 else:
386 else:
387 raise ctypes.WinError()
387 raise ctypes.WinError()
388 # FIXME: Python3
388 # FIXME: Python3
389 s = data.value[0:bytesRead.value]
389 s = data.value[0:bytesRead.value]
390 #print("\nv: %s" % repr(s), file=sys.stderr)
390 #print("\nv: %s" % repr(s), file=sys.stderr)
391 func(s.decode('utf_8', 'replace'))
391 func(s.decode('utf_8', 'replace'))
392
392
393 def run(self, stdout_func = None, stdin_func = None, stderr_func = None):
393 def run(self, stdout_func = None, stdin_func = None, stderr_func = None):
394 """Runs the process, using the provided functions for I/O.
394 """Runs the process, using the provided functions for I/O.
395
395
396 The function stdin_func should return strings whenever a
396 The function stdin_func should return strings whenever a
397 character or characters become available.
397 character or characters become available.
398 The functions stdout_func and stderr_func are called whenever
398 The functions stdout_func and stderr_func are called whenever
399 something is printed to stdout or stderr, respectively.
399 something is printed to stdout or stderr, respectively.
400 These functions are called from different threads (but not
400 These functions are called from different threads (but not
401 concurrently, because of the GIL).
401 concurrently, because of the GIL).
402 """
402 """
403 if stdout_func == None and stdin_func == None and stderr_func == None:
403 if stdout_func == None and stdin_func == None and stderr_func == None:
404 return self._run_stdio()
404 return self._run_stdio()
405
405
406 if stderr_func != None and self.mergeout:
406 if stderr_func != None and self.mergeout:
407 raise RuntimeError("Shell command was initiated with "
407 raise RuntimeError("Shell command was initiated with "
408 "merged stdin/stdout, but a separate stderr_func "
408 "merged stdin/stdout, but a separate stderr_func "
409 "was provided to the run() method")
409 "was provided to the run() method")
410
410
411 # Create a thread for each input/output handle
411 # Create a thread for each input/output handle
412 stdin_thread = None
412 stdin_thread = None
413 threads = []
413 threads = []
414 if stdin_func:
414 if stdin_func:
415 stdin_thread = threading.Thread(target=self._stdin_thread,
415 stdin_thread = threading.Thread(target=self._stdin_thread,
416 args=(self.hstdin, self.piProcInfo.hProcess,
416 args=(self.hstdin, self.piProcInfo.hProcess,
417 stdin_func, stdout_func))
417 stdin_func, stdout_func))
418 threads.append(threading.Thread(target=self._stdout_thread,
418 threads.append(threading.Thread(target=self._stdout_thread,
419 args=(self.hstdout, stdout_func)))
419 args=(self.hstdout, stdout_func)))
420 if not self.mergeout:
420 if not self.mergeout:
421 if stderr_func == None:
421 if stderr_func == None:
422 stderr_func = stdout_func
422 stderr_func = stdout_func
423 threads.append(threading.Thread(target=self._stdout_thread,
423 threads.append(threading.Thread(target=self._stdout_thread,
424 args=(self.hstderr, stderr_func)))
424 args=(self.hstderr, stderr_func)))
425 # Start the I/O threads and the process
425 # Start the I/O threads and the process
426 if ResumeThread(self.piProcInfo.hThread) == 0xFFFFFFFF:
426 if ResumeThread(self.piProcInfo.hThread) == 0xFFFFFFFF:
427 raise ctypes.WinError()
427 raise ctypes.WinError()
428 if stdin_thread is not None:
428 if stdin_thread is not None:
429 stdin_thread.start()
429 stdin_thread.start()
430 for thread in threads:
430 for thread in threads:
431 thread.start()
431 thread.start()
432 # Wait for the process to complete
432 # Wait for the process to complete
433 if WaitForSingleObject(self.piProcInfo.hProcess, INFINITE) == \
433 if WaitForSingleObject(self.piProcInfo.hProcess, INFINITE) == \
434 WAIT_FAILED:
434 WAIT_FAILED:
435 raise ctypes.WinError()
435 raise ctypes.WinError()
436 # Wait for the I/O threads to complete
436 # Wait for the I/O threads to complete
437 for thread in threads:
437 for thread in threads:
438 thread.join()
438 thread.join()
439
439
440 # Wait for the stdin thread to complete
440 # Wait for the stdin thread to complete
441 if stdin_thread is not None:
441 if stdin_thread is not None:
442 stdin_thread.join()
442 stdin_thread.join()
443
443
444 def _stdin_raw_nonblock(self):
444 def _stdin_raw_nonblock(self):
445 """Use the raw Win32 handle of sys.stdin to do non-blocking reads"""
445 """Use the raw Win32 handle of sys.stdin to do non-blocking reads"""
446 # WARNING: This is experimental, and produces inconsistent results.
446 # WARNING: This is experimental, and produces inconsistent results.
447 # It's possible for the handle not to be appropriate for use
447 # It's possible for the handle not to be appropriate for use
448 # with WaitForSingleObject, among other things.
448 # with WaitForSingleObject, among other things.
449 handle = msvcrt.get_osfhandle(sys.stdin.fileno())
449 handle = msvcrt.get_osfhandle(sys.stdin.fileno())
450 result = WaitForSingleObject(handle, 100)
450 result = WaitForSingleObject(handle, 100)
451 if result == WAIT_FAILED:
451 if result == WAIT_FAILED:
452 raise ctypes.WinError()
452 raise ctypes.WinError()
453 elif result == WAIT_TIMEOUT:
453 elif result == WAIT_TIMEOUT:
454 print(".", end='')
454 print(".", end='')
455 return None
455 return None
456 else:
456 else:
457 data = ctypes.create_string_buffer(256)
457 data = ctypes.create_string_buffer(256)
458 bytesRead = DWORD(0)
458 bytesRead = DWORD(0)
459 print('?', end='')
459 print('?', end='')
460
460
461 if not ReadFile(handle, data, 256,
461 if not ReadFile(handle, data, 256,
462 ctypes.byref(bytesRead), None):
462 ctypes.byref(bytesRead), None):
463 raise ctypes.WinError()
463 raise ctypes.WinError()
464 # This ensures the non-blocking works with an actual console
464 # This ensures the non-blocking works with an actual console
465 # Not checking the error, so the processing will still work with
465 # Not checking the error, so the processing will still work with
466 # other handle types
466 # other handle types
467 FlushConsoleInputBuffer(handle)
467 FlushConsoleInputBuffer(handle)
468
468
469 data = data.value
469 data = data.value
470 data = data.replace('\r\n', '\n')
470 data = data.replace('\r\n', '\n')
471 data = data.replace('\r', '\n')
471 data = data.replace('\r', '\n')
472 print(repr(data) + " ", end='')
472 print(repr(data) + " ", end='')
473 return data
473 return data
474
474
475 def _stdin_raw_block(self):
475 def _stdin_raw_block(self):
476 """Use a blocking stdin read"""
476 """Use a blocking stdin read"""
477 # The big problem with the blocking read is that it doesn't
477 # The big problem with the blocking read is that it doesn't
478 # exit when it's supposed to in all contexts. An extra
478 # exit when it's supposed to in all contexts. An extra
479 # key-press may be required to trigger the exit.
479 # key-press may be required to trigger the exit.
480 try:
480 try:
481 data = sys.stdin.read(1)
481 data = sys.stdin.read(1)
482 data = data.replace('\r', '\n')
482 data = data.replace('\r', '\n')
483 return data
483 return data
484 except WindowsError as we:
484 except WindowsError as we:
485 if we.winerror == ERROR_NO_DATA:
485 if we.winerror == ERROR_NO_DATA:
486 # This error occurs when the pipe is closed
486 # This error occurs when the pipe is closed
487 return None
487 return None
488 else:
488 else:
489 # Otherwise let the error propagate
489 # Otherwise let the error propagate
490 raise we
490 raise we
491
491
492 def _stdout_raw(self, s):
492 def _stdout_raw(self, s):
493 """Writes the string to stdout"""
493 """Writes the string to stdout"""
494 print(s, end='', file=sys.stdout)
494 print(s, end='', file=sys.stdout)
495 sys.stdout.flush()
495 sys.stdout.flush()
496
496
497 def _stderr_raw(self, s):
497 def _stderr_raw(self, s):
498 """Writes the string to stdout"""
498 """Writes the string to stdout"""
499 print(s, end='', file=sys.stderr)
499 print(s, end='', file=sys.stderr)
500 sys.stderr.flush()
500 sys.stderr.flush()
501
501
502 def _run_stdio(self):
502 def _run_stdio(self):
503 """Runs the process using the system standard I/O.
503 """Runs the process using the system standard I/O.
504
504
505 IMPORTANT: stdin needs to be asynchronous, so the Python
505 IMPORTANT: stdin needs to be asynchronous, so the Python
506 sys.stdin object is not used. Instead,
506 sys.stdin object is not used. Instead,
507 msvcrt.kbhit/getwch are used asynchronously.
507 msvcrt.kbhit/getwch are used asynchronously.
508 """
508 """
509 # Disable Line and Echo mode
509 # Disable Line and Echo mode
510 #lpMode = DWORD()
510 #lpMode = DWORD()
511 #handle = msvcrt.get_osfhandle(sys.stdin.fileno())
511 #handle = msvcrt.get_osfhandle(sys.stdin.fileno())
512 #if GetConsoleMode(handle, ctypes.byref(lpMode)):
512 #if GetConsoleMode(handle, ctypes.byref(lpMode)):
513 # set_console_mode = True
513 # set_console_mode = True
514 # if not SetConsoleMode(handle, lpMode.value &
514 # if not SetConsoleMode(handle, lpMode.value &
515 # ~(ENABLE_ECHO_INPUT | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)):
515 # ~(ENABLE_ECHO_INPUT | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)):
516 # raise ctypes.WinError()
516 # raise ctypes.WinError()
517
517
518 if self.mergeout:
518 if self.mergeout:
519 return self.run(stdout_func = self._stdout_raw,
519 return self.run(stdout_func = self._stdout_raw,
520 stdin_func = self._stdin_raw_block)
520 stdin_func = self._stdin_raw_block)
521 else:
521 else:
522 return self.run(stdout_func = self._stdout_raw,
522 return self.run(stdout_func = self._stdout_raw,
523 stdin_func = self._stdin_raw_block,
523 stdin_func = self._stdin_raw_block,
524 stderr_func = self._stderr_raw)
524 stderr_func = self._stderr_raw)
525
525
526 # Restore the previous console mode
526 # Restore the previous console mode
527 #if set_console_mode:
527 #if set_console_mode:
528 # if not SetConsoleMode(handle, lpMode.value):
528 # if not SetConsoleMode(handle, lpMode.value):
529 # raise ctypes.WinError()
529 # raise ctypes.WinError()
530
530
531 def __exit__(self, exc_type, exc_value, traceback):
531 def __exit__(self, exc_type, exc_value, traceback):
532 if self.hstdin:
532 if self.hstdin:
533 CloseHandle(self.hstdin)
533 CloseHandle(self.hstdin)
534 self.hstdin = None
534 self.hstdin = None
535 if self.hstdout:
535 if self.hstdout:
536 CloseHandle(self.hstdout)
536 CloseHandle(self.hstdout)
537 self.hstdout = None
537 self.hstdout = None
538 if self.hstderr:
538 if self.hstderr:
539 CloseHandle(self.hstderr)
539 CloseHandle(self.hstderr)
540 self.hstderr = None
540 self.hstderr = None
541 if self.piProcInfo != None:
541 if self.piProcInfo != None:
542 CloseHandle(self.piProcInfo.hProcess)
542 CloseHandle(self.piProcInfo.hProcess)
543 CloseHandle(self.piProcInfo.hThread)
543 CloseHandle(self.piProcInfo.hThread)
544 self.piProcInfo = None
544 self.piProcInfo = None
545
545
546
546
547 def system(cmd):
547 def system(cmd):
548 """Win32 version of os.system() that works with network shares.
548 """Win32 version of os.system() that works with network shares.
549
549
550 Note that this implementation returns None, as meant for use in IPython.
550 Note that this implementation returns None, as meant for use in IPython.
551
551
552 Parameters
552 Parameters
553 ----------
553 ----------
554 cmd : str
554 cmd : str
555 A command to be executed in the system shell.
555 A command to be executed in the system shell.
556
556
557 Returns
557 Returns
558 -------
558 -------
559 None : we explicitly do NOT return the subprocess status code, as this
559 None : we explicitly do NOT return the subprocess status code, as this
560 utility is meant to be used extensively in IPython, where any return value
560 utility is meant to be used extensively in IPython, where any return value
561 would trigger :func:`sys.displayhook` calls.
561 would trigger :func:`sys.displayhook` calls.
562 """
562 """
563 with AvoidUNCPath() as path:
563 with AvoidUNCPath() as path:
564 if path is not None:
564 if path is not None:
565 cmd = '"pushd %s &&"%s' % (path, cmd)
565 cmd = '"pushd %s &&"%s' % (path, cmd)
566 with Win32ShellCommandController(cmd) as scc:
566 with Win32ShellCommandController(cmd) as scc:
567 scc.run()
567 scc.run()
568
568
569
569
570 if __name__ == "__main__":
570 if __name__ == "__main__":
571 print("Test starting!")
571 print("Test starting!")
572 #system("cmd")
572 #system("cmd")
573 system("python -i")
573 system("python -i")
574 print("Test finished!")
574 print("Test finished!")
@@ -1,354 +1,353 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Pickle related utilities. Perhaps this should be called 'can'."""
3 """Pickle related utilities. Perhaps this should be called 'can'."""
4
4
5 __docformat__ = "restructuredtext en"
5 __docformat__ = "restructuredtext en"
6
6
7 #-------------------------------------------------------------------------------
7 #-------------------------------------------------------------------------------
8 # Copyright (C) 2008-2011 The IPython Development Team
8 # Copyright (C) 2008-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 #-------------------------------------------------------------------------------
14 #-------------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 import copy
18 import copy
19 import logging
19 import logging
20 import sys
20 import sys
21 from types import FunctionType
21 from types import FunctionType
22
22
23 try:
23 try:
24 import cPickle as pickle
24 import cPickle as pickle
25 except ImportError:
25 except ImportError:
26 import pickle
26 import pickle
27
27
28 try:
28 try:
29 import numpy
29 import numpy
30 except:
30 except:
31 numpy = None
31 numpy = None
32
32
33 import codeutil
34 import py3compat
33 import py3compat
35 from importstring import import_item
34 from importstring import import_item
36
35
37 from IPython.config import Application
36 from IPython.config import Application
38
37
39 if py3compat.PY3:
38 if py3compat.PY3:
40 buffer = memoryview
39 buffer = memoryview
41 class_type = type
40 class_type = type
42 else:
41 else:
43 from types import ClassType
42 from types import ClassType
44 class_type = (type, ClassType)
43 class_type = (type, ClassType)
45
44
46 #-------------------------------------------------------------------------------
45 #-------------------------------------------------------------------------------
47 # Classes
46 # Classes
48 #-------------------------------------------------------------------------------
47 #-------------------------------------------------------------------------------
49
48
50
49
51 class CannedObject(object):
50 class CannedObject(object):
52 def __init__(self, obj, keys=[], hook=None):
51 def __init__(self, obj, keys=[], hook=None):
53 """can an object for safe pickling
52 """can an object for safe pickling
54
53
55 Parameters
54 Parameters
56 ==========
55 ==========
57
56
58 obj:
57 obj:
59 The object to be canned
58 The object to be canned
60 keys: list (optional)
59 keys: list (optional)
61 list of attribute names that will be explicitly canned / uncanned
60 list of attribute names that will be explicitly canned / uncanned
62 hook: callable (optional)
61 hook: callable (optional)
63 An optional extra callable,
62 An optional extra callable,
64 which can do additional processing of the uncanned object.
63 which can do additional processing of the uncanned object.
65
64
66 large data may be offloaded into the buffers list,
65 large data may be offloaded into the buffers list,
67 used for zero-copy transfers.
66 used for zero-copy transfers.
68 """
67 """
69 self.keys = keys
68 self.keys = keys
70 self.obj = copy.copy(obj)
69 self.obj = copy.copy(obj)
71 self.hook = can(hook)
70 self.hook = can(hook)
72 for key in keys:
71 for key in keys:
73 setattr(self.obj, key, can(getattr(obj, key)))
72 setattr(self.obj, key, can(getattr(obj, key)))
74
73
75 self.buffers = []
74 self.buffers = []
76
75
77 def get_object(self, g=None):
76 def get_object(self, g=None):
78 if g is None:
77 if g is None:
79 g = {}
78 g = {}
80 obj = self.obj
79 obj = self.obj
81 for key in self.keys:
80 for key in self.keys:
82 setattr(obj, key, uncan(getattr(obj, key), g))
81 setattr(obj, key, uncan(getattr(obj, key), g))
83
82
84 if self.hook:
83 if self.hook:
85 self.hook = uncan(self.hook, g)
84 self.hook = uncan(self.hook, g)
86 self.hook(obj, g)
85 self.hook(obj, g)
87 return self.obj
86 return self.obj
88
87
89
88
90 class Reference(CannedObject):
89 class Reference(CannedObject):
91 """object for wrapping a remote reference by name."""
90 """object for wrapping a remote reference by name."""
92 def __init__(self, name):
91 def __init__(self, name):
93 if not isinstance(name, basestring):
92 if not isinstance(name, basestring):
94 raise TypeError("illegal name: %r"%name)
93 raise TypeError("illegal name: %r"%name)
95 self.name = name
94 self.name = name
96 self.buffers = []
95 self.buffers = []
97
96
98 def __repr__(self):
97 def __repr__(self):
99 return "<Reference: %r>"%self.name
98 return "<Reference: %r>"%self.name
100
99
101 def get_object(self, g=None):
100 def get_object(self, g=None):
102 if g is None:
101 if g is None:
103 g = {}
102 g = {}
104
103
105 return eval(self.name, g)
104 return eval(self.name, g)
106
105
107
106
108 class CannedFunction(CannedObject):
107 class CannedFunction(CannedObject):
109
108
110 def __init__(self, f):
109 def __init__(self, f):
111 self._check_type(f)
110 self._check_type(f)
112 self.code = f.func_code
111 self.code = f.func_code
113 if f.func_defaults:
112 if f.func_defaults:
114 self.defaults = [ can(fd) for fd in f.func_defaults ]
113 self.defaults = [ can(fd) for fd in f.func_defaults ]
115 else:
114 else:
116 self.defaults = None
115 self.defaults = None
117 self.module = f.__module__ or '__main__'
116 self.module = f.__module__ or '__main__'
118 self.__name__ = f.__name__
117 self.__name__ = f.__name__
119 self.buffers = []
118 self.buffers = []
120
119
121 def _check_type(self, obj):
120 def _check_type(self, obj):
122 assert isinstance(obj, FunctionType), "Not a function type"
121 assert isinstance(obj, FunctionType), "Not a function type"
123
122
124 def get_object(self, g=None):
123 def get_object(self, g=None):
125 # try to load function back into its module:
124 # try to load function back into its module:
126 if not self.module.startswith('__'):
125 if not self.module.startswith('__'):
127 __import__(self.module)
126 __import__(self.module)
128 g = sys.modules[self.module].__dict__
127 g = sys.modules[self.module].__dict__
129
128
130 if g is None:
129 if g is None:
131 g = {}
130 g = {}
132 if self.defaults:
131 if self.defaults:
133 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
132 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
134 else:
133 else:
135 defaults = None
134 defaults = None
136 newFunc = FunctionType(self.code, g, self.__name__, defaults)
135 newFunc = FunctionType(self.code, g, self.__name__, defaults)
137 return newFunc
136 return newFunc
138
137
139 class CannedClass(CannedObject):
138 class CannedClass(CannedObject):
140
139
141 def __init__(self, cls):
140 def __init__(self, cls):
142 self._check_type(cls)
141 self._check_type(cls)
143 self.name = cls.__name__
142 self.name = cls.__name__
144 self.old_style = not isinstance(cls, type)
143 self.old_style = not isinstance(cls, type)
145 self._canned_dict = {}
144 self._canned_dict = {}
146 for k,v in cls.__dict__.items():
145 for k,v in cls.__dict__.items():
147 if k not in ('__weakref__', '__dict__'):
146 if k not in ('__weakref__', '__dict__'):
148 self._canned_dict[k] = can(v)
147 self._canned_dict[k] = can(v)
149 if self.old_style:
148 if self.old_style:
150 mro = []
149 mro = []
151 else:
150 else:
152 mro = cls.mro()
151 mro = cls.mro()
153
152
154 self.parents = [ can(c) for c in mro[1:] ]
153 self.parents = [ can(c) for c in mro[1:] ]
155 self.buffers = []
154 self.buffers = []
156
155
157 def _check_type(self, obj):
156 def _check_type(self, obj):
158 assert isinstance(obj, class_type), "Not a class type"
157 assert isinstance(obj, class_type), "Not a class type"
159
158
160 def get_object(self, g=None):
159 def get_object(self, g=None):
161 parents = tuple(uncan(p, g) for p in self.parents)
160 parents = tuple(uncan(p, g) for p in self.parents)
162 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
161 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
163
162
164 class CannedArray(CannedObject):
163 class CannedArray(CannedObject):
165 def __init__(self, obj):
164 def __init__(self, obj):
166 self.shape = obj.shape
165 self.shape = obj.shape
167 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
166 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
168 if sum(obj.shape) == 0:
167 if sum(obj.shape) == 0:
169 # just pickle it
168 # just pickle it
170 self.buffers = [pickle.dumps(obj, -1)]
169 self.buffers = [pickle.dumps(obj, -1)]
171 else:
170 else:
172 # ensure contiguous
171 # ensure contiguous
173 obj = numpy.ascontiguousarray(obj, dtype=None)
172 obj = numpy.ascontiguousarray(obj, dtype=None)
174 self.buffers = [buffer(obj)]
173 self.buffers = [buffer(obj)]
175
174
176 def get_object(self, g=None):
175 def get_object(self, g=None):
177 data = self.buffers[0]
176 data = self.buffers[0]
178 if sum(self.shape) == 0:
177 if sum(self.shape) == 0:
179 # no shape, we just pickled it
178 # no shape, we just pickled it
180 return pickle.loads(data)
179 return pickle.loads(data)
181 else:
180 else:
182 return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
181 return numpy.frombuffer(data, dtype=self.dtype).reshape(self.shape)
183
182
184
183
185 class CannedBytes(CannedObject):
184 class CannedBytes(CannedObject):
186 wrap = bytes
185 wrap = bytes
187 def __init__(self, obj):
186 def __init__(self, obj):
188 self.buffers = [obj]
187 self.buffers = [obj]
189
188
190 def get_object(self, g=None):
189 def get_object(self, g=None):
191 data = self.buffers[0]
190 data = self.buffers[0]
192 return self.wrap(data)
191 return self.wrap(data)
193
192
194 def CannedBuffer(CannedBytes):
193 def CannedBuffer(CannedBytes):
195 wrap = buffer
194 wrap = buffer
196
195
197 #-------------------------------------------------------------------------------
196 #-------------------------------------------------------------------------------
198 # Functions
197 # Functions
199 #-------------------------------------------------------------------------------
198 #-------------------------------------------------------------------------------
200
199
201 def _logger():
200 def _logger():
202 """get the logger for the current Application
201 """get the logger for the current Application
203
202
204 the root logger will be used if no Application is running
203 the root logger will be used if no Application is running
205 """
204 """
206 if Application.initialized():
205 if Application.initialized():
207 logger = Application.instance().log
206 logger = Application.instance().log
208 else:
207 else:
209 logger = logging.getLogger()
208 logger = logging.getLogger()
210 if not logger.handlers:
209 if not logger.handlers:
211 logging.basicConfig()
210 logging.basicConfig()
212
211
213 return logger
212 return logger
214
213
215 def _import_mapping(mapping, original=None):
214 def _import_mapping(mapping, original=None):
216 """import any string-keys in a type mapping
215 """import any string-keys in a type mapping
217
216
218 """
217 """
219 log = _logger()
218 log = _logger()
220 log.debug("Importing canning map")
219 log.debug("Importing canning map")
221 for key,value in mapping.items():
220 for key,value in mapping.items():
222 if isinstance(key, basestring):
221 if isinstance(key, basestring):
223 try:
222 try:
224 cls = import_item(key)
223 cls = import_item(key)
225 except Exception:
224 except Exception:
226 if original and key not in original:
225 if original and key not in original:
227 # only message on user-added classes
226 # only message on user-added classes
228 log.error("cannning class not importable: %r", key, exc_info=True)
227 log.error("cannning class not importable: %r", key, exc_info=True)
229 mapping.pop(key)
228 mapping.pop(key)
230 else:
229 else:
231 mapping[cls] = mapping.pop(key)
230 mapping[cls] = mapping.pop(key)
232
231
233 def istype(obj, check):
232 def istype(obj, check):
234 """like isinstance(obj, check), but strict
233 """like isinstance(obj, check), but strict
235
234
236 This won't catch subclasses.
235 This won't catch subclasses.
237 """
236 """
238 if isinstance(check, tuple):
237 if isinstance(check, tuple):
239 for cls in check:
238 for cls in check:
240 if type(obj) is cls:
239 if type(obj) is cls:
241 return True
240 return True
242 return False
241 return False
243 else:
242 else:
244 return type(obj) is check
243 return type(obj) is check
245
244
246 def can(obj):
245 def can(obj):
247 """prepare an object for pickling"""
246 """prepare an object for pickling"""
248
247
249 import_needed = False
248 import_needed = False
250
249
251 for cls,canner in can_map.iteritems():
250 for cls,canner in can_map.iteritems():
252 if isinstance(cls, basestring):
251 if isinstance(cls, basestring):
253 import_needed = True
252 import_needed = True
254 break
253 break
255 elif istype(obj, cls):
254 elif istype(obj, cls):
256 return canner(obj)
255 return canner(obj)
257
256
258 if import_needed:
257 if import_needed:
259 # perform can_map imports, then try again
258 # perform can_map imports, then try again
260 # this will usually only happen once
259 # this will usually only happen once
261 _import_mapping(can_map, _original_can_map)
260 _import_mapping(can_map, _original_can_map)
262 return can(obj)
261 return can(obj)
263
262
264 return obj
263 return obj
265
264
266 def can_class(obj):
265 def can_class(obj):
267 if isinstance(obj, class_type) and obj.__module__ == '__main__':
266 if isinstance(obj, class_type) and obj.__module__ == '__main__':
268 return CannedClass(obj)
267 return CannedClass(obj)
269 else:
268 else:
270 return obj
269 return obj
271
270
272 def can_dict(obj):
271 def can_dict(obj):
273 """can the *values* of a dict"""
272 """can the *values* of a dict"""
274 if istype(obj, dict):
273 if istype(obj, dict):
275 newobj = {}
274 newobj = {}
276 for k, v in obj.iteritems():
275 for k, v in obj.iteritems():
277 newobj[k] = can(v)
276 newobj[k] = can(v)
278 return newobj
277 return newobj
279 else:
278 else:
280 return obj
279 return obj
281
280
282 sequence_types = (list, tuple, set)
281 sequence_types = (list, tuple, set)
283
282
284 def can_sequence(obj):
283 def can_sequence(obj):
285 """can the elements of a sequence"""
284 """can the elements of a sequence"""
286 if istype(obj, sequence_types):
285 if istype(obj, sequence_types):
287 t = type(obj)
286 t = type(obj)
288 return t([can(i) for i in obj])
287 return t([can(i) for i in obj])
289 else:
288 else:
290 return obj
289 return obj
291
290
292 def uncan(obj, g=None):
291 def uncan(obj, g=None):
293 """invert canning"""
292 """invert canning"""
294
293
295 import_needed = False
294 import_needed = False
296 for cls,uncanner in uncan_map.iteritems():
295 for cls,uncanner in uncan_map.iteritems():
297 if isinstance(cls, basestring):
296 if isinstance(cls, basestring):
298 import_needed = True
297 import_needed = True
299 break
298 break
300 elif isinstance(obj, cls):
299 elif isinstance(obj, cls):
301 return uncanner(obj, g)
300 return uncanner(obj, g)
302
301
303 if import_needed:
302 if import_needed:
304 # perform uncan_map imports, then try again
303 # perform uncan_map imports, then try again
305 # this will usually only happen once
304 # this will usually only happen once
306 _import_mapping(uncan_map, _original_uncan_map)
305 _import_mapping(uncan_map, _original_uncan_map)
307 return uncan(obj, g)
306 return uncan(obj, g)
308
307
309 return obj
308 return obj
310
309
311 def uncan_dict(obj, g=None):
310 def uncan_dict(obj, g=None):
312 if istype(obj, dict):
311 if istype(obj, dict):
313 newobj = {}
312 newobj = {}
314 for k, v in obj.iteritems():
313 for k, v in obj.iteritems():
315 newobj[k] = uncan(v,g)
314 newobj[k] = uncan(v,g)
316 return newobj
315 return newobj
317 else:
316 else:
318 return obj
317 return obj
319
318
320 def uncan_sequence(obj, g=None):
319 def uncan_sequence(obj, g=None):
321 if istype(obj, sequence_types):
320 if istype(obj, sequence_types):
322 t = type(obj)
321 t = type(obj)
323 return t([uncan(i,g) for i in obj])
322 return t([uncan(i,g) for i in obj])
324 else:
323 else:
325 return obj
324 return obj
326
325
327 def _uncan_dependent_hook(dep, g=None):
326 def _uncan_dependent_hook(dep, g=None):
328 dep.check_dependency()
327 dep.check_dependency()
329
328
330 def can_dependent(obj):
329 def can_dependent(obj):
331 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
330 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
332
331
333 #-------------------------------------------------------------------------------
332 #-------------------------------------------------------------------------------
334 # API dictionaries
333 # API dictionaries
335 #-------------------------------------------------------------------------------
334 #-------------------------------------------------------------------------------
336
335
337 # These dicts can be extended for custom serialization of new objects
336 # These dicts can be extended for custom serialization of new objects
338
337
339 can_map = {
338 can_map = {
340 'IPython.parallel.dependent' : can_dependent,
339 'IPython.parallel.dependent' : can_dependent,
341 'numpy.ndarray' : CannedArray,
340 'numpy.ndarray' : CannedArray,
342 FunctionType : CannedFunction,
341 FunctionType : CannedFunction,
343 bytes : CannedBytes,
342 bytes : CannedBytes,
344 buffer : CannedBuffer,
343 buffer : CannedBuffer,
345 class_type : can_class,
344 class_type : can_class,
346 }
345 }
347
346
348 uncan_map = {
347 uncan_map = {
349 CannedObject : lambda obj, g: obj.get_object(g),
348 CannedObject : lambda obj, g: obj.get_object(g),
350 }
349 }
351
350
352 # for use in _import_mapping:
351 # for use in _import_mapping:
353 _original_can_map = can_map.copy()
352 _original_can_map = can_map.copy()
354 _original_uncan_map = uncan_map.copy()
353 _original_uncan_map = uncan_map.copy()
@@ -1,135 +1,125 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.path.py"""
2 """Tests for IPython.utils.module_paths.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008-2011 The IPython Development Team
5 # Copyright (C) 2008-2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 from __future__ import with_statement
15 from __future__ import with_statement
16
16
17 import os
17 import os
18 import shutil
18 import shutil
19 import sys
19 import sys
20 import tempfile
20 import tempfile
21 import StringIO
22
21
23 from os.path import join, abspath, split
22 from os.path import join, abspath, split
24
23
25 import nose.tools as nt
26
27 from nose import with_setup
28
29 import IPython
30 from IPython.testing import decorators as dec
31 from IPython.testing.decorators import skip_if_not_win32, skip_win32
32 from IPython.testing.tools import make_tempfile
24 from IPython.testing.tools import make_tempfile
33 from IPython.utils import path, io
34 from IPython.utils import py3compat
35
25
36 import IPython.utils.module_paths as mp
26 import IPython.utils.module_paths as mp
37
27
38 env = os.environ
28 env = os.environ
39 TEST_FILE_PATH = split(abspath(__file__))[0]
29 TEST_FILE_PATH = split(abspath(__file__))[0]
40 TMP_TEST_DIR = tempfile.mkdtemp()
30 TMP_TEST_DIR = tempfile.mkdtemp()
41 #
31 #
42 # Setup/teardown functions/decorators
32 # Setup/teardown functions/decorators
43 #
33 #
44
34
45 old_syspath = sys.path
35 old_syspath = sys.path
46
36
47 def make_empty_file(fname):
37 def make_empty_file(fname):
48 f = open(fname, 'w')
38 f = open(fname, 'w')
49 f.close()
39 f.close()
50
40
51
41
52 def setup():
42 def setup():
53 """Setup testenvironment for the module:
43 """Setup testenvironment for the module:
54
44
55 """
45 """
56 # Do not mask exceptions here. In particular, catching WindowsError is a
46 # Do not mask exceptions here. In particular, catching WindowsError is a
57 # problem because that exception is only defined on Windows...
47 # problem because that exception is only defined on Windows...
58 os.makedirs(join(TMP_TEST_DIR, "xmod"))
48 os.makedirs(join(TMP_TEST_DIR, "xmod"))
59 os.makedirs(join(TMP_TEST_DIR, "nomod"))
49 os.makedirs(join(TMP_TEST_DIR, "nomod"))
60 make_empty_file(join(TMP_TEST_DIR, "xmod/__init__.py"))
50 make_empty_file(join(TMP_TEST_DIR, "xmod/__init__.py"))
61 make_empty_file(join(TMP_TEST_DIR, "xmod/sub.py"))
51 make_empty_file(join(TMP_TEST_DIR, "xmod/sub.py"))
62 make_empty_file(join(TMP_TEST_DIR, "pack.py"))
52 make_empty_file(join(TMP_TEST_DIR, "pack.py"))
63 make_empty_file(join(TMP_TEST_DIR, "packpyc.pyc"))
53 make_empty_file(join(TMP_TEST_DIR, "packpyc.pyc"))
64 sys.path = [TMP_TEST_DIR]
54 sys.path = [TMP_TEST_DIR]
65
55
66 def teardown():
56 def teardown():
67 """Teardown testenvironment for the module:
57 """Teardown testenvironment for the module:
68
58
69 - Remove tempdir
59 - Remove tempdir
70 - restore sys.path
60 - restore sys.path
71 """
61 """
72 # Note: we remove the parent test dir, which is the root of all test
62 # Note: we remove the parent test dir, which is the root of all test
73 # subdirs we may have created. Use shutil instead of os.removedirs, so
63 # subdirs we may have created. Use shutil instead of os.removedirs, so
74 # that non-empty directories are all recursively removed.
64 # that non-empty directories are all recursively removed.
75 shutil.rmtree(TMP_TEST_DIR)
65 shutil.rmtree(TMP_TEST_DIR)
76 sys.path = old_syspath
66 sys.path = old_syspath
77
67
78
68
79 def test_get_init_1():
69 def test_get_init_1():
80 """See if get_init can find __init__.py in this testdir"""
70 """See if get_init can find __init__.py in this testdir"""
81 with make_tempfile(join(TMP_TEST_DIR, "__init__.py")):
71 with make_tempfile(join(TMP_TEST_DIR, "__init__.py")):
82 assert mp.get_init(TMP_TEST_DIR)
72 assert mp.get_init(TMP_TEST_DIR)
83
73
84 def test_get_init_2():
74 def test_get_init_2():
85 """See if get_init can find __init__.pyw in this testdir"""
75 """See if get_init can find __init__.pyw in this testdir"""
86 with make_tempfile(join(TMP_TEST_DIR, "__init__.pyw")):
76 with make_tempfile(join(TMP_TEST_DIR, "__init__.pyw")):
87 assert mp.get_init(TMP_TEST_DIR)
77 assert mp.get_init(TMP_TEST_DIR)
88
78
89 def test_get_init_3():
79 def test_get_init_3():
90 """get_init can't find __init__.pyc in this testdir"""
80 """get_init can't find __init__.pyc in this testdir"""
91 with make_tempfile(join(TMP_TEST_DIR, "__init__.pyc")):
81 with make_tempfile(join(TMP_TEST_DIR, "__init__.pyc")):
92 assert mp.get_init(TMP_TEST_DIR) is None
82 assert mp.get_init(TMP_TEST_DIR) is None
93
83
94 def test_get_init_3():
84 def test_get_init_4():
95 """get_init can't find __init__ in empty testdir"""
85 """get_init can't find __init__ in empty testdir"""
96 assert mp.get_init(TMP_TEST_DIR) is None
86 assert mp.get_init(TMP_TEST_DIR) is None
97
87
98
88
99 def test_find_mod_1():
89 def test_find_mod_1():
100 modpath = join(TMP_TEST_DIR, "xmod", "__init__.py")
90 modpath = join(TMP_TEST_DIR, "xmod", "__init__.py")
101 assert mp.find_mod("xmod") == modpath
91 assert mp.find_mod("xmod") == modpath
102
92
103 def test_find_mod_2():
93 def test_find_mod_2():
104 modpath = join(TMP_TEST_DIR, "xmod", "__init__.py")
94 modpath = join(TMP_TEST_DIR, "xmod", "__init__.py")
105 assert mp.find_mod("xmod") == modpath
95 assert mp.find_mod("xmod") == modpath
106
96
107 def test_find_mod_3():
97 def test_find_mod_3():
108 modpath = join(TMP_TEST_DIR, "xmod", "sub.py")
98 modpath = join(TMP_TEST_DIR, "xmod", "sub.py")
109 assert mp.find_mod("xmod.sub") == modpath
99 assert mp.find_mod("xmod.sub") == modpath
110
100
111 def test_find_mod_4():
101 def test_find_mod_4():
112 modpath = join(TMP_TEST_DIR, "pack.py")
102 modpath = join(TMP_TEST_DIR, "pack.py")
113 assert mp.find_mod("pack") == modpath
103 assert mp.find_mod("pack") == modpath
114
104
115 def test_find_mod_5():
105 def test_find_mod_5():
116 assert mp.find_mod("packpyc") is None
106 assert mp.find_mod("packpyc") is None
117
107
118 def test_find_module_1():
108 def test_find_module_1():
119 modpath = join(TMP_TEST_DIR, "xmod")
109 modpath = join(TMP_TEST_DIR, "xmod")
120 assert mp.find_module("xmod") == modpath
110 assert mp.find_module("xmod") == modpath
121
111
122 def test_find_module_2():
112 def test_find_module_2():
123 """Testing sys.path that is empty"""
113 """Testing sys.path that is empty"""
124 assert mp.find_module("xmod", []) is None
114 assert mp.find_module("xmod", []) is None
125
115
126 def test_find_module_3():
116 def test_find_module_3():
127 """Testing sys.path that is empty"""
117 """Testing sys.path that is empty"""
128 assert mp.find_module(None, None) is None
118 assert mp.find_module(None, None) is None
129
119
130 def test_find_module_4():
120 def test_find_module_4():
131 """Testing sys.path that is empty"""
121 """Testing sys.path that is empty"""
132 assert mp.find_module(None) is None
122 assert mp.find_module(None) is None
133
123
134 def test_find_module_5():
124 def test_find_module_5():
135 assert mp.find_module("xmod.nopack") is None
125 assert mp.find_module("xmod.nopack") is None
@@ -1,560 +1,559 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.path.py"""
2 """Tests for IPython.utils.path.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008-2011 The IPython Development Team
5 # Copyright (C) 2008-2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 from __future__ import with_statement
15 from __future__ import with_statement
16
16
17 import os
17 import os
18 import shutil
18 import shutil
19 import sys
19 import sys
20 import tempfile
20 import tempfile
21 from io import StringIO
22 from contextlib import contextmanager
21 from contextlib import contextmanager
23
22
24 from os.path import join, abspath, split
23 from os.path import join, abspath, split
25
24
26 import nose.tools as nt
25 import nose.tools as nt
27
26
28 from nose import with_setup
27 from nose import with_setup
29
28
30 import IPython
29 import IPython
31 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
32 from IPython.testing.decorators import skip_if_not_win32, skip_win32
31 from IPython.testing.decorators import skip_if_not_win32, skip_win32
33 from IPython.testing.tools import make_tempfile, AssertPrints
32 from IPython.testing.tools import make_tempfile, AssertPrints
34 from IPython.utils import path, io
33 from IPython.utils import path
35 from IPython.utils import py3compat
34 from IPython.utils import py3compat
36 from IPython.utils.tempdir import TemporaryDirectory
35 from IPython.utils.tempdir import TemporaryDirectory
37
36
38 # Platform-dependent imports
37 # Platform-dependent imports
39 try:
38 try:
40 import _winreg as wreg
39 import _winreg as wreg
41 except ImportError:
40 except ImportError:
42 #Fake _winreg module on none windows platforms
41 #Fake _winreg module on none windows platforms
43 import types
42 import types
44 wr_name = "winreg" if py3compat.PY3 else "_winreg"
43 wr_name = "winreg" if py3compat.PY3 else "_winreg"
45 sys.modules[wr_name] = types.ModuleType(wr_name)
44 sys.modules[wr_name] = types.ModuleType(wr_name)
46 import _winreg as wreg
45 import _winreg as wreg
47 #Add entries that needs to be stubbed by the testing code
46 #Add entries that needs to be stubbed by the testing code
48 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
47 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
49
48
50 try:
49 try:
51 reload
50 reload
52 except NameError: # Python 3
51 except NameError: # Python 3
53 from imp import reload
52 from imp import reload
54
53
55 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
56 # Globals
55 # Globals
57 #-----------------------------------------------------------------------------
56 #-----------------------------------------------------------------------------
58 env = os.environ
57 env = os.environ
59 TEST_FILE_PATH = split(abspath(__file__))[0]
58 TEST_FILE_PATH = split(abspath(__file__))[0]
60 TMP_TEST_DIR = tempfile.mkdtemp()
59 TMP_TEST_DIR = tempfile.mkdtemp()
61 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
60 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
62 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
61 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
63 XDG_CACHE_DIR = join(HOME_TEST_DIR, "xdg_cache_dir")
62 XDG_CACHE_DIR = join(HOME_TEST_DIR, "xdg_cache_dir")
64 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
63 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
65 #
64 #
66 # Setup/teardown functions/decorators
65 # Setup/teardown functions/decorators
67 #
66 #
68
67
69 def setup():
68 def setup():
70 """Setup testenvironment for the module:
69 """Setup testenvironment for the module:
71
70
72 - Adds dummy home dir tree
71 - Adds dummy home dir tree
73 """
72 """
74 # Do not mask exceptions here. In particular, catching WindowsError is a
73 # Do not mask exceptions here. In particular, catching WindowsError is a
75 # problem because that exception is only defined on Windows...
74 # problem because that exception is only defined on Windows...
76 os.makedirs(IP_TEST_DIR)
75 os.makedirs(IP_TEST_DIR)
77 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
76 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
78 os.makedirs(os.path.join(XDG_CACHE_DIR, 'ipython'))
77 os.makedirs(os.path.join(XDG_CACHE_DIR, 'ipython'))
79
78
80
79
81 def teardown():
80 def teardown():
82 """Teardown testenvironment for the module:
81 """Teardown testenvironment for the module:
83
82
84 - Remove dummy home dir tree
83 - Remove dummy home dir tree
85 """
84 """
86 # Note: we remove the parent test dir, which is the root of all test
85 # Note: we remove the parent test dir, which is the root of all test
87 # subdirs we may have created. Use shutil instead of os.removedirs, so
86 # subdirs we may have created. Use shutil instead of os.removedirs, so
88 # that non-empty directories are all recursively removed.
87 # that non-empty directories are all recursively removed.
89 shutil.rmtree(TMP_TEST_DIR)
88 shutil.rmtree(TMP_TEST_DIR)
90
89
91
90
92 def setup_environment():
91 def setup_environment():
93 """Setup testenvironment for some functions that are tested
92 """Setup testenvironment for some functions that are tested
94 in this module. In particular this functions stores attributes
93 in this module. In particular this functions stores attributes
95 and other things that we need to stub in some test functions.
94 and other things that we need to stub in some test functions.
96 This needs to be done on a function level and not module level because
95 This needs to be done on a function level and not module level because
97 each testfunction needs a pristine environment.
96 each testfunction needs a pristine environment.
98 """
97 """
99 global oldstuff, platformstuff
98 global oldstuff, platformstuff
100 oldstuff = (env.copy(), os.name, sys.platform, path.get_home_dir, IPython.__file__, os.getcwd())
99 oldstuff = (env.copy(), os.name, sys.platform, path.get_home_dir, IPython.__file__, os.getcwd())
101
100
102 if os.name == 'nt':
101 if os.name == 'nt':
103 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
102 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
104
103
105
104
106 def teardown_environment():
105 def teardown_environment():
107 """Restore things that were remebered by the setup_environment function
106 """Restore things that were remebered by the setup_environment function
108 """
107 """
109 (oldenv, os.name, sys.platform, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
108 (oldenv, os.name, sys.platform, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
110 os.chdir(old_wd)
109 os.chdir(old_wd)
111 reload(path)
110 reload(path)
112
111
113 for key in env.keys():
112 for key in env.keys():
114 if key not in oldenv:
113 if key not in oldenv:
115 del env[key]
114 del env[key]
116 env.update(oldenv)
115 env.update(oldenv)
117 if hasattr(sys, 'frozen'):
116 if hasattr(sys, 'frozen'):
118 del sys.frozen
117 del sys.frozen
119 if os.name == 'nt':
118 if os.name == 'nt':
120 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
119 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
121
120
122 # Build decorator that uses the setup_environment/setup_environment
121 # Build decorator that uses the setup_environment/setup_environment
123 with_environment = with_setup(setup_environment, teardown_environment)
122 with_environment = with_setup(setup_environment, teardown_environment)
124
123
125 @skip_if_not_win32
124 @skip_if_not_win32
126 @with_environment
125 @with_environment
127 def test_get_home_dir_1():
126 def test_get_home_dir_1():
128 """Testcase for py2exe logic, un-compressed lib
127 """Testcase for py2exe logic, un-compressed lib
129 """
128 """
130 sys.frozen = True
129 sys.frozen = True
131
130
132 #fake filename for IPython.__init__
131 #fake filename for IPython.__init__
133 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
132 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
134
133
135 home_dir = path.get_home_dir()
134 home_dir = path.get_home_dir()
136 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
135 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
137
136
138
137
139 @skip_if_not_win32
138 @skip_if_not_win32
140 @with_environment
139 @with_environment
141 def test_get_home_dir_2():
140 def test_get_home_dir_2():
142 """Testcase for py2exe logic, compressed lib
141 """Testcase for py2exe logic, compressed lib
143 """
142 """
144 sys.frozen = True
143 sys.frozen = True
145 #fake filename for IPython.__init__
144 #fake filename for IPython.__init__
146 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
145 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
147
146
148 home_dir = path.get_home_dir(True)
147 home_dir = path.get_home_dir(True)
149 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
148 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
150
149
151
150
152 @with_environment
151 @with_environment
153 def test_get_home_dir_3():
152 def test_get_home_dir_3():
154 """get_home_dir() uses $HOME if set"""
153 """get_home_dir() uses $HOME if set"""
155 env["HOME"] = HOME_TEST_DIR
154 env["HOME"] = HOME_TEST_DIR
156 home_dir = path.get_home_dir(True)
155 home_dir = path.get_home_dir(True)
157 # get_home_dir expands symlinks
156 # get_home_dir expands symlinks
158 nt.assert_equal(home_dir, os.path.realpath(env["HOME"]))
157 nt.assert_equal(home_dir, os.path.realpath(env["HOME"]))
159
158
160
159
161 @with_environment
160 @with_environment
162 def test_get_home_dir_4():
161 def test_get_home_dir_4():
163 """get_home_dir() still works if $HOME is not set"""
162 """get_home_dir() still works if $HOME is not set"""
164
163
165 if 'HOME' in env: del env['HOME']
164 if 'HOME' in env: del env['HOME']
166 # this should still succeed, but we don't care what the answer is
165 # this should still succeed, but we don't care what the answer is
167 home = path.get_home_dir(False)
166 home = path.get_home_dir(False)
168
167
169 @with_environment
168 @with_environment
170 def test_get_home_dir_5():
169 def test_get_home_dir_5():
171 """raise HomeDirError if $HOME is specified, but not a writable dir"""
170 """raise HomeDirError if $HOME is specified, but not a writable dir"""
172 env['HOME'] = abspath(HOME_TEST_DIR+'garbage')
171 env['HOME'] = abspath(HOME_TEST_DIR+'garbage')
173 # set os.name = posix, to prevent My Documents fallback on Windows
172 # set os.name = posix, to prevent My Documents fallback on Windows
174 os.name = 'posix'
173 os.name = 'posix'
175 nt.assert_raises(path.HomeDirError, path.get_home_dir, True)
174 nt.assert_raises(path.HomeDirError, path.get_home_dir, True)
176
175
177
176
178 # Should we stub wreg fully so we can run the test on all platforms?
177 # Should we stub wreg fully so we can run the test on all platforms?
179 @skip_if_not_win32
178 @skip_if_not_win32
180 @with_environment
179 @with_environment
181 def test_get_home_dir_8():
180 def test_get_home_dir_8():
182 """Using registry hack for 'My Documents', os=='nt'
181 """Using registry hack for 'My Documents', os=='nt'
183
182
184 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
183 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
185 """
184 """
186 os.name = 'nt'
185 os.name = 'nt'
187 # Remove from stub environment all keys that may be set
186 # Remove from stub environment all keys that may be set
188 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
187 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
189 env.pop(key, None)
188 env.pop(key, None)
190
189
191 #Stub windows registry functions
190 #Stub windows registry functions
192 def OpenKey(x, y):
191 def OpenKey(x, y):
193 class key:
192 class key:
194 def Close(self):
193 def Close(self):
195 pass
194 pass
196 return key()
195 return key()
197 def QueryValueEx(x, y):
196 def QueryValueEx(x, y):
198 return [abspath(HOME_TEST_DIR)]
197 return [abspath(HOME_TEST_DIR)]
199
198
200 wreg.OpenKey = OpenKey
199 wreg.OpenKey = OpenKey
201 wreg.QueryValueEx = QueryValueEx
200 wreg.QueryValueEx = QueryValueEx
202
201
203 home_dir = path.get_home_dir()
202 home_dir = path.get_home_dir()
204 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
203 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
205
204
206
205
207 @with_environment
206 @with_environment
208 def test_get_ipython_dir_1():
207 def test_get_ipython_dir_1():
209 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
208 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
210 env_ipdir = os.path.join("someplace", ".ipython")
209 env_ipdir = os.path.join("someplace", ".ipython")
211 path._writable_dir = lambda path: True
210 path._writable_dir = lambda path: True
212 env['IPYTHONDIR'] = env_ipdir
211 env['IPYTHONDIR'] = env_ipdir
213 ipdir = path.get_ipython_dir()
212 ipdir = path.get_ipython_dir()
214 nt.assert_equal(ipdir, env_ipdir)
213 nt.assert_equal(ipdir, env_ipdir)
215
214
216
215
217 @with_environment
216 @with_environment
218 def test_get_ipython_dir_2():
217 def test_get_ipython_dir_2():
219 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
218 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
220 path.get_home_dir = lambda : "someplace"
219 path.get_home_dir = lambda : "someplace"
221 path.get_xdg_dir = lambda : None
220 path.get_xdg_dir = lambda : None
222 path._writable_dir = lambda path: True
221 path._writable_dir = lambda path: True
223 os.name = "posix"
222 os.name = "posix"
224 env.pop('IPYTHON_DIR', None)
223 env.pop('IPYTHON_DIR', None)
225 env.pop('IPYTHONDIR', None)
224 env.pop('IPYTHONDIR', None)
226 env.pop('XDG_CONFIG_HOME', None)
225 env.pop('XDG_CONFIG_HOME', None)
227 ipdir = path.get_ipython_dir()
226 ipdir = path.get_ipython_dir()
228 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
227 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
229
228
230 @with_environment
229 @with_environment
231 def test_get_ipython_dir_3():
230 def test_get_ipython_dir_3():
232 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
231 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
233 path.get_home_dir = lambda : "someplace"
232 path.get_home_dir = lambda : "someplace"
234 path._writable_dir = lambda path: True
233 path._writable_dir = lambda path: True
235 os.name = "posix"
234 os.name = "posix"
236 env.pop('IPYTHON_DIR', None)
235 env.pop('IPYTHON_DIR', None)
237 env.pop('IPYTHONDIR', None)
236 env.pop('IPYTHONDIR', None)
238 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
237 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
239 ipdir = path.get_ipython_dir()
238 ipdir = path.get_ipython_dir()
240 if sys.platform == "darwin":
239 if sys.platform == "darwin":
241 expected = os.path.join("someplace", ".ipython")
240 expected = os.path.join("someplace", ".ipython")
242 else:
241 else:
243 expected = os.path.join(XDG_TEST_DIR, "ipython")
242 expected = os.path.join(XDG_TEST_DIR, "ipython")
244 nt.assert_equal(ipdir, expected)
243 nt.assert_equal(ipdir, expected)
245
244
246 @with_environment
245 @with_environment
247 def test_get_ipython_dir_4():
246 def test_get_ipython_dir_4():
248 """test_get_ipython_dir_4, use XDG if both exist."""
247 """test_get_ipython_dir_4, use XDG if both exist."""
249 path.get_home_dir = lambda : HOME_TEST_DIR
248 path.get_home_dir = lambda : HOME_TEST_DIR
250 os.name = "posix"
249 os.name = "posix"
251 env.pop('IPYTHON_DIR', None)
250 env.pop('IPYTHON_DIR', None)
252 env.pop('IPYTHONDIR', None)
251 env.pop('IPYTHONDIR', None)
253 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
252 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
254 ipdir = path.get_ipython_dir()
253 ipdir = path.get_ipython_dir()
255 if sys.platform == "darwin":
254 if sys.platform == "darwin":
256 expected = os.path.join(HOME_TEST_DIR, ".ipython")
255 expected = os.path.join(HOME_TEST_DIR, ".ipython")
257 else:
256 else:
258 expected = os.path.join(XDG_TEST_DIR, "ipython")
257 expected = os.path.join(XDG_TEST_DIR, "ipython")
259 nt.assert_equal(ipdir, expected)
258 nt.assert_equal(ipdir, expected)
260
259
261 @with_environment
260 @with_environment
262 def test_get_ipython_dir_5():
261 def test_get_ipython_dir_5():
263 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
262 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
264 path.get_home_dir = lambda : HOME_TEST_DIR
263 path.get_home_dir = lambda : HOME_TEST_DIR
265 os.name = "posix"
264 os.name = "posix"
266 env.pop('IPYTHON_DIR', None)
265 env.pop('IPYTHON_DIR', None)
267 env.pop('IPYTHONDIR', None)
266 env.pop('IPYTHONDIR', None)
268 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
267 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
269 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
268 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
270 ipdir = path.get_ipython_dir()
269 ipdir = path.get_ipython_dir()
271 nt.assert_equal(ipdir, IP_TEST_DIR)
270 nt.assert_equal(ipdir, IP_TEST_DIR)
272
271
273 @with_environment
272 @with_environment
274 def test_get_ipython_dir_6():
273 def test_get_ipython_dir_6():
275 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
274 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
276 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
275 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
277 os.mkdir(xdg)
276 os.mkdir(xdg)
278 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
277 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
279 path.get_home_dir = lambda : HOME_TEST_DIR
278 path.get_home_dir = lambda : HOME_TEST_DIR
280 path.get_xdg_dir = lambda : xdg
279 path.get_xdg_dir = lambda : xdg
281 os.name = "posix"
280 os.name = "posix"
282 env.pop('IPYTHON_DIR', None)
281 env.pop('IPYTHON_DIR', None)
283 env.pop('IPYTHONDIR', None)
282 env.pop('IPYTHONDIR', None)
284 env.pop('XDG_CONFIG_HOME', None)
283 env.pop('XDG_CONFIG_HOME', None)
285 xdg_ipdir = os.path.join(xdg, "ipython")
284 xdg_ipdir = os.path.join(xdg, "ipython")
286 ipdir = path.get_ipython_dir()
285 ipdir = path.get_ipython_dir()
287 nt.assert_equal(ipdir, xdg_ipdir)
286 nt.assert_equal(ipdir, xdg_ipdir)
288
287
289 @with_environment
288 @with_environment
290 def test_get_ipython_dir_7():
289 def test_get_ipython_dir_7():
291 """test_get_ipython_dir_7, test home directory expansion on IPYTHONDIR"""
290 """test_get_ipython_dir_7, test home directory expansion on IPYTHONDIR"""
292 path._writable_dir = lambda path: True
291 path._writable_dir = lambda path: True
293 home_dir = os.path.normpath(os.path.expanduser('~'))
292 home_dir = os.path.normpath(os.path.expanduser('~'))
294 env['IPYTHONDIR'] = os.path.join('~', 'somewhere')
293 env['IPYTHONDIR'] = os.path.join('~', 'somewhere')
295 ipdir = path.get_ipython_dir()
294 ipdir = path.get_ipython_dir()
296 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
295 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
297
296
298 @skip_win32
297 @skip_win32
299 @with_environment
298 @with_environment
300 def test_get_ipython_dir_8():
299 def test_get_ipython_dir_8():
301 """test_get_ipython_dir_8, test / home directory"""
300 """test_get_ipython_dir_8, test / home directory"""
302 old = path._writable_dir, path.get_xdg_dir
301 old = path._writable_dir, path.get_xdg_dir
303 try:
302 try:
304 path._writable_dir = lambda path: bool(path)
303 path._writable_dir = lambda path: bool(path)
305 path.get_xdg_dir = lambda: None
304 path.get_xdg_dir = lambda: None
306 env.pop('IPYTHON_DIR', None)
305 env.pop('IPYTHON_DIR', None)
307 env.pop('IPYTHONDIR', None)
306 env.pop('IPYTHONDIR', None)
308 env['HOME'] = '/'
307 env['HOME'] = '/'
309 nt.assert_equal(path.get_ipython_dir(), '/.ipython')
308 nt.assert_equal(path.get_ipython_dir(), '/.ipython')
310 finally:
309 finally:
311 path._writable_dir, path.get_xdg_dir = old
310 path._writable_dir, path.get_xdg_dir = old
312
311
313 @with_environment
312 @with_environment
314 def test_get_xdg_dir_0():
313 def test_get_xdg_dir_0():
315 """test_get_xdg_dir_0, check xdg_dir"""
314 """test_get_xdg_dir_0, check xdg_dir"""
316 reload(path)
315 reload(path)
317 path._writable_dir = lambda path: True
316 path._writable_dir = lambda path: True
318 path.get_home_dir = lambda : 'somewhere'
317 path.get_home_dir = lambda : 'somewhere'
319 os.name = "posix"
318 os.name = "posix"
320 sys.platform = "linux2"
319 sys.platform = "linux2"
321 env.pop('IPYTHON_DIR', None)
320 env.pop('IPYTHON_DIR', None)
322 env.pop('IPYTHONDIR', None)
321 env.pop('IPYTHONDIR', None)
323 env.pop('XDG_CONFIG_HOME', None)
322 env.pop('XDG_CONFIG_HOME', None)
324
323
325 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
324 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
326
325
327
326
328 @with_environment
327 @with_environment
329 def test_get_xdg_dir_1():
328 def test_get_xdg_dir_1():
330 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
329 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
331 reload(path)
330 reload(path)
332 path.get_home_dir = lambda : HOME_TEST_DIR
331 path.get_home_dir = lambda : HOME_TEST_DIR
333 os.name = "posix"
332 os.name = "posix"
334 sys.platform = "linux2"
333 sys.platform = "linux2"
335 env.pop('IPYTHON_DIR', None)
334 env.pop('IPYTHON_DIR', None)
336 env.pop('IPYTHONDIR', None)
335 env.pop('IPYTHONDIR', None)
337 env.pop('XDG_CONFIG_HOME', None)
336 env.pop('XDG_CONFIG_HOME', None)
338 nt.assert_equal(path.get_xdg_dir(), None)
337 nt.assert_equal(path.get_xdg_dir(), None)
339
338
340 @with_environment
339 @with_environment
341 def test_get_xdg_dir_2():
340 def test_get_xdg_dir_2():
342 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
341 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
343 reload(path)
342 reload(path)
344 path.get_home_dir = lambda : HOME_TEST_DIR
343 path.get_home_dir = lambda : HOME_TEST_DIR
345 os.name = "posix"
344 os.name = "posix"
346 sys.platform = "linux2"
345 sys.platform = "linux2"
347 env.pop('IPYTHON_DIR', None)
346 env.pop('IPYTHON_DIR', None)
348 env.pop('IPYTHONDIR', None)
347 env.pop('IPYTHONDIR', None)
349 env.pop('XDG_CONFIG_HOME', None)
348 env.pop('XDG_CONFIG_HOME', None)
350 cfgdir=os.path.join(path.get_home_dir(), '.config')
349 cfgdir=os.path.join(path.get_home_dir(), '.config')
351 if not os.path.exists(cfgdir):
350 if not os.path.exists(cfgdir):
352 os.makedirs(cfgdir)
351 os.makedirs(cfgdir)
353
352
354 nt.assert_equal(path.get_xdg_dir(), cfgdir)
353 nt.assert_equal(path.get_xdg_dir(), cfgdir)
355
354
356 @with_environment
355 @with_environment
357 def test_get_xdg_dir_3():
356 def test_get_xdg_dir_3():
358 """test_get_xdg_dir_3, check xdg_dir not used on OS X"""
357 """test_get_xdg_dir_3, check xdg_dir not used on OS X"""
359 reload(path)
358 reload(path)
360 path.get_home_dir = lambda : HOME_TEST_DIR
359 path.get_home_dir = lambda : HOME_TEST_DIR
361 os.name = "posix"
360 os.name = "posix"
362 sys.platform = "darwin"
361 sys.platform = "darwin"
363 env.pop('IPYTHON_DIR', None)
362 env.pop('IPYTHON_DIR', None)
364 env.pop('IPYTHONDIR', None)
363 env.pop('IPYTHONDIR', None)
365 env.pop('XDG_CONFIG_HOME', None)
364 env.pop('XDG_CONFIG_HOME', None)
366 cfgdir=os.path.join(path.get_home_dir(), '.config')
365 cfgdir=os.path.join(path.get_home_dir(), '.config')
367 if not os.path.exists(cfgdir):
366 if not os.path.exists(cfgdir):
368 os.makedirs(cfgdir)
367 os.makedirs(cfgdir)
369
368
370 nt.assert_equal(path.get_xdg_dir(), None)
369 nt.assert_equal(path.get_xdg_dir(), None)
371
370
372 def test_filefind():
371 def test_filefind():
373 """Various tests for filefind"""
372 """Various tests for filefind"""
374 f = tempfile.NamedTemporaryFile()
373 f = tempfile.NamedTemporaryFile()
375 # print 'fname:',f.name
374 # print 'fname:',f.name
376 alt_dirs = path.get_ipython_dir()
375 alt_dirs = path.get_ipython_dir()
377 t = path.filefind(f.name, alt_dirs)
376 t = path.filefind(f.name, alt_dirs)
378 # print 'found:',t
377 # print 'found:',t
379
378
380 @with_environment
379 @with_environment
381 def test_get_ipython_cache_dir():
380 def test_get_ipython_cache_dir():
382 os.environ["HOME"] = HOME_TEST_DIR
381 os.environ["HOME"] = HOME_TEST_DIR
383 if os.name == 'posix' and sys.platform != 'darwin':
382 if os.name == 'posix' and sys.platform != 'darwin':
384 # test default
383 # test default
385 os.makedirs(os.path.join(HOME_TEST_DIR, ".cache"))
384 os.makedirs(os.path.join(HOME_TEST_DIR, ".cache"))
386 os.environ.pop("XDG_CACHE_HOME", None)
385 os.environ.pop("XDG_CACHE_HOME", None)
387 ipdir = path.get_ipython_cache_dir()
386 ipdir = path.get_ipython_cache_dir()
388 nt.assert_equal(os.path.join(HOME_TEST_DIR, ".cache", "ipython"),
387 nt.assert_equal(os.path.join(HOME_TEST_DIR, ".cache", "ipython"),
389 ipdir)
388 ipdir)
390 nt.assert_true(os.path.isdir(ipdir))
389 nt.assert_true(os.path.isdir(ipdir))
391
390
392 # test env override
391 # test env override
393 os.environ["XDG_CACHE_HOME"] = XDG_CACHE_DIR
392 os.environ["XDG_CACHE_HOME"] = XDG_CACHE_DIR
394 ipdir = path.get_ipython_cache_dir()
393 ipdir = path.get_ipython_cache_dir()
395 nt.assert_true(os.path.isdir(ipdir))
394 nt.assert_true(os.path.isdir(ipdir))
396 nt.assert_equal(ipdir, os.path.join(XDG_CACHE_DIR, "ipython"))
395 nt.assert_equal(ipdir, os.path.join(XDG_CACHE_DIR, "ipython"))
397 else:
396 else:
398 nt.assert_equal(path.get_ipython_cache_dir(),
397 nt.assert_equal(path.get_ipython_cache_dir(),
399 path.get_ipython_dir())
398 path.get_ipython_dir())
400
399
401 def test_get_ipython_package_dir():
400 def test_get_ipython_package_dir():
402 ipdir = path.get_ipython_package_dir()
401 ipdir = path.get_ipython_package_dir()
403 nt.assert_true(os.path.isdir(ipdir))
402 nt.assert_true(os.path.isdir(ipdir))
404
403
405
404
406 def test_get_ipython_module_path():
405 def test_get_ipython_module_path():
407 ipapp_path = path.get_ipython_module_path('IPython.terminal.ipapp')
406 ipapp_path = path.get_ipython_module_path('IPython.terminal.ipapp')
408 nt.assert_true(os.path.isfile(ipapp_path))
407 nt.assert_true(os.path.isfile(ipapp_path))
409
408
410
409
411 @dec.skip_if_not_win32
410 @dec.skip_if_not_win32
412 def test_get_long_path_name_win32():
411 def test_get_long_path_name_win32():
413 p = path.get_long_path_name('c:\\docume~1')
412 p = path.get_long_path_name('c:\\docume~1')
414 nt.assert_equal(p,u'c:\\Documents and Settings')
413 nt.assert_equal(p,u'c:\\Documents and Settings')
415
414
416
415
417 @dec.skip_win32
416 @dec.skip_win32
418 def test_get_long_path_name():
417 def test_get_long_path_name():
419 p = path.get_long_path_name('/usr/local')
418 p = path.get_long_path_name('/usr/local')
420 nt.assert_equal(p,'/usr/local')
419 nt.assert_equal(p,'/usr/local')
421
420
422 @dec.skip_win32 # can't create not-user-writable dir on win
421 @dec.skip_win32 # can't create not-user-writable dir on win
423 @with_environment
422 @with_environment
424 def test_not_writable_ipdir():
423 def test_not_writable_ipdir():
425 tmpdir = tempfile.mkdtemp()
424 tmpdir = tempfile.mkdtemp()
426 os.name = "posix"
425 os.name = "posix"
427 env.pop('IPYTHON_DIR', None)
426 env.pop('IPYTHON_DIR', None)
428 env.pop('IPYTHONDIR', None)
427 env.pop('IPYTHONDIR', None)
429 env.pop('XDG_CONFIG_HOME', None)
428 env.pop('XDG_CONFIG_HOME', None)
430 env['HOME'] = tmpdir
429 env['HOME'] = tmpdir
431 ipdir = os.path.join(tmpdir, '.ipython')
430 ipdir = os.path.join(tmpdir, '.ipython')
432 os.mkdir(ipdir)
431 os.mkdir(ipdir)
433 os.chmod(ipdir, 600)
432 os.chmod(ipdir, 600)
434 with AssertPrints('is not a writable location', channel='stderr'):
433 with AssertPrints('is not a writable location', channel='stderr'):
435 ipdir = path.get_ipython_dir()
434 ipdir = path.get_ipython_dir()
436 env.pop('IPYTHON_DIR', None)
435 env.pop('IPYTHON_DIR', None)
437
436
438 def test_unquote_filename():
437 def test_unquote_filename():
439 for win32 in (True, False):
438 for win32 in (True, False):
440 nt.assert_equal(path.unquote_filename('foo.py', win32=win32), 'foo.py')
439 nt.assert_equal(path.unquote_filename('foo.py', win32=win32), 'foo.py')
441 nt.assert_equal(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
440 nt.assert_equal(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
442 nt.assert_equal(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
441 nt.assert_equal(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
443 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
442 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
444 nt.assert_equal(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
443 nt.assert_equal(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
445 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
444 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
446 nt.assert_equal(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
445 nt.assert_equal(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
447 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
446 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
448 nt.assert_equal(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
447 nt.assert_equal(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
449 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
448 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
450
449
451 @with_environment
450 @with_environment
452 def test_get_py_filename():
451 def test_get_py_filename():
453 os.chdir(TMP_TEST_DIR)
452 os.chdir(TMP_TEST_DIR)
454 for win32 in (True, False):
453 for win32 in (True, False):
455 with make_tempfile('foo.py'):
454 with make_tempfile('foo.py'):
456 nt.assert_equal(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
455 nt.assert_equal(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
457 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo.py')
456 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo.py')
458 with make_tempfile('foo'):
457 with make_tempfile('foo'):
459 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo')
458 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo')
460 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
459 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
461 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
460 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
462 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
461 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
463 true_fn = 'foo with spaces.py'
462 true_fn = 'foo with spaces.py'
464 with make_tempfile(true_fn):
463 with make_tempfile(true_fn):
465 nt.assert_equal(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
464 nt.assert_equal(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
466 nt.assert_equal(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
465 nt.assert_equal(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
467 if win32:
466 if win32:
468 nt.assert_equal(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
467 nt.assert_equal(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
469 nt.assert_equal(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
468 nt.assert_equal(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
470 else:
469 else:
471 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
470 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
472 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
471 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
473
472
474 def test_unicode_in_filename():
473 def test_unicode_in_filename():
475 """When a file doesn't exist, the exception raised should be safe to call
474 """When a file doesn't exist, the exception raised should be safe to call
476 str() on - i.e. in Python 2 it must only have ASCII characters.
475 str() on - i.e. in Python 2 it must only have ASCII characters.
477
476
478 https://github.com/ipython/ipython/issues/875
477 https://github.com/ipython/ipython/issues/875
479 """
478 """
480 try:
479 try:
481 # these calls should not throw unicode encode exceptions
480 # these calls should not throw unicode encode exceptions
482 path.get_py_filename(u'fooéè.py', force_win32=False)
481 path.get_py_filename(u'fooéè.py', force_win32=False)
483 except IOError as ex:
482 except IOError as ex:
484 str(ex)
483 str(ex)
485
484
486
485
487 class TestShellGlob(object):
486 class TestShellGlob(object):
488
487
489 @classmethod
488 @classmethod
490 def setUpClass(cls):
489 def setUpClass(cls):
491 cls.filenames_start_with_a = map('a{0}'.format, range(3))
490 cls.filenames_start_with_a = map('a{0}'.format, range(3))
492 cls.filenames_end_with_b = map('{0}b'.format, range(3))
491 cls.filenames_end_with_b = map('{0}b'.format, range(3))
493 cls.filenames = cls.filenames_start_with_a + cls.filenames_end_with_b
492 cls.filenames = cls.filenames_start_with_a + cls.filenames_end_with_b
494 cls.tempdir = TemporaryDirectory()
493 cls.tempdir = TemporaryDirectory()
495 td = cls.tempdir.name
494 td = cls.tempdir.name
496
495
497 with cls.in_tempdir():
496 with cls.in_tempdir():
498 # Create empty files
497 # Create empty files
499 for fname in cls.filenames:
498 for fname in cls.filenames:
500 open(os.path.join(td, fname), 'w').close()
499 open(os.path.join(td, fname), 'w').close()
501
500
502 @classmethod
501 @classmethod
503 def tearDownClass(cls):
502 def tearDownClass(cls):
504 cls.tempdir.cleanup()
503 cls.tempdir.cleanup()
505
504
506 @classmethod
505 @classmethod
507 @contextmanager
506 @contextmanager
508 def in_tempdir(cls):
507 def in_tempdir(cls):
509 save = os.getcwdu()
508 save = os.getcwdu()
510 try:
509 try:
511 os.chdir(cls.tempdir.name)
510 os.chdir(cls.tempdir.name)
512 yield
511 yield
513 finally:
512 finally:
514 os.chdir(save)
513 os.chdir(save)
515
514
516 def check_match(self, patterns, matches):
515 def check_match(self, patterns, matches):
517 with self.in_tempdir():
516 with self.in_tempdir():
518 # glob returns unordered list. that's why sorted is required.
517 # glob returns unordered list. that's why sorted is required.
519 nt.assert_equals(sorted(path.shellglob(patterns)),
518 nt.assert_equals(sorted(path.shellglob(patterns)),
520 sorted(matches))
519 sorted(matches))
521
520
522 def common_cases(self):
521 def common_cases(self):
523 return [
522 return [
524 (['*'], self.filenames),
523 (['*'], self.filenames),
525 (['a*'], self.filenames_start_with_a),
524 (['a*'], self.filenames_start_with_a),
526 (['*c'], ['*c']),
525 (['*c'], ['*c']),
527 (['*', 'a*', '*b', '*c'], self.filenames
526 (['*', 'a*', '*b', '*c'], self.filenames
528 + self.filenames_start_with_a
527 + self.filenames_start_with_a
529 + self.filenames_end_with_b
528 + self.filenames_end_with_b
530 + ['*c']),
529 + ['*c']),
531 (['a[012]'], self.filenames_start_with_a),
530 (['a[012]'], self.filenames_start_with_a),
532 ]
531 ]
533
532
534 @skip_win32
533 @skip_win32
535 def test_match_posix(self):
534 def test_match_posix(self):
536 for (patterns, matches) in self.common_cases() + [
535 for (patterns, matches) in self.common_cases() + [
537 ([r'\*'], ['*']),
536 ([r'\*'], ['*']),
538 ([r'a\*', 'a*'], ['a*'] + self.filenames_start_with_a),
537 ([r'a\*', 'a*'], ['a*'] + self.filenames_start_with_a),
539 ([r'a\[012]'], ['a[012]']),
538 ([r'a\[012]'], ['a[012]']),
540 ]:
539 ]:
541 yield (self.check_match, patterns, matches)
540 yield (self.check_match, patterns, matches)
542
541
543 @skip_if_not_win32
542 @skip_if_not_win32
544 def test_match_windows(self):
543 def test_match_windows(self):
545 for (patterns, matches) in self.common_cases() + [
544 for (patterns, matches) in self.common_cases() + [
546 # In windows, backslash is interpreted as path
545 # In windows, backslash is interpreted as path
547 # separator. Therefore, you can't escape glob
546 # separator. Therefore, you can't escape glob
548 # using it.
547 # using it.
549 ([r'a\*', 'a*'], [r'a\*'] + self.filenames_start_with_a),
548 ([r'a\*', 'a*'], [r'a\*'] + self.filenames_start_with_a),
550 ([r'a\[012]'], [r'a\[012]']),
549 ([r'a\[012]'], [r'a\[012]']),
551 ]:
550 ]:
552 yield (self.check_match, patterns, matches)
551 yield (self.check_match, patterns, matches)
553
552
554
553
555 def test_unescape_glob():
554 def test_unescape_glob():
556 nt.assert_equals(path.unescape_glob(r'\*\[\!\]\?'), '*[!]?')
555 nt.assert_equals(path.unescape_glob(r'\*\[\!\]\?'), '*[!]?')
557 nt.assert_equals(path.unescape_glob(r'\\*'), r'\*')
556 nt.assert_equals(path.unescape_glob(r'\\*'), r'\*')
558 nt.assert_equals(path.unescape_glob(r'\\\*'), r'\*')
557 nt.assert_equals(path.unescape_glob(r'\\\*'), r'\*')
559 nt.assert_equals(path.unescape_glob(r'\\a'), r'\a')
558 nt.assert_equals(path.unescape_glob(r'\\a'), r'\a')
560 nt.assert_equals(path.unescape_glob(r'\a'), r'\a')
559 nt.assert_equals(path.unescape_glob(r'\a'), r'\a')
@@ -1,173 +1,170 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.text"""
2 """Tests for IPython.utils.text"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2011 The IPython Development Team
5 # Copyright (C) 2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import math
16 import math
17 import random
17 import random
18
18
19 import nose.tools as nt
19 import nose.tools as nt
20
20
21 from nose import with_setup
22
23 from IPython.testing import decorators as dec
24 from IPython.utils import text
21 from IPython.utils import text
25
22
26 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
27 # Globals
24 # Globals
28 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
29
26
30 def test_columnize():
27 def test_columnize():
31 """Basic columnize tests."""
28 """Basic columnize tests."""
32 size = 5
29 size = 5
33 items = [l*size for l in 'abc']
30 items = [l*size for l in 'abc']
34 out = text.columnize(items, displaywidth=80)
31 out = text.columnize(items, displaywidth=80)
35 nt.assert_equal(out, 'aaaaa bbbbb ccccc\n')
32 nt.assert_equal(out, 'aaaaa bbbbb ccccc\n')
36 out = text.columnize(items, displaywidth=12)
33 out = text.columnize(items, displaywidth=12)
37 nt.assert_equal(out, 'aaaaa ccccc\nbbbbb\n')
34 nt.assert_equal(out, 'aaaaa ccccc\nbbbbb\n')
38 out = text.columnize(items, displaywidth=10)
35 out = text.columnize(items, displaywidth=10)
39 nt.assert_equal(out, 'aaaaa\nbbbbb\nccccc\n')
36 nt.assert_equal(out, 'aaaaa\nbbbbb\nccccc\n')
40
37
41 def test_columnize_random():
38 def test_columnize_random():
42 """Test with random input to hopfully catch edge case """
39 """Test with random input to hopfully catch edge case """
43 for nitems in [random.randint(2,70) for i in range(2,20)]:
40 for nitems in [random.randint(2,70) for i in range(2,20)]:
44 displaywidth = random.randint(20,200)
41 displaywidth = random.randint(20,200)
45 rand_len = [random.randint(2,displaywidth) for i in range(nitems)]
42 rand_len = [random.randint(2,displaywidth) for i in range(nitems)]
46 items = ['x'*l for l in rand_len]
43 items = ['x'*l for l in rand_len]
47 out = text.columnize(items, displaywidth=displaywidth)
44 out = text.columnize(items, displaywidth=displaywidth)
48 longer_line = max([len(x) for x in out.split('\n')])
45 longer_line = max([len(x) for x in out.split('\n')])
49 longer_element = max(rand_len)
46 longer_element = max(rand_len)
50 if longer_line > displaywidth:
47 if longer_line > displaywidth:
51 print "Columnize displayed something lager than displaywidth : %s " % longer_line
48 print "Columnize displayed something lager than displaywidth : %s " % longer_line
52 print "longer element : %s " % longer_element
49 print "longer element : %s " % longer_element
53 print "displaywidth : %s " % displaywidth
50 print "displaywidth : %s " % displaywidth
54 print "number of element : %s " % nitems
51 print "number of element : %s " % nitems
55 print "size of each element :\n %s" % rand_len
52 print "size of each element :\n %s" % rand_len
56 assert False
53 assert False
57
54
58 def test_columnize_medium():
55 def test_columnize_medium():
59 """Test with inputs than shouldn't be wider tahn 80 """
56 """Test with inputs than shouldn't be wider tahn 80 """
60 size = 40
57 size = 40
61 items = [l*size for l in 'abc']
58 items = [l*size for l in 'abc']
62 out = text.columnize(items, displaywidth=80)
59 out = text.columnize(items, displaywidth=80)
63 nt.assert_equal(out, '\n'.join(items+['']))
60 nt.assert_equal(out, '\n'.join(items+['']))
64
61
65 def test_columnize_long():
62 def test_columnize_long():
66 """Test columnize with inputs longer than the display window"""
63 """Test columnize with inputs longer than the display window"""
67 size = 11
64 size = 11
68 items = [l*size for l in 'abc']
65 items = [l*size for l in 'abc']
69 out = text.columnize(items, displaywidth=size-1)
66 out = text.columnize(items, displaywidth=size-1)
70 nt.assert_equal(out, '\n'.join(items+['']))
67 nt.assert_equal(out, '\n'.join(items+['']))
71
68
72 def eval_formatter_check(f):
69 def eval_formatter_check(f):
73 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os, u=u"cafΓ©", b="cafΓ©")
70 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os, u=u"cafΓ©", b="cafΓ©")
74 s = f.format("{n} {n//4} {stuff.split()[0]}", **ns)
71 s = f.format("{n} {n//4} {stuff.split()[0]}", **ns)
75 nt.assert_equal(s, "12 3 hello")
72 nt.assert_equal(s, "12 3 hello")
76 s = f.format(' '.join(['{n//%i}'%i for i in range(1,8)]), **ns)
73 s = f.format(' '.join(['{n//%i}'%i for i in range(1,8)]), **ns)
77 nt.assert_equal(s, "12 6 4 3 2 2 1")
74 nt.assert_equal(s, "12 6 4 3 2 2 1")
78 s = f.format('{[n//i for i in range(1,8)]}', **ns)
75 s = f.format('{[n//i for i in range(1,8)]}', **ns)
79 nt.assert_equal(s, "[12, 6, 4, 3, 2, 2, 1]")
76 nt.assert_equal(s, "[12, 6, 4, 3, 2, 2, 1]")
80 s = f.format("{stuff!s}", **ns)
77 s = f.format("{stuff!s}", **ns)
81 nt.assert_equal(s, ns['stuff'])
78 nt.assert_equal(s, ns['stuff'])
82 s = f.format("{stuff!r}", **ns)
79 s = f.format("{stuff!r}", **ns)
83 nt.assert_equal(s, repr(ns['stuff']))
80 nt.assert_equal(s, repr(ns['stuff']))
84
81
85 # Check with unicode:
82 # Check with unicode:
86 s = f.format("{u}", **ns)
83 s = f.format("{u}", **ns)
87 nt.assert_equal(s, ns['u'])
84 nt.assert_equal(s, ns['u'])
88 # This decodes in a platform dependent manner, but it shouldn't error out
85 # This decodes in a platform dependent manner, but it shouldn't error out
89 s = f.format("{b}", **ns)
86 s = f.format("{b}", **ns)
90
87
91 nt.assert_raises(NameError, f.format, '{dne}', **ns)
88 nt.assert_raises(NameError, f.format, '{dne}', **ns)
92
89
93 def eval_formatter_slicing_check(f):
90 def eval_formatter_slicing_check(f):
94 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
91 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
95 s = f.format(" {stuff.split()[:]} ", **ns)
92 s = f.format(" {stuff.split()[:]} ", **ns)
96 nt.assert_equal(s, " ['hello', 'there'] ")
93 nt.assert_equal(s, " ['hello', 'there'] ")
97 s = f.format(" {stuff.split()[::-1]} ", **ns)
94 s = f.format(" {stuff.split()[::-1]} ", **ns)
98 nt.assert_equal(s, " ['there', 'hello'] ")
95 nt.assert_equal(s, " ['there', 'hello'] ")
99 s = f.format("{stuff[::2]}", **ns)
96 s = f.format("{stuff[::2]}", **ns)
100 nt.assert_equal(s, ns['stuff'][::2])
97 nt.assert_equal(s, ns['stuff'][::2])
101
98
102 nt.assert_raises(SyntaxError, f.format, "{n:x}", **ns)
99 nt.assert_raises(SyntaxError, f.format, "{n:x}", **ns)
103
100
104 def eval_formatter_no_slicing_check(f):
101 def eval_formatter_no_slicing_check(f):
105 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
102 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
106
103
107 s = f.format('{n:x} {pi**2:+f}', **ns)
104 s = f.format('{n:x} {pi**2:+f}', **ns)
108 nt.assert_equal(s, "c +9.869604")
105 nt.assert_equal(s, "c +9.869604")
109
106
110 s = f.format('{stuff[slice(1,4)]}', **ns)
107 s = f.format('{stuff[slice(1,4)]}', **ns)
111 nt.assert_equal(s, 'ell')
108 nt.assert_equal(s, 'ell')
112
109
113 nt.assert_raises(SyntaxError, f.format, "{a[:]}")
110 nt.assert_raises(SyntaxError, f.format, "{a[:]}")
114
111
115 def test_eval_formatter():
112 def test_eval_formatter():
116 f = text.EvalFormatter()
113 f = text.EvalFormatter()
117 eval_formatter_check(f)
114 eval_formatter_check(f)
118 eval_formatter_no_slicing_check(f)
115 eval_formatter_no_slicing_check(f)
119
116
120 def test_full_eval_formatter():
117 def test_full_eval_formatter():
121 f = text.FullEvalFormatter()
118 f = text.FullEvalFormatter()
122 eval_formatter_check(f)
119 eval_formatter_check(f)
123 eval_formatter_slicing_check(f)
120 eval_formatter_slicing_check(f)
124
121
125 def test_dollar_formatter():
122 def test_dollar_formatter():
126 f = text.DollarFormatter()
123 f = text.DollarFormatter()
127 eval_formatter_check(f)
124 eval_formatter_check(f)
128 eval_formatter_slicing_check(f)
125 eval_formatter_slicing_check(f)
129
126
130 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
127 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
131 s = f.format("$n", **ns)
128 s = f.format("$n", **ns)
132 nt.assert_equal(s, "12")
129 nt.assert_equal(s, "12")
133 s = f.format("$n.real", **ns)
130 s = f.format("$n.real", **ns)
134 nt.assert_equal(s, "12")
131 nt.assert_equal(s, "12")
135 s = f.format("$n/{stuff[:5]}", **ns)
132 s = f.format("$n/{stuff[:5]}", **ns)
136 nt.assert_equal(s, "12/hello")
133 nt.assert_equal(s, "12/hello")
137 s = f.format("$n $$HOME", **ns)
134 s = f.format("$n $$HOME", **ns)
138 nt.assert_equal(s, "12 $HOME")
135 nt.assert_equal(s, "12 $HOME")
139 s = f.format("${foo}", foo="HOME")
136 s = f.format("${foo}", foo="HOME")
140 nt.assert_equal(s, "$HOME")
137 nt.assert_equal(s, "$HOME")
141
138
142
139
143 def test_long_substr():
140 def test_long_substr():
144 data = ['hi']
141 data = ['hi']
145 nt.assert_equal(text.long_substr(data), 'hi')
142 nt.assert_equal(text.long_substr(data), 'hi')
146
143
147
144
148 def test_long_substr2():
145 def test_long_substr2():
149 data = ['abc', 'abd', 'abf', 'ab']
146 data = ['abc', 'abd', 'abf', 'ab']
150 nt.assert_equal(text.long_substr(data), 'ab')
147 nt.assert_equal(text.long_substr(data), 'ab')
151
148
152 def test_long_substr_empty():
149 def test_long_substr_empty():
153 data = []
150 data = []
154 nt.assert_equal(text.long_substr(data), '')
151 nt.assert_equal(text.long_substr(data), '')
155
152
156 def test_strip_email():
153 def test_strip_email():
157 src = """\
154 src = """\
158 >> >>> def f(x):
155 >> >>> def f(x):
159 >> ... return x+1
156 >> ... return x+1
160 >> ...
157 >> ...
161 >> >>> zz = f(2.5)"""
158 >> >>> zz = f(2.5)"""
162 cln = """\
159 cln = """\
163 >>> def f(x):
160 >>> def f(x):
164 ... return x+1
161 ... return x+1
165 ...
162 ...
166 >>> zz = f(2.5)"""
163 >>> zz = f(2.5)"""
167 nt.assert_equal(text.strip_email_quotes(src), cln)
164 nt.assert_equal(text.strip_email_quotes(src), cln)
168
165
169
166
170 def test_strip_email2():
167 def test_strip_email2():
171 src = '> > > list()'
168 src = '> > > list()'
172 cln = 'list()'
169 cln = 'list()'
173 nt.assert_equal(text.strip_email_quotes(src), cln)
170 nt.assert_equal(text.strip_email_quotes(src), cln)
@@ -1,148 +1,147 b''
1 """Some tests for the wildcard utilities."""
1 """Some tests for the wildcard utilities."""
2
2
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Library imports
4 # Library imports
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Stdlib
6 # Stdlib
7 import sys
8 import unittest
7 import unittest
9
8
10 # Our own
9 # Our own
11 from IPython.utils import wildcard
10 from IPython.utils import wildcard
12
11
13 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
14 # Globals for test
13 # Globals for test
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16
15
17 class obj_t(object):
16 class obj_t(object):
18 pass
17 pass
19
18
20 root = obj_t()
19 root = obj_t()
21 l = ["arna","abel","ABEL","active","bob","bark","abbot"]
20 l = ["arna","abel","ABEL","active","bob","bark","abbot"]
22 q = ["kate","loop","arne","vito","lucifer","koppel"]
21 q = ["kate","loop","arne","vito","lucifer","koppel"]
23 for x in l:
22 for x in l:
24 o = obj_t()
23 o = obj_t()
25 setattr(root,x,o)
24 setattr(root,x,o)
26 for y in q:
25 for y in q:
27 p = obj_t()
26 p = obj_t()
28 setattr(o,y,p)
27 setattr(o,y,p)
29 root._apan = obj_t()
28 root._apan = obj_t()
30 root._apan.a = 10
29 root._apan.a = 10
31 root._apan._a = 20
30 root._apan._a = 20
32 root._apan.__a = 20
31 root._apan.__a = 20
33 root.__anka = obj_t()
32 root.__anka = obj_t()
34 root.__anka.a = 10
33 root.__anka.a = 10
35 root.__anka._a = 20
34 root.__anka._a = 20
36 root.__anka.__a = 20
35 root.__anka.__a = 20
37
36
38 root._APAN = obj_t()
37 root._APAN = obj_t()
39 root._APAN.a = 10
38 root._APAN.a = 10
40 root._APAN._a = 20
39 root._APAN._a = 20
41 root._APAN.__a = 20
40 root._APAN.__a = 20
42 root.__ANKA = obj_t()
41 root.__ANKA = obj_t()
43 root.__ANKA.a = 10
42 root.__ANKA.a = 10
44 root.__ANKA._a = 20
43 root.__ANKA._a = 20
45 root.__ANKA.__a = 20
44 root.__ANKA.__a = 20
46
45
47 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
48 # Test cases
47 # Test cases
49 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
50
49
51 class Tests (unittest.TestCase):
50 class Tests (unittest.TestCase):
52 def test_case(self):
51 def test_case(self):
53 ns=root.__dict__
52 ns=root.__dict__
54 tests=[
53 tests=[
55 ("a*", ["abbot","abel","active","arna",]),
54 ("a*", ["abbot","abel","active","arna",]),
56 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",]),
55 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",]),
57 ("_a*", []),
56 ("_a*", []),
58 ("_*anka", ["__anka",]),
57 ("_*anka", ["__anka",]),
59 ("_*a*", ["__anka",]),
58 ("_*a*", ["__anka",]),
60 ]
59 ]
61 for pat,res in tests:
60 for pat,res in tests:
62 res.sort()
61 res.sort()
63 a=wildcard.list_namespace(ns,"all",pat,ignore_case=False,
62 a=wildcard.list_namespace(ns,"all",pat,ignore_case=False,
64 show_all=False).keys()
63 show_all=False).keys()
65 a.sort()
64 a.sort()
66 self.assertEqual(a,res)
65 self.assertEqual(a,res)
67
66
68 def test_case_showall(self):
67 def test_case_showall(self):
69 ns=root.__dict__
68 ns=root.__dict__
70 tests=[
69 tests=[
71 ("a*", ["abbot","abel","active","arna",]),
70 ("a*", ["abbot","abel","active","arna",]),
72 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",]),
71 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",]),
73 ("_a*", ["_apan"]),
72 ("_a*", ["_apan"]),
74 ("_*anka", ["__anka",]),
73 ("_*anka", ["__anka",]),
75 ("_*a*", ["__anka","_apan",]),
74 ("_*a*", ["__anka","_apan",]),
76 ]
75 ]
77 for pat,res in tests:
76 for pat,res in tests:
78 res.sort()
77 res.sort()
79 a=wildcard.list_namespace(ns,"all",pat,ignore_case=False,
78 a=wildcard.list_namespace(ns,"all",pat,ignore_case=False,
80 show_all=True).keys()
79 show_all=True).keys()
81 a.sort()
80 a.sort()
82 self.assertEqual(a,res)
81 self.assertEqual(a,res)
83
82
84
83
85 def test_nocase(self):
84 def test_nocase(self):
86 ns=root.__dict__
85 ns=root.__dict__
87 tests=[
86 tests=[
88 ("a*", ["abbot","abel","ABEL","active","arna",]),
87 ("a*", ["abbot","abel","ABEL","active","arna",]),
89 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",
88 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",
90 "ABEL.koppel","ABEL.loop",]),
89 "ABEL.koppel","ABEL.loop",]),
91 ("_a*", []),
90 ("_a*", []),
92 ("_*anka", ["__anka","__ANKA",]),
91 ("_*anka", ["__anka","__ANKA",]),
93 ("_*a*", ["__anka","__ANKA",]),
92 ("_*a*", ["__anka","__ANKA",]),
94 ]
93 ]
95 for pat,res in tests:
94 for pat,res in tests:
96 res.sort()
95 res.sort()
97 a=wildcard.list_namespace(ns,"all",pat,ignore_case=True,
96 a=wildcard.list_namespace(ns,"all",pat,ignore_case=True,
98 show_all=False).keys()
97 show_all=False).keys()
99 a.sort()
98 a.sort()
100 self.assertEqual(a,res)
99 self.assertEqual(a,res)
101
100
102 def test_nocase_showall(self):
101 def test_nocase_showall(self):
103 ns=root.__dict__
102 ns=root.__dict__
104 tests=[
103 tests=[
105 ("a*", ["abbot","abel","ABEL","active","arna",]),
104 ("a*", ["abbot","abel","ABEL","active","arna",]),
106 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",
105 ("?b*.?o*",["abbot.koppel","abbot.loop","abel.koppel","abel.loop",
107 "ABEL.koppel","ABEL.loop",]),
106 "ABEL.koppel","ABEL.loop",]),
108 ("_a*", ["_apan","_APAN"]),
107 ("_a*", ["_apan","_APAN"]),
109 ("_*anka", ["__anka","__ANKA",]),
108 ("_*anka", ["__anka","__ANKA",]),
110 ("_*a*", ["__anka","__ANKA","_apan","_APAN"]),
109 ("_*a*", ["__anka","__ANKA","_apan","_APAN"]),
111 ]
110 ]
112 for pat,res in tests:
111 for pat,res in tests:
113 res.sort()
112 res.sort()
114 a=wildcard.list_namespace(ns,"all",pat,ignore_case=True,
113 a=wildcard.list_namespace(ns,"all",pat,ignore_case=True,
115 show_all=True).keys()
114 show_all=True).keys()
116 a.sort()
115 a.sort()
117 self.assertEqual(a,res)
116 self.assertEqual(a,res)
118
117
119 def test_dict_attributes(self):
118 def test_dict_attributes(self):
120 """Dictionaries should be indexed by attributes, not by keys. This was
119 """Dictionaries should be indexed by attributes, not by keys. This was
121 causing Github issue 129."""
120 causing Github issue 129."""
122 ns = {"az":{"king":55}, "pq":{1:0}}
121 ns = {"az":{"king":55}, "pq":{1:0}}
123 tests = [
122 tests = [
124 ("a*", ["az"]),
123 ("a*", ["az"]),
125 ("az.k*", ["az.keys"]),
124 ("az.k*", ["az.keys"]),
126 ("pq.k*", ["pq.keys"])
125 ("pq.k*", ["pq.keys"])
127 ]
126 ]
128 for pat, res in tests:
127 for pat, res in tests:
129 res.sort()
128 res.sort()
130 a = wildcard.list_namespace(ns, "all", pat, ignore_case=False,
129 a = wildcard.list_namespace(ns, "all", pat, ignore_case=False,
131 show_all=True).keys()
130 show_all=True).keys()
132 a.sort()
131 a.sort()
133 self.assertEqual(a, res)
132 self.assertEqual(a, res)
134
133
135 def test_dict_dir(self):
134 def test_dict_dir(self):
136 class A(object):
135 class A(object):
137 def __init__(self):
136 def __init__(self):
138 self.a = 1
137 self.a = 1
139 self.b = 2
138 self.b = 2
140 def __getattribute__(self, name):
139 def __getattribute__(self, name):
141 if name=="a":
140 if name=="a":
142 raise AttributeError
141 raise AttributeError
143 return object.__getattribute__(self, name)
142 return object.__getattribute__(self, name)
144
143
145 a = A()
144 a = A()
146 adict = wildcard.dict_dir(a)
145 adict = wildcard.dict_dir(a)
147 assert "a" not in adict # change to assertNotIn method in >= 2.7
146 assert "a" not in adict # change to assertNotIn method in >= 2.7
148 self.assertEqual(adict["b"], 2)
147 self.assertEqual(adict["b"], 2)
@@ -1,717 +1,713 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with strings and text.
3 Utilities for working with strings and text.
4
4
5 Inheritance diagram:
5 Inheritance diagram:
6
6
7 .. inheritance-diagram:: IPython.utils.text
7 .. inheritance-diagram:: IPython.utils.text
8 :parts: 3
8 :parts: 3
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 import __main__
23
24 import os
22 import os
25 import re
23 import re
26 import sys
27 import textwrap
24 import textwrap
28 from string import Formatter
25 from string import Formatter
29
26
30 from IPython.external.path import path
27 from IPython.external.path import path
31 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
28 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
32 from IPython.utils import py3compat
29 from IPython.utils import py3compat
33 from IPython.utils.data import flatten
34
30
35 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
36 # Code
32 # Code
37 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
38
34
39 class LSString(str):
35 class LSString(str):
40 """String derivative with a special access attributes.
36 """String derivative with a special access attributes.
41
37
42 These are normal strings, but with the special attributes:
38 These are normal strings, but with the special attributes:
43
39
44 .l (or .list) : value as list (split on newlines).
40 .l (or .list) : value as list (split on newlines).
45 .n (or .nlstr): original value (the string itself).
41 .n (or .nlstr): original value (the string itself).
46 .s (or .spstr): value as whitespace-separated string.
42 .s (or .spstr): value as whitespace-separated string.
47 .p (or .paths): list of path objects
43 .p (or .paths): list of path objects
48
44
49 Any values which require transformations are computed only once and
45 Any values which require transformations are computed only once and
50 cached.
46 cached.
51
47
52 Such strings are very useful to efficiently interact with the shell, which
48 Such strings are very useful to efficiently interact with the shell, which
53 typically only understands whitespace-separated options for commands."""
49 typically only understands whitespace-separated options for commands."""
54
50
55 def get_list(self):
51 def get_list(self):
56 try:
52 try:
57 return self.__list
53 return self.__list
58 except AttributeError:
54 except AttributeError:
59 self.__list = self.split('\n')
55 self.__list = self.split('\n')
60 return self.__list
56 return self.__list
61
57
62 l = list = property(get_list)
58 l = list = property(get_list)
63
59
64 def get_spstr(self):
60 def get_spstr(self):
65 try:
61 try:
66 return self.__spstr
62 return self.__spstr
67 except AttributeError:
63 except AttributeError:
68 self.__spstr = self.replace('\n',' ')
64 self.__spstr = self.replace('\n',' ')
69 return self.__spstr
65 return self.__spstr
70
66
71 s = spstr = property(get_spstr)
67 s = spstr = property(get_spstr)
72
68
73 def get_nlstr(self):
69 def get_nlstr(self):
74 return self
70 return self
75
71
76 n = nlstr = property(get_nlstr)
72 n = nlstr = property(get_nlstr)
77
73
78 def get_paths(self):
74 def get_paths(self):
79 try:
75 try:
80 return self.__paths
76 return self.__paths
81 except AttributeError:
77 except AttributeError:
82 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
78 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
83 return self.__paths
79 return self.__paths
84
80
85 p = paths = property(get_paths)
81 p = paths = property(get_paths)
86
82
87 # FIXME: We need to reimplement type specific displayhook and then add this
83 # FIXME: We need to reimplement type specific displayhook and then add this
88 # back as a custom printer. This should also be moved outside utils into the
84 # back as a custom printer. This should also be moved outside utils into the
89 # core.
85 # core.
90
86
91 # def print_lsstring(arg):
87 # def print_lsstring(arg):
92 # """ Prettier (non-repr-like) and more informative printer for LSString """
88 # """ Prettier (non-repr-like) and more informative printer for LSString """
93 # print "LSString (.p, .n, .l, .s available). Value:"
89 # print "LSString (.p, .n, .l, .s available). Value:"
94 # print arg
90 # print arg
95 #
91 #
96 #
92 #
97 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
93 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
98
94
99
95
100 class SList(list):
96 class SList(list):
101 """List derivative with a special access attributes.
97 """List derivative with a special access attributes.
102
98
103 These are normal lists, but with the special attributes:
99 These are normal lists, but with the special attributes:
104
100
105 .l (or .list) : value as list (the list itself).
101 .l (or .list) : value as list (the list itself).
106 .n (or .nlstr): value as a string, joined on newlines.
102 .n (or .nlstr): value as a string, joined on newlines.
107 .s (or .spstr): value as a string, joined on spaces.
103 .s (or .spstr): value as a string, joined on spaces.
108 .p (or .paths): list of path objects
104 .p (or .paths): list of path objects
109
105
110 Any values which require transformations are computed only once and
106 Any values which require transformations are computed only once and
111 cached."""
107 cached."""
112
108
113 def get_list(self):
109 def get_list(self):
114 return self
110 return self
115
111
116 l = list = property(get_list)
112 l = list = property(get_list)
117
113
118 def get_spstr(self):
114 def get_spstr(self):
119 try:
115 try:
120 return self.__spstr
116 return self.__spstr
121 except AttributeError:
117 except AttributeError:
122 self.__spstr = ' '.join(self)
118 self.__spstr = ' '.join(self)
123 return self.__spstr
119 return self.__spstr
124
120
125 s = spstr = property(get_spstr)
121 s = spstr = property(get_spstr)
126
122
127 def get_nlstr(self):
123 def get_nlstr(self):
128 try:
124 try:
129 return self.__nlstr
125 return self.__nlstr
130 except AttributeError:
126 except AttributeError:
131 self.__nlstr = '\n'.join(self)
127 self.__nlstr = '\n'.join(self)
132 return self.__nlstr
128 return self.__nlstr
133
129
134 n = nlstr = property(get_nlstr)
130 n = nlstr = property(get_nlstr)
135
131
136 def get_paths(self):
132 def get_paths(self):
137 try:
133 try:
138 return self.__paths
134 return self.__paths
139 except AttributeError:
135 except AttributeError:
140 self.__paths = [path(p) for p in self if os.path.exists(p)]
136 self.__paths = [path(p) for p in self if os.path.exists(p)]
141 return self.__paths
137 return self.__paths
142
138
143 p = paths = property(get_paths)
139 p = paths = property(get_paths)
144
140
145 def grep(self, pattern, prune = False, field = None):
141 def grep(self, pattern, prune = False, field = None):
146 """ Return all strings matching 'pattern' (a regex or callable)
142 """ Return all strings matching 'pattern' (a regex or callable)
147
143
148 This is case-insensitive. If prune is true, return all items
144 This is case-insensitive. If prune is true, return all items
149 NOT matching the pattern.
145 NOT matching the pattern.
150
146
151 If field is specified, the match must occur in the specified
147 If field is specified, the match must occur in the specified
152 whitespace-separated field.
148 whitespace-separated field.
153
149
154 Examples::
150 Examples::
155
151
156 a.grep( lambda x: x.startswith('C') )
152 a.grep( lambda x: x.startswith('C') )
157 a.grep('Cha.*log', prune=1)
153 a.grep('Cha.*log', prune=1)
158 a.grep('chm', field=-1)
154 a.grep('chm', field=-1)
159 """
155 """
160
156
161 def match_target(s):
157 def match_target(s):
162 if field is None:
158 if field is None:
163 return s
159 return s
164 parts = s.split()
160 parts = s.split()
165 try:
161 try:
166 tgt = parts[field]
162 tgt = parts[field]
167 return tgt
163 return tgt
168 except IndexError:
164 except IndexError:
169 return ""
165 return ""
170
166
171 if isinstance(pattern, basestring):
167 if isinstance(pattern, basestring):
172 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
168 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
173 else:
169 else:
174 pred = pattern
170 pred = pattern
175 if not prune:
171 if not prune:
176 return SList([el for el in self if pred(match_target(el))])
172 return SList([el for el in self if pred(match_target(el))])
177 else:
173 else:
178 return SList([el for el in self if not pred(match_target(el))])
174 return SList([el for el in self if not pred(match_target(el))])
179
175
180 def fields(self, *fields):
176 def fields(self, *fields):
181 """ Collect whitespace-separated fields from string list
177 """ Collect whitespace-separated fields from string list
182
178
183 Allows quick awk-like usage of string lists.
179 Allows quick awk-like usage of string lists.
184
180
185 Example data (in var a, created by 'a = !ls -l')::
181 Example data (in var a, created by 'a = !ls -l')::
186 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
182 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
187 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
183 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
188
184
189 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
185 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
190 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
186 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
191 (note the joining by space).
187 (note the joining by space).
192 a.fields(-1) is ['ChangeLog', 'IPython']
188 a.fields(-1) is ['ChangeLog', 'IPython']
193
189
194 IndexErrors are ignored.
190 IndexErrors are ignored.
195
191
196 Without args, fields() just split()'s the strings.
192 Without args, fields() just split()'s the strings.
197 """
193 """
198 if len(fields) == 0:
194 if len(fields) == 0:
199 return [el.split() for el in self]
195 return [el.split() for el in self]
200
196
201 res = SList()
197 res = SList()
202 for el in [f.split() for f in self]:
198 for el in [f.split() for f in self]:
203 lineparts = []
199 lineparts = []
204
200
205 for fd in fields:
201 for fd in fields:
206 try:
202 try:
207 lineparts.append(el[fd])
203 lineparts.append(el[fd])
208 except IndexError:
204 except IndexError:
209 pass
205 pass
210 if lineparts:
206 if lineparts:
211 res.append(" ".join(lineparts))
207 res.append(" ".join(lineparts))
212
208
213 return res
209 return res
214
210
215 def sort(self,field= None, nums = False):
211 def sort(self,field= None, nums = False):
216 """ sort by specified fields (see fields())
212 """ sort by specified fields (see fields())
217
213
218 Example::
214 Example::
219 a.sort(1, nums = True)
215 a.sort(1, nums = True)
220
216
221 Sorts a by second field, in numerical order (so that 21 > 3)
217 Sorts a by second field, in numerical order (so that 21 > 3)
222
218
223 """
219 """
224
220
225 #decorate, sort, undecorate
221 #decorate, sort, undecorate
226 if field is not None:
222 if field is not None:
227 dsu = [[SList([line]).fields(field), line] for line in self]
223 dsu = [[SList([line]).fields(field), line] for line in self]
228 else:
224 else:
229 dsu = [[line, line] for line in self]
225 dsu = [[line, line] for line in self]
230 if nums:
226 if nums:
231 for i in range(len(dsu)):
227 for i in range(len(dsu)):
232 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
228 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
233 try:
229 try:
234 n = int(numstr)
230 n = int(numstr)
235 except ValueError:
231 except ValueError:
236 n = 0;
232 n = 0;
237 dsu[i][0] = n
233 dsu[i][0] = n
238
234
239
235
240 dsu.sort()
236 dsu.sort()
241 return SList([t[1] for t in dsu])
237 return SList([t[1] for t in dsu])
242
238
243
239
244 # FIXME: We need to reimplement type specific displayhook and then add this
240 # FIXME: We need to reimplement type specific displayhook and then add this
245 # back as a custom printer. This should also be moved outside utils into the
241 # back as a custom printer. This should also be moved outside utils into the
246 # core.
242 # core.
247
243
248 # def print_slist(arg):
244 # def print_slist(arg):
249 # """ Prettier (non-repr-like) and more informative printer for SList """
245 # """ Prettier (non-repr-like) and more informative printer for SList """
250 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
246 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
251 # if hasattr(arg, 'hideonce') and arg.hideonce:
247 # if hasattr(arg, 'hideonce') and arg.hideonce:
252 # arg.hideonce = False
248 # arg.hideonce = False
253 # return
249 # return
254 #
250 #
255 # nlprint(arg) # This was a nested list printer, now removed.
251 # nlprint(arg) # This was a nested list printer, now removed.
256 #
252 #
257 # print_slist = result_display.when_type(SList)(print_slist)
253 # print_slist = result_display.when_type(SList)(print_slist)
258
254
259
255
260 def indent(instr,nspaces=4, ntabs=0, flatten=False):
256 def indent(instr,nspaces=4, ntabs=0, flatten=False):
261 """Indent a string a given number of spaces or tabstops.
257 """Indent a string a given number of spaces or tabstops.
262
258
263 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
259 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
264
260
265 Parameters
261 Parameters
266 ----------
262 ----------
267
263
268 instr : basestring
264 instr : basestring
269 The string to be indented.
265 The string to be indented.
270 nspaces : int (default: 4)
266 nspaces : int (default: 4)
271 The number of spaces to be indented.
267 The number of spaces to be indented.
272 ntabs : int (default: 0)
268 ntabs : int (default: 0)
273 The number of tabs to be indented.
269 The number of tabs to be indented.
274 flatten : bool (default: False)
270 flatten : bool (default: False)
275 Whether to scrub existing indentation. If True, all lines will be
271 Whether to scrub existing indentation. If True, all lines will be
276 aligned to the same indentation. If False, existing indentation will
272 aligned to the same indentation. If False, existing indentation will
277 be strictly increased.
273 be strictly increased.
278
274
279 Returns
275 Returns
280 -------
276 -------
281
277
282 str|unicode : string indented by ntabs and nspaces.
278 str|unicode : string indented by ntabs and nspaces.
283
279
284 """
280 """
285 if instr is None:
281 if instr is None:
286 return
282 return
287 ind = '\t'*ntabs+' '*nspaces
283 ind = '\t'*ntabs+' '*nspaces
288 if flatten:
284 if flatten:
289 pat = re.compile(r'^\s*', re.MULTILINE)
285 pat = re.compile(r'^\s*', re.MULTILINE)
290 else:
286 else:
291 pat = re.compile(r'^', re.MULTILINE)
287 pat = re.compile(r'^', re.MULTILINE)
292 outstr = re.sub(pat, ind, instr)
288 outstr = re.sub(pat, ind, instr)
293 if outstr.endswith(os.linesep+ind):
289 if outstr.endswith(os.linesep+ind):
294 return outstr[:-len(ind)]
290 return outstr[:-len(ind)]
295 else:
291 else:
296 return outstr
292 return outstr
297
293
298
294
299 def list_strings(arg):
295 def list_strings(arg):
300 """Always return a list of strings, given a string or list of strings
296 """Always return a list of strings, given a string or list of strings
301 as input.
297 as input.
302
298
303 :Examples:
299 :Examples:
304
300
305 In [7]: list_strings('A single string')
301 In [7]: list_strings('A single string')
306 Out[7]: ['A single string']
302 Out[7]: ['A single string']
307
303
308 In [8]: list_strings(['A single string in a list'])
304 In [8]: list_strings(['A single string in a list'])
309 Out[8]: ['A single string in a list']
305 Out[8]: ['A single string in a list']
310
306
311 In [9]: list_strings(['A','list','of','strings'])
307 In [9]: list_strings(['A','list','of','strings'])
312 Out[9]: ['A', 'list', 'of', 'strings']
308 Out[9]: ['A', 'list', 'of', 'strings']
313 """
309 """
314
310
315 if isinstance(arg,basestring): return [arg]
311 if isinstance(arg,basestring): return [arg]
316 else: return arg
312 else: return arg
317
313
318
314
319 def marquee(txt='',width=78,mark='*'):
315 def marquee(txt='',width=78,mark='*'):
320 """Return the input string centered in a 'marquee'.
316 """Return the input string centered in a 'marquee'.
321
317
322 :Examples:
318 :Examples:
323
319
324 In [16]: marquee('A test',40)
320 In [16]: marquee('A test',40)
325 Out[16]: '**************** A test ****************'
321 Out[16]: '**************** A test ****************'
326
322
327 In [17]: marquee('A test',40,'-')
323 In [17]: marquee('A test',40,'-')
328 Out[17]: '---------------- A test ----------------'
324 Out[17]: '---------------- A test ----------------'
329
325
330 In [18]: marquee('A test',40,' ')
326 In [18]: marquee('A test',40,' ')
331 Out[18]: ' A test '
327 Out[18]: ' A test '
332
328
333 """
329 """
334 if not txt:
330 if not txt:
335 return (mark*width)[:width]
331 return (mark*width)[:width]
336 nmark = (width-len(txt)-2)//len(mark)//2
332 nmark = (width-len(txt)-2)//len(mark)//2
337 if nmark < 0: nmark =0
333 if nmark < 0: nmark =0
338 marks = mark*nmark
334 marks = mark*nmark
339 return '%s %s %s' % (marks,txt,marks)
335 return '%s %s %s' % (marks,txt,marks)
340
336
341
337
342 ini_spaces_re = re.compile(r'^(\s+)')
338 ini_spaces_re = re.compile(r'^(\s+)')
343
339
344 def num_ini_spaces(strng):
340 def num_ini_spaces(strng):
345 """Return the number of initial spaces in a string"""
341 """Return the number of initial spaces in a string"""
346
342
347 ini_spaces = ini_spaces_re.match(strng)
343 ini_spaces = ini_spaces_re.match(strng)
348 if ini_spaces:
344 if ini_spaces:
349 return ini_spaces.end()
345 return ini_spaces.end()
350 else:
346 else:
351 return 0
347 return 0
352
348
353
349
354 def format_screen(strng):
350 def format_screen(strng):
355 """Format a string for screen printing.
351 """Format a string for screen printing.
356
352
357 This removes some latex-type format codes."""
353 This removes some latex-type format codes."""
358 # Paragraph continue
354 # Paragraph continue
359 par_re = re.compile(r'\\$',re.MULTILINE)
355 par_re = re.compile(r'\\$',re.MULTILINE)
360 strng = par_re.sub('',strng)
356 strng = par_re.sub('',strng)
361 return strng
357 return strng
362
358
363
359
364 def dedent(text):
360 def dedent(text):
365 """Equivalent of textwrap.dedent that ignores unindented first line.
361 """Equivalent of textwrap.dedent that ignores unindented first line.
366
362
367 This means it will still dedent strings like:
363 This means it will still dedent strings like:
368 '''foo
364 '''foo
369 is a bar
365 is a bar
370 '''
366 '''
371
367
372 For use in wrap_paragraphs.
368 For use in wrap_paragraphs.
373 """
369 """
374
370
375 if text.startswith('\n'):
371 if text.startswith('\n'):
376 # text starts with blank line, don't ignore the first line
372 # text starts with blank line, don't ignore the first line
377 return textwrap.dedent(text)
373 return textwrap.dedent(text)
378
374
379 # split first line
375 # split first line
380 splits = text.split('\n',1)
376 splits = text.split('\n',1)
381 if len(splits) == 1:
377 if len(splits) == 1:
382 # only one line
378 # only one line
383 return textwrap.dedent(text)
379 return textwrap.dedent(text)
384
380
385 first, rest = splits
381 first, rest = splits
386 # dedent everything but the first line
382 # dedent everything but the first line
387 rest = textwrap.dedent(rest)
383 rest = textwrap.dedent(rest)
388 return '\n'.join([first, rest])
384 return '\n'.join([first, rest])
389
385
390
386
391 def wrap_paragraphs(text, ncols=80):
387 def wrap_paragraphs(text, ncols=80):
392 """Wrap multiple paragraphs to fit a specified width.
388 """Wrap multiple paragraphs to fit a specified width.
393
389
394 This is equivalent to textwrap.wrap, but with support for multiple
390 This is equivalent to textwrap.wrap, but with support for multiple
395 paragraphs, as separated by empty lines.
391 paragraphs, as separated by empty lines.
396
392
397 Returns
393 Returns
398 -------
394 -------
399
395
400 list of complete paragraphs, wrapped to fill `ncols` columns.
396 list of complete paragraphs, wrapped to fill `ncols` columns.
401 """
397 """
402 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
398 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
403 text = dedent(text).strip()
399 text = dedent(text).strip()
404 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
400 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
405 out_ps = []
401 out_ps = []
406 indent_re = re.compile(r'\n\s+', re.MULTILINE)
402 indent_re = re.compile(r'\n\s+', re.MULTILINE)
407 for p in paragraphs:
403 for p in paragraphs:
408 # presume indentation that survives dedent is meaningful formatting,
404 # presume indentation that survives dedent is meaningful formatting,
409 # so don't fill unless text is flush.
405 # so don't fill unless text is flush.
410 if indent_re.search(p) is None:
406 if indent_re.search(p) is None:
411 # wrap paragraph
407 # wrap paragraph
412 p = textwrap.fill(p, ncols)
408 p = textwrap.fill(p, ncols)
413 out_ps.append(p)
409 out_ps.append(p)
414 return out_ps
410 return out_ps
415
411
416
412
417 def long_substr(data):
413 def long_substr(data):
418 """Return the longest common substring in a list of strings.
414 """Return the longest common substring in a list of strings.
419
415
420 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
416 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
421 """
417 """
422 substr = ''
418 substr = ''
423 if len(data) > 1 and len(data[0]) > 0:
419 if len(data) > 1 and len(data[0]) > 0:
424 for i in range(len(data[0])):
420 for i in range(len(data[0])):
425 for j in range(len(data[0])-i+1):
421 for j in range(len(data[0])-i+1):
426 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
422 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
427 substr = data[0][i:i+j]
423 substr = data[0][i:i+j]
428 elif len(data) == 1:
424 elif len(data) == 1:
429 substr = data[0]
425 substr = data[0]
430 return substr
426 return substr
431
427
432
428
433 def strip_email_quotes(text):
429 def strip_email_quotes(text):
434 """Strip leading email quotation characters ('>').
430 """Strip leading email quotation characters ('>').
435
431
436 Removes any combination of leading '>' interspersed with whitespace that
432 Removes any combination of leading '>' interspersed with whitespace that
437 appears *identically* in all lines of the input text.
433 appears *identically* in all lines of the input text.
438
434
439 Parameters
435 Parameters
440 ----------
436 ----------
441 text : str
437 text : str
442
438
443 Examples
439 Examples
444 --------
440 --------
445
441
446 Simple uses::
442 Simple uses::
447
443
448 In [2]: strip_email_quotes('> > text')
444 In [2]: strip_email_quotes('> > text')
449 Out[2]: 'text'
445 Out[2]: 'text'
450
446
451 In [3]: strip_email_quotes('> > text\\n> > more')
447 In [3]: strip_email_quotes('> > text\\n> > more')
452 Out[3]: 'text\\nmore'
448 Out[3]: 'text\\nmore'
453
449
454 Note how only the common prefix that appears in all lines is stripped::
450 Note how only the common prefix that appears in all lines is stripped::
455
451
456 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
452 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
457 Out[4]: '> text\\n> more\\nmore...'
453 Out[4]: '> text\\n> more\\nmore...'
458
454
459 So if any line has no quote marks ('>') , then none are stripped from any
455 So if any line has no quote marks ('>') , then none are stripped from any
460 of them ::
456 of them ::
461
457
462 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
458 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
463 Out[5]: '> > text\\n> > more\\nlast different'
459 Out[5]: '> > text\\n> > more\\nlast different'
464 """
460 """
465 lines = text.splitlines()
461 lines = text.splitlines()
466 matches = set()
462 matches = set()
467 for line in lines:
463 for line in lines:
468 prefix = re.match(r'^(\s*>[ >]*)', line)
464 prefix = re.match(r'^(\s*>[ >]*)', line)
469 if prefix:
465 if prefix:
470 matches.add(prefix.group(1))
466 matches.add(prefix.group(1))
471 else:
467 else:
472 break
468 break
473 else:
469 else:
474 prefix = long_substr(list(matches))
470 prefix = long_substr(list(matches))
475 if prefix:
471 if prefix:
476 strip = len(prefix)
472 strip = len(prefix)
477 text = '\n'.join([ ln[strip:] for ln in lines])
473 text = '\n'.join([ ln[strip:] for ln in lines])
478 return text
474 return text
479
475
480
476
481 class EvalFormatter(Formatter):
477 class EvalFormatter(Formatter):
482 """A String Formatter that allows evaluation of simple expressions.
478 """A String Formatter that allows evaluation of simple expressions.
483
479
484 Note that this version interprets a : as specifying a format string (as per
480 Note that this version interprets a : as specifying a format string (as per
485 standard string formatting), so if slicing is required, you must explicitly
481 standard string formatting), so if slicing is required, you must explicitly
486 create a slice.
482 create a slice.
487
483
488 This is to be used in templating cases, such as the parallel batch
484 This is to be used in templating cases, such as the parallel batch
489 script templates, where simple arithmetic on arguments is useful.
485 script templates, where simple arithmetic on arguments is useful.
490
486
491 Examples
487 Examples
492 --------
488 --------
493
489
494 In [1]: f = EvalFormatter()
490 In [1]: f = EvalFormatter()
495 In [2]: f.format('{n//4}', n=8)
491 In [2]: f.format('{n//4}', n=8)
496 Out [2]: '2'
492 Out [2]: '2'
497
493
498 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
494 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
499 Out [3]: 'll'
495 Out [3]: 'll'
500 """
496 """
501 def get_field(self, name, args, kwargs):
497 def get_field(self, name, args, kwargs):
502 v = eval(name, kwargs)
498 v = eval(name, kwargs)
503 return v, name
499 return v, name
504
500
505
501
506 @skip_doctest_py3
502 @skip_doctest_py3
507 class FullEvalFormatter(Formatter):
503 class FullEvalFormatter(Formatter):
508 """A String Formatter that allows evaluation of simple expressions.
504 """A String Formatter that allows evaluation of simple expressions.
509
505
510 Any time a format key is not found in the kwargs,
506 Any time a format key is not found in the kwargs,
511 it will be tried as an expression in the kwargs namespace.
507 it will be tried as an expression in the kwargs namespace.
512
508
513 Note that this version allows slicing using [1:2], so you cannot specify
509 Note that this version allows slicing using [1:2], so you cannot specify
514 a format string. Use :class:`EvalFormatter` to permit format strings.
510 a format string. Use :class:`EvalFormatter` to permit format strings.
515
511
516 Examples
512 Examples
517 --------
513 --------
518
514
519 In [1]: f = FullEvalFormatter()
515 In [1]: f = FullEvalFormatter()
520 In [2]: f.format('{n//4}', n=8)
516 In [2]: f.format('{n//4}', n=8)
521 Out[2]: u'2'
517 Out[2]: u'2'
522
518
523 In [3]: f.format('{list(range(5))[2:4]}')
519 In [3]: f.format('{list(range(5))[2:4]}')
524 Out[3]: u'[2, 3]'
520 Out[3]: u'[2, 3]'
525
521
526 In [4]: f.format('{3*2}')
522 In [4]: f.format('{3*2}')
527 Out[4]: u'6'
523 Out[4]: u'6'
528 """
524 """
529 # copied from Formatter._vformat with minor changes to allow eval
525 # copied from Formatter._vformat with minor changes to allow eval
530 # and replace the format_spec code with slicing
526 # and replace the format_spec code with slicing
531 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
527 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
532 if recursion_depth < 0:
528 if recursion_depth < 0:
533 raise ValueError('Max string recursion exceeded')
529 raise ValueError('Max string recursion exceeded')
534 result = []
530 result = []
535 for literal_text, field_name, format_spec, conversion in \
531 for literal_text, field_name, format_spec, conversion in \
536 self.parse(format_string):
532 self.parse(format_string):
537
533
538 # output the literal text
534 # output the literal text
539 if literal_text:
535 if literal_text:
540 result.append(literal_text)
536 result.append(literal_text)
541
537
542 # if there's a field, output it
538 # if there's a field, output it
543 if field_name is not None:
539 if field_name is not None:
544 # this is some markup, find the object and do
540 # this is some markup, find the object and do
545 # the formatting
541 # the formatting
546
542
547 if format_spec:
543 if format_spec:
548 # override format spec, to allow slicing:
544 # override format spec, to allow slicing:
549 field_name = ':'.join([field_name, format_spec])
545 field_name = ':'.join([field_name, format_spec])
550
546
551 # eval the contents of the field for the object
547 # eval the contents of the field for the object
552 # to be formatted
548 # to be formatted
553 obj = eval(field_name, kwargs)
549 obj = eval(field_name, kwargs)
554
550
555 # do any conversion on the resulting object
551 # do any conversion on the resulting object
556 obj = self.convert_field(obj, conversion)
552 obj = self.convert_field(obj, conversion)
557
553
558 # format the object and append to the result
554 # format the object and append to the result
559 result.append(self.format_field(obj, ''))
555 result.append(self.format_field(obj, ''))
560
556
561 return u''.join(py3compat.cast_unicode(s) for s in result)
557 return u''.join(py3compat.cast_unicode(s) for s in result)
562
558
563
559
564 @skip_doctest_py3
560 @skip_doctest_py3
565 class DollarFormatter(FullEvalFormatter):
561 class DollarFormatter(FullEvalFormatter):
566 """Formatter allowing Itpl style $foo replacement, for names and attribute
562 """Formatter allowing Itpl style $foo replacement, for names and attribute
567 access only. Standard {foo} replacement also works, and allows full
563 access only. Standard {foo} replacement also works, and allows full
568 evaluation of its arguments.
564 evaluation of its arguments.
569
565
570 Examples
566 Examples
571 --------
567 --------
572 In [1]: f = DollarFormatter()
568 In [1]: f = DollarFormatter()
573 In [2]: f.format('{n//4}', n=8)
569 In [2]: f.format('{n//4}', n=8)
574 Out[2]: u'2'
570 Out[2]: u'2'
575
571
576 In [3]: f.format('23 * 76 is $result', result=23*76)
572 In [3]: f.format('23 * 76 is $result', result=23*76)
577 Out[3]: u'23 * 76 is 1748'
573 Out[3]: u'23 * 76 is 1748'
578
574
579 In [4]: f.format('$a or {b}', a=1, b=2)
575 In [4]: f.format('$a or {b}', a=1, b=2)
580 Out[4]: u'1 or 2'
576 Out[4]: u'1 or 2'
581 """
577 """
582 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
578 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
583 def parse(self, fmt_string):
579 def parse(self, fmt_string):
584 for literal_txt, field_name, format_spec, conversion \
580 for literal_txt, field_name, format_spec, conversion \
585 in Formatter.parse(self, fmt_string):
581 in Formatter.parse(self, fmt_string):
586
582
587 # Find $foo patterns in the literal text.
583 # Find $foo patterns in the literal text.
588 continue_from = 0
584 continue_from = 0
589 txt = ""
585 txt = ""
590 for m in self._dollar_pattern.finditer(literal_txt):
586 for m in self._dollar_pattern.finditer(literal_txt):
591 new_txt, new_field = m.group(1,2)
587 new_txt, new_field = m.group(1,2)
592 # $$foo --> $foo
588 # $$foo --> $foo
593 if new_field.startswith("$"):
589 if new_field.startswith("$"):
594 txt += new_txt + new_field
590 txt += new_txt + new_field
595 else:
591 else:
596 yield (txt + new_txt, new_field, "", None)
592 yield (txt + new_txt, new_field, "", None)
597 txt = ""
593 txt = ""
598 continue_from = m.end()
594 continue_from = m.end()
599
595
600 # Re-yield the {foo} style pattern
596 # Re-yield the {foo} style pattern
601 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
597 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
602
598
603 #-----------------------------------------------------------------------------
599 #-----------------------------------------------------------------------------
604 # Utils to columnize a list of string
600 # Utils to columnize a list of string
605 #-----------------------------------------------------------------------------
601 #-----------------------------------------------------------------------------
606
602
607 def _chunks(l, n):
603 def _chunks(l, n):
608 """Yield successive n-sized chunks from l."""
604 """Yield successive n-sized chunks from l."""
609 for i in xrange(0, len(l), n):
605 for i in xrange(0, len(l), n):
610 yield l[i:i+n]
606 yield l[i:i+n]
611
607
612
608
613 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
609 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
614 """Calculate optimal info to columnize a list of string"""
610 """Calculate optimal info to columnize a list of string"""
615 for nrow in range(1, len(rlist)+1) :
611 for nrow in range(1, len(rlist)+1) :
616 chk = map(max,_chunks(rlist, nrow))
612 chk = map(max,_chunks(rlist, nrow))
617 sumlength = sum(chk)
613 sumlength = sum(chk)
618 ncols = len(chk)
614 ncols = len(chk)
619 if sumlength+separator_size*(ncols-1) <= displaywidth :
615 if sumlength+separator_size*(ncols-1) <= displaywidth :
620 break;
616 break;
621 return {'columns_numbers' : ncols,
617 return {'columns_numbers' : ncols,
622 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
618 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
623 'rows_numbers' : nrow,
619 'rows_numbers' : nrow,
624 'columns_width' : chk
620 'columns_width' : chk
625 }
621 }
626
622
627
623
628 def _get_or_default(mylist, i, default=None):
624 def _get_or_default(mylist, i, default=None):
629 """return list item number, or default if don't exist"""
625 """return list item number, or default if don't exist"""
630 if i >= len(mylist):
626 if i >= len(mylist):
631 return default
627 return default
632 else :
628 else :
633 return mylist[i]
629 return mylist[i]
634
630
635
631
636 @skip_doctest
632 @skip_doctest
637 def compute_item_matrix(items, empty=None, *args, **kwargs) :
633 def compute_item_matrix(items, empty=None, *args, **kwargs) :
638 """Returns a nested list, and info to columnize items
634 """Returns a nested list, and info to columnize items
639
635
640 Parameters
636 Parameters
641 ----------
637 ----------
642
638
643 items :
639 items :
644 list of strings to columize
640 list of strings to columize
645 empty : (default None)
641 empty : (default None)
646 default value to fill list if needed
642 default value to fill list if needed
647 separator_size : int (default=2)
643 separator_size : int (default=2)
648 How much caracters will be used as a separation between each columns.
644 How much caracters will be used as a separation between each columns.
649 displaywidth : int (default=80)
645 displaywidth : int (default=80)
650 The width of the area onto wich the columns should enter
646 The width of the area onto wich the columns should enter
651
647
652 Returns
648 Returns
653 -------
649 -------
654
650
655 Returns a tuple of (strings_matrix, dict_info)
651 Returns a tuple of (strings_matrix, dict_info)
656
652
657 strings_matrix :
653 strings_matrix :
658
654
659 nested list of string, the outer most list contains as many list as
655 nested list of string, the outer most list contains as many list as
660 rows, the innermost lists have each as many element as colums. If the
656 rows, the innermost lists have each as many element as colums. If the
661 total number of elements in `items` does not equal the product of
657 total number of elements in `items` does not equal the product of
662 rows*columns, the last element of some lists are filled with `None`.
658 rows*columns, the last element of some lists are filled with `None`.
663
659
664 dict_info :
660 dict_info :
665 some info to make columnize easier:
661 some info to make columnize easier:
666
662
667 columns_numbers : number of columns
663 columns_numbers : number of columns
668 rows_numbers : number of rows
664 rows_numbers : number of rows
669 columns_width : list of with of each columns
665 columns_width : list of with of each columns
670 optimal_separator_width : best separator width between columns
666 optimal_separator_width : best separator width between columns
671
667
672 Examples
668 Examples
673 --------
669 --------
674
670
675 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
671 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
676 ...: compute_item_matrix(l,displaywidth=12)
672 ...: compute_item_matrix(l,displaywidth=12)
677 Out[1]:
673 Out[1]:
678 ([['aaa', 'f', 'k'],
674 ([['aaa', 'f', 'k'],
679 ['b', 'g', 'l'],
675 ['b', 'g', 'l'],
680 ['cc', 'h', None],
676 ['cc', 'h', None],
681 ['d', 'i', None],
677 ['d', 'i', None],
682 ['eeeee', 'j', None]],
678 ['eeeee', 'j', None]],
683 {'columns_numbers': 3,
679 {'columns_numbers': 3,
684 'columns_width': [5, 1, 1],
680 'columns_width': [5, 1, 1],
685 'optimal_separator_width': 2,
681 'optimal_separator_width': 2,
686 'rows_numbers': 5})
682 'rows_numbers': 5})
687
683
688 """
684 """
689 info = _find_optimal(map(len, items), *args, **kwargs)
685 info = _find_optimal(map(len, items), *args, **kwargs)
690 nrow, ncol = info['rows_numbers'], info['columns_numbers']
686 nrow, ncol = info['rows_numbers'], info['columns_numbers']
691 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
687 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
692
688
693
689
694 def columnize(items, separator=' ', displaywidth=80):
690 def columnize(items, separator=' ', displaywidth=80):
695 """ Transform a list of strings into a single string with columns.
691 """ Transform a list of strings into a single string with columns.
696
692
697 Parameters
693 Parameters
698 ----------
694 ----------
699 items : sequence of strings
695 items : sequence of strings
700 The strings to process.
696 The strings to process.
701
697
702 separator : str, optional [default is two spaces]
698 separator : str, optional [default is two spaces]
703 The string that separates columns.
699 The string that separates columns.
704
700
705 displaywidth : int, optional [default is 80]
701 displaywidth : int, optional [default is 80]
706 Width of the display in number of characters.
702 Width of the display in number of characters.
707
703
708 Returns
704 Returns
709 -------
705 -------
710 The formatted string.
706 The formatted string.
711 """
707 """
712 if not items :
708 if not items :
713 return '\n'
709 return '\n'
714 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
710 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
715 fmatrix = [filter(None, x) for x in matrix]
711 fmatrix = [filter(None, x) for x in matrix]
716 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
712 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
717 return '\n'.join(map(sjoin, fmatrix))+'\n'
713 return '\n'.join(map(sjoin, fmatrix))+'\n'
@@ -1,47 +1,46 b''
1 """utilities for checking zmq versions"""
1 """utilities for checking zmq versions"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2013 The IPython Development Team
3 # Copyright (C) 2013 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING.txt, distributed as part of this software.
6 # the file COPYING.txt, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Verify zmq version dependency >= 2.1.11
10 # Verify zmq version dependency >= 2.1.11
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import warnings
14 from IPython.utils.version import check_version
13 from IPython.utils.version import check_version
15
14
16
15
17 def patch_pyzmq():
16 def patch_pyzmq():
18 """backport a few patches from newer pyzmq
17 """backport a few patches from newer pyzmq
19
18
20 These can be removed as we bump our minimum pyzmq version
19 These can be removed as we bump our minimum pyzmq version
21 """
20 """
22
21
23 import zmq
22 import zmq
24
23
25 # fallback on stdlib json if jsonlib is selected, because jsonlib breaks things.
24 # fallback on stdlib json if jsonlib is selected, because jsonlib breaks things.
26 # jsonlib support is removed from pyzmq >= 2.2.0
25 # jsonlib support is removed from pyzmq >= 2.2.0
27
26
28 from zmq.utils import jsonapi
27 from zmq.utils import jsonapi
29 if jsonapi.jsonmod.__name__ == 'jsonlib':
28 if jsonapi.jsonmod.__name__ == 'jsonlib':
30 import json
29 import json
31 jsonapi.jsonmod = json
30 jsonapi.jsonmod = json
32
31
33
32
34 def check_for_zmq(minimum_version, required_by='Someone'):
33 def check_for_zmq(minimum_version, required_by='Someone'):
35 try:
34 try:
36 import zmq
35 import zmq
37 except ImportError:
36 except ImportError:
38 raise ImportError("%s requires pyzmq >= %s"%(required_by, minimum_version))
37 raise ImportError("%s requires pyzmq >= %s"%(required_by, minimum_version))
39
38
40 patch_pyzmq()
39 patch_pyzmq()
41
40
42 pyzmq_version = zmq.__version__
41 pyzmq_version = zmq.__version__
43
42
44 if not check_version(pyzmq_version, minimum_version):
43 if not check_version(pyzmq_version, minimum_version):
45 raise ImportError("%s requires pyzmq >= %s, but you have %s"%(
44 raise ImportError("%s requires pyzmq >= %s, but you have %s"%(
46 required_by, minimum_version, pyzmq_version))
45 required_by, minimum_version, pyzmq_version))
47
46
General Comments 0
You need to be logged in to leave comments. Login now