##// END OF EJS Templates
Various fixes for test_genutils under win32, now all tests pass.
Fernando Perez -
Show More
@@ -1,1873 +1,1879 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """General purpose utilities.
2 """General purpose utilities.
3
3
4 This is a grab-bag of stuff I find useful in most programs I write. Some of
4 This is a grab-bag of stuff I find useful in most programs I write. Some of
5 these things are also convenient when working at the command line.
5 these things are also convenient when working at the command line.
6 """
6 """
7
7
8 #*****************************************************************************
8 #*****************************************************************************
9 # Copyright (C) 2001-2006 Fernando Perez. <fperez@colorado.edu>
9 # Copyright (C) 2001-2006 Fernando Perez. <fperez@colorado.edu>
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #*****************************************************************************
13 #*****************************************************************************
14
14
15 #****************************************************************************
15 #****************************************************************************
16 # required modules from the Python standard library
16 # required modules from the Python standard library
17 import __main__
17 import __main__
18
18
19 import os
19 import os
20 import platform
20 import platform
21 import re
21 import re
22 import shlex
22 import shlex
23 import shutil
23 import shutil
24 import subprocess
24 import subprocess
25 import sys
25 import sys
26 import time
26 import time
27 import types
27 import types
28 import warnings
28 import warnings
29
29
30 # Curses and termios are Unix-only modules
30 # Curses and termios are Unix-only modules
31 try:
31 try:
32 import curses
32 import curses
33 # We need termios as well, so if its import happens to raise, we bail on
33 # We need termios as well, so if its import happens to raise, we bail on
34 # using curses altogether.
34 # using curses altogether.
35 import termios
35 import termios
36 except ImportError:
36 except ImportError:
37 USE_CURSES = False
37 USE_CURSES = False
38 else:
38 else:
39 # Curses on Solaris may not be complete, so we can't use it there
39 # Curses on Solaris may not be complete, so we can't use it there
40 USE_CURSES = hasattr(curses,'initscr')
40 USE_CURSES = hasattr(curses,'initscr')
41
41
42 # Other IPython utilities
42 # Other IPython utilities
43 import IPython
43 import IPython
44 from IPython.external.Itpl import itpl,printpl
44 from IPython.external.Itpl import itpl,printpl
45 from IPython.utils import platutils
45 from IPython.utils import platutils
46 from IPython.utils.generics import result_display
46 from IPython.utils.generics import result_display
47 from IPython.external.path import path
47 from IPython.external.path import path
48
48
49
49
50 #****************************************************************************
50 #****************************************************************************
51 # Exceptions
51 # Exceptions
52 class Error(Exception):
52 class Error(Exception):
53 """Base class for exceptions in this module."""
53 """Base class for exceptions in this module."""
54 pass
54 pass
55
55
56 #----------------------------------------------------------------------------
56 #----------------------------------------------------------------------------
57 class IOStream:
57 class IOStream:
58 def __init__(self,stream,fallback):
58 def __init__(self,stream,fallback):
59 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
59 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
60 stream = fallback
60 stream = fallback
61 self.stream = stream
61 self.stream = stream
62 self._swrite = stream.write
62 self._swrite = stream.write
63 self.flush = stream.flush
63 self.flush = stream.flush
64
64
65 def write(self,data):
65 def write(self,data):
66 try:
66 try:
67 self._swrite(data)
67 self._swrite(data)
68 except:
68 except:
69 try:
69 try:
70 # print handles some unicode issues which may trip a plain
70 # print handles some unicode issues which may trip a plain
71 # write() call. Attempt to emulate write() by using a
71 # write() call. Attempt to emulate write() by using a
72 # trailing comma
72 # trailing comma
73 print >> self.stream, data,
73 print >> self.stream, data,
74 except:
74 except:
75 # if we get here, something is seriously broken.
75 # if we get here, something is seriously broken.
76 print >> sys.stderr, \
76 print >> sys.stderr, \
77 'ERROR - failed to write data to stream:', self.stream
77 'ERROR - failed to write data to stream:', self.stream
78
78
79 def close(self):
79 def close(self):
80 pass
80 pass
81
81
82
82
83 class IOTerm:
83 class IOTerm:
84 """ Term holds the file or file-like objects for handling I/O operations.
84 """ Term holds the file or file-like objects for handling I/O operations.
85
85
86 These are normally just sys.stdin, sys.stdout and sys.stderr but for
86 These are normally just sys.stdin, sys.stdout and sys.stderr but for
87 Windows they can can replaced to allow editing the strings before they are
87 Windows they can can replaced to allow editing the strings before they are
88 displayed."""
88 displayed."""
89
89
90 # In the future, having IPython channel all its I/O operations through
90 # In the future, having IPython channel all its I/O operations through
91 # this class will make it easier to embed it into other environments which
91 # this class will make it easier to embed it into other environments which
92 # are not a normal terminal (such as a GUI-based shell)
92 # are not a normal terminal (such as a GUI-based shell)
93 def __init__(self,cin=None,cout=None,cerr=None):
93 def __init__(self,cin=None,cout=None,cerr=None):
94 self.cin = IOStream(cin,sys.stdin)
94 self.cin = IOStream(cin,sys.stdin)
95 self.cout = IOStream(cout,sys.stdout)
95 self.cout = IOStream(cout,sys.stdout)
96 self.cerr = IOStream(cerr,sys.stderr)
96 self.cerr = IOStream(cerr,sys.stderr)
97
97
98 # Global variable to be used for all I/O
98 # Global variable to be used for all I/O
99 Term = IOTerm()
99 Term = IOTerm()
100
100
101 import IPython.utils.rlineimpl as readline
101 import IPython.utils.rlineimpl as readline
102 # Remake Term to use the readline i/o facilities
102 # Remake Term to use the readline i/o facilities
103 if sys.platform == 'win32' and readline.have_readline:
103 if sys.platform == 'win32' and readline.have_readline:
104
104
105 Term = IOTerm(cout=readline._outputfile,cerr=readline._outputfile)
105 Term = IOTerm(cout=readline._outputfile,cerr=readline._outputfile)
106
106
107
107
108 class Tee(object):
108 class Tee(object):
109 """A class to duplicate an output stream to stdout/err.
109 """A class to duplicate an output stream to stdout/err.
110
110
111 This works in a manner very similar to the Unix 'tee' command.
111 This works in a manner very similar to the Unix 'tee' command.
112
112
113 When the object is closed or deleted, it closes the original file given to
113 When the object is closed or deleted, it closes the original file given to
114 it for duplication.
114 it for duplication.
115 """
115 """
116 # Inspired by:
116 # Inspired by:
117 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
117 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
118
118
119 def __init__(self, file, mode=None, channel='stdout'):
119 def __init__(self, file, mode=None, channel='stdout'):
120 """Construct a new Tee object.
120 """Construct a new Tee object.
121
121
122 Parameters
122 Parameters
123 ----------
123 ----------
124 file : filename or open filehandle (writable)
124 file : filename or open filehandle (writable)
125 File that will be duplicated
125 File that will be duplicated
126
126
127 mode : optional, valid mode for open().
127 mode : optional, valid mode for open().
128 If a filename was give, open with this mode.
128 If a filename was give, open with this mode.
129
129
130 channel : str, one of ['stdout', 'stderr']
130 channel : str, one of ['stdout', 'stderr']
131 """
131 """
132 if channel not in ['stdout', 'stderr']:
132 if channel not in ['stdout', 'stderr']:
133 raise ValueError('Invalid channel spec %s' % channel)
133 raise ValueError('Invalid channel spec %s' % channel)
134
134
135 if hasattr(file, 'write') and hasattr(file, 'seek'):
135 if hasattr(file, 'write') and hasattr(file, 'seek'):
136 self.file = file
136 self.file = file
137 else:
137 else:
138 self.file = open(name, mode)
138 self.file = open(name, mode)
139 self.channel = channel
139 self.channel = channel
140 self.ostream = getattr(sys, channel)
140 self.ostream = getattr(sys, channel)
141 setattr(sys, channel, self)
141 setattr(sys, channel, self)
142 self._closed = False
142 self._closed = False
143
143
144 def close(self):
144 def close(self):
145 """Close the file and restore the channel."""
145 """Close the file and restore the channel."""
146 self.flush()
146 self.flush()
147 setattr(sys, self.channel, self.ostream)
147 setattr(sys, self.channel, self.ostream)
148 self.file.close()
148 self.file.close()
149 self._closed = True
149 self._closed = True
150
150
151 def write(self, data):
151 def write(self, data):
152 """Write data to both channels."""
152 """Write data to both channels."""
153 self.file.write(data)
153 self.file.write(data)
154 self.ostream.write(data)
154 self.ostream.write(data)
155 self.ostream.flush()
155 self.ostream.flush()
156
156
157 def flush(self):
157 def flush(self):
158 """Flush both channels."""
158 """Flush both channels."""
159 self.file.flush()
159 self.file.flush()
160 self.ostream.flush()
160 self.ostream.flush()
161
161
162 def __del__(self):
162 def __del__(self):
163 if not self._closed:
163 if not self._closed:
164 self.close()
164 self.close()
165
165
166
166
167 #****************************************************************************
167 #****************************************************************************
168 # Generic warning/error printer, used by everything else
168 # Generic warning/error printer, used by everything else
169 def warn(msg,level=2,exit_val=1):
169 def warn(msg,level=2,exit_val=1):
170 """Standard warning printer. Gives formatting consistency.
170 """Standard warning printer. Gives formatting consistency.
171
171
172 Output is sent to Term.cerr (sys.stderr by default).
172 Output is sent to Term.cerr (sys.stderr by default).
173
173
174 Options:
174 Options:
175
175
176 -level(2): allows finer control:
176 -level(2): allows finer control:
177 0 -> Do nothing, dummy function.
177 0 -> Do nothing, dummy function.
178 1 -> Print message.
178 1 -> Print message.
179 2 -> Print 'WARNING:' + message. (Default level).
179 2 -> Print 'WARNING:' + message. (Default level).
180 3 -> Print 'ERROR:' + message.
180 3 -> Print 'ERROR:' + message.
181 4 -> Print 'FATAL ERROR:' + message and trigger a sys.exit(exit_val).
181 4 -> Print 'FATAL ERROR:' + message and trigger a sys.exit(exit_val).
182
182
183 -exit_val (1): exit value returned by sys.exit() for a level 4
183 -exit_val (1): exit value returned by sys.exit() for a level 4
184 warning. Ignored for all other levels."""
184 warning. Ignored for all other levels."""
185
185
186 if level>0:
186 if level>0:
187 header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
187 header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
188 print >> Term.cerr, '%s%s' % (header[level],msg)
188 print >> Term.cerr, '%s%s' % (header[level],msg)
189 if level == 4:
189 if level == 4:
190 print >> Term.cerr,'Exiting.\n'
190 print >> Term.cerr,'Exiting.\n'
191 sys.exit(exit_val)
191 sys.exit(exit_val)
192
192
193 def info(msg):
193 def info(msg):
194 """Equivalent to warn(msg,level=1)."""
194 """Equivalent to warn(msg,level=1)."""
195
195
196 warn(msg,level=1)
196 warn(msg,level=1)
197
197
198 def error(msg):
198 def error(msg):
199 """Equivalent to warn(msg,level=3)."""
199 """Equivalent to warn(msg,level=3)."""
200
200
201 warn(msg,level=3)
201 warn(msg,level=3)
202
202
203 def fatal(msg,exit_val=1):
203 def fatal(msg,exit_val=1):
204 """Equivalent to warn(msg,exit_val=exit_val,level=4)."""
204 """Equivalent to warn(msg,exit_val=exit_val,level=4)."""
205
205
206 warn(msg,exit_val=exit_val,level=4)
206 warn(msg,exit_val=exit_val,level=4)
207
207
208 #---------------------------------------------------------------------------
208 #---------------------------------------------------------------------------
209 # Debugging routines
209 # Debugging routines
210 #
210 #
211 def debugx(expr,pre_msg=''):
211 def debugx(expr,pre_msg=''):
212 """Print the value of an expression from the caller's frame.
212 """Print the value of an expression from the caller's frame.
213
213
214 Takes an expression, evaluates it in the caller's frame and prints both
214 Takes an expression, evaluates it in the caller's frame and prints both
215 the given expression and the resulting value (as well as a debug mark
215 the given expression and the resulting value (as well as a debug mark
216 indicating the name of the calling function. The input must be of a form
216 indicating the name of the calling function. The input must be of a form
217 suitable for eval().
217 suitable for eval().
218
218
219 An optional message can be passed, which will be prepended to the printed
219 An optional message can be passed, which will be prepended to the printed
220 expr->value pair."""
220 expr->value pair."""
221
221
222 cf = sys._getframe(1)
222 cf = sys._getframe(1)
223 print '[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
223 print '[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
224 eval(expr,cf.f_globals,cf.f_locals))
224 eval(expr,cf.f_globals,cf.f_locals))
225
225
226 # deactivate it by uncommenting the following line, which makes it a no-op
226 # deactivate it by uncommenting the following line, which makes it a no-op
227 #def debugx(expr,pre_msg=''): pass
227 #def debugx(expr,pre_msg=''): pass
228
228
229 #----------------------------------------------------------------------------
229 #----------------------------------------------------------------------------
230 StringTypes = types.StringTypes
230 StringTypes = types.StringTypes
231
231
232 # Basic timing functionality
232 # Basic timing functionality
233
233
234 # If possible (Unix), use the resource module instead of time.clock()
234 # If possible (Unix), use the resource module instead of time.clock()
235 try:
235 try:
236 import resource
236 import resource
237 def clocku():
237 def clocku():
238 """clocku() -> floating point number
238 """clocku() -> floating point number
239
239
240 Return the *USER* CPU time in seconds since the start of the process.
240 Return the *USER* CPU time in seconds since the start of the process.
241 This is done via a call to resource.getrusage, so it avoids the
241 This is done via a call to resource.getrusage, so it avoids the
242 wraparound problems in time.clock()."""
242 wraparound problems in time.clock()."""
243
243
244 return resource.getrusage(resource.RUSAGE_SELF)[0]
244 return resource.getrusage(resource.RUSAGE_SELF)[0]
245
245
246 def clocks():
246 def clocks():
247 """clocks() -> floating point number
247 """clocks() -> floating point number
248
248
249 Return the *SYSTEM* CPU time in seconds since the start of the process.
249 Return the *SYSTEM* CPU time in seconds since the start of the process.
250 This is done via a call to resource.getrusage, so it avoids the
250 This is done via a call to resource.getrusage, so it avoids the
251 wraparound problems in time.clock()."""
251 wraparound problems in time.clock()."""
252
252
253 return resource.getrusage(resource.RUSAGE_SELF)[1]
253 return resource.getrusage(resource.RUSAGE_SELF)[1]
254
254
255 def clock():
255 def clock():
256 """clock() -> floating point number
256 """clock() -> floating point number
257
257
258 Return the *TOTAL USER+SYSTEM* CPU time in seconds since the start of
258 Return the *TOTAL USER+SYSTEM* CPU time in seconds since the start of
259 the process. This is done via a call to resource.getrusage, so it
259 the process. This is done via a call to resource.getrusage, so it
260 avoids the wraparound problems in time.clock()."""
260 avoids the wraparound problems in time.clock()."""
261
261
262 u,s = resource.getrusage(resource.RUSAGE_SELF)[:2]
262 u,s = resource.getrusage(resource.RUSAGE_SELF)[:2]
263 return u+s
263 return u+s
264
264
265 def clock2():
265 def clock2():
266 """clock2() -> (t_user,t_system)
266 """clock2() -> (t_user,t_system)
267
267
268 Similar to clock(), but return a tuple of user/system times."""
268 Similar to clock(), but return a tuple of user/system times."""
269 return resource.getrusage(resource.RUSAGE_SELF)[:2]
269 return resource.getrusage(resource.RUSAGE_SELF)[:2]
270
270
271 except ImportError:
271 except ImportError:
272 # There is no distinction of user/system time under windows, so we just use
272 # There is no distinction of user/system time under windows, so we just use
273 # time.clock() for everything...
273 # time.clock() for everything...
274 clocku = clocks = clock = time.clock
274 clocku = clocks = clock = time.clock
275 def clock2():
275 def clock2():
276 """Under windows, system CPU time can't be measured.
276 """Under windows, system CPU time can't be measured.
277
277
278 This just returns clock() and zero."""
278 This just returns clock() and zero."""
279 return time.clock(),0.0
279 return time.clock(),0.0
280
280
281 def timings_out(reps,func,*args,**kw):
281 def timings_out(reps,func,*args,**kw):
282 """timings_out(reps,func,*args,**kw) -> (t_total,t_per_call,output)
282 """timings_out(reps,func,*args,**kw) -> (t_total,t_per_call,output)
283
283
284 Execute a function reps times, return a tuple with the elapsed total
284 Execute a function reps times, return a tuple with the elapsed total
285 CPU time in seconds, the time per call and the function's output.
285 CPU time in seconds, the time per call and the function's output.
286
286
287 Under Unix, the return value is the sum of user+system time consumed by
287 Under Unix, the return value is the sum of user+system time consumed by
288 the process, computed via the resource module. This prevents problems
288 the process, computed via the resource module. This prevents problems
289 related to the wraparound effect which the time.clock() function has.
289 related to the wraparound effect which the time.clock() function has.
290
290
291 Under Windows the return value is in wall clock seconds. See the
291 Under Windows the return value is in wall clock seconds. See the
292 documentation for the time module for more details."""
292 documentation for the time module for more details."""
293
293
294 reps = int(reps)
294 reps = int(reps)
295 assert reps >=1, 'reps must be >= 1'
295 assert reps >=1, 'reps must be >= 1'
296 if reps==1:
296 if reps==1:
297 start = clock()
297 start = clock()
298 out = func(*args,**kw)
298 out = func(*args,**kw)
299 tot_time = clock()-start
299 tot_time = clock()-start
300 else:
300 else:
301 rng = xrange(reps-1) # the last time is executed separately to store output
301 rng = xrange(reps-1) # the last time is executed separately to store output
302 start = clock()
302 start = clock()
303 for dummy in rng: func(*args,**kw)
303 for dummy in rng: func(*args,**kw)
304 out = func(*args,**kw) # one last time
304 out = func(*args,**kw) # one last time
305 tot_time = clock()-start
305 tot_time = clock()-start
306 av_time = tot_time / reps
306 av_time = tot_time / reps
307 return tot_time,av_time,out
307 return tot_time,av_time,out
308
308
309 def timings(reps,func,*args,**kw):
309 def timings(reps,func,*args,**kw):
310 """timings(reps,func,*args,**kw) -> (t_total,t_per_call)
310 """timings(reps,func,*args,**kw) -> (t_total,t_per_call)
311
311
312 Execute a function reps times, return a tuple with the elapsed total CPU
312 Execute a function reps times, return a tuple with the elapsed total CPU
313 time in seconds and the time per call. These are just the first two values
313 time in seconds and the time per call. These are just the first two values
314 in timings_out()."""
314 in timings_out()."""
315
315
316 return timings_out(reps,func,*args,**kw)[0:2]
316 return timings_out(reps,func,*args,**kw)[0:2]
317
317
318 def timing(func,*args,**kw):
318 def timing(func,*args,**kw):
319 """timing(func,*args,**kw) -> t_total
319 """timing(func,*args,**kw) -> t_total
320
320
321 Execute a function once, return the elapsed total CPU time in
321 Execute a function once, return the elapsed total CPU time in
322 seconds. This is just the first value in timings_out()."""
322 seconds. This is just the first value in timings_out()."""
323
323
324 return timings_out(1,func,*args,**kw)[0]
324 return timings_out(1,func,*args,**kw)[0]
325
325
326 #****************************************************************************
326 #****************************************************************************
327 # file and system
327 # file and system
328
328
329 def arg_split(s,posix=False):
329 def arg_split(s,posix=False):
330 """Split a command line's arguments in a shell-like manner.
330 """Split a command line's arguments in a shell-like manner.
331
331
332 This is a modified version of the standard library's shlex.split()
332 This is a modified version of the standard library's shlex.split()
333 function, but with a default of posix=False for splitting, so that quotes
333 function, but with a default of posix=False for splitting, so that quotes
334 in inputs are respected."""
334 in inputs are respected."""
335
335
336 # XXX - there may be unicode-related problems here!!! I'm not sure that
336 # XXX - there may be unicode-related problems here!!! I'm not sure that
337 # shlex is truly unicode-safe, so it might be necessary to do
337 # shlex is truly unicode-safe, so it might be necessary to do
338 #
338 #
339 # s = s.encode(sys.stdin.encoding)
339 # s = s.encode(sys.stdin.encoding)
340 #
340 #
341 # first, to ensure that shlex gets a normal string. Input from anyone who
341 # first, to ensure that shlex gets a normal string. Input from anyone who
342 # knows more about unicode and shlex than I would be good to have here...
342 # knows more about unicode and shlex than I would be good to have here...
343 lex = shlex.shlex(s, posix=posix)
343 lex = shlex.shlex(s, posix=posix)
344 lex.whitespace_split = True
344 lex.whitespace_split = True
345 return list(lex)
345 return list(lex)
346
346
347 def system(cmd,verbose=0,debug=0,header=''):
347 def system(cmd,verbose=0,debug=0,header=''):
348 """Execute a system command, return its exit status.
348 """Execute a system command, return its exit status.
349
349
350 Options:
350 Options:
351
351
352 - verbose (0): print the command to be executed.
352 - verbose (0): print the command to be executed.
353
353
354 - debug (0): only print, do not actually execute.
354 - debug (0): only print, do not actually execute.
355
355
356 - header (''): Header to print on screen prior to the executed command (it
356 - header (''): Header to print on screen prior to the executed command (it
357 is only prepended to the command, no newlines are added).
357 is only prepended to the command, no newlines are added).
358
358
359 Note: a stateful version of this function is available through the
359 Note: a stateful version of this function is available through the
360 SystemExec class."""
360 SystemExec class."""
361
361
362 stat = 0
362 stat = 0
363 if verbose or debug: print header+cmd
363 if verbose or debug: print header+cmd
364 sys.stdout.flush()
364 sys.stdout.flush()
365 if not debug: stat = os.system(cmd)
365 if not debug: stat = os.system(cmd)
366 return stat
366 return stat
367
367
368 def abbrev_cwd():
368 def abbrev_cwd():
369 """ Return abbreviated version of cwd, e.g. d:mydir """
369 """ Return abbreviated version of cwd, e.g. d:mydir """
370 cwd = os.getcwd().replace('\\','/')
370 cwd = os.getcwd().replace('\\','/')
371 drivepart = ''
371 drivepart = ''
372 tail = cwd
372 tail = cwd
373 if sys.platform == 'win32':
373 if sys.platform == 'win32':
374 if len(cwd) < 4:
374 if len(cwd) < 4:
375 return cwd
375 return cwd
376 drivepart,tail = os.path.splitdrive(cwd)
376 drivepart,tail = os.path.splitdrive(cwd)
377
377
378
378
379 parts = tail.split('/')
379 parts = tail.split('/')
380 if len(parts) > 2:
380 if len(parts) > 2:
381 tail = '/'.join(parts[-2:])
381 tail = '/'.join(parts[-2:])
382
382
383 return (drivepart + (
383 return (drivepart + (
384 cwd == '/' and '/' or tail))
384 cwd == '/' and '/' or tail))
385
385
386
386
387 # This function is used by ipython in a lot of places to make system calls.
387 # This function is used by ipython in a lot of places to make system calls.
388 # We need it to be slightly different under win32, due to the vagaries of
388 # We need it to be slightly different under win32, due to the vagaries of
389 # 'network shares'. A win32 override is below.
389 # 'network shares'. A win32 override is below.
390
390
391 def shell(cmd,verbose=0,debug=0,header=''):
391 def shell(cmd,verbose=0,debug=0,header=''):
392 """Execute a command in the system shell, always return None.
392 """Execute a command in the system shell, always return None.
393
393
394 Options:
394 Options:
395
395
396 - verbose (0): print the command to be executed.
396 - verbose (0): print the command to be executed.
397
397
398 - debug (0): only print, do not actually execute.
398 - debug (0): only print, do not actually execute.
399
399
400 - header (''): Header to print on screen prior to the executed command (it
400 - header (''): Header to print on screen prior to the executed command (it
401 is only prepended to the command, no newlines are added).
401 is only prepended to the command, no newlines are added).
402
402
403 Note: this is similar to genutils.system(), but it returns None so it can
403 Note: this is similar to genutils.system(), but it returns None so it can
404 be conveniently used in interactive loops without getting the return value
404 be conveniently used in interactive loops without getting the return value
405 (typically 0) printed many times."""
405 (typically 0) printed many times."""
406
406
407 stat = 0
407 stat = 0
408 if verbose or debug: print header+cmd
408 if verbose or debug: print header+cmd
409 # flush stdout so we don't mangle python's buffering
409 # flush stdout so we don't mangle python's buffering
410 sys.stdout.flush()
410 sys.stdout.flush()
411
411
412 if not debug:
412 if not debug:
413 platutils.set_term_title("IPy " + cmd)
413 platutils.set_term_title("IPy " + cmd)
414 os.system(cmd)
414 os.system(cmd)
415 platutils.set_term_title("IPy " + abbrev_cwd())
415 platutils.set_term_title("IPy " + abbrev_cwd())
416
416
417 # override shell() for win32 to deal with network shares
417 # override shell() for win32 to deal with network shares
418 if os.name in ('nt','dos'):
418 if os.name in ('nt','dos'):
419
419
420 shell_ori = shell
420 shell_ori = shell
421
421
422 def shell(cmd,verbose=0,debug=0,header=''):
422 def shell(cmd,verbose=0,debug=0,header=''):
423 if os.getcwd().startswith(r"\\"):
423 if os.getcwd().startswith(r"\\"):
424 path = os.getcwd()
424 path = os.getcwd()
425 # change to c drive (cannot be on UNC-share when issuing os.system,
425 # change to c drive (cannot be on UNC-share when issuing os.system,
426 # as cmd.exe cannot handle UNC addresses)
426 # as cmd.exe cannot handle UNC addresses)
427 os.chdir("c:")
427 os.chdir("c:")
428 # issue pushd to the UNC-share and then run the command
428 # issue pushd to the UNC-share and then run the command
429 try:
429 try:
430 shell_ori('"pushd %s&&"'%path+cmd,verbose,debug,header)
430 shell_ori('"pushd %s&&"'%path+cmd,verbose,debug,header)
431 finally:
431 finally:
432 os.chdir(path)
432 os.chdir(path)
433 else:
433 else:
434 shell_ori(cmd,verbose,debug,header)
434 shell_ori(cmd,verbose,debug,header)
435
435
436 shell.__doc__ = shell_ori.__doc__
436 shell.__doc__ = shell_ori.__doc__
437
437
438 def getoutput(cmd,verbose=0,debug=0,header='',split=0):
438 def getoutput(cmd,verbose=0,debug=0,header='',split=0):
439 """Dummy substitute for perl's backquotes.
439 """Dummy substitute for perl's backquotes.
440
440
441 Executes a command and returns the output.
441 Executes a command and returns the output.
442
442
443 Accepts the same arguments as system(), plus:
443 Accepts the same arguments as system(), plus:
444
444
445 - split(0): if true, the output is returned as a list split on newlines.
445 - split(0): if true, the output is returned as a list split on newlines.
446
446
447 Note: a stateful version of this function is available through the
447 Note: a stateful version of this function is available through the
448 SystemExec class.
448 SystemExec class.
449
449
450 This is pretty much deprecated and rarely used,
450 This is pretty much deprecated and rarely used,
451 genutils.getoutputerror may be what you need.
451 genutils.getoutputerror may be what you need.
452
452
453 """
453 """
454
454
455 if verbose or debug: print header+cmd
455 if verbose or debug: print header+cmd
456 if not debug:
456 if not debug:
457 pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE).stdout
457 pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE).stdout
458 output = pipe.read()
458 output = pipe.read()
459 # stipping last \n is here for backwards compat.
459 # stipping last \n is here for backwards compat.
460 if output.endswith('\n'):
460 if output.endswith('\n'):
461 output = output[:-1]
461 output = output[:-1]
462 if split:
462 if split:
463 return output.split('\n')
463 return output.split('\n')
464 else:
464 else:
465 return output
465 return output
466
466
467 def getoutputerror(cmd,verbose=0,debug=0,header='',split=0):
467 def getoutputerror(cmd,verbose=0,debug=0,header='',split=0):
468 """Return (standard output,standard error) of executing cmd in a shell.
468 """Return (standard output,standard error) of executing cmd in a shell.
469
469
470 Accepts the same arguments as system(), plus:
470 Accepts the same arguments as system(), plus:
471
471
472 - split(0): if true, each of stdout/err is returned as a list split on
472 - split(0): if true, each of stdout/err is returned as a list split on
473 newlines.
473 newlines.
474
474
475 Note: a stateful version of this function is available through the
475 Note: a stateful version of this function is available through the
476 SystemExec class."""
476 SystemExec class."""
477
477
478 if verbose or debug: print header+cmd
478 if verbose or debug: print header+cmd
479 if not cmd:
479 if not cmd:
480 if split:
480 if split:
481 return [],[]
481 return [],[]
482 else:
482 else:
483 return '',''
483 return '',''
484 if not debug:
484 if not debug:
485 p = subprocess.Popen(cmd, shell=True,
485 p = subprocess.Popen(cmd, shell=True,
486 stdin=subprocess.PIPE,
486 stdin=subprocess.PIPE,
487 stdout=subprocess.PIPE,
487 stdout=subprocess.PIPE,
488 stderr=subprocess.PIPE,
488 stderr=subprocess.PIPE,
489 close_fds=True)
489 close_fds=True)
490 pin, pout, perr = (p.stdin, p.stdout, p.stderr)
490 pin, pout, perr = (p.stdin, p.stdout, p.stderr)
491
491
492 tout = pout.read().rstrip()
492 tout = pout.read().rstrip()
493 terr = perr.read().rstrip()
493 terr = perr.read().rstrip()
494 pin.close()
494 pin.close()
495 pout.close()
495 pout.close()
496 perr.close()
496 perr.close()
497 if split:
497 if split:
498 return tout.split('\n'),terr.split('\n')
498 return tout.split('\n'),terr.split('\n')
499 else:
499 else:
500 return tout,terr
500 return tout,terr
501
501
502 # for compatibility with older naming conventions
502 # for compatibility with older naming conventions
503 xsys = system
503 xsys = system
504 bq = getoutput
504 bq = getoutput
505
505
506 class SystemExec:
506 class SystemExec:
507 """Access the system and getoutput functions through a stateful interface.
507 """Access the system and getoutput functions through a stateful interface.
508
508
509 Note: here we refer to the system and getoutput functions from this
509 Note: here we refer to the system and getoutput functions from this
510 library, not the ones from the standard python library.
510 library, not the ones from the standard python library.
511
511
512 This class offers the system and getoutput functions as methods, but the
512 This class offers the system and getoutput functions as methods, but the
513 verbose, debug and header parameters can be set for the instance (at
513 verbose, debug and header parameters can be set for the instance (at
514 creation time or later) so that they don't need to be specified on each
514 creation time or later) so that they don't need to be specified on each
515 call.
515 call.
516
516
517 For efficiency reasons, there's no way to override the parameters on a
517 For efficiency reasons, there's no way to override the parameters on a
518 per-call basis other than by setting instance attributes. If you need
518 per-call basis other than by setting instance attributes. If you need
519 local overrides, it's best to directly call system() or getoutput().
519 local overrides, it's best to directly call system() or getoutput().
520
520
521 The following names are provided as alternate options:
521 The following names are provided as alternate options:
522 - xsys: alias to system
522 - xsys: alias to system
523 - bq: alias to getoutput
523 - bq: alias to getoutput
524
524
525 An instance can then be created as:
525 An instance can then be created as:
526 >>> sysexec = SystemExec(verbose=1,debug=0,header='Calling: ')
526 >>> sysexec = SystemExec(verbose=1,debug=0,header='Calling: ')
527 """
527 """
528
528
529 def __init__(self,verbose=0,debug=0,header='',split=0):
529 def __init__(self,verbose=0,debug=0,header='',split=0):
530 """Specify the instance's values for verbose, debug and header."""
530 """Specify the instance's values for verbose, debug and header."""
531 setattr_list(self,'verbose debug header split')
531 setattr_list(self,'verbose debug header split')
532
532
533 def system(self,cmd):
533 def system(self,cmd):
534 """Stateful interface to system(), with the same keyword parameters."""
534 """Stateful interface to system(), with the same keyword parameters."""
535
535
536 system(cmd,self.verbose,self.debug,self.header)
536 system(cmd,self.verbose,self.debug,self.header)
537
537
538 def shell(self,cmd):
538 def shell(self,cmd):
539 """Stateful interface to shell(), with the same keyword parameters."""
539 """Stateful interface to shell(), with the same keyword parameters."""
540
540
541 shell(cmd,self.verbose,self.debug,self.header)
541 shell(cmd,self.verbose,self.debug,self.header)
542
542
543 xsys = system # alias
543 xsys = system # alias
544
544
545 def getoutput(self,cmd):
545 def getoutput(self,cmd):
546 """Stateful interface to getoutput()."""
546 """Stateful interface to getoutput()."""
547
547
548 return getoutput(cmd,self.verbose,self.debug,self.header,self.split)
548 return getoutput(cmd,self.verbose,self.debug,self.header,self.split)
549
549
550 def getoutputerror(self,cmd):
550 def getoutputerror(self,cmd):
551 """Stateful interface to getoutputerror()."""
551 """Stateful interface to getoutputerror()."""
552
552
553 return getoutputerror(cmd,self.verbose,self.debug,self.header,self.split)
553 return getoutputerror(cmd,self.verbose,self.debug,self.header,self.split)
554
554
555 bq = getoutput # alias
555 bq = getoutput # alias
556
556
557 #-----------------------------------------------------------------------------
557 #-----------------------------------------------------------------------------
558 def mutex_opts(dict,ex_op):
558 def mutex_opts(dict,ex_op):
559 """Check for presence of mutually exclusive keys in a dict.
559 """Check for presence of mutually exclusive keys in a dict.
560
560
561 Call: mutex_opts(dict,[[op1a,op1b],[op2a,op2b]...]"""
561 Call: mutex_opts(dict,[[op1a,op1b],[op2a,op2b]...]"""
562 for op1,op2 in ex_op:
562 for op1,op2 in ex_op:
563 if op1 in dict and op2 in dict:
563 if op1 in dict and op2 in dict:
564 raise ValueError,'\n*** ERROR in Arguments *** '\
564 raise ValueError,'\n*** ERROR in Arguments *** '\
565 'Options '+op1+' and '+op2+' are mutually exclusive.'
565 'Options '+op1+' and '+op2+' are mutually exclusive.'
566
566
567 #-----------------------------------------------------------------------------
567 #-----------------------------------------------------------------------------
568 def get_py_filename(name):
568 def get_py_filename(name):
569 """Return a valid python filename in the current directory.
569 """Return a valid python filename in the current directory.
570
570
571 If the given name is not a file, it adds '.py' and searches again.
571 If the given name is not a file, it adds '.py' and searches again.
572 Raises IOError with an informative message if the file isn't found."""
572 Raises IOError with an informative message if the file isn't found."""
573
573
574 name = os.path.expanduser(name)
574 name = os.path.expanduser(name)
575 if not os.path.isfile(name) and not name.endswith('.py'):
575 if not os.path.isfile(name) and not name.endswith('.py'):
576 name += '.py'
576 name += '.py'
577 if os.path.isfile(name):
577 if os.path.isfile(name):
578 return name
578 return name
579 else:
579 else:
580 raise IOError,'File `%s` not found.' % name
580 raise IOError,'File `%s` not found.' % name
581
581
582 #-----------------------------------------------------------------------------
582 #-----------------------------------------------------------------------------
583
583
584
584
585 def filefind(filename, path_dirs=None):
585 def filefind(filename, path_dirs=None):
586 """Find a file by looking through a sequence of paths.
586 """Find a file by looking through a sequence of paths.
587
587
588 This iterates through a sequence of paths looking for a file and returns
588 This iterates through a sequence of paths looking for a file and returns
589 the full, absolute path of the first occurence of the file. If no set of
589 the full, absolute path of the first occurence of the file. If no set of
590 path dirs is given, the filename is tested as is, after running through
590 path dirs is given, the filename is tested as is, after running through
591 :func:`expandvars` and :func:`expanduser`. Thus a simple call::
591 :func:`expandvars` and :func:`expanduser`. Thus a simple call::
592
592
593 filefind('myfile.txt')
593 filefind('myfile.txt')
594
594
595 will find the file in the current working dir, but::
595 will find the file in the current working dir, but::
596
596
597 filefind('~/myfile.txt')
597 filefind('~/myfile.txt')
598
598
599 Will find the file in the users home directory. This function does not
599 Will find the file in the users home directory. This function does not
600 automatically try any paths, such as the cwd or the user's home directory.
600 automatically try any paths, such as the cwd or the user's home directory.
601
601
602 Parameters
602 Parameters
603 ----------
603 ----------
604 filename : str
604 filename : str
605 The filename to look for.
605 The filename to look for.
606 path_dirs : str, None or sequence of str
606 path_dirs : str, None or sequence of str
607 The sequence of paths to look for the file in. If None, the filename
607 The sequence of paths to look for the file in. If None, the filename
608 need to be absolute or be in the cwd. If a string, the string is
608 need to be absolute or be in the cwd. If a string, the string is
609 put into a sequence and the searched. If a sequence, walk through
609 put into a sequence and the searched. If a sequence, walk through
610 each element and join with ``filename``, calling :func:`expandvars`
610 each element and join with ``filename``, calling :func:`expandvars`
611 and :func:`expanduser` before testing for existence.
611 and :func:`expanduser` before testing for existence.
612
612
613 Returns
613 Returns
614 -------
614 -------
615 Raises :exc:`IOError` or returns absolute path to file.
615 Raises :exc:`IOError` or returns absolute path to file.
616 """
616 """
617 if path_dirs is None:
617 if path_dirs is None:
618 path_dirs = ("",)
618 path_dirs = ("",)
619 elif isinstance(path_dirs, basestring):
619 elif isinstance(path_dirs, basestring):
620 path_dirs = (path_dirs,)
620 path_dirs = (path_dirs,)
621 for path in path_dirs:
621 for path in path_dirs:
622 if path == '.': path = os.getcwd()
622 if path == '.': path = os.getcwd()
623 testname = expand_path(os.path.join(path, filename))
623 testname = expand_path(os.path.join(path, filename))
624 if os.path.isfile(testname):
624 if os.path.isfile(testname):
625 return os.path.abspath(testname)
625 return os.path.abspath(testname)
626 raise IOError("File does not exist in any "
626 raise IOError("File does not exist in any "
627 "of the search paths: %r, %r" % \
627 "of the search paths: %r, %r" % \
628 (filename, path_dirs))
628 (filename, path_dirs))
629
629
630
630
631 #----------------------------------------------------------------------------
631 #----------------------------------------------------------------------------
632 def file_read(filename):
632 def file_read(filename):
633 """Read a file and close it. Returns the file source."""
633 """Read a file and close it. Returns the file source."""
634 fobj = open(filename,'r');
634 fobj = open(filename,'r');
635 source = fobj.read();
635 source = fobj.read();
636 fobj.close()
636 fobj.close()
637 return source
637 return source
638
638
639 def file_readlines(filename):
639 def file_readlines(filename):
640 """Read a file and close it. Returns the file source using readlines()."""
640 """Read a file and close it. Returns the file source using readlines()."""
641 fobj = open(filename,'r');
641 fobj = open(filename,'r');
642 lines = fobj.readlines();
642 lines = fobj.readlines();
643 fobj.close()
643 fobj.close()
644 return lines
644 return lines
645
645
646 #----------------------------------------------------------------------------
646 #----------------------------------------------------------------------------
647 def target_outdated(target,deps):
647 def target_outdated(target,deps):
648 """Determine whether a target is out of date.
648 """Determine whether a target is out of date.
649
649
650 target_outdated(target,deps) -> 1/0
650 target_outdated(target,deps) -> 1/0
651
651
652 deps: list of filenames which MUST exist.
652 deps: list of filenames which MUST exist.
653 target: single filename which may or may not exist.
653 target: single filename which may or may not exist.
654
654
655 If target doesn't exist or is older than any file listed in deps, return
655 If target doesn't exist or is older than any file listed in deps, return
656 true, otherwise return false.
656 true, otherwise return false.
657 """
657 """
658 try:
658 try:
659 target_time = os.path.getmtime(target)
659 target_time = os.path.getmtime(target)
660 except os.error:
660 except os.error:
661 return 1
661 return 1
662 for dep in deps:
662 for dep in deps:
663 dep_time = os.path.getmtime(dep)
663 dep_time = os.path.getmtime(dep)
664 if dep_time > target_time:
664 if dep_time > target_time:
665 #print "For target",target,"Dep failed:",dep # dbg
665 #print "For target",target,"Dep failed:",dep # dbg
666 #print "times (dep,tar):",dep_time,target_time # dbg
666 #print "times (dep,tar):",dep_time,target_time # dbg
667 return 1
667 return 1
668 return 0
668 return 0
669
669
670 #-----------------------------------------------------------------------------
670 #-----------------------------------------------------------------------------
671 def target_update(target,deps,cmd):
671 def target_update(target,deps,cmd):
672 """Update a target with a given command given a list of dependencies.
672 """Update a target with a given command given a list of dependencies.
673
673
674 target_update(target,deps,cmd) -> runs cmd if target is outdated.
674 target_update(target,deps,cmd) -> runs cmd if target is outdated.
675
675
676 This is just a wrapper around target_outdated() which calls the given
676 This is just a wrapper around target_outdated() which calls the given
677 command if target is outdated."""
677 command if target is outdated."""
678
678
679 if target_outdated(target,deps):
679 if target_outdated(target,deps):
680 xsys(cmd)
680 xsys(cmd)
681
681
682 #----------------------------------------------------------------------------
682 #----------------------------------------------------------------------------
683 def unquote_ends(istr):
683 def unquote_ends(istr):
684 """Remove a single pair of quotes from the endpoints of a string."""
684 """Remove a single pair of quotes from the endpoints of a string."""
685
685
686 if not istr:
686 if not istr:
687 return istr
687 return istr
688 if (istr[0]=="'" and istr[-1]=="'") or \
688 if (istr[0]=="'" and istr[-1]=="'") or \
689 (istr[0]=='"' and istr[-1]=='"'):
689 (istr[0]=='"' and istr[-1]=='"'):
690 return istr[1:-1]
690 return istr[1:-1]
691 else:
691 else:
692 return istr
692 return istr
693
693
694 #----------------------------------------------------------------------------
694 #----------------------------------------------------------------------------
695 def flag_calls(func):
695 def flag_calls(func):
696 """Wrap a function to detect and flag when it gets called.
696 """Wrap a function to detect and flag when it gets called.
697
697
698 This is a decorator which takes a function and wraps it in a function with
698 This is a decorator which takes a function and wraps it in a function with
699 a 'called' attribute. wrapper.called is initialized to False.
699 a 'called' attribute. wrapper.called is initialized to False.
700
700
701 The wrapper.called attribute is set to False right before each call to the
701 The wrapper.called attribute is set to False right before each call to the
702 wrapped function, so if the call fails it remains False. After the call
702 wrapped function, so if the call fails it remains False. After the call
703 completes, wrapper.called is set to True and the output is returned.
703 completes, wrapper.called is set to True and the output is returned.
704
704
705 Testing for truth in wrapper.called allows you to determine if a call to
705 Testing for truth in wrapper.called allows you to determine if a call to
706 func() was attempted and succeeded."""
706 func() was attempted and succeeded."""
707
707
708 def wrapper(*args,**kw):
708 def wrapper(*args,**kw):
709 wrapper.called = False
709 wrapper.called = False
710 out = func(*args,**kw)
710 out = func(*args,**kw)
711 wrapper.called = True
711 wrapper.called = True
712 return out
712 return out
713
713
714 wrapper.called = False
714 wrapper.called = False
715 wrapper.__doc__ = func.__doc__
715 wrapper.__doc__ = func.__doc__
716 return wrapper
716 return wrapper
717
717
718 #----------------------------------------------------------------------------
718 #----------------------------------------------------------------------------
719 def dhook_wrap(func,*a,**k):
719 def dhook_wrap(func,*a,**k):
720 """Wrap a function call in a sys.displayhook controller.
720 """Wrap a function call in a sys.displayhook controller.
721
721
722 Returns a wrapper around func which calls func, with all its arguments and
722 Returns a wrapper around func which calls func, with all its arguments and
723 keywords unmodified, using the default sys.displayhook. Since IPython
723 keywords unmodified, using the default sys.displayhook. Since IPython
724 modifies sys.displayhook, it breaks the behavior of certain systems that
724 modifies sys.displayhook, it breaks the behavior of certain systems that
725 rely on the default behavior, notably doctest.
725 rely on the default behavior, notably doctest.
726 """
726 """
727
727
728 def f(*a,**k):
728 def f(*a,**k):
729
729
730 dhook_s = sys.displayhook
730 dhook_s = sys.displayhook
731 sys.displayhook = sys.__displayhook__
731 sys.displayhook = sys.__displayhook__
732 try:
732 try:
733 out = func(*a,**k)
733 out = func(*a,**k)
734 finally:
734 finally:
735 sys.displayhook = dhook_s
735 sys.displayhook = dhook_s
736
736
737 return out
737 return out
738
738
739 f.__doc__ = func.__doc__
739 f.__doc__ = func.__doc__
740 return f
740 return f
741
741
742 #----------------------------------------------------------------------------
742 #----------------------------------------------------------------------------
743 def doctest_reload():
743 def doctest_reload():
744 """Properly reload doctest to reuse it interactively.
744 """Properly reload doctest to reuse it interactively.
745
745
746 This routine:
746 This routine:
747
747
748 - imports doctest but does NOT reload it (see below).
748 - imports doctest but does NOT reload it (see below).
749
749
750 - resets its global 'master' attribute to None, so that multiple uses of
750 - resets its global 'master' attribute to None, so that multiple uses of
751 the module interactively don't produce cumulative reports.
751 the module interactively don't produce cumulative reports.
752
752
753 - Monkeypatches its core test runner method to protect it from IPython's
753 - Monkeypatches its core test runner method to protect it from IPython's
754 modified displayhook. Doctest expects the default displayhook behavior
754 modified displayhook. Doctest expects the default displayhook behavior
755 deep down, so our modification breaks it completely. For this reason, a
755 deep down, so our modification breaks it completely. For this reason, a
756 hard monkeypatch seems like a reasonable solution rather than asking
756 hard monkeypatch seems like a reasonable solution rather than asking
757 users to manually use a different doctest runner when under IPython.
757 users to manually use a different doctest runner when under IPython.
758
758
759 Notes
759 Notes
760 -----
760 -----
761
761
762 This function *used to* reload doctest, but this has been disabled because
762 This function *used to* reload doctest, but this has been disabled because
763 reloading doctest unconditionally can cause massive breakage of other
763 reloading doctest unconditionally can cause massive breakage of other
764 doctest-dependent modules already in memory, such as those for IPython's
764 doctest-dependent modules already in memory, such as those for IPython's
765 own testing system. The name wasn't changed to avoid breaking people's
765 own testing system. The name wasn't changed to avoid breaking people's
766 code, but the reload call isn't actually made anymore."""
766 code, but the reload call isn't actually made anymore."""
767
767
768 import doctest
768 import doctest
769 doctest.master = None
769 doctest.master = None
770 doctest.DocTestRunner.run = dhook_wrap(doctest.DocTestRunner.run)
770 doctest.DocTestRunner.run = dhook_wrap(doctest.DocTestRunner.run)
771
771
772 #----------------------------------------------------------------------------
772 #----------------------------------------------------------------------------
773 class HomeDirError(Error):
773 class HomeDirError(Error):
774 pass
774 pass
775
775
776 def get_home_dir():
776 def get_home_dir():
777 """Return the closest possible equivalent to a 'home' directory.
777 """Return the closest possible equivalent to a 'home' directory.
778
778
779 * On POSIX, we try $HOME.
779 * On POSIX, we try $HOME.
780 * On Windows we try:
780 * On Windows we try:
781 - %HOME%: rare, but some people with unix-like setups may have defined it
781 - %HOMESHARE%
782 - %HOMESHARE%
782 - %HOMEDRIVE\%HOMEPATH%
783 - %HOMEDRIVE\%HOMEPATH%
783 - %USERPROFILE%
784 - %USERPROFILE%
784 - Registry hack
785 - Registry hack
785 * On Dos C:\
786 * On Dos C:\
786
787
787 Currently only Posix and NT are implemented, a HomeDirError exception is
788 Currently only Posix and NT are implemented, a HomeDirError exception is
788 raised for all other OSes.
789 raised for all other OSes.
789 """
790 """
790
791
791 isdir = os.path.isdir
792 isdir = os.path.isdir
792 env = os.environ
793 env = os.environ
793
794
794 # first, check py2exe distribution root directory for _ipython.
795 # first, check py2exe distribution root directory for _ipython.
795 # This overrides all. Normally does not exist.
796 # This overrides all. Normally does not exist.
796
797
797 if hasattr(sys, "frozen"): #Is frozen by py2exe
798 if hasattr(sys, "frozen"): #Is frozen by py2exe
798 if '\\library.zip\\' in IPython.__file__.lower():#libraries compressed to zip-file
799 if '\\library.zip\\' in IPython.__file__.lower():#libraries compressed to zip-file
799 root, rest = IPython.__file__.lower().split('library.zip')
800 root, rest = IPython.__file__.lower().split('library.zip')
800 else:
801 else:
801 root=os.path.join(os.path.split(IPython.__file__)[0],"../../")
802 root=os.path.join(os.path.split(IPython.__file__)[0],"../../")
802 root=os.path.abspath(root).rstrip('\\')
803 root=os.path.abspath(root).rstrip('\\')
803 if isdir(os.path.join(root, '_ipython')):
804 if isdir(os.path.join(root, '_ipython')):
804 os.environ["IPYKITROOT"] = root
805 os.environ["IPYKITROOT"] = root
805 return root.decode(sys.getfilesystemencoding())
806 return root.decode(sys.getfilesystemencoding())
806
807
807 if os.name == 'posix':
808 if os.name == 'posix':
808 # Linux, Unix, AIX, OS X
809 # Linux, Unix, AIX, OS X
809 try:
810 try:
810 homedir = env['HOME']
811 homedir = env['HOME']
811 except KeyError:
812 except KeyError:
812 raise HomeDirError('Undefined $HOME, IPython cannot proceed.')
813 raise HomeDirError('Undefined $HOME, IPython cannot proceed.')
813 else:
814 else:
814 return homedir.decode(sys.getfilesystemencoding())
815 return homedir.decode(sys.getfilesystemencoding())
815 elif os.name == 'nt':
816 elif os.name == 'nt':
816 # Now for win9x, XP, Vista, 7?
817 # Now for win9x, XP, Vista, 7?
817 # For some strange reason all of these return 'nt' for os.name.
818 # For some strange reason all of these return 'nt' for os.name.
818 # First look for a network home directory. This will return the UNC
819 # First look for a network home directory. This will return the UNC
819 # path (\\server\\Users\%username%) not the mapped path (Z:\). This
820 # path (\\server\\Users\%username%) not the mapped path (Z:\). This
820 # is needed when running IPython on cluster where all paths have to
821 # is needed when running IPython on cluster where all paths have to
821 # be UNC.
822 # be UNC.
822 try:
823 try:
823 homedir = env['HOMESHARE']
824 # A user with a lot of unix tools in win32 may have defined $HOME,
825 # honor it if it exists, but otherwise let the more typical
826 # %HOMESHARE% variable be used.
827 homedir = env.get('HOME')
828 if homedir is None:
829 homedir = env['HOMESHARE']
824 except KeyError:
830 except KeyError:
825 pass
831 pass
826 else:
832 else:
827 if isdir(homedir):
833 if isdir(homedir):
828 return homedir.decode(sys.getfilesystemencoding())
834 return homedir.decode(sys.getfilesystemencoding())
829
835
830 # Now look for a local home directory
836 # Now look for a local home directory
831 try:
837 try:
832 homedir = os.path.join(env['HOMEDRIVE'],env['HOMEPATH'])
838 homedir = os.path.join(env['HOMEDRIVE'],env['HOMEPATH'])
833 except KeyError:
839 except KeyError:
834 pass
840 pass
835 else:
841 else:
836 if isdir(homedir):
842 if isdir(homedir):
837 return homedir.decode(sys.getfilesystemencoding())
843 return homedir.decode(sys.getfilesystemencoding())
838
844
839 # Now the users profile directory
845 # Now the users profile directory
840 try:
846 try:
841 homedir = os.path.join(env['USERPROFILE'])
847 homedir = os.path.join(env['USERPROFILE'])
842 except KeyError:
848 except KeyError:
843 pass
849 pass
844 else:
850 else:
845 if isdir(homedir):
851 if isdir(homedir):
846 return homedir.decode(sys.getfilesystemencoding())
852 return homedir.decode(sys.getfilesystemencoding())
847
853
848 # Use the registry to get the 'My Documents' folder.
854 # Use the registry to get the 'My Documents' folder.
849 try:
855 try:
850 import _winreg as wreg
856 import _winreg as wreg
851 key = wreg.OpenKey(
857 key = wreg.OpenKey(
852 wreg.HKEY_CURRENT_USER,
858 wreg.HKEY_CURRENT_USER,
853 "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
859 "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
854 )
860 )
855 homedir = wreg.QueryValueEx(key,'Personal')[0]
861 homedir = wreg.QueryValueEx(key,'Personal')[0]
856 key.Close()
862 key.Close()
857 except:
863 except:
858 pass
864 pass
859 else:
865 else:
860 if isdir(homedir):
866 if isdir(homedir):
861 return homedir.decode(sys.getfilesystemencoding())
867 return homedir.decode(sys.getfilesystemencoding())
862
868
863 # If all else fails, raise HomeDirError
869 # If all else fails, raise HomeDirError
864 raise HomeDirError('No valid home directory could be found')
870 raise HomeDirError('No valid home directory could be found')
865 elif os.name == 'dos':
871 elif os.name == 'dos':
866 # Desperate, may do absurd things in classic MacOS. May work under DOS.
872 # Desperate, may do absurd things in classic MacOS. May work under DOS.
867 return 'C:\\'.decode(sys.getfilesystemencoding())
873 return 'C:\\'.decode(sys.getfilesystemencoding())
868 else:
874 else:
869 raise HomeDirError('No valid home directory could be found for your OS')
875 raise HomeDirError('No valid home directory could be found for your OS')
870
876
871
877
872 def get_ipython_dir():
878 def get_ipython_dir():
873 """Get the IPython directory for this platform and user.
879 """Get the IPython directory for this platform and user.
874
880
875 This uses the logic in `get_home_dir` to find the home directory
881 This uses the logic in `get_home_dir` to find the home directory
876 and the adds .ipython to the end of the path.
882 and the adds .ipython to the end of the path.
877 """
883 """
878 ipdir_def = '.ipython'
884 ipdir_def = '.ipython'
879 home_dir = get_home_dir()
885 home_dir = get_home_dir()
880 ipdir = os.environ.get(
886 ipdir = os.environ.get(
881 'IPYTHON_DIR', os.environ.get(
887 'IPYTHON_DIR', os.environ.get(
882 'IPYTHONDIR', os.path.join(home_dir, ipdir_def)
888 'IPYTHONDIR', os.path.join(home_dir, ipdir_def)
883 )
889 )
884 )
890 )
885 return ipdir.decode(sys.getfilesystemencoding())
891 return ipdir.decode(sys.getfilesystemencoding())
886
892
887
893
888 def get_ipython_package_dir():
894 def get_ipython_package_dir():
889 """Get the base directory where IPython itself is installed."""
895 """Get the base directory where IPython itself is installed."""
890 ipdir = os.path.dirname(IPython.__file__)
896 ipdir = os.path.dirname(IPython.__file__)
891 return ipdir.decode(sys.getfilesystemencoding())
897 return ipdir.decode(sys.getfilesystemencoding())
892
898
893
899
894 #****************************************************************************
900 #****************************************************************************
895 # strings and text
901 # strings and text
896
902
897 class LSString(str):
903 class LSString(str):
898 """String derivative with a special access attributes.
904 """String derivative with a special access attributes.
899
905
900 These are normal strings, but with the special attributes:
906 These are normal strings, but with the special attributes:
901
907
902 .l (or .list) : value as list (split on newlines).
908 .l (or .list) : value as list (split on newlines).
903 .n (or .nlstr): original value (the string itself).
909 .n (or .nlstr): original value (the string itself).
904 .s (or .spstr): value as whitespace-separated string.
910 .s (or .spstr): value as whitespace-separated string.
905 .p (or .paths): list of path objects
911 .p (or .paths): list of path objects
906
912
907 Any values which require transformations are computed only once and
913 Any values which require transformations are computed only once and
908 cached.
914 cached.
909
915
910 Such strings are very useful to efficiently interact with the shell, which
916 Such strings are very useful to efficiently interact with the shell, which
911 typically only understands whitespace-separated options for commands."""
917 typically only understands whitespace-separated options for commands."""
912
918
913 def get_list(self):
919 def get_list(self):
914 try:
920 try:
915 return self.__list
921 return self.__list
916 except AttributeError:
922 except AttributeError:
917 self.__list = self.split('\n')
923 self.__list = self.split('\n')
918 return self.__list
924 return self.__list
919
925
920 l = list = property(get_list)
926 l = list = property(get_list)
921
927
922 def get_spstr(self):
928 def get_spstr(self):
923 try:
929 try:
924 return self.__spstr
930 return self.__spstr
925 except AttributeError:
931 except AttributeError:
926 self.__spstr = self.replace('\n',' ')
932 self.__spstr = self.replace('\n',' ')
927 return self.__spstr
933 return self.__spstr
928
934
929 s = spstr = property(get_spstr)
935 s = spstr = property(get_spstr)
930
936
931 def get_nlstr(self):
937 def get_nlstr(self):
932 return self
938 return self
933
939
934 n = nlstr = property(get_nlstr)
940 n = nlstr = property(get_nlstr)
935
941
936 def get_paths(self):
942 def get_paths(self):
937 try:
943 try:
938 return self.__paths
944 return self.__paths
939 except AttributeError:
945 except AttributeError:
940 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
946 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
941 return self.__paths
947 return self.__paths
942
948
943 p = paths = property(get_paths)
949 p = paths = property(get_paths)
944
950
945 def print_lsstring(arg):
951 def print_lsstring(arg):
946 """ Prettier (non-repr-like) and more informative printer for LSString """
952 """ Prettier (non-repr-like) and more informative printer for LSString """
947 print "LSString (.p, .n, .l, .s available). Value:"
953 print "LSString (.p, .n, .l, .s available). Value:"
948 print arg
954 print arg
949
955
950 print_lsstring = result_display.when_type(LSString)(print_lsstring)
956 print_lsstring = result_display.when_type(LSString)(print_lsstring)
951
957
952 #----------------------------------------------------------------------------
958 #----------------------------------------------------------------------------
953 class SList(list):
959 class SList(list):
954 """List derivative with a special access attributes.
960 """List derivative with a special access attributes.
955
961
956 These are normal lists, but with the special attributes:
962 These are normal lists, but with the special attributes:
957
963
958 .l (or .list) : value as list (the list itself).
964 .l (or .list) : value as list (the list itself).
959 .n (or .nlstr): value as a string, joined on newlines.
965 .n (or .nlstr): value as a string, joined on newlines.
960 .s (or .spstr): value as a string, joined on spaces.
966 .s (or .spstr): value as a string, joined on spaces.
961 .p (or .paths): list of path objects
967 .p (or .paths): list of path objects
962
968
963 Any values which require transformations are computed only once and
969 Any values which require transformations are computed only once and
964 cached."""
970 cached."""
965
971
966 def get_list(self):
972 def get_list(self):
967 return self
973 return self
968
974
969 l = list = property(get_list)
975 l = list = property(get_list)
970
976
971 def get_spstr(self):
977 def get_spstr(self):
972 try:
978 try:
973 return self.__spstr
979 return self.__spstr
974 except AttributeError:
980 except AttributeError:
975 self.__spstr = ' '.join(self)
981 self.__spstr = ' '.join(self)
976 return self.__spstr
982 return self.__spstr
977
983
978 s = spstr = property(get_spstr)
984 s = spstr = property(get_spstr)
979
985
980 def get_nlstr(self):
986 def get_nlstr(self):
981 try:
987 try:
982 return self.__nlstr
988 return self.__nlstr
983 except AttributeError:
989 except AttributeError:
984 self.__nlstr = '\n'.join(self)
990 self.__nlstr = '\n'.join(self)
985 return self.__nlstr
991 return self.__nlstr
986
992
987 n = nlstr = property(get_nlstr)
993 n = nlstr = property(get_nlstr)
988
994
989 def get_paths(self):
995 def get_paths(self):
990 try:
996 try:
991 return self.__paths
997 return self.__paths
992 except AttributeError:
998 except AttributeError:
993 self.__paths = [path(p) for p in self if os.path.exists(p)]
999 self.__paths = [path(p) for p in self if os.path.exists(p)]
994 return self.__paths
1000 return self.__paths
995
1001
996 p = paths = property(get_paths)
1002 p = paths = property(get_paths)
997
1003
998 def grep(self, pattern, prune = False, field = None):
1004 def grep(self, pattern, prune = False, field = None):
999 """ Return all strings matching 'pattern' (a regex or callable)
1005 """ Return all strings matching 'pattern' (a regex or callable)
1000
1006
1001 This is case-insensitive. If prune is true, return all items
1007 This is case-insensitive. If prune is true, return all items
1002 NOT matching the pattern.
1008 NOT matching the pattern.
1003
1009
1004 If field is specified, the match must occur in the specified
1010 If field is specified, the match must occur in the specified
1005 whitespace-separated field.
1011 whitespace-separated field.
1006
1012
1007 Examples::
1013 Examples::
1008
1014
1009 a.grep( lambda x: x.startswith('C') )
1015 a.grep( lambda x: x.startswith('C') )
1010 a.grep('Cha.*log', prune=1)
1016 a.grep('Cha.*log', prune=1)
1011 a.grep('chm', field=-1)
1017 a.grep('chm', field=-1)
1012 """
1018 """
1013
1019
1014 def match_target(s):
1020 def match_target(s):
1015 if field is None:
1021 if field is None:
1016 return s
1022 return s
1017 parts = s.split()
1023 parts = s.split()
1018 try:
1024 try:
1019 tgt = parts[field]
1025 tgt = parts[field]
1020 return tgt
1026 return tgt
1021 except IndexError:
1027 except IndexError:
1022 return ""
1028 return ""
1023
1029
1024 if isinstance(pattern, basestring):
1030 if isinstance(pattern, basestring):
1025 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
1031 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
1026 else:
1032 else:
1027 pred = pattern
1033 pred = pattern
1028 if not prune:
1034 if not prune:
1029 return SList([el for el in self if pred(match_target(el))])
1035 return SList([el for el in self if pred(match_target(el))])
1030 else:
1036 else:
1031 return SList([el for el in self if not pred(match_target(el))])
1037 return SList([el for el in self if not pred(match_target(el))])
1032 def fields(self, *fields):
1038 def fields(self, *fields):
1033 """ Collect whitespace-separated fields from string list
1039 """ Collect whitespace-separated fields from string list
1034
1040
1035 Allows quick awk-like usage of string lists.
1041 Allows quick awk-like usage of string lists.
1036
1042
1037 Example data (in var a, created by 'a = !ls -l')::
1043 Example data (in var a, created by 'a = !ls -l')::
1038 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
1044 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
1039 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
1045 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
1040
1046
1041 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
1047 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
1042 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
1048 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
1043 (note the joining by space).
1049 (note the joining by space).
1044 a.fields(-1) is ['ChangeLog', 'IPython']
1050 a.fields(-1) is ['ChangeLog', 'IPython']
1045
1051
1046 IndexErrors are ignored.
1052 IndexErrors are ignored.
1047
1053
1048 Without args, fields() just split()'s the strings.
1054 Without args, fields() just split()'s the strings.
1049 """
1055 """
1050 if len(fields) == 0:
1056 if len(fields) == 0:
1051 return [el.split() for el in self]
1057 return [el.split() for el in self]
1052
1058
1053 res = SList()
1059 res = SList()
1054 for el in [f.split() for f in self]:
1060 for el in [f.split() for f in self]:
1055 lineparts = []
1061 lineparts = []
1056
1062
1057 for fd in fields:
1063 for fd in fields:
1058 try:
1064 try:
1059 lineparts.append(el[fd])
1065 lineparts.append(el[fd])
1060 except IndexError:
1066 except IndexError:
1061 pass
1067 pass
1062 if lineparts:
1068 if lineparts:
1063 res.append(" ".join(lineparts))
1069 res.append(" ".join(lineparts))
1064
1070
1065 return res
1071 return res
1066 def sort(self,field= None, nums = False):
1072 def sort(self,field= None, nums = False):
1067 """ sort by specified fields (see fields())
1073 """ sort by specified fields (see fields())
1068
1074
1069 Example::
1075 Example::
1070 a.sort(1, nums = True)
1076 a.sort(1, nums = True)
1071
1077
1072 Sorts a by second field, in numerical order (so that 21 > 3)
1078 Sorts a by second field, in numerical order (so that 21 > 3)
1073
1079
1074 """
1080 """
1075
1081
1076 #decorate, sort, undecorate
1082 #decorate, sort, undecorate
1077 if field is not None:
1083 if field is not None:
1078 dsu = [[SList([line]).fields(field), line] for line in self]
1084 dsu = [[SList([line]).fields(field), line] for line in self]
1079 else:
1085 else:
1080 dsu = [[line, line] for line in self]
1086 dsu = [[line, line] for line in self]
1081 if nums:
1087 if nums:
1082 for i in range(len(dsu)):
1088 for i in range(len(dsu)):
1083 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
1089 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
1084 try:
1090 try:
1085 n = int(numstr)
1091 n = int(numstr)
1086 except ValueError:
1092 except ValueError:
1087 n = 0;
1093 n = 0;
1088 dsu[i][0] = n
1094 dsu[i][0] = n
1089
1095
1090
1096
1091 dsu.sort()
1097 dsu.sort()
1092 return SList([t[1] for t in dsu])
1098 return SList([t[1] for t in dsu])
1093
1099
1094 def print_slist(arg):
1100 def print_slist(arg):
1095 """ Prettier (non-repr-like) and more informative printer for SList """
1101 """ Prettier (non-repr-like) and more informative printer for SList """
1096 print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
1102 print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
1097 if hasattr(arg, 'hideonce') and arg.hideonce:
1103 if hasattr(arg, 'hideonce') and arg.hideonce:
1098 arg.hideonce = False
1104 arg.hideonce = False
1099 return
1105 return
1100
1106
1101 nlprint(arg)
1107 nlprint(arg)
1102
1108
1103 print_slist = result_display.when_type(SList)(print_slist)
1109 print_slist = result_display.when_type(SList)(print_slist)
1104
1110
1105
1111
1106
1112
1107 #----------------------------------------------------------------------------
1113 #----------------------------------------------------------------------------
1108 def esc_quotes(strng):
1114 def esc_quotes(strng):
1109 """Return the input string with single and double quotes escaped out"""
1115 """Return the input string with single and double quotes escaped out"""
1110
1116
1111 return strng.replace('"','\\"').replace("'","\\'")
1117 return strng.replace('"','\\"').replace("'","\\'")
1112
1118
1113 #----------------------------------------------------------------------------
1119 #----------------------------------------------------------------------------
1114 def make_quoted_expr(s):
1120 def make_quoted_expr(s):
1115 """Return string s in appropriate quotes, using raw string if possible.
1121 """Return string s in appropriate quotes, using raw string if possible.
1116
1122
1117 XXX - example removed because it caused encoding errors in documentation
1123 XXX - example removed because it caused encoding errors in documentation
1118 generation. We need a new example that doesn't contain invalid chars.
1124 generation. We need a new example that doesn't contain invalid chars.
1119
1125
1120 Note the use of raw string and padding at the end to allow trailing
1126 Note the use of raw string and padding at the end to allow trailing
1121 backslash.
1127 backslash.
1122 """
1128 """
1123
1129
1124 tail = ''
1130 tail = ''
1125 tailpadding = ''
1131 tailpadding = ''
1126 raw = ''
1132 raw = ''
1127 if "\\" in s:
1133 if "\\" in s:
1128 raw = 'r'
1134 raw = 'r'
1129 if s.endswith('\\'):
1135 if s.endswith('\\'):
1130 tail = '[:-1]'
1136 tail = '[:-1]'
1131 tailpadding = '_'
1137 tailpadding = '_'
1132 if '"' not in s:
1138 if '"' not in s:
1133 quote = '"'
1139 quote = '"'
1134 elif "'" not in s:
1140 elif "'" not in s:
1135 quote = "'"
1141 quote = "'"
1136 elif '"""' not in s and not s.endswith('"'):
1142 elif '"""' not in s and not s.endswith('"'):
1137 quote = '"""'
1143 quote = '"""'
1138 elif "'''" not in s and not s.endswith("'"):
1144 elif "'''" not in s and not s.endswith("'"):
1139 quote = "'''"
1145 quote = "'''"
1140 else:
1146 else:
1141 # give up, backslash-escaped string will do
1147 # give up, backslash-escaped string will do
1142 return '"%s"' % esc_quotes(s)
1148 return '"%s"' % esc_quotes(s)
1143 res = raw + quote + s + tailpadding + quote + tail
1149 res = raw + quote + s + tailpadding + quote + tail
1144 return res
1150 return res
1145
1151
1146
1152
1147 #----------------------------------------------------------------------------
1153 #----------------------------------------------------------------------------
1148 def raw_input_multi(header='', ps1='==> ', ps2='..> ',terminate_str = '.'):
1154 def raw_input_multi(header='', ps1='==> ', ps2='..> ',terminate_str = '.'):
1149 """Take multiple lines of input.
1155 """Take multiple lines of input.
1150
1156
1151 A list with each line of input as a separate element is returned when a
1157 A list with each line of input as a separate element is returned when a
1152 termination string is entered (defaults to a single '.'). Input can also
1158 termination string is entered (defaults to a single '.'). Input can also
1153 terminate via EOF (^D in Unix, ^Z-RET in Windows).
1159 terminate via EOF (^D in Unix, ^Z-RET in Windows).
1154
1160
1155 Lines of input which end in \\ are joined into single entries (and a
1161 Lines of input which end in \\ are joined into single entries (and a
1156 secondary continuation prompt is issued as long as the user terminates
1162 secondary continuation prompt is issued as long as the user terminates
1157 lines with \\). This allows entering very long strings which are still
1163 lines with \\). This allows entering very long strings which are still
1158 meant to be treated as single entities.
1164 meant to be treated as single entities.
1159 """
1165 """
1160
1166
1161 try:
1167 try:
1162 if header:
1168 if header:
1163 header += '\n'
1169 header += '\n'
1164 lines = [raw_input(header + ps1)]
1170 lines = [raw_input(header + ps1)]
1165 except EOFError:
1171 except EOFError:
1166 return []
1172 return []
1167 terminate = [terminate_str]
1173 terminate = [terminate_str]
1168 try:
1174 try:
1169 while lines[-1:] != terminate:
1175 while lines[-1:] != terminate:
1170 new_line = raw_input(ps1)
1176 new_line = raw_input(ps1)
1171 while new_line.endswith('\\'):
1177 while new_line.endswith('\\'):
1172 new_line = new_line[:-1] + raw_input(ps2)
1178 new_line = new_line[:-1] + raw_input(ps2)
1173 lines.append(new_line)
1179 lines.append(new_line)
1174
1180
1175 return lines[:-1] # don't return the termination command
1181 return lines[:-1] # don't return the termination command
1176 except EOFError:
1182 except EOFError:
1177 print
1183 print
1178 return lines
1184 return lines
1179
1185
1180 #----------------------------------------------------------------------------
1186 #----------------------------------------------------------------------------
1181 def raw_input_ext(prompt='', ps2='... '):
1187 def raw_input_ext(prompt='', ps2='... '):
1182 """Similar to raw_input(), but accepts extended lines if input ends with \\."""
1188 """Similar to raw_input(), but accepts extended lines if input ends with \\."""
1183
1189
1184 line = raw_input(prompt)
1190 line = raw_input(prompt)
1185 while line.endswith('\\'):
1191 while line.endswith('\\'):
1186 line = line[:-1] + raw_input(ps2)
1192 line = line[:-1] + raw_input(ps2)
1187 return line
1193 return line
1188
1194
1189 #----------------------------------------------------------------------------
1195 #----------------------------------------------------------------------------
1190 def ask_yes_no(prompt,default=None):
1196 def ask_yes_no(prompt,default=None):
1191 """Asks a question and returns a boolean (y/n) answer.
1197 """Asks a question and returns a boolean (y/n) answer.
1192
1198
1193 If default is given (one of 'y','n'), it is used if the user input is
1199 If default is given (one of 'y','n'), it is used if the user input is
1194 empty. Otherwise the question is repeated until an answer is given.
1200 empty. Otherwise the question is repeated until an answer is given.
1195
1201
1196 An EOF is treated as the default answer. If there is no default, an
1202 An EOF is treated as the default answer. If there is no default, an
1197 exception is raised to prevent infinite loops.
1203 exception is raised to prevent infinite loops.
1198
1204
1199 Valid answers are: y/yes/n/no (match is not case sensitive)."""
1205 Valid answers are: y/yes/n/no (match is not case sensitive)."""
1200
1206
1201 answers = {'y':True,'n':False,'yes':True,'no':False}
1207 answers = {'y':True,'n':False,'yes':True,'no':False}
1202 ans = None
1208 ans = None
1203 while ans not in answers.keys():
1209 while ans not in answers.keys():
1204 try:
1210 try:
1205 ans = raw_input(prompt+' ').lower()
1211 ans = raw_input(prompt+' ').lower()
1206 if not ans: # response was an empty string
1212 if not ans: # response was an empty string
1207 ans = default
1213 ans = default
1208 except KeyboardInterrupt:
1214 except KeyboardInterrupt:
1209 pass
1215 pass
1210 except EOFError:
1216 except EOFError:
1211 if default in answers.keys():
1217 if default in answers.keys():
1212 ans = default
1218 ans = default
1213 print
1219 print
1214 else:
1220 else:
1215 raise
1221 raise
1216
1222
1217 return answers[ans]
1223 return answers[ans]
1218
1224
1219 #----------------------------------------------------------------------------
1225 #----------------------------------------------------------------------------
1220 class EvalDict:
1226 class EvalDict:
1221 """
1227 """
1222 Emulate a dict which evaluates its contents in the caller's frame.
1228 Emulate a dict which evaluates its contents in the caller's frame.
1223
1229
1224 Usage:
1230 Usage:
1225 >>> number = 19
1231 >>> number = 19
1226
1232
1227 >>> text = "python"
1233 >>> text = "python"
1228
1234
1229 >>> print "%(text.capitalize())s %(number/9.0).1f rules!" % EvalDict()
1235 >>> print "%(text.capitalize())s %(number/9.0).1f rules!" % EvalDict()
1230 Python 2.1 rules!
1236 Python 2.1 rules!
1231 """
1237 """
1232
1238
1233 # This version is due to sismex01@hebmex.com on c.l.py, and is basically a
1239 # This version is due to sismex01@hebmex.com on c.l.py, and is basically a
1234 # modified (shorter) version of:
1240 # modified (shorter) version of:
1235 # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66018 by
1241 # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66018 by
1236 # Skip Montanaro (skip@pobox.com).
1242 # Skip Montanaro (skip@pobox.com).
1237
1243
1238 def __getitem__(self, name):
1244 def __getitem__(self, name):
1239 frame = sys._getframe(1)
1245 frame = sys._getframe(1)
1240 return eval(name, frame.f_globals, frame.f_locals)
1246 return eval(name, frame.f_globals, frame.f_locals)
1241
1247
1242 EvalString = EvalDict # for backwards compatibility
1248 EvalString = EvalDict # for backwards compatibility
1243 #----------------------------------------------------------------------------
1249 #----------------------------------------------------------------------------
1244 def qw(words,flat=0,sep=None,maxsplit=-1):
1250 def qw(words,flat=0,sep=None,maxsplit=-1):
1245 """Similar to Perl's qw() operator, but with some more options.
1251 """Similar to Perl's qw() operator, but with some more options.
1246
1252
1247 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
1253 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
1248
1254
1249 words can also be a list itself, and with flat=1, the output will be
1255 words can also be a list itself, and with flat=1, the output will be
1250 recursively flattened.
1256 recursively flattened.
1251
1257
1252 Examples:
1258 Examples:
1253
1259
1254 >>> qw('1 2')
1260 >>> qw('1 2')
1255 ['1', '2']
1261 ['1', '2']
1256
1262
1257 >>> qw(['a b','1 2',['m n','p q']])
1263 >>> qw(['a b','1 2',['m n','p q']])
1258 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
1264 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
1259
1265
1260 >>> qw(['a b','1 2',['m n','p q']],flat=1)
1266 >>> qw(['a b','1 2',['m n','p q']],flat=1)
1261 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
1267 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
1262 """
1268 """
1263
1269
1264 if type(words) in StringTypes:
1270 if type(words) in StringTypes:
1265 return [word.strip() for word in words.split(sep,maxsplit)
1271 return [word.strip() for word in words.split(sep,maxsplit)
1266 if word and not word.isspace() ]
1272 if word and not word.isspace() ]
1267 if flat:
1273 if flat:
1268 return flatten(map(qw,words,[1]*len(words)))
1274 return flatten(map(qw,words,[1]*len(words)))
1269 return map(qw,words)
1275 return map(qw,words)
1270
1276
1271 #----------------------------------------------------------------------------
1277 #----------------------------------------------------------------------------
1272 def qwflat(words,sep=None,maxsplit=-1):
1278 def qwflat(words,sep=None,maxsplit=-1):
1273 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
1279 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
1274 return qw(words,1,sep,maxsplit)
1280 return qw(words,1,sep,maxsplit)
1275
1281
1276 #----------------------------------------------------------------------------
1282 #----------------------------------------------------------------------------
1277 def qw_lol(indata):
1283 def qw_lol(indata):
1278 """qw_lol('a b') -> [['a','b']],
1284 """qw_lol('a b') -> [['a','b']],
1279 otherwise it's just a call to qw().
1285 otherwise it's just a call to qw().
1280
1286
1281 We need this to make sure the modules_some keys *always* end up as a
1287 We need this to make sure the modules_some keys *always* end up as a
1282 list of lists."""
1288 list of lists."""
1283
1289
1284 if type(indata) in StringTypes:
1290 if type(indata) in StringTypes:
1285 return [qw(indata)]
1291 return [qw(indata)]
1286 else:
1292 else:
1287 return qw(indata)
1293 return qw(indata)
1288
1294
1289 #----------------------------------------------------------------------------
1295 #----------------------------------------------------------------------------
1290 def grep(pat,list,case=1):
1296 def grep(pat,list,case=1):
1291 """Simple minded grep-like function.
1297 """Simple minded grep-like function.
1292 grep(pat,list) returns occurrences of pat in list, None on failure.
1298 grep(pat,list) returns occurrences of pat in list, None on failure.
1293
1299
1294 It only does simple string matching, with no support for regexps. Use the
1300 It only does simple string matching, with no support for regexps. Use the
1295 option case=0 for case-insensitive matching."""
1301 option case=0 for case-insensitive matching."""
1296
1302
1297 # This is pretty crude. At least it should implement copying only references
1303 # This is pretty crude. At least it should implement copying only references
1298 # to the original data in case it's big. Now it copies the data for output.
1304 # to the original data in case it's big. Now it copies the data for output.
1299 out=[]
1305 out=[]
1300 if case:
1306 if case:
1301 for term in list:
1307 for term in list:
1302 if term.find(pat)>-1: out.append(term)
1308 if term.find(pat)>-1: out.append(term)
1303 else:
1309 else:
1304 lpat=pat.lower()
1310 lpat=pat.lower()
1305 for term in list:
1311 for term in list:
1306 if term.lower().find(lpat)>-1: out.append(term)
1312 if term.lower().find(lpat)>-1: out.append(term)
1307
1313
1308 if len(out): return out
1314 if len(out): return out
1309 else: return None
1315 else: return None
1310
1316
1311 #----------------------------------------------------------------------------
1317 #----------------------------------------------------------------------------
1312 def dgrep(pat,*opts):
1318 def dgrep(pat,*opts):
1313 """Return grep() on dir()+dir(__builtins__).
1319 """Return grep() on dir()+dir(__builtins__).
1314
1320
1315 A very common use of grep() when working interactively."""
1321 A very common use of grep() when working interactively."""
1316
1322
1317 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
1323 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
1318
1324
1319 #----------------------------------------------------------------------------
1325 #----------------------------------------------------------------------------
1320 def idgrep(pat):
1326 def idgrep(pat):
1321 """Case-insensitive dgrep()"""
1327 """Case-insensitive dgrep()"""
1322
1328
1323 return dgrep(pat,0)
1329 return dgrep(pat,0)
1324
1330
1325 #----------------------------------------------------------------------------
1331 #----------------------------------------------------------------------------
1326 def igrep(pat,list):
1332 def igrep(pat,list):
1327 """Synonym for case-insensitive grep."""
1333 """Synonym for case-insensitive grep."""
1328
1334
1329 return grep(pat,list,case=0)
1335 return grep(pat,list,case=0)
1330
1336
1331 #----------------------------------------------------------------------------
1337 #----------------------------------------------------------------------------
1332 def indent(str,nspaces=4,ntabs=0):
1338 def indent(str,nspaces=4,ntabs=0):
1333 """Indent a string a given number of spaces or tabstops.
1339 """Indent a string a given number of spaces or tabstops.
1334
1340
1335 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
1341 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
1336 """
1342 """
1337 if str is None:
1343 if str is None:
1338 return
1344 return
1339 ind = '\t'*ntabs+' '*nspaces
1345 ind = '\t'*ntabs+' '*nspaces
1340 outstr = '%s%s' % (ind,str.replace(os.linesep,os.linesep+ind))
1346 outstr = '%s%s' % (ind,str.replace(os.linesep,os.linesep+ind))
1341 if outstr.endswith(os.linesep+ind):
1347 if outstr.endswith(os.linesep+ind):
1342 return outstr[:-len(ind)]
1348 return outstr[:-len(ind)]
1343 else:
1349 else:
1344 return outstr
1350 return outstr
1345
1351
1346 #-----------------------------------------------------------------------------
1352 #-----------------------------------------------------------------------------
1347 def native_line_ends(filename,backup=1):
1353 def native_line_ends(filename,backup=1):
1348 """Convert (in-place) a file to line-ends native to the current OS.
1354 """Convert (in-place) a file to line-ends native to the current OS.
1349
1355
1350 If the optional backup argument is given as false, no backup of the
1356 If the optional backup argument is given as false, no backup of the
1351 original file is left. """
1357 original file is left. """
1352
1358
1353 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
1359 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
1354
1360
1355 bak_filename = filename + backup_suffixes[os.name]
1361 bak_filename = filename + backup_suffixes[os.name]
1356
1362
1357 original = open(filename).read()
1363 original = open(filename).read()
1358 shutil.copy2(filename,bak_filename)
1364 shutil.copy2(filename,bak_filename)
1359 try:
1365 try:
1360 new = open(filename,'wb')
1366 new = open(filename,'wb')
1361 new.write(os.linesep.join(original.splitlines()))
1367 new.write(os.linesep.join(original.splitlines()))
1362 new.write(os.linesep) # ALWAYS put an eol at the end of the file
1368 new.write(os.linesep) # ALWAYS put an eol at the end of the file
1363 new.close()
1369 new.close()
1364 except:
1370 except:
1365 os.rename(bak_filename,filename)
1371 os.rename(bak_filename,filename)
1366 if not backup:
1372 if not backup:
1367 try:
1373 try:
1368 os.remove(bak_filename)
1374 os.remove(bak_filename)
1369 except:
1375 except:
1370 pass
1376 pass
1371
1377
1372 #****************************************************************************
1378 #****************************************************************************
1373 # lists, dicts and structures
1379 # lists, dicts and structures
1374
1380
1375 def belong(candidates,checklist):
1381 def belong(candidates,checklist):
1376 """Check whether a list of items appear in a given list of options.
1382 """Check whether a list of items appear in a given list of options.
1377
1383
1378 Returns a list of 1 and 0, one for each candidate given."""
1384 Returns a list of 1 and 0, one for each candidate given."""
1379
1385
1380 return [x in checklist for x in candidates]
1386 return [x in checklist for x in candidates]
1381
1387
1382 #----------------------------------------------------------------------------
1388 #----------------------------------------------------------------------------
1383 def uniq_stable(elems):
1389 def uniq_stable(elems):
1384 """uniq_stable(elems) -> list
1390 """uniq_stable(elems) -> list
1385
1391
1386 Return from an iterable, a list of all the unique elements in the input,
1392 Return from an iterable, a list of all the unique elements in the input,
1387 but maintaining the order in which they first appear.
1393 but maintaining the order in which they first appear.
1388
1394
1389 A naive solution to this problem which just makes a dictionary with the
1395 A naive solution to this problem which just makes a dictionary with the
1390 elements as keys fails to respect the stability condition, since
1396 elements as keys fails to respect the stability condition, since
1391 dictionaries are unsorted by nature.
1397 dictionaries are unsorted by nature.
1392
1398
1393 Note: All elements in the input must be valid dictionary keys for this
1399 Note: All elements in the input must be valid dictionary keys for this
1394 routine to work, as it internally uses a dictionary for efficiency
1400 routine to work, as it internally uses a dictionary for efficiency
1395 reasons."""
1401 reasons."""
1396
1402
1397 unique = []
1403 unique = []
1398 unique_dict = {}
1404 unique_dict = {}
1399 for nn in elems:
1405 for nn in elems:
1400 if nn not in unique_dict:
1406 if nn not in unique_dict:
1401 unique.append(nn)
1407 unique.append(nn)
1402 unique_dict[nn] = None
1408 unique_dict[nn] = None
1403 return unique
1409 return unique
1404
1410
1405 #----------------------------------------------------------------------------
1411 #----------------------------------------------------------------------------
1406 class NLprinter:
1412 class NLprinter:
1407 """Print an arbitrarily nested list, indicating index numbers.
1413 """Print an arbitrarily nested list, indicating index numbers.
1408
1414
1409 An instance of this class called nlprint is available and callable as a
1415 An instance of this class called nlprint is available and callable as a
1410 function.
1416 function.
1411
1417
1412 nlprint(list,indent=' ',sep=': ') -> prints indenting each level by 'indent'
1418 nlprint(list,indent=' ',sep=': ') -> prints indenting each level by 'indent'
1413 and using 'sep' to separate the index from the value. """
1419 and using 'sep' to separate the index from the value. """
1414
1420
1415 def __init__(self):
1421 def __init__(self):
1416 self.depth = 0
1422 self.depth = 0
1417
1423
1418 def __call__(self,lst,pos='',**kw):
1424 def __call__(self,lst,pos='',**kw):
1419 """Prints the nested list numbering levels."""
1425 """Prints the nested list numbering levels."""
1420 kw.setdefault('indent',' ')
1426 kw.setdefault('indent',' ')
1421 kw.setdefault('sep',': ')
1427 kw.setdefault('sep',': ')
1422 kw.setdefault('start',0)
1428 kw.setdefault('start',0)
1423 kw.setdefault('stop',len(lst))
1429 kw.setdefault('stop',len(lst))
1424 # we need to remove start and stop from kw so they don't propagate
1430 # we need to remove start and stop from kw so they don't propagate
1425 # into a recursive call for a nested list.
1431 # into a recursive call for a nested list.
1426 start = kw['start']; del kw['start']
1432 start = kw['start']; del kw['start']
1427 stop = kw['stop']; del kw['stop']
1433 stop = kw['stop']; del kw['stop']
1428 if self.depth == 0 and 'header' in kw.keys():
1434 if self.depth == 0 and 'header' in kw.keys():
1429 print kw['header']
1435 print kw['header']
1430
1436
1431 for idx in range(start,stop):
1437 for idx in range(start,stop):
1432 elem = lst[idx]
1438 elem = lst[idx]
1433 if type(elem)==type([]):
1439 if type(elem)==type([]):
1434 self.depth += 1
1440 self.depth += 1
1435 self.__call__(elem,itpl('$pos$idx,'),**kw)
1441 self.__call__(elem,itpl('$pos$idx,'),**kw)
1436 self.depth -= 1
1442 self.depth -= 1
1437 else:
1443 else:
1438 printpl(kw['indent']*self.depth+'$pos$idx$kw["sep"]$elem')
1444 printpl(kw['indent']*self.depth+'$pos$idx$kw["sep"]$elem')
1439
1445
1440 nlprint = NLprinter()
1446 nlprint = NLprinter()
1441 #----------------------------------------------------------------------------
1447 #----------------------------------------------------------------------------
1442 def all_belong(candidates,checklist):
1448 def all_belong(candidates,checklist):
1443 """Check whether a list of items ALL appear in a given list of options.
1449 """Check whether a list of items ALL appear in a given list of options.
1444
1450
1445 Returns a single 1 or 0 value."""
1451 Returns a single 1 or 0 value."""
1446
1452
1447 return 1-(0 in [x in checklist for x in candidates])
1453 return 1-(0 in [x in checklist for x in candidates])
1448
1454
1449 #----------------------------------------------------------------------------
1455 #----------------------------------------------------------------------------
1450 def sort_compare(lst1,lst2,inplace = 1):
1456 def sort_compare(lst1,lst2,inplace = 1):
1451 """Sort and compare two lists.
1457 """Sort and compare two lists.
1452
1458
1453 By default it does it in place, thus modifying the lists. Use inplace = 0
1459 By default it does it in place, thus modifying the lists. Use inplace = 0
1454 to avoid that (at the cost of temporary copy creation)."""
1460 to avoid that (at the cost of temporary copy creation)."""
1455 if not inplace:
1461 if not inplace:
1456 lst1 = lst1[:]
1462 lst1 = lst1[:]
1457 lst2 = lst2[:]
1463 lst2 = lst2[:]
1458 lst1.sort(); lst2.sort()
1464 lst1.sort(); lst2.sort()
1459 return lst1 == lst2
1465 return lst1 == lst2
1460
1466
1461 #----------------------------------------------------------------------------
1467 #----------------------------------------------------------------------------
1462 def list2dict(lst):
1468 def list2dict(lst):
1463 """Takes a list of (key,value) pairs and turns it into a dict."""
1469 """Takes a list of (key,value) pairs and turns it into a dict."""
1464
1470
1465 dic = {}
1471 dic = {}
1466 for k,v in lst: dic[k] = v
1472 for k,v in lst: dic[k] = v
1467 return dic
1473 return dic
1468
1474
1469 #----------------------------------------------------------------------------
1475 #----------------------------------------------------------------------------
1470 def list2dict2(lst,default=''):
1476 def list2dict2(lst,default=''):
1471 """Takes a list and turns it into a dict.
1477 """Takes a list and turns it into a dict.
1472 Much slower than list2dict, but more versatile. This version can take
1478 Much slower than list2dict, but more versatile. This version can take
1473 lists with sublists of arbitrary length (including sclars)."""
1479 lists with sublists of arbitrary length (including sclars)."""
1474
1480
1475 dic = {}
1481 dic = {}
1476 for elem in lst:
1482 for elem in lst:
1477 if type(elem) in (types.ListType,types.TupleType):
1483 if type(elem) in (types.ListType,types.TupleType):
1478 size = len(elem)
1484 size = len(elem)
1479 if size == 0:
1485 if size == 0:
1480 pass
1486 pass
1481 elif size == 1:
1487 elif size == 1:
1482 dic[elem] = default
1488 dic[elem] = default
1483 else:
1489 else:
1484 k,v = elem[0], elem[1:]
1490 k,v = elem[0], elem[1:]
1485 if len(v) == 1: v = v[0]
1491 if len(v) == 1: v = v[0]
1486 dic[k] = v
1492 dic[k] = v
1487 else:
1493 else:
1488 dic[elem] = default
1494 dic[elem] = default
1489 return dic
1495 return dic
1490
1496
1491 #----------------------------------------------------------------------------
1497 #----------------------------------------------------------------------------
1492 def flatten(seq):
1498 def flatten(seq):
1493 """Flatten a list of lists (NOT recursive, only works for 2d lists)."""
1499 """Flatten a list of lists (NOT recursive, only works for 2d lists)."""
1494
1500
1495 return [x for subseq in seq for x in subseq]
1501 return [x for subseq in seq for x in subseq]
1496
1502
1497 #----------------------------------------------------------------------------
1503 #----------------------------------------------------------------------------
1498 def get_slice(seq,start=0,stop=None,step=1):
1504 def get_slice(seq,start=0,stop=None,step=1):
1499 """Get a slice of a sequence with variable step. Specify start,stop,step."""
1505 """Get a slice of a sequence with variable step. Specify start,stop,step."""
1500 if stop == None:
1506 if stop == None:
1501 stop = len(seq)
1507 stop = len(seq)
1502 item = lambda i: seq[i]
1508 item = lambda i: seq[i]
1503 return map(item,xrange(start,stop,step))
1509 return map(item,xrange(start,stop,step))
1504
1510
1505 #----------------------------------------------------------------------------
1511 #----------------------------------------------------------------------------
1506 def chop(seq,size):
1512 def chop(seq,size):
1507 """Chop a sequence into chunks of the given size."""
1513 """Chop a sequence into chunks of the given size."""
1508 chunk = lambda i: seq[i:i+size]
1514 chunk = lambda i: seq[i:i+size]
1509 return map(chunk,xrange(0,len(seq),size))
1515 return map(chunk,xrange(0,len(seq),size))
1510
1516
1511 #----------------------------------------------------------------------------
1517 #----------------------------------------------------------------------------
1512 # with is a keyword as of python 2.5, so this function is renamed to withobj
1518 # with is a keyword as of python 2.5, so this function is renamed to withobj
1513 # from its old 'with' name.
1519 # from its old 'with' name.
1514 def with_obj(object, **args):
1520 def with_obj(object, **args):
1515 """Set multiple attributes for an object, similar to Pascal's with.
1521 """Set multiple attributes for an object, similar to Pascal's with.
1516
1522
1517 Example:
1523 Example:
1518 with_obj(jim,
1524 with_obj(jim,
1519 born = 1960,
1525 born = 1960,
1520 haircolour = 'Brown',
1526 haircolour = 'Brown',
1521 eyecolour = 'Green')
1527 eyecolour = 'Green')
1522
1528
1523 Credit: Greg Ewing, in
1529 Credit: Greg Ewing, in
1524 http://mail.python.org/pipermail/python-list/2001-May/040703.html.
1530 http://mail.python.org/pipermail/python-list/2001-May/040703.html.
1525
1531
1526 NOTE: up until IPython 0.7.2, this was called simply 'with', but 'with'
1532 NOTE: up until IPython 0.7.2, this was called simply 'with', but 'with'
1527 has become a keyword for Python 2.5, so we had to rename it."""
1533 has become a keyword for Python 2.5, so we had to rename it."""
1528
1534
1529 object.__dict__.update(args)
1535 object.__dict__.update(args)
1530
1536
1531 #----------------------------------------------------------------------------
1537 #----------------------------------------------------------------------------
1532 def setattr_list(obj,alist,nspace = None):
1538 def setattr_list(obj,alist,nspace = None):
1533 """Set a list of attributes for an object taken from a namespace.
1539 """Set a list of attributes for an object taken from a namespace.
1534
1540
1535 setattr_list(obj,alist,nspace) -> sets in obj all the attributes listed in
1541 setattr_list(obj,alist,nspace) -> sets in obj all the attributes listed in
1536 alist with their values taken from nspace, which must be a dict (something
1542 alist with their values taken from nspace, which must be a dict (something
1537 like locals() will often do) If nspace isn't given, locals() of the
1543 like locals() will often do) If nspace isn't given, locals() of the
1538 *caller* is used, so in most cases you can omit it.
1544 *caller* is used, so in most cases you can omit it.
1539
1545
1540 Note that alist can be given as a string, which will be automatically
1546 Note that alist can be given as a string, which will be automatically
1541 split into a list on whitespace. If given as a list, it must be a list of
1547 split into a list on whitespace. If given as a list, it must be a list of
1542 *strings* (the variable names themselves), not of variables."""
1548 *strings* (the variable names themselves), not of variables."""
1543
1549
1544 # this grabs the local variables from the *previous* call frame -- that is
1550 # this grabs the local variables from the *previous* call frame -- that is
1545 # the locals from the function that called setattr_list().
1551 # the locals from the function that called setattr_list().
1546 # - snipped from weave.inline()
1552 # - snipped from weave.inline()
1547 if nspace is None:
1553 if nspace is None:
1548 call_frame = sys._getframe().f_back
1554 call_frame = sys._getframe().f_back
1549 nspace = call_frame.f_locals
1555 nspace = call_frame.f_locals
1550
1556
1551 if type(alist) in StringTypes:
1557 if type(alist) in StringTypes:
1552 alist = alist.split()
1558 alist = alist.split()
1553 for attr in alist:
1559 for attr in alist:
1554 val = eval(attr,nspace)
1560 val = eval(attr,nspace)
1555 setattr(obj,attr,val)
1561 setattr(obj,attr,val)
1556
1562
1557 #----------------------------------------------------------------------------
1563 #----------------------------------------------------------------------------
1558 def getattr_list(obj,alist,*args):
1564 def getattr_list(obj,alist,*args):
1559 """getattr_list(obj,alist[, default]) -> attribute list.
1565 """getattr_list(obj,alist[, default]) -> attribute list.
1560
1566
1561 Get a list of named attributes for an object. When a default argument is
1567 Get a list of named attributes for an object. When a default argument is
1562 given, it is returned when the attribute doesn't exist; without it, an
1568 given, it is returned when the attribute doesn't exist; without it, an
1563 exception is raised in that case.
1569 exception is raised in that case.
1564
1570
1565 Note that alist can be given as a string, which will be automatically
1571 Note that alist can be given as a string, which will be automatically
1566 split into a list on whitespace. If given as a list, it must be a list of
1572 split into a list on whitespace. If given as a list, it must be a list of
1567 *strings* (the variable names themselves), not of variables."""
1573 *strings* (the variable names themselves), not of variables."""
1568
1574
1569 if type(alist) in StringTypes:
1575 if type(alist) in StringTypes:
1570 alist = alist.split()
1576 alist = alist.split()
1571 if args:
1577 if args:
1572 if len(args)==1:
1578 if len(args)==1:
1573 default = args[0]
1579 default = args[0]
1574 return map(lambda attr: getattr(obj,attr,default),alist)
1580 return map(lambda attr: getattr(obj,attr,default),alist)
1575 else:
1581 else:
1576 raise ValueError,'getattr_list() takes only one optional argument'
1582 raise ValueError,'getattr_list() takes only one optional argument'
1577 else:
1583 else:
1578 return map(lambda attr: getattr(obj,attr),alist)
1584 return map(lambda attr: getattr(obj,attr),alist)
1579
1585
1580 #----------------------------------------------------------------------------
1586 #----------------------------------------------------------------------------
1581 def map_method(method,object_list,*argseq,**kw):
1587 def map_method(method,object_list,*argseq,**kw):
1582 """map_method(method,object_list,*args,**kw) -> list
1588 """map_method(method,object_list,*args,**kw) -> list
1583
1589
1584 Return a list of the results of applying the methods to the items of the
1590 Return a list of the results of applying the methods to the items of the
1585 argument sequence(s). If more than one sequence is given, the method is
1591 argument sequence(s). If more than one sequence is given, the method is
1586 called with an argument list consisting of the corresponding item of each
1592 called with an argument list consisting of the corresponding item of each
1587 sequence. All sequences must be of the same length.
1593 sequence. All sequences must be of the same length.
1588
1594
1589 Keyword arguments are passed verbatim to all objects called.
1595 Keyword arguments are passed verbatim to all objects called.
1590
1596
1591 This is Python code, so it's not nearly as fast as the builtin map()."""
1597 This is Python code, so it's not nearly as fast as the builtin map()."""
1592
1598
1593 out_list = []
1599 out_list = []
1594 idx = 0
1600 idx = 0
1595 for object in object_list:
1601 for object in object_list:
1596 try:
1602 try:
1597 handler = getattr(object, method)
1603 handler = getattr(object, method)
1598 except AttributeError:
1604 except AttributeError:
1599 out_list.append(None)
1605 out_list.append(None)
1600 else:
1606 else:
1601 if argseq:
1607 if argseq:
1602 args = map(lambda lst:lst[idx],argseq)
1608 args = map(lambda lst:lst[idx],argseq)
1603 #print 'ob',object,'hand',handler,'ar',args # dbg
1609 #print 'ob',object,'hand',handler,'ar',args # dbg
1604 out_list.append(handler(args,**kw))
1610 out_list.append(handler(args,**kw))
1605 else:
1611 else:
1606 out_list.append(handler(**kw))
1612 out_list.append(handler(**kw))
1607 idx += 1
1613 idx += 1
1608 return out_list
1614 return out_list
1609
1615
1610 #----------------------------------------------------------------------------
1616 #----------------------------------------------------------------------------
1611 def get_class_members(cls):
1617 def get_class_members(cls):
1612 ret = dir(cls)
1618 ret = dir(cls)
1613 if hasattr(cls,'__bases__'):
1619 if hasattr(cls,'__bases__'):
1614 for base in cls.__bases__:
1620 for base in cls.__bases__:
1615 ret.extend(get_class_members(base))
1621 ret.extend(get_class_members(base))
1616 return ret
1622 return ret
1617
1623
1618 #----------------------------------------------------------------------------
1624 #----------------------------------------------------------------------------
1619 def dir2(obj):
1625 def dir2(obj):
1620 """dir2(obj) -> list of strings
1626 """dir2(obj) -> list of strings
1621
1627
1622 Extended version of the Python builtin dir(), which does a few extra
1628 Extended version of the Python builtin dir(), which does a few extra
1623 checks, and supports common objects with unusual internals that confuse
1629 checks, and supports common objects with unusual internals that confuse
1624 dir(), such as Traits and PyCrust.
1630 dir(), such as Traits and PyCrust.
1625
1631
1626 This version is guaranteed to return only a list of true strings, whereas
1632 This version is guaranteed to return only a list of true strings, whereas
1627 dir() returns anything that objects inject into themselves, even if they
1633 dir() returns anything that objects inject into themselves, even if they
1628 are later not really valid for attribute access (many extension libraries
1634 are later not really valid for attribute access (many extension libraries
1629 have such bugs).
1635 have such bugs).
1630 """
1636 """
1631
1637
1632 # Start building the attribute list via dir(), and then complete it
1638 # Start building the attribute list via dir(), and then complete it
1633 # with a few extra special-purpose calls.
1639 # with a few extra special-purpose calls.
1634 words = dir(obj)
1640 words = dir(obj)
1635
1641
1636 if hasattr(obj,'__class__'):
1642 if hasattr(obj,'__class__'):
1637 words.append('__class__')
1643 words.append('__class__')
1638 words.extend(get_class_members(obj.__class__))
1644 words.extend(get_class_members(obj.__class__))
1639 #if '__base__' in words: 1/0
1645 #if '__base__' in words: 1/0
1640
1646
1641 # Some libraries (such as traits) may introduce duplicates, we want to
1647 # Some libraries (such as traits) may introduce duplicates, we want to
1642 # track and clean this up if it happens
1648 # track and clean this up if it happens
1643 may_have_dupes = False
1649 may_have_dupes = False
1644
1650
1645 # this is the 'dir' function for objects with Enthought's traits
1651 # this is the 'dir' function for objects with Enthought's traits
1646 if hasattr(obj, 'trait_names'):
1652 if hasattr(obj, 'trait_names'):
1647 try:
1653 try:
1648 words.extend(obj.trait_names())
1654 words.extend(obj.trait_names())
1649 may_have_dupes = True
1655 may_have_dupes = True
1650 except TypeError:
1656 except TypeError:
1651 # This will happen if `obj` is a class and not an instance.
1657 # This will happen if `obj` is a class and not an instance.
1652 pass
1658 pass
1653
1659
1654 # Support for PyCrust-style _getAttributeNames magic method.
1660 # Support for PyCrust-style _getAttributeNames magic method.
1655 if hasattr(obj, '_getAttributeNames'):
1661 if hasattr(obj, '_getAttributeNames'):
1656 try:
1662 try:
1657 words.extend(obj._getAttributeNames())
1663 words.extend(obj._getAttributeNames())
1658 may_have_dupes = True
1664 may_have_dupes = True
1659 except TypeError:
1665 except TypeError:
1660 # `obj` is a class and not an instance. Ignore
1666 # `obj` is a class and not an instance. Ignore
1661 # this error.
1667 # this error.
1662 pass
1668 pass
1663
1669
1664 if may_have_dupes:
1670 if may_have_dupes:
1665 # eliminate possible duplicates, as some traits may also
1671 # eliminate possible duplicates, as some traits may also
1666 # appear as normal attributes in the dir() call.
1672 # appear as normal attributes in the dir() call.
1667 words = list(set(words))
1673 words = list(set(words))
1668 words.sort()
1674 words.sort()
1669
1675
1670 # filter out non-string attributes which may be stuffed by dir() calls
1676 # filter out non-string attributes which may be stuffed by dir() calls
1671 # and poor coding in third-party modules
1677 # and poor coding in third-party modules
1672 return [w for w in words if isinstance(w, basestring)]
1678 return [w for w in words if isinstance(w, basestring)]
1673
1679
1674 #----------------------------------------------------------------------------
1680 #----------------------------------------------------------------------------
1675 def import_fail_info(mod_name,fns=None):
1681 def import_fail_info(mod_name,fns=None):
1676 """Inform load failure for a module."""
1682 """Inform load failure for a module."""
1677
1683
1678 if fns == None:
1684 if fns == None:
1679 warn("Loading of %s failed.\n" % (mod_name,))
1685 warn("Loading of %s failed.\n" % (mod_name,))
1680 else:
1686 else:
1681 warn("Loading of %s from %s failed.\n" % (fns,mod_name))
1687 warn("Loading of %s from %s failed.\n" % (fns,mod_name))
1682
1688
1683 #----------------------------------------------------------------------------
1689 #----------------------------------------------------------------------------
1684 # Proposed popitem() extension, written as a method
1690 # Proposed popitem() extension, written as a method
1685
1691
1686
1692
1687 class NotGiven: pass
1693 class NotGiven: pass
1688
1694
1689 def popkey(dct,key,default=NotGiven):
1695 def popkey(dct,key,default=NotGiven):
1690 """Return dct[key] and delete dct[key].
1696 """Return dct[key] and delete dct[key].
1691
1697
1692 If default is given, return it if dct[key] doesn't exist, otherwise raise
1698 If default is given, return it if dct[key] doesn't exist, otherwise raise
1693 KeyError. """
1699 KeyError. """
1694
1700
1695 try:
1701 try:
1696 val = dct[key]
1702 val = dct[key]
1697 except KeyError:
1703 except KeyError:
1698 if default is NotGiven:
1704 if default is NotGiven:
1699 raise
1705 raise
1700 else:
1706 else:
1701 return default
1707 return default
1702 else:
1708 else:
1703 del dct[key]
1709 del dct[key]
1704 return val
1710 return val
1705
1711
1706 def wrap_deprecated(func, suggest = '<nothing>'):
1712 def wrap_deprecated(func, suggest = '<nothing>'):
1707 def newFunc(*args, **kwargs):
1713 def newFunc(*args, **kwargs):
1708 warnings.warn("Call to deprecated function %s, use %s instead" %
1714 warnings.warn("Call to deprecated function %s, use %s instead" %
1709 ( func.__name__, suggest),
1715 ( func.__name__, suggest),
1710 category=DeprecationWarning,
1716 category=DeprecationWarning,
1711 stacklevel = 2)
1717 stacklevel = 2)
1712 return func(*args, **kwargs)
1718 return func(*args, **kwargs)
1713 return newFunc
1719 return newFunc
1714
1720
1715
1721
1716 def _num_cpus_unix():
1722 def _num_cpus_unix():
1717 """Return the number of active CPUs on a Unix system."""
1723 """Return the number of active CPUs on a Unix system."""
1718 return os.sysconf("SC_NPROCESSORS_ONLN")
1724 return os.sysconf("SC_NPROCESSORS_ONLN")
1719
1725
1720
1726
1721 def _num_cpus_darwin():
1727 def _num_cpus_darwin():
1722 """Return the number of active CPUs on a Darwin system."""
1728 """Return the number of active CPUs on a Darwin system."""
1723 p = subprocess.Popen(['sysctl','-n','hw.ncpu'],stdout=subprocess.PIPE)
1729 p = subprocess.Popen(['sysctl','-n','hw.ncpu'],stdout=subprocess.PIPE)
1724 return p.stdout.read()
1730 return p.stdout.read()
1725
1731
1726
1732
1727 def _num_cpus_windows():
1733 def _num_cpus_windows():
1728 """Return the number of active CPUs on a Windows system."""
1734 """Return the number of active CPUs on a Windows system."""
1729 return os.environ.get("NUMBER_OF_PROCESSORS")
1735 return os.environ.get("NUMBER_OF_PROCESSORS")
1730
1736
1731
1737
1732 def num_cpus():
1738 def num_cpus():
1733 """Return the effective number of CPUs in the system as an integer.
1739 """Return the effective number of CPUs in the system as an integer.
1734
1740
1735 This cross-platform function makes an attempt at finding the total number of
1741 This cross-platform function makes an attempt at finding the total number of
1736 available CPUs in the system, as returned by various underlying system and
1742 available CPUs in the system, as returned by various underlying system and
1737 python calls.
1743 python calls.
1738
1744
1739 If it can't find a sensible answer, it returns 1 (though an error *may* make
1745 If it can't find a sensible answer, it returns 1 (though an error *may* make
1740 it return a large positive number that's actually incorrect).
1746 it return a large positive number that's actually incorrect).
1741 """
1747 """
1742
1748
1743 # Many thanks to the Parallel Python project (http://www.parallelpython.com)
1749 # Many thanks to the Parallel Python project (http://www.parallelpython.com)
1744 # for the names of the keys we needed to look up for this function. This
1750 # for the names of the keys we needed to look up for this function. This
1745 # code was inspired by their equivalent function.
1751 # code was inspired by their equivalent function.
1746
1752
1747 ncpufuncs = {'Linux':_num_cpus_unix,
1753 ncpufuncs = {'Linux':_num_cpus_unix,
1748 'Darwin':_num_cpus_darwin,
1754 'Darwin':_num_cpus_darwin,
1749 'Windows':_num_cpus_windows,
1755 'Windows':_num_cpus_windows,
1750 # On Vista, python < 2.5.2 has a bug and returns 'Microsoft'
1756 # On Vista, python < 2.5.2 has a bug and returns 'Microsoft'
1751 # See http://bugs.python.org/issue1082 for details.
1757 # See http://bugs.python.org/issue1082 for details.
1752 'Microsoft':_num_cpus_windows,
1758 'Microsoft':_num_cpus_windows,
1753 }
1759 }
1754
1760
1755 ncpufunc = ncpufuncs.get(platform.system(),
1761 ncpufunc = ncpufuncs.get(platform.system(),
1756 # default to unix version (Solaris, AIX, etc)
1762 # default to unix version (Solaris, AIX, etc)
1757 _num_cpus_unix)
1763 _num_cpus_unix)
1758
1764
1759 try:
1765 try:
1760 ncpus = max(1,int(ncpufunc()))
1766 ncpus = max(1,int(ncpufunc()))
1761 except:
1767 except:
1762 ncpus = 1
1768 ncpus = 1
1763 return ncpus
1769 return ncpus
1764
1770
1765 def extract_vars(*names,**kw):
1771 def extract_vars(*names,**kw):
1766 """Extract a set of variables by name from another frame.
1772 """Extract a set of variables by name from another frame.
1767
1773
1768 :Parameters:
1774 :Parameters:
1769 - `*names`: strings
1775 - `*names`: strings
1770 One or more variable names which will be extracted from the caller's
1776 One or more variable names which will be extracted from the caller's
1771 frame.
1777 frame.
1772
1778
1773 :Keywords:
1779 :Keywords:
1774 - `depth`: integer (0)
1780 - `depth`: integer (0)
1775 How many frames in the stack to walk when looking for your variables.
1781 How many frames in the stack to walk when looking for your variables.
1776
1782
1777
1783
1778 Examples:
1784 Examples:
1779
1785
1780 In [2]: def func(x):
1786 In [2]: def func(x):
1781 ...: y = 1
1787 ...: y = 1
1782 ...: print extract_vars('x','y')
1788 ...: print extract_vars('x','y')
1783 ...:
1789 ...:
1784
1790
1785 In [3]: func('hello')
1791 In [3]: func('hello')
1786 {'y': 1, 'x': 'hello'}
1792 {'y': 1, 'x': 'hello'}
1787 """
1793 """
1788
1794
1789 depth = kw.get('depth',0)
1795 depth = kw.get('depth',0)
1790
1796
1791 callerNS = sys._getframe(depth+1).f_locals
1797 callerNS = sys._getframe(depth+1).f_locals
1792 return dict((k,callerNS[k]) for k in names)
1798 return dict((k,callerNS[k]) for k in names)
1793
1799
1794
1800
1795 def extract_vars_above(*names):
1801 def extract_vars_above(*names):
1796 """Extract a set of variables by name from another frame.
1802 """Extract a set of variables by name from another frame.
1797
1803
1798 Similar to extractVars(), but with a specified depth of 1, so that names
1804 Similar to extractVars(), but with a specified depth of 1, so that names
1799 are exctracted exactly from above the caller.
1805 are exctracted exactly from above the caller.
1800
1806
1801 This is simply a convenience function so that the very common case (for us)
1807 This is simply a convenience function so that the very common case (for us)
1802 of skipping exactly 1 frame doesn't have to construct a special dict for
1808 of skipping exactly 1 frame doesn't have to construct a special dict for
1803 keyword passing."""
1809 keyword passing."""
1804
1810
1805 callerNS = sys._getframe(2).f_locals
1811 callerNS = sys._getframe(2).f_locals
1806 return dict((k,callerNS[k]) for k in names)
1812 return dict((k,callerNS[k]) for k in names)
1807
1813
1808 def expand_path(s):
1814 def expand_path(s):
1809 """Expand $VARS and ~names in a string, like a shell
1815 """Expand $VARS and ~names in a string, like a shell
1810
1816
1811 :Examples:
1817 :Examples:
1812
1818
1813 In [2]: os.environ['FOO']='test'
1819 In [2]: os.environ['FOO']='test'
1814
1820
1815 In [3]: expand_path('variable FOO is $FOO')
1821 In [3]: expand_path('variable FOO is $FOO')
1816 Out[3]: 'variable FOO is test'
1822 Out[3]: 'variable FOO is test'
1817 """
1823 """
1818 # This is a pretty subtle hack. When expand user is given a UNC path
1824 # This is a pretty subtle hack. When expand user is given a UNC path
1819 # on Windows (\\server\share$\%username%), os.path.expandvars, removes
1825 # on Windows (\\server\share$\%username%), os.path.expandvars, removes
1820 # the $ to get (\\server\share\%username%). I think it considered $
1826 # the $ to get (\\server\share\%username%). I think it considered $
1821 # alone an empty var. But, we need the $ to remains there (it indicates
1827 # alone an empty var. But, we need the $ to remains there (it indicates
1822 # a hidden share).
1828 # a hidden share).
1823 if os.name=='nt':
1829 if os.name=='nt':
1824 s = s.replace('$\\', 'IPYTHON_TEMP')
1830 s = s.replace('$\\', 'IPYTHON_TEMP')
1825 s = os.path.expandvars(os.path.expanduser(s))
1831 s = os.path.expandvars(os.path.expanduser(s))
1826 if os.name=='nt':
1832 if os.name=='nt':
1827 s = s.replace('IPYTHON_TEMP', '$\\')
1833 s = s.replace('IPYTHON_TEMP', '$\\')
1828 return s
1834 return s
1829
1835
1830 def list_strings(arg):
1836 def list_strings(arg):
1831 """Always return a list of strings, given a string or list of strings
1837 """Always return a list of strings, given a string or list of strings
1832 as input.
1838 as input.
1833
1839
1834 :Examples:
1840 :Examples:
1835
1841
1836 In [7]: list_strings('A single string')
1842 In [7]: list_strings('A single string')
1837 Out[7]: ['A single string']
1843 Out[7]: ['A single string']
1838
1844
1839 In [8]: list_strings(['A single string in a list'])
1845 In [8]: list_strings(['A single string in a list'])
1840 Out[8]: ['A single string in a list']
1846 Out[8]: ['A single string in a list']
1841
1847
1842 In [9]: list_strings(['A','list','of','strings'])
1848 In [9]: list_strings(['A','list','of','strings'])
1843 Out[9]: ['A', 'list', 'of', 'strings']
1849 Out[9]: ['A', 'list', 'of', 'strings']
1844 """
1850 """
1845
1851
1846 if isinstance(arg,basestring): return [arg]
1852 if isinstance(arg,basestring): return [arg]
1847 else: return arg
1853 else: return arg
1848
1854
1849
1855
1850 #----------------------------------------------------------------------------
1856 #----------------------------------------------------------------------------
1851 def marquee(txt='',width=78,mark='*'):
1857 def marquee(txt='',width=78,mark='*'):
1852 """Return the input string centered in a 'marquee'.
1858 """Return the input string centered in a 'marquee'.
1853
1859
1854 :Examples:
1860 :Examples:
1855
1861
1856 In [16]: marquee('A test',40)
1862 In [16]: marquee('A test',40)
1857 Out[16]: '**************** A test ****************'
1863 Out[16]: '**************** A test ****************'
1858
1864
1859 In [17]: marquee('A test',40,'-')
1865 In [17]: marquee('A test',40,'-')
1860 Out[17]: '---------------- A test ----------------'
1866 Out[17]: '---------------- A test ----------------'
1861
1867
1862 In [18]: marquee('A test',40,' ')
1868 In [18]: marquee('A test',40,' ')
1863 Out[18]: ' A test '
1869 Out[18]: ' A test '
1864
1870
1865 """
1871 """
1866 if not txt:
1872 if not txt:
1867 return (mark*width)[:width]
1873 return (mark*width)[:width]
1868 nmark = (width-len(txt)-2)/len(mark)/2
1874 nmark = (width-len(txt)-2)/len(mark)/2
1869 if nmark < 0: nmark =0
1875 if nmark < 0: nmark =0
1870 marks = mark*nmark
1876 marks = mark*nmark
1871 return '%s %s %s' % (marks,txt,marks)
1877 return '%s %s %s' % (marks,txt,marks)
1872
1878
1873 #*************************** end of file <genutils.py> **********************
1879 #*************************** end of file <genutils.py> **********************
@@ -1,323 +1,326 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Tests for genutils.py"""
3 """Tests for genutils.py"""
4
4
5 __docformat__ = "restructuredtext en"
5 __docformat__ = "restructuredtext en"
6
6
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008 The IPython Development Team
8 # Copyright (C) 2008 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # stdlib
18 # stdlib
19 import os
19 import os
20 import shutil
20 import shutil
21 import sys
21 import sys
22 import tempfile
22 import tempfile
23 import unittest
23 import unittest
24
24
25 from cStringIO import StringIO
25 from cStringIO import StringIO
26 from os.path import join, abspath, split
26 from os.path import join, abspath, split
27
27
28 # third-party
28 # third-party
29 import nose.tools as nt
29 import nose.tools as nt
30
30
31 from nose import with_setup
31 from nose import with_setup
32 from nose.tools import raises
32 from nose.tools import raises
33
33
34 # Our own
34 # Our own
35 import IPython
35 import IPython
36 from IPython.utils import genutils
36 from IPython.utils import genutils
37 from IPython.testing import decorators as dec
37 from IPython.testing import decorators as dec
38 from IPython.testing.decorators import skipif, skip_if_not_win32
38 from IPython.testing.decorators import skipif, skip_if_not_win32
39
39
40 # Platform-dependent imports
40 # Platform-dependent imports
41 try:
41 try:
42 import _winreg as wreg
42 import _winreg as wreg
43 except ImportError:
43 except ImportError:
44 #Fake _winreg module on none windows platforms
44 #Fake _winreg module on none windows platforms
45 import new
45 import new
46 sys.modules["_winreg"] = new.module("_winreg")
46 sys.modules["_winreg"] = new.module("_winreg")
47 import _winreg as wreg
47 import _winreg as wreg
48 #Add entries that needs to be stubbed by the testing code
48 #Add entries that needs to be stubbed by the testing code
49 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
49 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
50
50
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52 # Globals
52 # Globals
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 env = os.environ
54 env = os.environ
55 TEST_FILE_PATH = split(abspath(__file__))[0]
55 TEST_FILE_PATH = split(abspath(__file__))[0]
56 TMP_TEST_DIR = tempfile.mkdtemp()
56 TMP_TEST_DIR = tempfile.mkdtemp()
57 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
57 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
58 IP_TEST_DIR = join(HOME_TEST_DIR,'_ipython')
58 IP_TEST_DIR = join(HOME_TEST_DIR,'_ipython')
59 #
59 #
60 # Setup/teardown functions/decorators
60 # Setup/teardown functions/decorators
61 #
61 #
62
62
63 def setup():
63 def setup():
64 """Setup testenvironment for the module:
64 """Setup testenvironment for the module:
65
65
66 - Adds dummy home dir tree
66 - Adds dummy home dir tree
67 """
67 """
68 # Do not mask exceptions here. In particular, catching WindowsError is a
68 # Do not mask exceptions here. In particular, catching WindowsError is a
69 # problem because that exception is only defined on Windows...
69 # problem because that exception is only defined on Windows...
70 os.makedirs(IP_TEST_DIR)
70 os.makedirs(IP_TEST_DIR)
71
71
72 def teardown():
72 def teardown():
73 """Teardown testenvironment for the module:
73 """Teardown testenvironment for the module:
74
74
75 - Remove dummy home dir tree
75 - Remove dummy home dir tree
76 """
76 """
77 # Note: we remove the parent test dir, which is the root of all test
77 # Note: we remove the parent test dir, which is the root of all test
78 # subdirs we may have created. Use shutil instead of os.removedirs, so
78 # subdirs we may have created. Use shutil instead of os.removedirs, so
79 # that non-empty directories are all recursively removed.
79 # that non-empty directories are all recursively removed.
80 shutil.rmtree(TMP_TEST_DIR)
80 shutil.rmtree(TMP_TEST_DIR)
81
81
82
82
83 def setup_environment():
83 def setup_environment():
84 """Setup testenvironment for some functions that are tested
84 """Setup testenvironment for some functions that are tested
85 in this module. In particular this functions stores attributes
85 in this module. In particular this functions stores attributes
86 and other things that we need to stub in some test functions.
86 and other things that we need to stub in some test functions.
87 This needs to be done on a function level and not module level because
87 This needs to be done on a function level and not module level because
88 each testfunction needs a pristine environment.
88 each testfunction needs a pristine environment.
89 """
89 """
90 global oldstuff, platformstuff
90 global oldstuff, platformstuff
91 oldstuff = (env.copy(), os.name, genutils.get_home_dir, IPython.__file__,)
91 oldstuff = (env.copy(), os.name, genutils.get_home_dir, IPython.__file__,)
92
92
93 if os.name == 'nt':
93 if os.name == 'nt':
94 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
94 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
95
95
96 if 'IPYTHONDIR' in env:
96 # Remove both spellings of env variables if present
97 del env['IPYTHONDIR']
97 env.pop('IPYTHON_DIR', None)
98 env.pop('IPYTHONDIR', None)
98
99
99 def teardown_environment():
100 def teardown_environment():
100 """Restore things that were remebered by the setup_environment function
101 """Restore things that were remebered by the setup_environment function
101 """
102 """
102 (oldenv, os.name, genutils.get_home_dir, IPython.__file__,) = oldstuff
103 (oldenv, os.name, genutils.get_home_dir, IPython.__file__,) = oldstuff
103 for key in env.keys():
104 for key in env.keys():
104 if key not in oldenv:
105 if key not in oldenv:
105 del env[key]
106 del env[key]
106 env.update(oldenv)
107 env.update(oldenv)
107 if hasattr(sys, 'frozen'):
108 if hasattr(sys, 'frozen'):
108 del sys.frozen
109 del sys.frozen
109 if os.name == 'nt':
110 if os.name == 'nt':
110 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
111 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
111
112
112 # Build decorator that uses the setup_environment/setup_environment
113 # Build decorator that uses the setup_environment/setup_environment
113 with_environment = with_setup(setup_environment, teardown_environment)
114 with_environment = with_setup(setup_environment, teardown_environment)
114
115
115
116
116 #
117 #
117 # Tests for get_home_dir
118 # Tests for get_home_dir
118 #
119 #
119
120
120 @skip_if_not_win32
121 @skip_if_not_win32
121 @with_environment
122 @with_environment
122 def test_get_home_dir_1():
123 def test_get_home_dir_1():
123 """Testcase for py2exe logic, un-compressed lib
124 """Testcase for py2exe logic, un-compressed lib
124 """
125 """
125 sys.frozen = True
126 sys.frozen = True
126
127
127 #fake filename for IPython.__init__
128 #fake filename for IPython.__init__
128 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
129 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
129
130
130 home_dir = genutils.get_home_dir()
131 home_dir = genutils.get_home_dir()
131 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
132 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
132
133
133 @skip_if_not_win32
134 @skip_if_not_win32
134 @with_environment
135 @with_environment
135 def test_get_home_dir_2():
136 def test_get_home_dir_2():
136 """Testcase for py2exe logic, compressed lib
137 """Testcase for py2exe logic, compressed lib
137 """
138 """
138 sys.frozen = True
139 sys.frozen = True
139 #fake filename for IPython.__init__
140 #fake filename for IPython.__init__
140 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
141 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
141
142
142 home_dir = genutils.get_home_dir()
143 home_dir = genutils.get_home_dir()
143 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
144 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
144
145
145 @with_environment
146 @with_environment
146 def test_get_home_dir_3():
147 def test_get_home_dir_3():
147 """Testcase $HOME is set, then use its value as home directory."""
148 """Testcase $HOME is set, then use its value as home directory."""
148 env["HOME"] = HOME_TEST_DIR
149 env["HOME"] = HOME_TEST_DIR
149 home_dir = genutils.get_home_dir()
150 home_dir = genutils.get_home_dir()
150 nt.assert_equal(home_dir, env["HOME"])
151 nt.assert_equal(home_dir, env["HOME"])
151
152
152 @with_environment
153 @with_environment
153 def test_get_home_dir_4():
154 def test_get_home_dir_4():
154 """Testcase $HOME is not set, os=='poix'.
155 """Testcase $HOME is not set, os=='poix'.
155 This should fail with HomeDirError"""
156 This should fail with HomeDirError"""
156
157
157 os.name = 'posix'
158 os.name = 'posix'
158 if 'HOME' in env: del env['HOME']
159 if 'HOME' in env: del env['HOME']
159 nt.assert_raises(genutils.HomeDirError, genutils.get_home_dir)
160 nt.assert_raises(genutils.HomeDirError, genutils.get_home_dir)
160
161
161 @skip_if_not_win32
162 @skip_if_not_win32
162 @with_environment
163 @with_environment
163 def test_get_home_dir_5():
164 def test_get_home_dir_5():
164 """Testcase $HOME is not set, os=='nt'
165 """Testcase $HOME is not set, os=='nt'
165 env['HOMEDRIVE'],env['HOMEPATH'] points to path."""
166 env['HOMEDRIVE'],env['HOMEPATH'] points to path."""
166
167
167 os.name = 'nt'
168 os.name = 'nt'
168 if 'HOME' in env: del env['HOME']
169 if 'HOME' in env: del env['HOME']
169 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
170 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
170
171
171 home_dir = genutils.get_home_dir()
172 home_dir = genutils.get_home_dir()
172 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
173 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
173
174
174 @skip_if_not_win32
175 @skip_if_not_win32
175 @with_environment
176 @with_environment
176 def test_get_home_dir_6():
177 def test_get_home_dir_6():
177 """Testcase $HOME is not set, os=='nt'
178 """Testcase $HOME is not set, os=='nt'
178 env['HOMEDRIVE'],env['HOMEPATH'] do not point to path.
179 env['HOMEDRIVE'],env['HOMEPATH'] do not point to path.
179 env['USERPROFILE'] points to path
180 env['USERPROFILE'] points to path
180 """
181 """
181
182
182 os.name = 'nt'
183 os.name = 'nt'
183 if 'HOME' in env: del env['HOME']
184 if 'HOME' in env: del env['HOME']
184 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.abspath(TEST_FILE_PATH), "DOES NOT EXIST"
185 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.abspath(TEST_FILE_PATH), "DOES NOT EXIST"
185 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
186 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
186
187
187 home_dir = genutils.get_home_dir()
188 home_dir = genutils.get_home_dir()
188 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
189 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
189
190
190 # Should we stub wreg fully so we can run the test on all platforms?
191 # Should we stub wreg fully so we can run the test on all platforms?
191 @skip_if_not_win32
192 @skip_if_not_win32
192 @with_environment
193 @with_environment
193 def test_get_home_dir_7():
194 def test_get_home_dir_7():
194 """Testcase $HOME is not set, os=='nt'
195 """Testcase $HOME is not set, os=='nt'
195 env['HOMEDRIVE'],env['HOMEPATH'], env['USERPROFILE'] missing
196
197 env['HOMEDRIVE'],env['HOMEPATH'], env['USERPROFILE'] and others missing
196 """
198 """
197 os.name = 'nt'
199 os.name = 'nt'
198 if 'HOME' in env: del env['HOME']
200 # Remove from stub environment all keys that may be set
199 if 'HOMEDRIVE' in env: del env['HOMEDRIVE']
201 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
202 env.pop(key, None)
200
203
201 #Stub windows registry functions
204 #Stub windows registry functions
202 def OpenKey(x, y):
205 def OpenKey(x, y):
203 class key:
206 class key:
204 def Close(self):
207 def Close(self):
205 pass
208 pass
206 return key()
209 return key()
207 def QueryValueEx(x, y):
210 def QueryValueEx(x, y):
208 return [abspath(HOME_TEST_DIR)]
211 return [abspath(HOME_TEST_DIR)]
209
212
210 wreg.OpenKey = OpenKey
213 wreg.OpenKey = OpenKey
211 wreg.QueryValueEx = QueryValueEx
214 wreg.QueryValueEx = QueryValueEx
212
215
213 home_dir = genutils.get_home_dir()
216 home_dir = genutils.get_home_dir()
214 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
217 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
215
218
216 #
219 #
217 # Tests for get_ipython_dir
220 # Tests for get_ipython_dir
218 #
221 #
219
222
220 @with_environment
223 @with_environment
221 def test_get_ipython_dir_1():
224 def test_get_ipython_dir_1():
222 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
225 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
223 env['IPYTHONDIR'] = "someplace/.ipython"
226 env['IPYTHON_DIR'] = "someplace/.ipython"
224 ipdir = genutils.get_ipython_dir()
227 ipdir = genutils.get_ipython_dir()
225 nt.assert_equal(ipdir, "someplace/.ipython")
228 nt.assert_equal(ipdir, "someplace/.ipython")
226
229
227
230
228 @with_environment
231 @with_environment
229 def test_get_ipython_dir_2():
232 def test_get_ipython_dir_2():
230 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
233 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
231 genutils.get_home_dir = lambda : "someplace"
234 genutils.get_home_dir = lambda : "someplace"
232 os.name = "posix"
235 os.name = "posix"
233 ipdir = genutils.get_ipython_dir()
236 ipdir = genutils.get_ipython_dir()
234 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
237 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
235
238
236 #
239 #
237 # Tests for popkey
240 # Tests for popkey
238 #
241 #
239
242
240 def test_popkey_1():
243 def test_popkey_1():
241 """test_popkey_1, Basic usage test of popkey
244 """test_popkey_1, Basic usage test of popkey
242 """
245 """
243 dct = dict(a=1, b=2, c=3)
246 dct = dict(a=1, b=2, c=3)
244 nt.assert_equal(genutils.popkey(dct, "a"), 1)
247 nt.assert_equal(genutils.popkey(dct, "a"), 1)
245 nt.assert_equal(dct, dict(b=2, c=3))
248 nt.assert_equal(dct, dict(b=2, c=3))
246 nt.assert_equal(genutils.popkey(dct, "b"), 2)
249 nt.assert_equal(genutils.popkey(dct, "b"), 2)
247 nt.assert_equal(dct, dict(c=3))
250 nt.assert_equal(dct, dict(c=3))
248 nt.assert_equal(genutils.popkey(dct, "c"), 3)
251 nt.assert_equal(genutils.popkey(dct, "c"), 3)
249 nt.assert_equal(dct, dict())
252 nt.assert_equal(dct, dict())
250
253
251 def test_popkey_2():
254 def test_popkey_2():
252 """test_popkey_2, Test to see that popkey of non occuring keys
255 """test_popkey_2, Test to see that popkey of non occuring keys
253 generates a KeyError exception
256 generates a KeyError exception
254 """
257 """
255 dct = dict(a=1, b=2, c=3)
258 dct = dict(a=1, b=2, c=3)
256 nt.assert_raises(KeyError, genutils.popkey, dct, "d")
259 nt.assert_raises(KeyError, genutils.popkey, dct, "d")
257
260
258 def test_popkey_3():
261 def test_popkey_3():
259 """test_popkey_3, Tests to see that popkey calls returns the correct value
262 """test_popkey_3, Tests to see that popkey calls returns the correct value
260 and that the key/value was removed from the dict.
263 and that the key/value was removed from the dict.
261 """
264 """
262 dct = dict(a=1, b=2, c=3)
265 dct = dict(a=1, b=2, c=3)
263 nt.assert_equal(genutils.popkey(dct, "A", 13), 13)
266 nt.assert_equal(genutils.popkey(dct, "A", 13), 13)
264 nt.assert_equal(dct, dict(a=1, b=2, c=3))
267 nt.assert_equal(dct, dict(a=1, b=2, c=3))
265 nt.assert_equal(genutils.popkey(dct, "B", 14), 14)
268 nt.assert_equal(genutils.popkey(dct, "B", 14), 14)
266 nt.assert_equal(dct, dict(a=1, b=2, c=3))
269 nt.assert_equal(dct, dict(a=1, b=2, c=3))
267 nt.assert_equal(genutils.popkey(dct, "C", 15), 15)
270 nt.assert_equal(genutils.popkey(dct, "C", 15), 15)
268 nt.assert_equal(dct, dict(a=1, b=2, c=3))
271 nt.assert_equal(dct, dict(a=1, b=2, c=3))
269 nt.assert_equal(genutils.popkey(dct, "a"), 1)
272 nt.assert_equal(genutils.popkey(dct, "a"), 1)
270 nt.assert_equal(dct, dict(b=2, c=3))
273 nt.assert_equal(dct, dict(b=2, c=3))
271 nt.assert_equal(genutils.popkey(dct, "b"), 2)
274 nt.assert_equal(genutils.popkey(dct, "b"), 2)
272 nt.assert_equal(dct, dict(c=3))
275 nt.assert_equal(dct, dict(c=3))
273 nt.assert_equal(genutils.popkey(dct, "c"), 3)
276 nt.assert_equal(genutils.popkey(dct, "c"), 3)
274 nt.assert_equal(dct, dict())
277 nt.assert_equal(dct, dict())
275
278
276
279
277 def test_filefind():
280 def test_filefind():
278 """Various tests for filefind"""
281 """Various tests for filefind"""
279 f = tempfile.NamedTemporaryFile()
282 f = tempfile.NamedTemporaryFile()
280 print 'fname:',f.name
283 print 'fname:',f.name
281 alt_dirs = genutils.get_ipython_dir()
284 alt_dirs = genutils.get_ipython_dir()
282 t = genutils.filefind(f.name,alt_dirs)
285 t = genutils.filefind(f.name,alt_dirs)
283 print 'found:',t
286 print 'found:',t
284
287
285
288
286 def test_get_ipython_package_dir():
289 def test_get_ipython_package_dir():
287 ipdir = genutils.get_ipython_package_dir()
290 ipdir = genutils.get_ipython_package_dir()
288 nt.assert_true(os.path.isdir(ipdir))
291 nt.assert_true(os.path.isdir(ipdir))
289
292
290
293
291 def test_tee_simple():
294 def test_tee_simple():
292 "Very simple check with stdout only"
295 "Very simple check with stdout only"
293 chan = StringIO()
296 chan = StringIO()
294 text = 'Hello'
297 text = 'Hello'
295 tee = genutils.Tee(chan, channel='stdout')
298 tee = genutils.Tee(chan, channel='stdout')
296 print >> chan, text,
299 print >> chan, text,
297 nt.assert_equal(chan.getvalue(), text)
300 nt.assert_equal(chan.getvalue(), text)
298
301
299
302
300 class TeeTestCase(dec.ParametricTestCase):
303 class TeeTestCase(dec.ParametricTestCase):
301
304
302 def tchan(self, channel, check='close'):
305 def tchan(self, channel, check='close'):
303 trap = StringIO()
306 trap = StringIO()
304 chan = StringIO()
307 chan = StringIO()
305 text = 'Hello'
308 text = 'Hello'
306
309
307 std_ori = getattr(sys, channel)
310 std_ori = getattr(sys, channel)
308 setattr(sys, channel, trap)
311 setattr(sys, channel, trap)
309
312
310 tee = genutils.Tee(chan, channel=channel)
313 tee = genutils.Tee(chan, channel=channel)
311 print >> chan, text,
314 print >> chan, text,
312 setattr(sys, channel, std_ori)
315 setattr(sys, channel, std_ori)
313 trap_val = trap.getvalue()
316 trap_val = trap.getvalue()
314 nt.assert_equals(chan.getvalue(), text)
317 nt.assert_equals(chan.getvalue(), text)
315 if check=='close':
318 if check=='close':
316 tee.close()
319 tee.close()
317 else:
320 else:
318 del tee
321 del tee
319
322
320 def test(self):
323 def test(self):
321 for chan in ['stdout', 'stderr']:
324 for chan in ['stdout', 'stderr']:
322 for check in ['close', 'del']:
325 for check in ['close', 'del']:
323 yield self.tchan(chan, check)
326 yield self.tchan(chan, check)
General Comments 0
You need to be logged in to leave comments. Login now