##// END OF EJS Templates
Various fixes for test_genutils under win32, now all tests pass.
Fernando Perez -
Show More
@@ -1,1873 +1,1879 b''
1 1 # -*- coding: utf-8 -*-
2 2 """General purpose utilities.
3 3
4 4 This is a grab-bag of stuff I find useful in most programs I write. Some of
5 5 these things are also convenient when working at the command line.
6 6 """
7 7
8 8 #*****************************************************************************
9 9 # Copyright (C) 2001-2006 Fernando Perez. <fperez@colorado.edu>
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #*****************************************************************************
14 14
15 15 #****************************************************************************
16 16 # required modules from the Python standard library
17 17 import __main__
18 18
19 19 import os
20 20 import platform
21 21 import re
22 22 import shlex
23 23 import shutil
24 24 import subprocess
25 25 import sys
26 26 import time
27 27 import types
28 28 import warnings
29 29
30 30 # Curses and termios are Unix-only modules
31 31 try:
32 32 import curses
33 33 # We need termios as well, so if its import happens to raise, we bail on
34 34 # using curses altogether.
35 35 import termios
36 36 except ImportError:
37 37 USE_CURSES = False
38 38 else:
39 39 # Curses on Solaris may not be complete, so we can't use it there
40 40 USE_CURSES = hasattr(curses,'initscr')
41 41
42 42 # Other IPython utilities
43 43 import IPython
44 44 from IPython.external.Itpl import itpl,printpl
45 45 from IPython.utils import platutils
46 46 from IPython.utils.generics import result_display
47 47 from IPython.external.path import path
48 48
49 49
50 50 #****************************************************************************
51 51 # Exceptions
52 52 class Error(Exception):
53 53 """Base class for exceptions in this module."""
54 54 pass
55 55
56 56 #----------------------------------------------------------------------------
57 57 class IOStream:
58 58 def __init__(self,stream,fallback):
59 59 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
60 60 stream = fallback
61 61 self.stream = stream
62 62 self._swrite = stream.write
63 63 self.flush = stream.flush
64 64
65 65 def write(self,data):
66 66 try:
67 67 self._swrite(data)
68 68 except:
69 69 try:
70 70 # print handles some unicode issues which may trip a plain
71 71 # write() call. Attempt to emulate write() by using a
72 72 # trailing comma
73 73 print >> self.stream, data,
74 74 except:
75 75 # if we get here, something is seriously broken.
76 76 print >> sys.stderr, \
77 77 'ERROR - failed to write data to stream:', self.stream
78 78
79 79 def close(self):
80 80 pass
81 81
82 82
83 83 class IOTerm:
84 84 """ Term holds the file or file-like objects for handling I/O operations.
85 85
86 86 These are normally just sys.stdin, sys.stdout and sys.stderr but for
87 87 Windows they can can replaced to allow editing the strings before they are
88 88 displayed."""
89 89
90 90 # In the future, having IPython channel all its I/O operations through
91 91 # this class will make it easier to embed it into other environments which
92 92 # are not a normal terminal (such as a GUI-based shell)
93 93 def __init__(self,cin=None,cout=None,cerr=None):
94 94 self.cin = IOStream(cin,sys.stdin)
95 95 self.cout = IOStream(cout,sys.stdout)
96 96 self.cerr = IOStream(cerr,sys.stderr)
97 97
98 98 # Global variable to be used for all I/O
99 99 Term = IOTerm()
100 100
101 101 import IPython.utils.rlineimpl as readline
102 102 # Remake Term to use the readline i/o facilities
103 103 if sys.platform == 'win32' and readline.have_readline:
104 104
105 105 Term = IOTerm(cout=readline._outputfile,cerr=readline._outputfile)
106 106
107 107
108 108 class Tee(object):
109 109 """A class to duplicate an output stream to stdout/err.
110 110
111 111 This works in a manner very similar to the Unix 'tee' command.
112 112
113 113 When the object is closed or deleted, it closes the original file given to
114 114 it for duplication.
115 115 """
116 116 # Inspired by:
117 117 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
118 118
119 119 def __init__(self, file, mode=None, channel='stdout'):
120 120 """Construct a new Tee object.
121 121
122 122 Parameters
123 123 ----------
124 124 file : filename or open filehandle (writable)
125 125 File that will be duplicated
126 126
127 127 mode : optional, valid mode for open().
128 128 If a filename was give, open with this mode.
129 129
130 130 channel : str, one of ['stdout', 'stderr']
131 131 """
132 132 if channel not in ['stdout', 'stderr']:
133 133 raise ValueError('Invalid channel spec %s' % channel)
134 134
135 135 if hasattr(file, 'write') and hasattr(file, 'seek'):
136 136 self.file = file
137 137 else:
138 138 self.file = open(name, mode)
139 139 self.channel = channel
140 140 self.ostream = getattr(sys, channel)
141 141 setattr(sys, channel, self)
142 142 self._closed = False
143 143
144 144 def close(self):
145 145 """Close the file and restore the channel."""
146 146 self.flush()
147 147 setattr(sys, self.channel, self.ostream)
148 148 self.file.close()
149 149 self._closed = True
150 150
151 151 def write(self, data):
152 152 """Write data to both channels."""
153 153 self.file.write(data)
154 154 self.ostream.write(data)
155 155 self.ostream.flush()
156 156
157 157 def flush(self):
158 158 """Flush both channels."""
159 159 self.file.flush()
160 160 self.ostream.flush()
161 161
162 162 def __del__(self):
163 163 if not self._closed:
164 164 self.close()
165 165
166 166
167 167 #****************************************************************************
168 168 # Generic warning/error printer, used by everything else
169 169 def warn(msg,level=2,exit_val=1):
170 170 """Standard warning printer. Gives formatting consistency.
171 171
172 172 Output is sent to Term.cerr (sys.stderr by default).
173 173
174 174 Options:
175 175
176 176 -level(2): allows finer control:
177 177 0 -> Do nothing, dummy function.
178 178 1 -> Print message.
179 179 2 -> Print 'WARNING:' + message. (Default level).
180 180 3 -> Print 'ERROR:' + message.
181 181 4 -> Print 'FATAL ERROR:' + message and trigger a sys.exit(exit_val).
182 182
183 183 -exit_val (1): exit value returned by sys.exit() for a level 4
184 184 warning. Ignored for all other levels."""
185 185
186 186 if level>0:
187 187 header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
188 188 print >> Term.cerr, '%s%s' % (header[level],msg)
189 189 if level == 4:
190 190 print >> Term.cerr,'Exiting.\n'
191 191 sys.exit(exit_val)
192 192
193 193 def info(msg):
194 194 """Equivalent to warn(msg,level=1)."""
195 195
196 196 warn(msg,level=1)
197 197
198 198 def error(msg):
199 199 """Equivalent to warn(msg,level=3)."""
200 200
201 201 warn(msg,level=3)
202 202
203 203 def fatal(msg,exit_val=1):
204 204 """Equivalent to warn(msg,exit_val=exit_val,level=4)."""
205 205
206 206 warn(msg,exit_val=exit_val,level=4)
207 207
208 208 #---------------------------------------------------------------------------
209 209 # Debugging routines
210 210 #
211 211 def debugx(expr,pre_msg=''):
212 212 """Print the value of an expression from the caller's frame.
213 213
214 214 Takes an expression, evaluates it in the caller's frame and prints both
215 215 the given expression and the resulting value (as well as a debug mark
216 216 indicating the name of the calling function. The input must be of a form
217 217 suitable for eval().
218 218
219 219 An optional message can be passed, which will be prepended to the printed
220 220 expr->value pair."""
221 221
222 222 cf = sys._getframe(1)
223 223 print '[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
224 224 eval(expr,cf.f_globals,cf.f_locals))
225 225
226 226 # deactivate it by uncommenting the following line, which makes it a no-op
227 227 #def debugx(expr,pre_msg=''): pass
228 228
229 229 #----------------------------------------------------------------------------
230 230 StringTypes = types.StringTypes
231 231
232 232 # Basic timing functionality
233 233
234 234 # If possible (Unix), use the resource module instead of time.clock()
235 235 try:
236 236 import resource
237 237 def clocku():
238 238 """clocku() -> floating point number
239 239
240 240 Return the *USER* CPU time in seconds since the start of the process.
241 241 This is done via a call to resource.getrusage, so it avoids the
242 242 wraparound problems in time.clock()."""
243 243
244 244 return resource.getrusage(resource.RUSAGE_SELF)[0]
245 245
246 246 def clocks():
247 247 """clocks() -> floating point number
248 248
249 249 Return the *SYSTEM* CPU time in seconds since the start of the process.
250 250 This is done via a call to resource.getrusage, so it avoids the
251 251 wraparound problems in time.clock()."""
252 252
253 253 return resource.getrusage(resource.RUSAGE_SELF)[1]
254 254
255 255 def clock():
256 256 """clock() -> floating point number
257 257
258 258 Return the *TOTAL USER+SYSTEM* CPU time in seconds since the start of
259 259 the process. This is done via a call to resource.getrusage, so it
260 260 avoids the wraparound problems in time.clock()."""
261 261
262 262 u,s = resource.getrusage(resource.RUSAGE_SELF)[:2]
263 263 return u+s
264 264
265 265 def clock2():
266 266 """clock2() -> (t_user,t_system)
267 267
268 268 Similar to clock(), but return a tuple of user/system times."""
269 269 return resource.getrusage(resource.RUSAGE_SELF)[:2]
270 270
271 271 except ImportError:
272 272 # There is no distinction of user/system time under windows, so we just use
273 273 # time.clock() for everything...
274 274 clocku = clocks = clock = time.clock
275 275 def clock2():
276 276 """Under windows, system CPU time can't be measured.
277 277
278 278 This just returns clock() and zero."""
279 279 return time.clock(),0.0
280 280
281 281 def timings_out(reps,func,*args,**kw):
282 282 """timings_out(reps,func,*args,**kw) -> (t_total,t_per_call,output)
283 283
284 284 Execute a function reps times, return a tuple with the elapsed total
285 285 CPU time in seconds, the time per call and the function's output.
286 286
287 287 Under Unix, the return value is the sum of user+system time consumed by
288 288 the process, computed via the resource module. This prevents problems
289 289 related to the wraparound effect which the time.clock() function has.
290 290
291 291 Under Windows the return value is in wall clock seconds. See the
292 292 documentation for the time module for more details."""
293 293
294 294 reps = int(reps)
295 295 assert reps >=1, 'reps must be >= 1'
296 296 if reps==1:
297 297 start = clock()
298 298 out = func(*args,**kw)
299 299 tot_time = clock()-start
300 300 else:
301 301 rng = xrange(reps-1) # the last time is executed separately to store output
302 302 start = clock()
303 303 for dummy in rng: func(*args,**kw)
304 304 out = func(*args,**kw) # one last time
305 305 tot_time = clock()-start
306 306 av_time = tot_time / reps
307 307 return tot_time,av_time,out
308 308
309 309 def timings(reps,func,*args,**kw):
310 310 """timings(reps,func,*args,**kw) -> (t_total,t_per_call)
311 311
312 312 Execute a function reps times, return a tuple with the elapsed total CPU
313 313 time in seconds and the time per call. These are just the first two values
314 314 in timings_out()."""
315 315
316 316 return timings_out(reps,func,*args,**kw)[0:2]
317 317
318 318 def timing(func,*args,**kw):
319 319 """timing(func,*args,**kw) -> t_total
320 320
321 321 Execute a function once, return the elapsed total CPU time in
322 322 seconds. This is just the first value in timings_out()."""
323 323
324 324 return timings_out(1,func,*args,**kw)[0]
325 325
326 326 #****************************************************************************
327 327 # file and system
328 328
329 329 def arg_split(s,posix=False):
330 330 """Split a command line's arguments in a shell-like manner.
331 331
332 332 This is a modified version of the standard library's shlex.split()
333 333 function, but with a default of posix=False for splitting, so that quotes
334 334 in inputs are respected."""
335 335
336 336 # XXX - there may be unicode-related problems here!!! I'm not sure that
337 337 # shlex is truly unicode-safe, so it might be necessary to do
338 338 #
339 339 # s = s.encode(sys.stdin.encoding)
340 340 #
341 341 # first, to ensure that shlex gets a normal string. Input from anyone who
342 342 # knows more about unicode and shlex than I would be good to have here...
343 343 lex = shlex.shlex(s, posix=posix)
344 344 lex.whitespace_split = True
345 345 return list(lex)
346 346
347 347 def system(cmd,verbose=0,debug=0,header=''):
348 348 """Execute a system command, return its exit status.
349 349
350 350 Options:
351 351
352 352 - verbose (0): print the command to be executed.
353 353
354 354 - debug (0): only print, do not actually execute.
355 355
356 356 - header (''): Header to print on screen prior to the executed command (it
357 357 is only prepended to the command, no newlines are added).
358 358
359 359 Note: a stateful version of this function is available through the
360 360 SystemExec class."""
361 361
362 362 stat = 0
363 363 if verbose or debug: print header+cmd
364 364 sys.stdout.flush()
365 365 if not debug: stat = os.system(cmd)
366 366 return stat
367 367
368 368 def abbrev_cwd():
369 369 """ Return abbreviated version of cwd, e.g. d:mydir """
370 370 cwd = os.getcwd().replace('\\','/')
371 371 drivepart = ''
372 372 tail = cwd
373 373 if sys.platform == 'win32':
374 374 if len(cwd) < 4:
375 375 return cwd
376 376 drivepart,tail = os.path.splitdrive(cwd)
377 377
378 378
379 379 parts = tail.split('/')
380 380 if len(parts) > 2:
381 381 tail = '/'.join(parts[-2:])
382 382
383 383 return (drivepart + (
384 384 cwd == '/' and '/' or tail))
385 385
386 386
387 387 # This function is used by ipython in a lot of places to make system calls.
388 388 # We need it to be slightly different under win32, due to the vagaries of
389 389 # 'network shares'. A win32 override is below.
390 390
391 391 def shell(cmd,verbose=0,debug=0,header=''):
392 392 """Execute a command in the system shell, always return None.
393 393
394 394 Options:
395 395
396 396 - verbose (0): print the command to be executed.
397 397
398 398 - debug (0): only print, do not actually execute.
399 399
400 400 - header (''): Header to print on screen prior to the executed command (it
401 401 is only prepended to the command, no newlines are added).
402 402
403 403 Note: this is similar to genutils.system(), but it returns None so it can
404 404 be conveniently used in interactive loops without getting the return value
405 405 (typically 0) printed many times."""
406 406
407 407 stat = 0
408 408 if verbose or debug: print header+cmd
409 409 # flush stdout so we don't mangle python's buffering
410 410 sys.stdout.flush()
411 411
412 412 if not debug:
413 413 platutils.set_term_title("IPy " + cmd)
414 414 os.system(cmd)
415 415 platutils.set_term_title("IPy " + abbrev_cwd())
416 416
417 417 # override shell() for win32 to deal with network shares
418 418 if os.name in ('nt','dos'):
419 419
420 420 shell_ori = shell
421 421
422 422 def shell(cmd,verbose=0,debug=0,header=''):
423 423 if os.getcwd().startswith(r"\\"):
424 424 path = os.getcwd()
425 425 # change to c drive (cannot be on UNC-share when issuing os.system,
426 426 # as cmd.exe cannot handle UNC addresses)
427 427 os.chdir("c:")
428 428 # issue pushd to the UNC-share and then run the command
429 429 try:
430 430 shell_ori('"pushd %s&&"'%path+cmd,verbose,debug,header)
431 431 finally:
432 432 os.chdir(path)
433 433 else:
434 434 shell_ori(cmd,verbose,debug,header)
435 435
436 436 shell.__doc__ = shell_ori.__doc__
437 437
438 438 def getoutput(cmd,verbose=0,debug=0,header='',split=0):
439 439 """Dummy substitute for perl's backquotes.
440 440
441 441 Executes a command and returns the output.
442 442
443 443 Accepts the same arguments as system(), plus:
444 444
445 445 - split(0): if true, the output is returned as a list split on newlines.
446 446
447 447 Note: a stateful version of this function is available through the
448 448 SystemExec class.
449 449
450 450 This is pretty much deprecated and rarely used,
451 451 genutils.getoutputerror may be what you need.
452 452
453 453 """
454 454
455 455 if verbose or debug: print header+cmd
456 456 if not debug:
457 457 pipe = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE).stdout
458 458 output = pipe.read()
459 459 # stipping last \n is here for backwards compat.
460 460 if output.endswith('\n'):
461 461 output = output[:-1]
462 462 if split:
463 463 return output.split('\n')
464 464 else:
465 465 return output
466 466
467 467 def getoutputerror(cmd,verbose=0,debug=0,header='',split=0):
468 468 """Return (standard output,standard error) of executing cmd in a shell.
469 469
470 470 Accepts the same arguments as system(), plus:
471 471
472 472 - split(0): if true, each of stdout/err is returned as a list split on
473 473 newlines.
474 474
475 475 Note: a stateful version of this function is available through the
476 476 SystemExec class."""
477 477
478 478 if verbose or debug: print header+cmd
479 479 if not cmd:
480 480 if split:
481 481 return [],[]
482 482 else:
483 483 return '',''
484 484 if not debug:
485 485 p = subprocess.Popen(cmd, shell=True,
486 486 stdin=subprocess.PIPE,
487 487 stdout=subprocess.PIPE,
488 488 stderr=subprocess.PIPE,
489 489 close_fds=True)
490 490 pin, pout, perr = (p.stdin, p.stdout, p.stderr)
491 491
492 492 tout = pout.read().rstrip()
493 493 terr = perr.read().rstrip()
494 494 pin.close()
495 495 pout.close()
496 496 perr.close()
497 497 if split:
498 498 return tout.split('\n'),terr.split('\n')
499 499 else:
500 500 return tout,terr
501 501
502 502 # for compatibility with older naming conventions
503 503 xsys = system
504 504 bq = getoutput
505 505
506 506 class SystemExec:
507 507 """Access the system and getoutput functions through a stateful interface.
508 508
509 509 Note: here we refer to the system and getoutput functions from this
510 510 library, not the ones from the standard python library.
511 511
512 512 This class offers the system and getoutput functions as methods, but the
513 513 verbose, debug and header parameters can be set for the instance (at
514 514 creation time or later) so that they don't need to be specified on each
515 515 call.
516 516
517 517 For efficiency reasons, there's no way to override the parameters on a
518 518 per-call basis other than by setting instance attributes. If you need
519 519 local overrides, it's best to directly call system() or getoutput().
520 520
521 521 The following names are provided as alternate options:
522 522 - xsys: alias to system
523 523 - bq: alias to getoutput
524 524
525 525 An instance can then be created as:
526 526 >>> sysexec = SystemExec(verbose=1,debug=0,header='Calling: ')
527 527 """
528 528
529 529 def __init__(self,verbose=0,debug=0,header='',split=0):
530 530 """Specify the instance's values for verbose, debug and header."""
531 531 setattr_list(self,'verbose debug header split')
532 532
533 533 def system(self,cmd):
534 534 """Stateful interface to system(), with the same keyword parameters."""
535 535
536 536 system(cmd,self.verbose,self.debug,self.header)
537 537
538 538 def shell(self,cmd):
539 539 """Stateful interface to shell(), with the same keyword parameters."""
540 540
541 541 shell(cmd,self.verbose,self.debug,self.header)
542 542
543 543 xsys = system # alias
544 544
545 545 def getoutput(self,cmd):
546 546 """Stateful interface to getoutput()."""
547 547
548 548 return getoutput(cmd,self.verbose,self.debug,self.header,self.split)
549 549
550 550 def getoutputerror(self,cmd):
551 551 """Stateful interface to getoutputerror()."""
552 552
553 553 return getoutputerror(cmd,self.verbose,self.debug,self.header,self.split)
554 554
555 555 bq = getoutput # alias
556 556
557 557 #-----------------------------------------------------------------------------
558 558 def mutex_opts(dict,ex_op):
559 559 """Check for presence of mutually exclusive keys in a dict.
560 560
561 561 Call: mutex_opts(dict,[[op1a,op1b],[op2a,op2b]...]"""
562 562 for op1,op2 in ex_op:
563 563 if op1 in dict and op2 in dict:
564 564 raise ValueError,'\n*** ERROR in Arguments *** '\
565 565 'Options '+op1+' and '+op2+' are mutually exclusive.'
566 566
567 567 #-----------------------------------------------------------------------------
568 568 def get_py_filename(name):
569 569 """Return a valid python filename in the current directory.
570 570
571 571 If the given name is not a file, it adds '.py' and searches again.
572 572 Raises IOError with an informative message if the file isn't found."""
573 573
574 574 name = os.path.expanduser(name)
575 575 if not os.path.isfile(name) and not name.endswith('.py'):
576 576 name += '.py'
577 577 if os.path.isfile(name):
578 578 return name
579 579 else:
580 580 raise IOError,'File `%s` not found.' % name
581 581
582 582 #-----------------------------------------------------------------------------
583 583
584 584
585 585 def filefind(filename, path_dirs=None):
586 586 """Find a file by looking through a sequence of paths.
587 587
588 588 This iterates through a sequence of paths looking for a file and returns
589 589 the full, absolute path of the first occurence of the file. If no set of
590 590 path dirs is given, the filename is tested as is, after running through
591 591 :func:`expandvars` and :func:`expanduser`. Thus a simple call::
592 592
593 593 filefind('myfile.txt')
594 594
595 595 will find the file in the current working dir, but::
596 596
597 597 filefind('~/myfile.txt')
598 598
599 599 Will find the file in the users home directory. This function does not
600 600 automatically try any paths, such as the cwd or the user's home directory.
601 601
602 602 Parameters
603 603 ----------
604 604 filename : str
605 605 The filename to look for.
606 606 path_dirs : str, None or sequence of str
607 607 The sequence of paths to look for the file in. If None, the filename
608 608 need to be absolute or be in the cwd. If a string, the string is
609 609 put into a sequence and the searched. If a sequence, walk through
610 610 each element and join with ``filename``, calling :func:`expandvars`
611 611 and :func:`expanduser` before testing for existence.
612 612
613 613 Returns
614 614 -------
615 615 Raises :exc:`IOError` or returns absolute path to file.
616 616 """
617 617 if path_dirs is None:
618 618 path_dirs = ("",)
619 619 elif isinstance(path_dirs, basestring):
620 620 path_dirs = (path_dirs,)
621 621 for path in path_dirs:
622 622 if path == '.': path = os.getcwd()
623 623 testname = expand_path(os.path.join(path, filename))
624 624 if os.path.isfile(testname):
625 625 return os.path.abspath(testname)
626 626 raise IOError("File does not exist in any "
627 627 "of the search paths: %r, %r" % \
628 628 (filename, path_dirs))
629 629
630 630
631 631 #----------------------------------------------------------------------------
632 632 def file_read(filename):
633 633 """Read a file and close it. Returns the file source."""
634 634 fobj = open(filename,'r');
635 635 source = fobj.read();
636 636 fobj.close()
637 637 return source
638 638
639 639 def file_readlines(filename):
640 640 """Read a file and close it. Returns the file source using readlines()."""
641 641 fobj = open(filename,'r');
642 642 lines = fobj.readlines();
643 643 fobj.close()
644 644 return lines
645 645
646 646 #----------------------------------------------------------------------------
647 647 def target_outdated(target,deps):
648 648 """Determine whether a target is out of date.
649 649
650 650 target_outdated(target,deps) -> 1/0
651 651
652 652 deps: list of filenames which MUST exist.
653 653 target: single filename which may or may not exist.
654 654
655 655 If target doesn't exist or is older than any file listed in deps, return
656 656 true, otherwise return false.
657 657 """
658 658 try:
659 659 target_time = os.path.getmtime(target)
660 660 except os.error:
661 661 return 1
662 662 for dep in deps:
663 663 dep_time = os.path.getmtime(dep)
664 664 if dep_time > target_time:
665 665 #print "For target",target,"Dep failed:",dep # dbg
666 666 #print "times (dep,tar):",dep_time,target_time # dbg
667 667 return 1
668 668 return 0
669 669
670 670 #-----------------------------------------------------------------------------
671 671 def target_update(target,deps,cmd):
672 672 """Update a target with a given command given a list of dependencies.
673 673
674 674 target_update(target,deps,cmd) -> runs cmd if target is outdated.
675 675
676 676 This is just a wrapper around target_outdated() which calls the given
677 677 command if target is outdated."""
678 678
679 679 if target_outdated(target,deps):
680 680 xsys(cmd)
681 681
682 682 #----------------------------------------------------------------------------
683 683 def unquote_ends(istr):
684 684 """Remove a single pair of quotes from the endpoints of a string."""
685 685
686 686 if not istr:
687 687 return istr
688 688 if (istr[0]=="'" and istr[-1]=="'") or \
689 689 (istr[0]=='"' and istr[-1]=='"'):
690 690 return istr[1:-1]
691 691 else:
692 692 return istr
693 693
694 694 #----------------------------------------------------------------------------
695 695 def flag_calls(func):
696 696 """Wrap a function to detect and flag when it gets called.
697 697
698 698 This is a decorator which takes a function and wraps it in a function with
699 699 a 'called' attribute. wrapper.called is initialized to False.
700 700
701 701 The wrapper.called attribute is set to False right before each call to the
702 702 wrapped function, so if the call fails it remains False. After the call
703 703 completes, wrapper.called is set to True and the output is returned.
704 704
705 705 Testing for truth in wrapper.called allows you to determine if a call to
706 706 func() was attempted and succeeded."""
707 707
708 708 def wrapper(*args,**kw):
709 709 wrapper.called = False
710 710 out = func(*args,**kw)
711 711 wrapper.called = True
712 712 return out
713 713
714 714 wrapper.called = False
715 715 wrapper.__doc__ = func.__doc__
716 716 return wrapper
717 717
718 718 #----------------------------------------------------------------------------
719 719 def dhook_wrap(func,*a,**k):
720 720 """Wrap a function call in a sys.displayhook controller.
721 721
722 722 Returns a wrapper around func which calls func, with all its arguments and
723 723 keywords unmodified, using the default sys.displayhook. Since IPython
724 724 modifies sys.displayhook, it breaks the behavior of certain systems that
725 725 rely on the default behavior, notably doctest.
726 726 """
727 727
728 728 def f(*a,**k):
729 729
730 730 dhook_s = sys.displayhook
731 731 sys.displayhook = sys.__displayhook__
732 732 try:
733 733 out = func(*a,**k)
734 734 finally:
735 735 sys.displayhook = dhook_s
736 736
737 737 return out
738 738
739 739 f.__doc__ = func.__doc__
740 740 return f
741 741
742 742 #----------------------------------------------------------------------------
743 743 def doctest_reload():
744 744 """Properly reload doctest to reuse it interactively.
745 745
746 746 This routine:
747 747
748 748 - imports doctest but does NOT reload it (see below).
749 749
750 750 - resets its global 'master' attribute to None, so that multiple uses of
751 751 the module interactively don't produce cumulative reports.
752 752
753 753 - Monkeypatches its core test runner method to protect it from IPython's
754 754 modified displayhook. Doctest expects the default displayhook behavior
755 755 deep down, so our modification breaks it completely. For this reason, a
756 756 hard monkeypatch seems like a reasonable solution rather than asking
757 757 users to manually use a different doctest runner when under IPython.
758 758
759 759 Notes
760 760 -----
761 761
762 762 This function *used to* reload doctest, but this has been disabled because
763 763 reloading doctest unconditionally can cause massive breakage of other
764 764 doctest-dependent modules already in memory, such as those for IPython's
765 765 own testing system. The name wasn't changed to avoid breaking people's
766 766 code, but the reload call isn't actually made anymore."""
767 767
768 768 import doctest
769 769 doctest.master = None
770 770 doctest.DocTestRunner.run = dhook_wrap(doctest.DocTestRunner.run)
771 771
772 772 #----------------------------------------------------------------------------
773 773 class HomeDirError(Error):
774 774 pass
775 775
776 776 def get_home_dir():
777 777 """Return the closest possible equivalent to a 'home' directory.
778 778
779 779 * On POSIX, we try $HOME.
780 780 * On Windows we try:
781 - %HOME%: rare, but some people with unix-like setups may have defined it
781 782 - %HOMESHARE%
782 783 - %HOMEDRIVE\%HOMEPATH%
783 784 - %USERPROFILE%
784 785 - Registry hack
785 786 * On Dos C:\
786 787
787 788 Currently only Posix and NT are implemented, a HomeDirError exception is
788 789 raised for all other OSes.
789 790 """
790 791
791 792 isdir = os.path.isdir
792 793 env = os.environ
793 794
794 795 # first, check py2exe distribution root directory for _ipython.
795 796 # This overrides all. Normally does not exist.
796 797
797 798 if hasattr(sys, "frozen"): #Is frozen by py2exe
798 799 if '\\library.zip\\' in IPython.__file__.lower():#libraries compressed to zip-file
799 800 root, rest = IPython.__file__.lower().split('library.zip')
800 801 else:
801 802 root=os.path.join(os.path.split(IPython.__file__)[0],"../../")
802 803 root=os.path.abspath(root).rstrip('\\')
803 804 if isdir(os.path.join(root, '_ipython')):
804 805 os.environ["IPYKITROOT"] = root
805 806 return root.decode(sys.getfilesystemencoding())
806 807
807 808 if os.name == 'posix':
808 809 # Linux, Unix, AIX, OS X
809 810 try:
810 811 homedir = env['HOME']
811 812 except KeyError:
812 813 raise HomeDirError('Undefined $HOME, IPython cannot proceed.')
813 814 else:
814 815 return homedir.decode(sys.getfilesystemencoding())
815 816 elif os.name == 'nt':
816 817 # Now for win9x, XP, Vista, 7?
817 818 # For some strange reason all of these return 'nt' for os.name.
818 819 # First look for a network home directory. This will return the UNC
819 820 # path (\\server\\Users\%username%) not the mapped path (Z:\). This
820 821 # is needed when running IPython on cluster where all paths have to
821 822 # be UNC.
822 823 try:
823 homedir = env['HOMESHARE']
824 # A user with a lot of unix tools in win32 may have defined $HOME,
825 # honor it if it exists, but otherwise let the more typical
826 # %HOMESHARE% variable be used.
827 homedir = env.get('HOME')
828 if homedir is None:
829 homedir = env['HOMESHARE']
824 830 except KeyError:
825 831 pass
826 832 else:
827 833 if isdir(homedir):
828 834 return homedir.decode(sys.getfilesystemencoding())
829 835
830 836 # Now look for a local home directory
831 837 try:
832 838 homedir = os.path.join(env['HOMEDRIVE'],env['HOMEPATH'])
833 839 except KeyError:
834 840 pass
835 841 else:
836 842 if isdir(homedir):
837 843 return homedir.decode(sys.getfilesystemencoding())
838 844
839 845 # Now the users profile directory
840 846 try:
841 847 homedir = os.path.join(env['USERPROFILE'])
842 848 except KeyError:
843 849 pass
844 850 else:
845 851 if isdir(homedir):
846 852 return homedir.decode(sys.getfilesystemencoding())
847 853
848 854 # Use the registry to get the 'My Documents' folder.
849 855 try:
850 856 import _winreg as wreg
851 857 key = wreg.OpenKey(
852 858 wreg.HKEY_CURRENT_USER,
853 859 "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
854 860 )
855 861 homedir = wreg.QueryValueEx(key,'Personal')[0]
856 862 key.Close()
857 863 except:
858 864 pass
859 865 else:
860 866 if isdir(homedir):
861 867 return homedir.decode(sys.getfilesystemencoding())
862 868
863 869 # If all else fails, raise HomeDirError
864 870 raise HomeDirError('No valid home directory could be found')
865 871 elif os.name == 'dos':
866 872 # Desperate, may do absurd things in classic MacOS. May work under DOS.
867 873 return 'C:\\'.decode(sys.getfilesystemencoding())
868 874 else:
869 875 raise HomeDirError('No valid home directory could be found for your OS')
870 876
871 877
872 878 def get_ipython_dir():
873 879 """Get the IPython directory for this platform and user.
874 880
875 881 This uses the logic in `get_home_dir` to find the home directory
876 882 and the adds .ipython to the end of the path.
877 883 """
878 884 ipdir_def = '.ipython'
879 885 home_dir = get_home_dir()
880 886 ipdir = os.environ.get(
881 887 'IPYTHON_DIR', os.environ.get(
882 888 'IPYTHONDIR', os.path.join(home_dir, ipdir_def)
883 889 )
884 890 )
885 891 return ipdir.decode(sys.getfilesystemencoding())
886 892
887 893
888 894 def get_ipython_package_dir():
889 895 """Get the base directory where IPython itself is installed."""
890 896 ipdir = os.path.dirname(IPython.__file__)
891 897 return ipdir.decode(sys.getfilesystemencoding())
892 898
893 899
894 900 #****************************************************************************
895 901 # strings and text
896 902
897 903 class LSString(str):
898 904 """String derivative with a special access attributes.
899 905
900 906 These are normal strings, but with the special attributes:
901 907
902 908 .l (or .list) : value as list (split on newlines).
903 909 .n (or .nlstr): original value (the string itself).
904 910 .s (or .spstr): value as whitespace-separated string.
905 911 .p (or .paths): list of path objects
906 912
907 913 Any values which require transformations are computed only once and
908 914 cached.
909 915
910 916 Such strings are very useful to efficiently interact with the shell, which
911 917 typically only understands whitespace-separated options for commands."""
912 918
913 919 def get_list(self):
914 920 try:
915 921 return self.__list
916 922 except AttributeError:
917 923 self.__list = self.split('\n')
918 924 return self.__list
919 925
920 926 l = list = property(get_list)
921 927
922 928 def get_spstr(self):
923 929 try:
924 930 return self.__spstr
925 931 except AttributeError:
926 932 self.__spstr = self.replace('\n',' ')
927 933 return self.__spstr
928 934
929 935 s = spstr = property(get_spstr)
930 936
931 937 def get_nlstr(self):
932 938 return self
933 939
934 940 n = nlstr = property(get_nlstr)
935 941
936 942 def get_paths(self):
937 943 try:
938 944 return self.__paths
939 945 except AttributeError:
940 946 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
941 947 return self.__paths
942 948
943 949 p = paths = property(get_paths)
944 950
945 951 def print_lsstring(arg):
946 952 """ Prettier (non-repr-like) and more informative printer for LSString """
947 953 print "LSString (.p, .n, .l, .s available). Value:"
948 954 print arg
949 955
950 956 print_lsstring = result_display.when_type(LSString)(print_lsstring)
951 957
952 958 #----------------------------------------------------------------------------
953 959 class SList(list):
954 960 """List derivative with a special access attributes.
955 961
956 962 These are normal lists, but with the special attributes:
957 963
958 964 .l (or .list) : value as list (the list itself).
959 965 .n (or .nlstr): value as a string, joined on newlines.
960 966 .s (or .spstr): value as a string, joined on spaces.
961 967 .p (or .paths): list of path objects
962 968
963 969 Any values which require transformations are computed only once and
964 970 cached."""
965 971
966 972 def get_list(self):
967 973 return self
968 974
969 975 l = list = property(get_list)
970 976
971 977 def get_spstr(self):
972 978 try:
973 979 return self.__spstr
974 980 except AttributeError:
975 981 self.__spstr = ' '.join(self)
976 982 return self.__spstr
977 983
978 984 s = spstr = property(get_spstr)
979 985
980 986 def get_nlstr(self):
981 987 try:
982 988 return self.__nlstr
983 989 except AttributeError:
984 990 self.__nlstr = '\n'.join(self)
985 991 return self.__nlstr
986 992
987 993 n = nlstr = property(get_nlstr)
988 994
989 995 def get_paths(self):
990 996 try:
991 997 return self.__paths
992 998 except AttributeError:
993 999 self.__paths = [path(p) for p in self if os.path.exists(p)]
994 1000 return self.__paths
995 1001
996 1002 p = paths = property(get_paths)
997 1003
998 1004 def grep(self, pattern, prune = False, field = None):
999 1005 """ Return all strings matching 'pattern' (a regex or callable)
1000 1006
1001 1007 This is case-insensitive. If prune is true, return all items
1002 1008 NOT matching the pattern.
1003 1009
1004 1010 If field is specified, the match must occur in the specified
1005 1011 whitespace-separated field.
1006 1012
1007 1013 Examples::
1008 1014
1009 1015 a.grep( lambda x: x.startswith('C') )
1010 1016 a.grep('Cha.*log', prune=1)
1011 1017 a.grep('chm', field=-1)
1012 1018 """
1013 1019
1014 1020 def match_target(s):
1015 1021 if field is None:
1016 1022 return s
1017 1023 parts = s.split()
1018 1024 try:
1019 1025 tgt = parts[field]
1020 1026 return tgt
1021 1027 except IndexError:
1022 1028 return ""
1023 1029
1024 1030 if isinstance(pattern, basestring):
1025 1031 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
1026 1032 else:
1027 1033 pred = pattern
1028 1034 if not prune:
1029 1035 return SList([el for el in self if pred(match_target(el))])
1030 1036 else:
1031 1037 return SList([el for el in self if not pred(match_target(el))])
1032 1038 def fields(self, *fields):
1033 1039 """ Collect whitespace-separated fields from string list
1034 1040
1035 1041 Allows quick awk-like usage of string lists.
1036 1042
1037 1043 Example data (in var a, created by 'a = !ls -l')::
1038 1044 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
1039 1045 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
1040 1046
1041 1047 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
1042 1048 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
1043 1049 (note the joining by space).
1044 1050 a.fields(-1) is ['ChangeLog', 'IPython']
1045 1051
1046 1052 IndexErrors are ignored.
1047 1053
1048 1054 Without args, fields() just split()'s the strings.
1049 1055 """
1050 1056 if len(fields) == 0:
1051 1057 return [el.split() for el in self]
1052 1058
1053 1059 res = SList()
1054 1060 for el in [f.split() for f in self]:
1055 1061 lineparts = []
1056 1062
1057 1063 for fd in fields:
1058 1064 try:
1059 1065 lineparts.append(el[fd])
1060 1066 except IndexError:
1061 1067 pass
1062 1068 if lineparts:
1063 1069 res.append(" ".join(lineparts))
1064 1070
1065 1071 return res
1066 1072 def sort(self,field= None, nums = False):
1067 1073 """ sort by specified fields (see fields())
1068 1074
1069 1075 Example::
1070 1076 a.sort(1, nums = True)
1071 1077
1072 1078 Sorts a by second field, in numerical order (so that 21 > 3)
1073 1079
1074 1080 """
1075 1081
1076 1082 #decorate, sort, undecorate
1077 1083 if field is not None:
1078 1084 dsu = [[SList([line]).fields(field), line] for line in self]
1079 1085 else:
1080 1086 dsu = [[line, line] for line in self]
1081 1087 if nums:
1082 1088 for i in range(len(dsu)):
1083 1089 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
1084 1090 try:
1085 1091 n = int(numstr)
1086 1092 except ValueError:
1087 1093 n = 0;
1088 1094 dsu[i][0] = n
1089 1095
1090 1096
1091 1097 dsu.sort()
1092 1098 return SList([t[1] for t in dsu])
1093 1099
1094 1100 def print_slist(arg):
1095 1101 """ Prettier (non-repr-like) and more informative printer for SList """
1096 1102 print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
1097 1103 if hasattr(arg, 'hideonce') and arg.hideonce:
1098 1104 arg.hideonce = False
1099 1105 return
1100 1106
1101 1107 nlprint(arg)
1102 1108
1103 1109 print_slist = result_display.when_type(SList)(print_slist)
1104 1110
1105 1111
1106 1112
1107 1113 #----------------------------------------------------------------------------
1108 1114 def esc_quotes(strng):
1109 1115 """Return the input string with single and double quotes escaped out"""
1110 1116
1111 1117 return strng.replace('"','\\"').replace("'","\\'")
1112 1118
1113 1119 #----------------------------------------------------------------------------
1114 1120 def make_quoted_expr(s):
1115 1121 """Return string s in appropriate quotes, using raw string if possible.
1116 1122
1117 1123 XXX - example removed because it caused encoding errors in documentation
1118 1124 generation. We need a new example that doesn't contain invalid chars.
1119 1125
1120 1126 Note the use of raw string and padding at the end to allow trailing
1121 1127 backslash.
1122 1128 """
1123 1129
1124 1130 tail = ''
1125 1131 tailpadding = ''
1126 1132 raw = ''
1127 1133 if "\\" in s:
1128 1134 raw = 'r'
1129 1135 if s.endswith('\\'):
1130 1136 tail = '[:-1]'
1131 1137 tailpadding = '_'
1132 1138 if '"' not in s:
1133 1139 quote = '"'
1134 1140 elif "'" not in s:
1135 1141 quote = "'"
1136 1142 elif '"""' not in s and not s.endswith('"'):
1137 1143 quote = '"""'
1138 1144 elif "'''" not in s and not s.endswith("'"):
1139 1145 quote = "'''"
1140 1146 else:
1141 1147 # give up, backslash-escaped string will do
1142 1148 return '"%s"' % esc_quotes(s)
1143 1149 res = raw + quote + s + tailpadding + quote + tail
1144 1150 return res
1145 1151
1146 1152
1147 1153 #----------------------------------------------------------------------------
1148 1154 def raw_input_multi(header='', ps1='==> ', ps2='..> ',terminate_str = '.'):
1149 1155 """Take multiple lines of input.
1150 1156
1151 1157 A list with each line of input as a separate element is returned when a
1152 1158 termination string is entered (defaults to a single '.'). Input can also
1153 1159 terminate via EOF (^D in Unix, ^Z-RET in Windows).
1154 1160
1155 1161 Lines of input which end in \\ are joined into single entries (and a
1156 1162 secondary continuation prompt is issued as long as the user terminates
1157 1163 lines with \\). This allows entering very long strings which are still
1158 1164 meant to be treated as single entities.
1159 1165 """
1160 1166
1161 1167 try:
1162 1168 if header:
1163 1169 header += '\n'
1164 1170 lines = [raw_input(header + ps1)]
1165 1171 except EOFError:
1166 1172 return []
1167 1173 terminate = [terminate_str]
1168 1174 try:
1169 1175 while lines[-1:] != terminate:
1170 1176 new_line = raw_input(ps1)
1171 1177 while new_line.endswith('\\'):
1172 1178 new_line = new_line[:-1] + raw_input(ps2)
1173 1179 lines.append(new_line)
1174 1180
1175 1181 return lines[:-1] # don't return the termination command
1176 1182 except EOFError:
1177 1183 print
1178 1184 return lines
1179 1185
1180 1186 #----------------------------------------------------------------------------
1181 1187 def raw_input_ext(prompt='', ps2='... '):
1182 1188 """Similar to raw_input(), but accepts extended lines if input ends with \\."""
1183 1189
1184 1190 line = raw_input(prompt)
1185 1191 while line.endswith('\\'):
1186 1192 line = line[:-1] + raw_input(ps2)
1187 1193 return line
1188 1194
1189 1195 #----------------------------------------------------------------------------
1190 1196 def ask_yes_no(prompt,default=None):
1191 1197 """Asks a question and returns a boolean (y/n) answer.
1192 1198
1193 1199 If default is given (one of 'y','n'), it is used if the user input is
1194 1200 empty. Otherwise the question is repeated until an answer is given.
1195 1201
1196 1202 An EOF is treated as the default answer. If there is no default, an
1197 1203 exception is raised to prevent infinite loops.
1198 1204
1199 1205 Valid answers are: y/yes/n/no (match is not case sensitive)."""
1200 1206
1201 1207 answers = {'y':True,'n':False,'yes':True,'no':False}
1202 1208 ans = None
1203 1209 while ans not in answers.keys():
1204 1210 try:
1205 1211 ans = raw_input(prompt+' ').lower()
1206 1212 if not ans: # response was an empty string
1207 1213 ans = default
1208 1214 except KeyboardInterrupt:
1209 1215 pass
1210 1216 except EOFError:
1211 1217 if default in answers.keys():
1212 1218 ans = default
1213 1219 print
1214 1220 else:
1215 1221 raise
1216 1222
1217 1223 return answers[ans]
1218 1224
1219 1225 #----------------------------------------------------------------------------
1220 1226 class EvalDict:
1221 1227 """
1222 1228 Emulate a dict which evaluates its contents in the caller's frame.
1223 1229
1224 1230 Usage:
1225 1231 >>> number = 19
1226 1232
1227 1233 >>> text = "python"
1228 1234
1229 1235 >>> print "%(text.capitalize())s %(number/9.0).1f rules!" % EvalDict()
1230 1236 Python 2.1 rules!
1231 1237 """
1232 1238
1233 1239 # This version is due to sismex01@hebmex.com on c.l.py, and is basically a
1234 1240 # modified (shorter) version of:
1235 1241 # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66018 by
1236 1242 # Skip Montanaro (skip@pobox.com).
1237 1243
1238 1244 def __getitem__(self, name):
1239 1245 frame = sys._getframe(1)
1240 1246 return eval(name, frame.f_globals, frame.f_locals)
1241 1247
1242 1248 EvalString = EvalDict # for backwards compatibility
1243 1249 #----------------------------------------------------------------------------
1244 1250 def qw(words,flat=0,sep=None,maxsplit=-1):
1245 1251 """Similar to Perl's qw() operator, but with some more options.
1246 1252
1247 1253 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
1248 1254
1249 1255 words can also be a list itself, and with flat=1, the output will be
1250 1256 recursively flattened.
1251 1257
1252 1258 Examples:
1253 1259
1254 1260 >>> qw('1 2')
1255 1261 ['1', '2']
1256 1262
1257 1263 >>> qw(['a b','1 2',['m n','p q']])
1258 1264 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
1259 1265
1260 1266 >>> qw(['a b','1 2',['m n','p q']],flat=1)
1261 1267 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
1262 1268 """
1263 1269
1264 1270 if type(words) in StringTypes:
1265 1271 return [word.strip() for word in words.split(sep,maxsplit)
1266 1272 if word and not word.isspace() ]
1267 1273 if flat:
1268 1274 return flatten(map(qw,words,[1]*len(words)))
1269 1275 return map(qw,words)
1270 1276
1271 1277 #----------------------------------------------------------------------------
1272 1278 def qwflat(words,sep=None,maxsplit=-1):
1273 1279 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
1274 1280 return qw(words,1,sep,maxsplit)
1275 1281
1276 1282 #----------------------------------------------------------------------------
1277 1283 def qw_lol(indata):
1278 1284 """qw_lol('a b') -> [['a','b']],
1279 1285 otherwise it's just a call to qw().
1280 1286
1281 1287 We need this to make sure the modules_some keys *always* end up as a
1282 1288 list of lists."""
1283 1289
1284 1290 if type(indata) in StringTypes:
1285 1291 return [qw(indata)]
1286 1292 else:
1287 1293 return qw(indata)
1288 1294
1289 1295 #----------------------------------------------------------------------------
1290 1296 def grep(pat,list,case=1):
1291 1297 """Simple minded grep-like function.
1292 1298 grep(pat,list) returns occurrences of pat in list, None on failure.
1293 1299
1294 1300 It only does simple string matching, with no support for regexps. Use the
1295 1301 option case=0 for case-insensitive matching."""
1296 1302
1297 1303 # This is pretty crude. At least it should implement copying only references
1298 1304 # to the original data in case it's big. Now it copies the data for output.
1299 1305 out=[]
1300 1306 if case:
1301 1307 for term in list:
1302 1308 if term.find(pat)>-1: out.append(term)
1303 1309 else:
1304 1310 lpat=pat.lower()
1305 1311 for term in list:
1306 1312 if term.lower().find(lpat)>-1: out.append(term)
1307 1313
1308 1314 if len(out): return out
1309 1315 else: return None
1310 1316
1311 1317 #----------------------------------------------------------------------------
1312 1318 def dgrep(pat,*opts):
1313 1319 """Return grep() on dir()+dir(__builtins__).
1314 1320
1315 1321 A very common use of grep() when working interactively."""
1316 1322
1317 1323 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
1318 1324
1319 1325 #----------------------------------------------------------------------------
1320 1326 def idgrep(pat):
1321 1327 """Case-insensitive dgrep()"""
1322 1328
1323 1329 return dgrep(pat,0)
1324 1330
1325 1331 #----------------------------------------------------------------------------
1326 1332 def igrep(pat,list):
1327 1333 """Synonym for case-insensitive grep."""
1328 1334
1329 1335 return grep(pat,list,case=0)
1330 1336
1331 1337 #----------------------------------------------------------------------------
1332 1338 def indent(str,nspaces=4,ntabs=0):
1333 1339 """Indent a string a given number of spaces or tabstops.
1334 1340
1335 1341 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
1336 1342 """
1337 1343 if str is None:
1338 1344 return
1339 1345 ind = '\t'*ntabs+' '*nspaces
1340 1346 outstr = '%s%s' % (ind,str.replace(os.linesep,os.linesep+ind))
1341 1347 if outstr.endswith(os.linesep+ind):
1342 1348 return outstr[:-len(ind)]
1343 1349 else:
1344 1350 return outstr
1345 1351
1346 1352 #-----------------------------------------------------------------------------
1347 1353 def native_line_ends(filename,backup=1):
1348 1354 """Convert (in-place) a file to line-ends native to the current OS.
1349 1355
1350 1356 If the optional backup argument is given as false, no backup of the
1351 1357 original file is left. """
1352 1358
1353 1359 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
1354 1360
1355 1361 bak_filename = filename + backup_suffixes[os.name]
1356 1362
1357 1363 original = open(filename).read()
1358 1364 shutil.copy2(filename,bak_filename)
1359 1365 try:
1360 1366 new = open(filename,'wb')
1361 1367 new.write(os.linesep.join(original.splitlines()))
1362 1368 new.write(os.linesep) # ALWAYS put an eol at the end of the file
1363 1369 new.close()
1364 1370 except:
1365 1371 os.rename(bak_filename,filename)
1366 1372 if not backup:
1367 1373 try:
1368 1374 os.remove(bak_filename)
1369 1375 except:
1370 1376 pass
1371 1377
1372 1378 #****************************************************************************
1373 1379 # lists, dicts and structures
1374 1380
1375 1381 def belong(candidates,checklist):
1376 1382 """Check whether a list of items appear in a given list of options.
1377 1383
1378 1384 Returns a list of 1 and 0, one for each candidate given."""
1379 1385
1380 1386 return [x in checklist for x in candidates]
1381 1387
1382 1388 #----------------------------------------------------------------------------
1383 1389 def uniq_stable(elems):
1384 1390 """uniq_stable(elems) -> list
1385 1391
1386 1392 Return from an iterable, a list of all the unique elements in the input,
1387 1393 but maintaining the order in which they first appear.
1388 1394
1389 1395 A naive solution to this problem which just makes a dictionary with the
1390 1396 elements as keys fails to respect the stability condition, since
1391 1397 dictionaries are unsorted by nature.
1392 1398
1393 1399 Note: All elements in the input must be valid dictionary keys for this
1394 1400 routine to work, as it internally uses a dictionary for efficiency
1395 1401 reasons."""
1396 1402
1397 1403 unique = []
1398 1404 unique_dict = {}
1399 1405 for nn in elems:
1400 1406 if nn not in unique_dict:
1401 1407 unique.append(nn)
1402 1408 unique_dict[nn] = None
1403 1409 return unique
1404 1410
1405 1411 #----------------------------------------------------------------------------
1406 1412 class NLprinter:
1407 1413 """Print an arbitrarily nested list, indicating index numbers.
1408 1414
1409 1415 An instance of this class called nlprint is available and callable as a
1410 1416 function.
1411 1417
1412 1418 nlprint(list,indent=' ',sep=': ') -> prints indenting each level by 'indent'
1413 1419 and using 'sep' to separate the index from the value. """
1414 1420
1415 1421 def __init__(self):
1416 1422 self.depth = 0
1417 1423
1418 1424 def __call__(self,lst,pos='',**kw):
1419 1425 """Prints the nested list numbering levels."""
1420 1426 kw.setdefault('indent',' ')
1421 1427 kw.setdefault('sep',': ')
1422 1428 kw.setdefault('start',0)
1423 1429 kw.setdefault('stop',len(lst))
1424 1430 # we need to remove start and stop from kw so they don't propagate
1425 1431 # into a recursive call for a nested list.
1426 1432 start = kw['start']; del kw['start']
1427 1433 stop = kw['stop']; del kw['stop']
1428 1434 if self.depth == 0 and 'header' in kw.keys():
1429 1435 print kw['header']
1430 1436
1431 1437 for idx in range(start,stop):
1432 1438 elem = lst[idx]
1433 1439 if type(elem)==type([]):
1434 1440 self.depth += 1
1435 1441 self.__call__(elem,itpl('$pos$idx,'),**kw)
1436 1442 self.depth -= 1
1437 1443 else:
1438 1444 printpl(kw['indent']*self.depth+'$pos$idx$kw["sep"]$elem')
1439 1445
1440 1446 nlprint = NLprinter()
1441 1447 #----------------------------------------------------------------------------
1442 1448 def all_belong(candidates,checklist):
1443 1449 """Check whether a list of items ALL appear in a given list of options.
1444 1450
1445 1451 Returns a single 1 or 0 value."""
1446 1452
1447 1453 return 1-(0 in [x in checklist for x in candidates])
1448 1454
1449 1455 #----------------------------------------------------------------------------
1450 1456 def sort_compare(lst1,lst2,inplace = 1):
1451 1457 """Sort and compare two lists.
1452 1458
1453 1459 By default it does it in place, thus modifying the lists. Use inplace = 0
1454 1460 to avoid that (at the cost of temporary copy creation)."""
1455 1461 if not inplace:
1456 1462 lst1 = lst1[:]
1457 1463 lst2 = lst2[:]
1458 1464 lst1.sort(); lst2.sort()
1459 1465 return lst1 == lst2
1460 1466
1461 1467 #----------------------------------------------------------------------------
1462 1468 def list2dict(lst):
1463 1469 """Takes a list of (key,value) pairs and turns it into a dict."""
1464 1470
1465 1471 dic = {}
1466 1472 for k,v in lst: dic[k] = v
1467 1473 return dic
1468 1474
1469 1475 #----------------------------------------------------------------------------
1470 1476 def list2dict2(lst,default=''):
1471 1477 """Takes a list and turns it into a dict.
1472 1478 Much slower than list2dict, but more versatile. This version can take
1473 1479 lists with sublists of arbitrary length (including sclars)."""
1474 1480
1475 1481 dic = {}
1476 1482 for elem in lst:
1477 1483 if type(elem) in (types.ListType,types.TupleType):
1478 1484 size = len(elem)
1479 1485 if size == 0:
1480 1486 pass
1481 1487 elif size == 1:
1482 1488 dic[elem] = default
1483 1489 else:
1484 1490 k,v = elem[0], elem[1:]
1485 1491 if len(v) == 1: v = v[0]
1486 1492 dic[k] = v
1487 1493 else:
1488 1494 dic[elem] = default
1489 1495 return dic
1490 1496
1491 1497 #----------------------------------------------------------------------------
1492 1498 def flatten(seq):
1493 1499 """Flatten a list of lists (NOT recursive, only works for 2d lists)."""
1494 1500
1495 1501 return [x for subseq in seq for x in subseq]
1496 1502
1497 1503 #----------------------------------------------------------------------------
1498 1504 def get_slice(seq,start=0,stop=None,step=1):
1499 1505 """Get a slice of a sequence with variable step. Specify start,stop,step."""
1500 1506 if stop == None:
1501 1507 stop = len(seq)
1502 1508 item = lambda i: seq[i]
1503 1509 return map(item,xrange(start,stop,step))
1504 1510
1505 1511 #----------------------------------------------------------------------------
1506 1512 def chop(seq,size):
1507 1513 """Chop a sequence into chunks of the given size."""
1508 1514 chunk = lambda i: seq[i:i+size]
1509 1515 return map(chunk,xrange(0,len(seq),size))
1510 1516
1511 1517 #----------------------------------------------------------------------------
1512 1518 # with is a keyword as of python 2.5, so this function is renamed to withobj
1513 1519 # from its old 'with' name.
1514 1520 def with_obj(object, **args):
1515 1521 """Set multiple attributes for an object, similar to Pascal's with.
1516 1522
1517 1523 Example:
1518 1524 with_obj(jim,
1519 1525 born = 1960,
1520 1526 haircolour = 'Brown',
1521 1527 eyecolour = 'Green')
1522 1528
1523 1529 Credit: Greg Ewing, in
1524 1530 http://mail.python.org/pipermail/python-list/2001-May/040703.html.
1525 1531
1526 1532 NOTE: up until IPython 0.7.2, this was called simply 'with', but 'with'
1527 1533 has become a keyword for Python 2.5, so we had to rename it."""
1528 1534
1529 1535 object.__dict__.update(args)
1530 1536
1531 1537 #----------------------------------------------------------------------------
1532 1538 def setattr_list(obj,alist,nspace = None):
1533 1539 """Set a list of attributes for an object taken from a namespace.
1534 1540
1535 1541 setattr_list(obj,alist,nspace) -> sets in obj all the attributes listed in
1536 1542 alist with their values taken from nspace, which must be a dict (something
1537 1543 like locals() will often do) If nspace isn't given, locals() of the
1538 1544 *caller* is used, so in most cases you can omit it.
1539 1545
1540 1546 Note that alist can be given as a string, which will be automatically
1541 1547 split into a list on whitespace. If given as a list, it must be a list of
1542 1548 *strings* (the variable names themselves), not of variables."""
1543 1549
1544 1550 # this grabs the local variables from the *previous* call frame -- that is
1545 1551 # the locals from the function that called setattr_list().
1546 1552 # - snipped from weave.inline()
1547 1553 if nspace is None:
1548 1554 call_frame = sys._getframe().f_back
1549 1555 nspace = call_frame.f_locals
1550 1556
1551 1557 if type(alist) in StringTypes:
1552 1558 alist = alist.split()
1553 1559 for attr in alist:
1554 1560 val = eval(attr,nspace)
1555 1561 setattr(obj,attr,val)
1556 1562
1557 1563 #----------------------------------------------------------------------------
1558 1564 def getattr_list(obj,alist,*args):
1559 1565 """getattr_list(obj,alist[, default]) -> attribute list.
1560 1566
1561 1567 Get a list of named attributes for an object. When a default argument is
1562 1568 given, it is returned when the attribute doesn't exist; without it, an
1563 1569 exception is raised in that case.
1564 1570
1565 1571 Note that alist can be given as a string, which will be automatically
1566 1572 split into a list on whitespace. If given as a list, it must be a list of
1567 1573 *strings* (the variable names themselves), not of variables."""
1568 1574
1569 1575 if type(alist) in StringTypes:
1570 1576 alist = alist.split()
1571 1577 if args:
1572 1578 if len(args)==1:
1573 1579 default = args[0]
1574 1580 return map(lambda attr: getattr(obj,attr,default),alist)
1575 1581 else:
1576 1582 raise ValueError,'getattr_list() takes only one optional argument'
1577 1583 else:
1578 1584 return map(lambda attr: getattr(obj,attr),alist)
1579 1585
1580 1586 #----------------------------------------------------------------------------
1581 1587 def map_method(method,object_list,*argseq,**kw):
1582 1588 """map_method(method,object_list,*args,**kw) -> list
1583 1589
1584 1590 Return a list of the results of applying the methods to the items of the
1585 1591 argument sequence(s). If more than one sequence is given, the method is
1586 1592 called with an argument list consisting of the corresponding item of each
1587 1593 sequence. All sequences must be of the same length.
1588 1594
1589 1595 Keyword arguments are passed verbatim to all objects called.
1590 1596
1591 1597 This is Python code, so it's not nearly as fast as the builtin map()."""
1592 1598
1593 1599 out_list = []
1594 1600 idx = 0
1595 1601 for object in object_list:
1596 1602 try:
1597 1603 handler = getattr(object, method)
1598 1604 except AttributeError:
1599 1605 out_list.append(None)
1600 1606 else:
1601 1607 if argseq:
1602 1608 args = map(lambda lst:lst[idx],argseq)
1603 1609 #print 'ob',object,'hand',handler,'ar',args # dbg
1604 1610 out_list.append(handler(args,**kw))
1605 1611 else:
1606 1612 out_list.append(handler(**kw))
1607 1613 idx += 1
1608 1614 return out_list
1609 1615
1610 1616 #----------------------------------------------------------------------------
1611 1617 def get_class_members(cls):
1612 1618 ret = dir(cls)
1613 1619 if hasattr(cls,'__bases__'):
1614 1620 for base in cls.__bases__:
1615 1621 ret.extend(get_class_members(base))
1616 1622 return ret
1617 1623
1618 1624 #----------------------------------------------------------------------------
1619 1625 def dir2(obj):
1620 1626 """dir2(obj) -> list of strings
1621 1627
1622 1628 Extended version of the Python builtin dir(), which does a few extra
1623 1629 checks, and supports common objects with unusual internals that confuse
1624 1630 dir(), such as Traits and PyCrust.
1625 1631
1626 1632 This version is guaranteed to return only a list of true strings, whereas
1627 1633 dir() returns anything that objects inject into themselves, even if they
1628 1634 are later not really valid for attribute access (many extension libraries
1629 1635 have such bugs).
1630 1636 """
1631 1637
1632 1638 # Start building the attribute list via dir(), and then complete it
1633 1639 # with a few extra special-purpose calls.
1634 1640 words = dir(obj)
1635 1641
1636 1642 if hasattr(obj,'__class__'):
1637 1643 words.append('__class__')
1638 1644 words.extend(get_class_members(obj.__class__))
1639 1645 #if '__base__' in words: 1/0
1640 1646
1641 1647 # Some libraries (such as traits) may introduce duplicates, we want to
1642 1648 # track and clean this up if it happens
1643 1649 may_have_dupes = False
1644 1650
1645 1651 # this is the 'dir' function for objects with Enthought's traits
1646 1652 if hasattr(obj, 'trait_names'):
1647 1653 try:
1648 1654 words.extend(obj.trait_names())
1649 1655 may_have_dupes = True
1650 1656 except TypeError:
1651 1657 # This will happen if `obj` is a class and not an instance.
1652 1658 pass
1653 1659
1654 1660 # Support for PyCrust-style _getAttributeNames magic method.
1655 1661 if hasattr(obj, '_getAttributeNames'):
1656 1662 try:
1657 1663 words.extend(obj._getAttributeNames())
1658 1664 may_have_dupes = True
1659 1665 except TypeError:
1660 1666 # `obj` is a class and not an instance. Ignore
1661 1667 # this error.
1662 1668 pass
1663 1669
1664 1670 if may_have_dupes:
1665 1671 # eliminate possible duplicates, as some traits may also
1666 1672 # appear as normal attributes in the dir() call.
1667 1673 words = list(set(words))
1668 1674 words.sort()
1669 1675
1670 1676 # filter out non-string attributes which may be stuffed by dir() calls
1671 1677 # and poor coding in third-party modules
1672 1678 return [w for w in words if isinstance(w, basestring)]
1673 1679
1674 1680 #----------------------------------------------------------------------------
1675 1681 def import_fail_info(mod_name,fns=None):
1676 1682 """Inform load failure for a module."""
1677 1683
1678 1684 if fns == None:
1679 1685 warn("Loading of %s failed.\n" % (mod_name,))
1680 1686 else:
1681 1687 warn("Loading of %s from %s failed.\n" % (fns,mod_name))
1682 1688
1683 1689 #----------------------------------------------------------------------------
1684 1690 # Proposed popitem() extension, written as a method
1685 1691
1686 1692
1687 1693 class NotGiven: pass
1688 1694
1689 1695 def popkey(dct,key,default=NotGiven):
1690 1696 """Return dct[key] and delete dct[key].
1691 1697
1692 1698 If default is given, return it if dct[key] doesn't exist, otherwise raise
1693 1699 KeyError. """
1694 1700
1695 1701 try:
1696 1702 val = dct[key]
1697 1703 except KeyError:
1698 1704 if default is NotGiven:
1699 1705 raise
1700 1706 else:
1701 1707 return default
1702 1708 else:
1703 1709 del dct[key]
1704 1710 return val
1705 1711
1706 1712 def wrap_deprecated(func, suggest = '<nothing>'):
1707 1713 def newFunc(*args, **kwargs):
1708 1714 warnings.warn("Call to deprecated function %s, use %s instead" %
1709 1715 ( func.__name__, suggest),
1710 1716 category=DeprecationWarning,
1711 1717 stacklevel = 2)
1712 1718 return func(*args, **kwargs)
1713 1719 return newFunc
1714 1720
1715 1721
1716 1722 def _num_cpus_unix():
1717 1723 """Return the number of active CPUs on a Unix system."""
1718 1724 return os.sysconf("SC_NPROCESSORS_ONLN")
1719 1725
1720 1726
1721 1727 def _num_cpus_darwin():
1722 1728 """Return the number of active CPUs on a Darwin system."""
1723 1729 p = subprocess.Popen(['sysctl','-n','hw.ncpu'],stdout=subprocess.PIPE)
1724 1730 return p.stdout.read()
1725 1731
1726 1732
1727 1733 def _num_cpus_windows():
1728 1734 """Return the number of active CPUs on a Windows system."""
1729 1735 return os.environ.get("NUMBER_OF_PROCESSORS")
1730 1736
1731 1737
1732 1738 def num_cpus():
1733 1739 """Return the effective number of CPUs in the system as an integer.
1734 1740
1735 1741 This cross-platform function makes an attempt at finding the total number of
1736 1742 available CPUs in the system, as returned by various underlying system and
1737 1743 python calls.
1738 1744
1739 1745 If it can't find a sensible answer, it returns 1 (though an error *may* make
1740 1746 it return a large positive number that's actually incorrect).
1741 1747 """
1742 1748
1743 1749 # Many thanks to the Parallel Python project (http://www.parallelpython.com)
1744 1750 # for the names of the keys we needed to look up for this function. This
1745 1751 # code was inspired by their equivalent function.
1746 1752
1747 1753 ncpufuncs = {'Linux':_num_cpus_unix,
1748 1754 'Darwin':_num_cpus_darwin,
1749 1755 'Windows':_num_cpus_windows,
1750 1756 # On Vista, python < 2.5.2 has a bug and returns 'Microsoft'
1751 1757 # See http://bugs.python.org/issue1082 for details.
1752 1758 'Microsoft':_num_cpus_windows,
1753 1759 }
1754 1760
1755 1761 ncpufunc = ncpufuncs.get(platform.system(),
1756 1762 # default to unix version (Solaris, AIX, etc)
1757 1763 _num_cpus_unix)
1758 1764
1759 1765 try:
1760 1766 ncpus = max(1,int(ncpufunc()))
1761 1767 except:
1762 1768 ncpus = 1
1763 1769 return ncpus
1764 1770
1765 1771 def extract_vars(*names,**kw):
1766 1772 """Extract a set of variables by name from another frame.
1767 1773
1768 1774 :Parameters:
1769 1775 - `*names`: strings
1770 1776 One or more variable names which will be extracted from the caller's
1771 1777 frame.
1772 1778
1773 1779 :Keywords:
1774 1780 - `depth`: integer (0)
1775 1781 How many frames in the stack to walk when looking for your variables.
1776 1782
1777 1783
1778 1784 Examples:
1779 1785
1780 1786 In [2]: def func(x):
1781 1787 ...: y = 1
1782 1788 ...: print extract_vars('x','y')
1783 1789 ...:
1784 1790
1785 1791 In [3]: func('hello')
1786 1792 {'y': 1, 'x': 'hello'}
1787 1793 """
1788 1794
1789 1795 depth = kw.get('depth',0)
1790 1796
1791 1797 callerNS = sys._getframe(depth+1).f_locals
1792 1798 return dict((k,callerNS[k]) for k in names)
1793 1799
1794 1800
1795 1801 def extract_vars_above(*names):
1796 1802 """Extract a set of variables by name from another frame.
1797 1803
1798 1804 Similar to extractVars(), but with a specified depth of 1, so that names
1799 1805 are exctracted exactly from above the caller.
1800 1806
1801 1807 This is simply a convenience function so that the very common case (for us)
1802 1808 of skipping exactly 1 frame doesn't have to construct a special dict for
1803 1809 keyword passing."""
1804 1810
1805 1811 callerNS = sys._getframe(2).f_locals
1806 1812 return dict((k,callerNS[k]) for k in names)
1807 1813
1808 1814 def expand_path(s):
1809 1815 """Expand $VARS and ~names in a string, like a shell
1810 1816
1811 1817 :Examples:
1812 1818
1813 1819 In [2]: os.environ['FOO']='test'
1814 1820
1815 1821 In [3]: expand_path('variable FOO is $FOO')
1816 1822 Out[3]: 'variable FOO is test'
1817 1823 """
1818 1824 # This is a pretty subtle hack. When expand user is given a UNC path
1819 1825 # on Windows (\\server\share$\%username%), os.path.expandvars, removes
1820 1826 # the $ to get (\\server\share\%username%). I think it considered $
1821 1827 # alone an empty var. But, we need the $ to remains there (it indicates
1822 1828 # a hidden share).
1823 1829 if os.name=='nt':
1824 1830 s = s.replace('$\\', 'IPYTHON_TEMP')
1825 1831 s = os.path.expandvars(os.path.expanduser(s))
1826 1832 if os.name=='nt':
1827 1833 s = s.replace('IPYTHON_TEMP', '$\\')
1828 1834 return s
1829 1835
1830 1836 def list_strings(arg):
1831 1837 """Always return a list of strings, given a string or list of strings
1832 1838 as input.
1833 1839
1834 1840 :Examples:
1835 1841
1836 1842 In [7]: list_strings('A single string')
1837 1843 Out[7]: ['A single string']
1838 1844
1839 1845 In [8]: list_strings(['A single string in a list'])
1840 1846 Out[8]: ['A single string in a list']
1841 1847
1842 1848 In [9]: list_strings(['A','list','of','strings'])
1843 1849 Out[9]: ['A', 'list', 'of', 'strings']
1844 1850 """
1845 1851
1846 1852 if isinstance(arg,basestring): return [arg]
1847 1853 else: return arg
1848 1854
1849 1855
1850 1856 #----------------------------------------------------------------------------
1851 1857 def marquee(txt='',width=78,mark='*'):
1852 1858 """Return the input string centered in a 'marquee'.
1853 1859
1854 1860 :Examples:
1855 1861
1856 1862 In [16]: marquee('A test',40)
1857 1863 Out[16]: '**************** A test ****************'
1858 1864
1859 1865 In [17]: marquee('A test',40,'-')
1860 1866 Out[17]: '---------------- A test ----------------'
1861 1867
1862 1868 In [18]: marquee('A test',40,' ')
1863 1869 Out[18]: ' A test '
1864 1870
1865 1871 """
1866 1872 if not txt:
1867 1873 return (mark*width)[:width]
1868 1874 nmark = (width-len(txt)-2)/len(mark)/2
1869 1875 if nmark < 0: nmark =0
1870 1876 marks = mark*nmark
1871 1877 return '%s %s %s' % (marks,txt,marks)
1872 1878
1873 1879 #*************************** end of file <genutils.py> **********************
@@ -1,323 +1,326 b''
1 1 # encoding: utf-8
2 2
3 3 """Tests for genutils.py"""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # stdlib
19 19 import os
20 20 import shutil
21 21 import sys
22 22 import tempfile
23 23 import unittest
24 24
25 25 from cStringIO import StringIO
26 26 from os.path import join, abspath, split
27 27
28 28 # third-party
29 29 import nose.tools as nt
30 30
31 31 from nose import with_setup
32 32 from nose.tools import raises
33 33
34 34 # Our own
35 35 import IPython
36 36 from IPython.utils import genutils
37 37 from IPython.testing import decorators as dec
38 38 from IPython.testing.decorators import skipif, skip_if_not_win32
39 39
40 40 # Platform-dependent imports
41 41 try:
42 42 import _winreg as wreg
43 43 except ImportError:
44 44 #Fake _winreg module on none windows platforms
45 45 import new
46 46 sys.modules["_winreg"] = new.module("_winreg")
47 47 import _winreg as wreg
48 48 #Add entries that needs to be stubbed by the testing code
49 49 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
50 50
51 51 #-----------------------------------------------------------------------------
52 52 # Globals
53 53 #-----------------------------------------------------------------------------
54 54 env = os.environ
55 55 TEST_FILE_PATH = split(abspath(__file__))[0]
56 56 TMP_TEST_DIR = tempfile.mkdtemp()
57 57 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
58 58 IP_TEST_DIR = join(HOME_TEST_DIR,'_ipython')
59 59 #
60 60 # Setup/teardown functions/decorators
61 61 #
62 62
63 63 def setup():
64 64 """Setup testenvironment for the module:
65 65
66 66 - Adds dummy home dir tree
67 67 """
68 68 # Do not mask exceptions here. In particular, catching WindowsError is a
69 69 # problem because that exception is only defined on Windows...
70 70 os.makedirs(IP_TEST_DIR)
71 71
72 72 def teardown():
73 73 """Teardown testenvironment for the module:
74 74
75 75 - Remove dummy home dir tree
76 76 """
77 77 # Note: we remove the parent test dir, which is the root of all test
78 78 # subdirs we may have created. Use shutil instead of os.removedirs, so
79 79 # that non-empty directories are all recursively removed.
80 80 shutil.rmtree(TMP_TEST_DIR)
81 81
82 82
83 83 def setup_environment():
84 84 """Setup testenvironment for some functions that are tested
85 85 in this module. In particular this functions stores attributes
86 86 and other things that we need to stub in some test functions.
87 87 This needs to be done on a function level and not module level because
88 88 each testfunction needs a pristine environment.
89 89 """
90 90 global oldstuff, platformstuff
91 91 oldstuff = (env.copy(), os.name, genutils.get_home_dir, IPython.__file__,)
92 92
93 93 if os.name == 'nt':
94 94 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
95 95
96 if 'IPYTHONDIR' in env:
97 del env['IPYTHONDIR']
96 # Remove both spellings of env variables if present
97 env.pop('IPYTHON_DIR', None)
98 env.pop('IPYTHONDIR', None)
98 99
99 100 def teardown_environment():
100 101 """Restore things that were remebered by the setup_environment function
101 102 """
102 103 (oldenv, os.name, genutils.get_home_dir, IPython.__file__,) = oldstuff
103 104 for key in env.keys():
104 105 if key not in oldenv:
105 106 del env[key]
106 107 env.update(oldenv)
107 108 if hasattr(sys, 'frozen'):
108 109 del sys.frozen
109 110 if os.name == 'nt':
110 111 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
111 112
112 113 # Build decorator that uses the setup_environment/setup_environment
113 114 with_environment = with_setup(setup_environment, teardown_environment)
114 115
115 116
116 117 #
117 118 # Tests for get_home_dir
118 119 #
119 120
120 121 @skip_if_not_win32
121 122 @with_environment
122 123 def test_get_home_dir_1():
123 124 """Testcase for py2exe logic, un-compressed lib
124 125 """
125 126 sys.frozen = True
126 127
127 128 #fake filename for IPython.__init__
128 129 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
129 130
130 131 home_dir = genutils.get_home_dir()
131 132 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
132 133
133 134 @skip_if_not_win32
134 135 @with_environment
135 136 def test_get_home_dir_2():
136 137 """Testcase for py2exe logic, compressed lib
137 138 """
138 139 sys.frozen = True
139 140 #fake filename for IPython.__init__
140 141 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
141 142
142 143 home_dir = genutils.get_home_dir()
143 144 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
144 145
145 146 @with_environment
146 147 def test_get_home_dir_3():
147 148 """Testcase $HOME is set, then use its value as home directory."""
148 149 env["HOME"] = HOME_TEST_DIR
149 150 home_dir = genutils.get_home_dir()
150 151 nt.assert_equal(home_dir, env["HOME"])
151 152
152 153 @with_environment
153 154 def test_get_home_dir_4():
154 155 """Testcase $HOME is not set, os=='poix'.
155 156 This should fail with HomeDirError"""
156 157
157 158 os.name = 'posix'
158 159 if 'HOME' in env: del env['HOME']
159 160 nt.assert_raises(genutils.HomeDirError, genutils.get_home_dir)
160 161
161 162 @skip_if_not_win32
162 163 @with_environment
163 164 def test_get_home_dir_5():
164 165 """Testcase $HOME is not set, os=='nt'
165 166 env['HOMEDRIVE'],env['HOMEPATH'] points to path."""
166 167
167 168 os.name = 'nt'
168 169 if 'HOME' in env: del env['HOME']
169 170 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
170 171
171 172 home_dir = genutils.get_home_dir()
172 173 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
173 174
174 175 @skip_if_not_win32
175 176 @with_environment
176 177 def test_get_home_dir_6():
177 178 """Testcase $HOME is not set, os=='nt'
178 179 env['HOMEDRIVE'],env['HOMEPATH'] do not point to path.
179 180 env['USERPROFILE'] points to path
180 181 """
181 182
182 183 os.name = 'nt'
183 184 if 'HOME' in env: del env['HOME']
184 185 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.abspath(TEST_FILE_PATH), "DOES NOT EXIST"
185 186 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
186 187
187 188 home_dir = genutils.get_home_dir()
188 189 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
189 190
190 191 # Should we stub wreg fully so we can run the test on all platforms?
191 192 @skip_if_not_win32
192 193 @with_environment
193 194 def test_get_home_dir_7():
194 """Testcase $HOME is not set, os=='nt'
195 env['HOMEDRIVE'],env['HOMEPATH'], env['USERPROFILE'] missing
195 """Testcase $HOME is not set, os=='nt'
196
197 env['HOMEDRIVE'],env['HOMEPATH'], env['USERPROFILE'] and others missing
196 198 """
197 199 os.name = 'nt'
198 if 'HOME' in env: del env['HOME']
199 if 'HOMEDRIVE' in env: del env['HOMEDRIVE']
200 # Remove from stub environment all keys that may be set
201 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
202 env.pop(key, None)
200 203
201 204 #Stub windows registry functions
202 205 def OpenKey(x, y):
203 206 class key:
204 207 def Close(self):
205 208 pass
206 209 return key()
207 210 def QueryValueEx(x, y):
208 211 return [abspath(HOME_TEST_DIR)]
209 212
210 213 wreg.OpenKey = OpenKey
211 214 wreg.QueryValueEx = QueryValueEx
212 215
213 216 home_dir = genutils.get_home_dir()
214 217 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
215 218
216 219 #
217 220 # Tests for get_ipython_dir
218 221 #
219 222
220 223 @with_environment
221 224 def test_get_ipython_dir_1():
222 225 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
223 env['IPYTHONDIR'] = "someplace/.ipython"
226 env['IPYTHON_DIR'] = "someplace/.ipython"
224 227 ipdir = genutils.get_ipython_dir()
225 228 nt.assert_equal(ipdir, "someplace/.ipython")
226 229
227 230
228 231 @with_environment
229 232 def test_get_ipython_dir_2():
230 233 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
231 234 genutils.get_home_dir = lambda : "someplace"
232 235 os.name = "posix"
233 236 ipdir = genutils.get_ipython_dir()
234 237 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
235 238
236 239 #
237 240 # Tests for popkey
238 241 #
239 242
240 243 def test_popkey_1():
241 244 """test_popkey_1, Basic usage test of popkey
242 245 """
243 246 dct = dict(a=1, b=2, c=3)
244 247 nt.assert_equal(genutils.popkey(dct, "a"), 1)
245 248 nt.assert_equal(dct, dict(b=2, c=3))
246 249 nt.assert_equal(genutils.popkey(dct, "b"), 2)
247 250 nt.assert_equal(dct, dict(c=3))
248 251 nt.assert_equal(genutils.popkey(dct, "c"), 3)
249 252 nt.assert_equal(dct, dict())
250 253
251 254 def test_popkey_2():
252 255 """test_popkey_2, Test to see that popkey of non occuring keys
253 256 generates a KeyError exception
254 257 """
255 258 dct = dict(a=1, b=2, c=3)
256 259 nt.assert_raises(KeyError, genutils.popkey, dct, "d")
257 260
258 261 def test_popkey_3():
259 262 """test_popkey_3, Tests to see that popkey calls returns the correct value
260 263 and that the key/value was removed from the dict.
261 264 """
262 265 dct = dict(a=1, b=2, c=3)
263 266 nt.assert_equal(genutils.popkey(dct, "A", 13), 13)
264 267 nt.assert_equal(dct, dict(a=1, b=2, c=3))
265 268 nt.assert_equal(genutils.popkey(dct, "B", 14), 14)
266 269 nt.assert_equal(dct, dict(a=1, b=2, c=3))
267 270 nt.assert_equal(genutils.popkey(dct, "C", 15), 15)
268 271 nt.assert_equal(dct, dict(a=1, b=2, c=3))
269 272 nt.assert_equal(genutils.popkey(dct, "a"), 1)
270 273 nt.assert_equal(dct, dict(b=2, c=3))
271 274 nt.assert_equal(genutils.popkey(dct, "b"), 2)
272 275 nt.assert_equal(dct, dict(c=3))
273 276 nt.assert_equal(genutils.popkey(dct, "c"), 3)
274 277 nt.assert_equal(dct, dict())
275 278
276 279
277 280 def test_filefind():
278 281 """Various tests for filefind"""
279 282 f = tempfile.NamedTemporaryFile()
280 283 print 'fname:',f.name
281 284 alt_dirs = genutils.get_ipython_dir()
282 285 t = genutils.filefind(f.name,alt_dirs)
283 286 print 'found:',t
284 287
285 288
286 289 def test_get_ipython_package_dir():
287 290 ipdir = genutils.get_ipython_package_dir()
288 291 nt.assert_true(os.path.isdir(ipdir))
289 292
290 293
291 294 def test_tee_simple():
292 295 "Very simple check with stdout only"
293 296 chan = StringIO()
294 297 text = 'Hello'
295 298 tee = genutils.Tee(chan, channel='stdout')
296 299 print >> chan, text,
297 300 nt.assert_equal(chan.getvalue(), text)
298 301
299 302
300 303 class TeeTestCase(dec.ParametricTestCase):
301 304
302 305 def tchan(self, channel, check='close'):
303 306 trap = StringIO()
304 307 chan = StringIO()
305 308 text = 'Hello'
306 309
307 310 std_ori = getattr(sys, channel)
308 311 setattr(sys, channel, trap)
309 312
310 313 tee = genutils.Tee(chan, channel=channel)
311 314 print >> chan, text,
312 315 setattr(sys, channel, std_ori)
313 316 trap_val = trap.getvalue()
314 317 nt.assert_equals(chan.getvalue(), text)
315 318 if check=='close':
316 319 tee.close()
317 320 else:
318 321 del tee
319 322
320 323 def test(self):
321 324 for chan in ['stdout', 'stderr']:
322 325 for check in ['close', 'del']:
323 326 yield self.tchan(chan, check)
General Comments 0
You need to be logged in to leave comments. Login now