##// END OF EJS Templates
Added tests test_get_home_dir_3-test_get_home_dir_9...
Jorgen Stenarson -
Show More
@@ -1,2169 +1,2167 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 General purpose utilities.
4 4
5 5 This is a grab-bag of stuff I find useful in most programs I write. Some of
6 6 these things are also convenient when working at the command line.
7 7
8 8 $Id: genutils.py 2998 2008-01-31 10:06:04Z vivainio $"""
9 9
10 10 #*****************************************************************************
11 11 # Copyright (C) 2001-2006 Fernando Perez. <fperez@colorado.edu>
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #*****************************************************************************
16 16
17 17 from IPython import Release
18 18 __author__ = '%s <%s>' % Release.authors['Fernando']
19 19 __license__ = Release.license
20 20
21 21 #****************************************************************************
22 22 # required modules from the Python standard library
23 23 import __main__
24 24 import commands
25 25 try:
26 26 import doctest
27 27 except ImportError:
28 28 pass
29 29 import os
30 30 import platform
31 31 import re
32 32 import shlex
33 33 import shutil
34 34 import subprocess
35 35 import sys
36 36 import tempfile
37 37 import time
38 38 import types
39 39 import warnings
40 40
41 41 # Curses and termios are Unix-only modules
42 42 try:
43 43 import curses
44 44 # We need termios as well, so if its import happens to raise, we bail on
45 45 # using curses altogether.
46 46 import termios
47 47 except ImportError:
48 48 USE_CURSES = False
49 49 else:
50 50 # Curses on Solaris may not be complete, so we can't use it there
51 51 USE_CURSES = hasattr(curses,'initscr')
52 52
53 53 # Other IPython utilities
54 54 import IPython
55 55 from IPython.Itpl import Itpl,itpl,printpl
56 56 from IPython import DPyGetOpt, platutils
57 57 from IPython.generics import result_display
58 58 import IPython.ipapi
59 59 from IPython.external.path import path
60 60 if os.name == "nt":
61 61 from IPython.winconsole import get_console_size
62 62
63 63 try:
64 64 set
65 65 except:
66 66 from sets import Set as set
67 67
68 68
69 69 #****************************************************************************
70 70 # Exceptions
71 71 class Error(Exception):
72 72 """Base class for exceptions in this module."""
73 73 pass
74 74
75 75 #----------------------------------------------------------------------------
76 76 class IOStream:
77 77 def __init__(self,stream,fallback):
78 78 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
79 79 stream = fallback
80 80 self.stream = stream
81 81 self._swrite = stream.write
82 82 self.flush = stream.flush
83 83
84 84 def write(self,data):
85 85 try:
86 86 self._swrite(data)
87 87 except:
88 88 try:
89 89 # print handles some unicode issues which may trip a plain
90 90 # write() call. Attempt to emulate write() by using a
91 91 # trailing comma
92 92 print >> self.stream, data,
93 93 except:
94 94 # if we get here, something is seriously broken.
95 95 print >> sys.stderr, \
96 96 'ERROR - failed to write data to stream:', self.stream
97 97
98 98 def close(self):
99 99 pass
100 100
101 101
102 102 class IOTerm:
103 103 """ Term holds the file or file-like objects for handling I/O operations.
104 104
105 105 These are normally just sys.stdin, sys.stdout and sys.stderr but for
106 106 Windows they can can replaced to allow editing the strings before they are
107 107 displayed."""
108 108
109 109 # In the future, having IPython channel all its I/O operations through
110 110 # this class will make it easier to embed it into other environments which
111 111 # are not a normal terminal (such as a GUI-based shell)
112 112 def __init__(self,cin=None,cout=None,cerr=None):
113 113 self.cin = IOStream(cin,sys.stdin)
114 114 self.cout = IOStream(cout,sys.stdout)
115 115 self.cerr = IOStream(cerr,sys.stderr)
116 116
117 117 # Global variable to be used for all I/O
118 118 Term = IOTerm()
119 119
120 120 import IPython.rlineimpl as readline
121 121 # Remake Term to use the readline i/o facilities
122 122 if sys.platform == 'win32' and readline.have_readline:
123 123
124 124 Term = IOTerm(cout=readline._outputfile,cerr=readline._outputfile)
125 125
126 126
127 127 #****************************************************************************
128 128 # Generic warning/error printer, used by everything else
129 129 def warn(msg,level=2,exit_val=1):
130 130 """Standard warning printer. Gives formatting consistency.
131 131
132 132 Output is sent to Term.cerr (sys.stderr by default).
133 133
134 134 Options:
135 135
136 136 -level(2): allows finer control:
137 137 0 -> Do nothing, dummy function.
138 138 1 -> Print message.
139 139 2 -> Print 'WARNING:' + message. (Default level).
140 140 3 -> Print 'ERROR:' + message.
141 141 4 -> Print 'FATAL ERROR:' + message and trigger a sys.exit(exit_val).
142 142
143 143 -exit_val (1): exit value returned by sys.exit() for a level 4
144 144 warning. Ignored for all other levels."""
145 145
146 146 if level>0:
147 147 header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
148 148 print >> Term.cerr, '%s%s' % (header[level],msg)
149 149 if level == 4:
150 150 print >> Term.cerr,'Exiting.\n'
151 151 sys.exit(exit_val)
152 152
153 153 def info(msg):
154 154 """Equivalent to warn(msg,level=1)."""
155 155
156 156 warn(msg,level=1)
157 157
158 158 def error(msg):
159 159 """Equivalent to warn(msg,level=3)."""
160 160
161 161 warn(msg,level=3)
162 162
163 163 def fatal(msg,exit_val=1):
164 164 """Equivalent to warn(msg,exit_val=exit_val,level=4)."""
165 165
166 166 warn(msg,exit_val=exit_val,level=4)
167 167
168 168 #---------------------------------------------------------------------------
169 169 # Debugging routines
170 170 #
171 171 def debugx(expr,pre_msg=''):
172 172 """Print the value of an expression from the caller's frame.
173 173
174 174 Takes an expression, evaluates it in the caller's frame and prints both
175 175 the given expression and the resulting value (as well as a debug mark
176 176 indicating the name of the calling function. The input must be of a form
177 177 suitable for eval().
178 178
179 179 An optional message can be passed, which will be prepended to the printed
180 180 expr->value pair."""
181 181
182 182 cf = sys._getframe(1)
183 183 print '[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
184 184 eval(expr,cf.f_globals,cf.f_locals))
185 185
186 186 # deactivate it by uncommenting the following line, which makes it a no-op
187 187 #def debugx(expr,pre_msg=''): pass
188 188
189 189 #----------------------------------------------------------------------------
190 190 StringTypes = types.StringTypes
191 191
192 192 # Basic timing functionality
193 193
194 194 # If possible (Unix), use the resource module instead of time.clock()
195 195 try:
196 196 import resource
197 197 def clocku():
198 198 """clocku() -> floating point number
199 199
200 200 Return the *USER* CPU time in seconds since the start of the process.
201 201 This is done via a call to resource.getrusage, so it avoids the
202 202 wraparound problems in time.clock()."""
203 203
204 204 return resource.getrusage(resource.RUSAGE_SELF)[0]
205 205
206 206 def clocks():
207 207 """clocks() -> floating point number
208 208
209 209 Return the *SYSTEM* CPU time in seconds since the start of the process.
210 210 This is done via a call to resource.getrusage, so it avoids the
211 211 wraparound problems in time.clock()."""
212 212
213 213 return resource.getrusage(resource.RUSAGE_SELF)[1]
214 214
215 215 def clock():
216 216 """clock() -> floating point number
217 217
218 218 Return the *TOTAL USER+SYSTEM* CPU time in seconds since the start of
219 219 the process. This is done via a call to resource.getrusage, so it
220 220 avoids the wraparound problems in time.clock()."""
221 221
222 222 u,s = resource.getrusage(resource.RUSAGE_SELF)[:2]
223 223 return u+s
224 224
225 225 def clock2():
226 226 """clock2() -> (t_user,t_system)
227 227
228 228 Similar to clock(), but return a tuple of user/system times."""
229 229 return resource.getrusage(resource.RUSAGE_SELF)[:2]
230 230
231 231 except ImportError:
232 232 # There is no distinction of user/system time under windows, so we just use
233 233 # time.clock() for everything...
234 234 clocku = clocks = clock = time.clock
235 235 def clock2():
236 236 """Under windows, system CPU time can't be measured.
237 237
238 238 This just returns clock() and zero."""
239 239 return time.clock(),0.0
240 240
241 241 def timings_out(reps,func,*args,**kw):
242 242 """timings_out(reps,func,*args,**kw) -> (t_total,t_per_call,output)
243 243
244 244 Execute a function reps times, return a tuple with the elapsed total
245 245 CPU time in seconds, the time per call and the function's output.
246 246
247 247 Under Unix, the return value is the sum of user+system time consumed by
248 248 the process, computed via the resource module. This prevents problems
249 249 related to the wraparound effect which the time.clock() function has.
250 250
251 251 Under Windows the return value is in wall clock seconds. See the
252 252 documentation for the time module for more details."""
253 253
254 254 reps = int(reps)
255 255 assert reps >=1, 'reps must be >= 1'
256 256 if reps==1:
257 257 start = clock()
258 258 out = func(*args,**kw)
259 259 tot_time = clock()-start
260 260 else:
261 261 rng = xrange(reps-1) # the last time is executed separately to store output
262 262 start = clock()
263 263 for dummy in rng: func(*args,**kw)
264 264 out = func(*args,**kw) # one last time
265 265 tot_time = clock()-start
266 266 av_time = tot_time / reps
267 267 return tot_time,av_time,out
268 268
269 269 def timings(reps,func,*args,**kw):
270 270 """timings(reps,func,*args,**kw) -> (t_total,t_per_call)
271 271
272 272 Execute a function reps times, return a tuple with the elapsed total CPU
273 273 time in seconds and the time per call. These are just the first two values
274 274 in timings_out()."""
275 275
276 276 return timings_out(reps,func,*args,**kw)[0:2]
277 277
278 278 def timing(func,*args,**kw):
279 279 """timing(func,*args,**kw) -> t_total
280 280
281 281 Execute a function once, return the elapsed total CPU time in
282 282 seconds. This is just the first value in timings_out()."""
283 283
284 284 return timings_out(1,func,*args,**kw)[0]
285 285
286 286 #****************************************************************************
287 287 # file and system
288 288
289 289 def arg_split(s,posix=False):
290 290 """Split a command line's arguments in a shell-like manner.
291 291
292 292 This is a modified version of the standard library's shlex.split()
293 293 function, but with a default of posix=False for splitting, so that quotes
294 294 in inputs are respected."""
295 295
296 296 # XXX - there may be unicode-related problems here!!! I'm not sure that
297 297 # shlex is truly unicode-safe, so it might be necessary to do
298 298 #
299 299 # s = s.encode(sys.stdin.encoding)
300 300 #
301 301 # first, to ensure that shlex gets a normal string. Input from anyone who
302 302 # knows more about unicode and shlex than I would be good to have here...
303 303 lex = shlex.shlex(s, posix=posix)
304 304 lex.whitespace_split = True
305 305 return list(lex)
306 306
307 307 def system(cmd,verbose=0,debug=0,header=''):
308 308 """Execute a system command, return its exit status.
309 309
310 310 Options:
311 311
312 312 - verbose (0): print the command to be executed.
313 313
314 314 - debug (0): only print, do not actually execute.
315 315
316 316 - header (''): Header to print on screen prior to the executed command (it
317 317 is only prepended to the command, no newlines are added).
318 318
319 319 Note: a stateful version of this function is available through the
320 320 SystemExec class."""
321 321
322 322 stat = 0
323 323 if verbose or debug: print header+cmd
324 324 sys.stdout.flush()
325 325 if not debug: stat = os.system(cmd)
326 326 return stat
327 327
328 328 def abbrev_cwd():
329 329 """ Return abbreviated version of cwd, e.g. d:mydir """
330 330 cwd = os.getcwd().replace('\\','/')
331 331 drivepart = ''
332 332 tail = cwd
333 333 if sys.platform == 'win32':
334 334 if len(cwd) < 4:
335 335 return cwd
336 336 drivepart,tail = os.path.splitdrive(cwd)
337 337
338 338
339 339 parts = tail.split('/')
340 340 if len(parts) > 2:
341 341 tail = '/'.join(parts[-2:])
342 342
343 343 return (drivepart + (
344 344 cwd == '/' and '/' or tail))
345 345
346 346
347 347 # This function is used by ipython in a lot of places to make system calls.
348 348 # We need it to be slightly different under win32, due to the vagaries of
349 349 # 'network shares'. A win32 override is below.
350 350
351 351 def shell(cmd,verbose=0,debug=0,header=''):
352 352 """Execute a command in the system shell, always return None.
353 353
354 354 Options:
355 355
356 356 - verbose (0): print the command to be executed.
357 357
358 358 - debug (0): only print, do not actually execute.
359 359
360 360 - header (''): Header to print on screen prior to the executed command (it
361 361 is only prepended to the command, no newlines are added).
362 362
363 363 Note: this is similar to genutils.system(), but it returns None so it can
364 364 be conveniently used in interactive loops without getting the return value
365 365 (typically 0) printed many times."""
366 366
367 367 stat = 0
368 368 if verbose or debug: print header+cmd
369 369 # flush stdout so we don't mangle python's buffering
370 370 sys.stdout.flush()
371 371
372 372 if not debug:
373 373 platutils.set_term_title("IPy " + cmd)
374 374 os.system(cmd)
375 375 platutils.set_term_title("IPy " + abbrev_cwd())
376 376
377 377 # override shell() for win32 to deal with network shares
378 378 if os.name in ('nt','dos'):
379 379
380 380 shell_ori = shell
381 381
382 382 def shell(cmd,verbose=0,debug=0,header=''):
383 383 if os.getcwd().startswith(r"\\"):
384 384 path = os.getcwd()
385 385 # change to c drive (cannot be on UNC-share when issuing os.system,
386 386 # as cmd.exe cannot handle UNC addresses)
387 387 os.chdir("c:")
388 388 # issue pushd to the UNC-share and then run the command
389 389 try:
390 390 shell_ori('"pushd %s&&"'%path+cmd,verbose,debug,header)
391 391 finally:
392 392 os.chdir(path)
393 393 else:
394 394 shell_ori(cmd,verbose,debug,header)
395 395
396 396 shell.__doc__ = shell_ori.__doc__
397 397
398 398 def getoutput(cmd,verbose=0,debug=0,header='',split=0):
399 399 """Dummy substitute for perl's backquotes.
400 400
401 401 Executes a command and returns the output.
402 402
403 403 Accepts the same arguments as system(), plus:
404 404
405 405 - split(0): if true, the output is returned as a list split on newlines.
406 406
407 407 Note: a stateful version of this function is available through the
408 408 SystemExec class.
409 409
410 410 This is pretty much deprecated and rarely used,
411 411 genutils.getoutputerror may be what you need.
412 412
413 413 """
414 414
415 415 if verbose or debug: print header+cmd
416 416 if not debug:
417 417 output = os.popen(cmd).read()
418 418 # stipping last \n is here for backwards compat.
419 419 if output.endswith('\n'):
420 420 output = output[:-1]
421 421 if split:
422 422 return output.split('\n')
423 423 else:
424 424 return output
425 425
426 426 def getoutputerror(cmd,verbose=0,debug=0,header='',split=0):
427 427 """Return (standard output,standard error) of executing cmd in a shell.
428 428
429 429 Accepts the same arguments as system(), plus:
430 430
431 431 - split(0): if true, each of stdout/err is returned as a list split on
432 432 newlines.
433 433
434 434 Note: a stateful version of this function is available through the
435 435 SystemExec class."""
436 436
437 437 if verbose or debug: print header+cmd
438 438 if not cmd:
439 439 if split:
440 440 return [],[]
441 441 else:
442 442 return '',''
443 443 if not debug:
444 444 pin,pout,perr = os.popen3(cmd)
445 445 tout = pout.read().rstrip()
446 446 terr = perr.read().rstrip()
447 447 pin.close()
448 448 pout.close()
449 449 perr.close()
450 450 if split:
451 451 return tout.split('\n'),terr.split('\n')
452 452 else:
453 453 return tout,terr
454 454
455 455 # for compatibility with older naming conventions
456 456 xsys = system
457 457 bq = getoutput
458 458
459 459 class SystemExec:
460 460 """Access the system and getoutput functions through a stateful interface.
461 461
462 462 Note: here we refer to the system and getoutput functions from this
463 463 library, not the ones from the standard python library.
464 464
465 465 This class offers the system and getoutput functions as methods, but the
466 466 verbose, debug and header parameters can be set for the instance (at
467 467 creation time or later) so that they don't need to be specified on each
468 468 call.
469 469
470 470 For efficiency reasons, there's no way to override the parameters on a
471 471 per-call basis other than by setting instance attributes. If you need
472 472 local overrides, it's best to directly call system() or getoutput().
473 473
474 474 The following names are provided as alternate options:
475 475 - xsys: alias to system
476 476 - bq: alias to getoutput
477 477
478 478 An instance can then be created as:
479 479 >>> sysexec = SystemExec(verbose=1,debug=0,header='Calling: ')
480 480 """
481 481
482 482 def __init__(self,verbose=0,debug=0,header='',split=0):
483 483 """Specify the instance's values for verbose, debug and header."""
484 484 setattr_list(self,'verbose debug header split')
485 485
486 486 def system(self,cmd):
487 487 """Stateful interface to system(), with the same keyword parameters."""
488 488
489 489 system(cmd,self.verbose,self.debug,self.header)
490 490
491 491 def shell(self,cmd):
492 492 """Stateful interface to shell(), with the same keyword parameters."""
493 493
494 494 shell(cmd,self.verbose,self.debug,self.header)
495 495
496 496 xsys = system # alias
497 497
498 498 def getoutput(self,cmd):
499 499 """Stateful interface to getoutput()."""
500 500
501 501 return getoutput(cmd,self.verbose,self.debug,self.header,self.split)
502 502
503 503 def getoutputerror(self,cmd):
504 504 """Stateful interface to getoutputerror()."""
505 505
506 506 return getoutputerror(cmd,self.verbose,self.debug,self.header,self.split)
507 507
508 508 bq = getoutput # alias
509 509
510 510 #-----------------------------------------------------------------------------
511 511 def mutex_opts(dict,ex_op):
512 512 """Check for presence of mutually exclusive keys in a dict.
513 513
514 514 Call: mutex_opts(dict,[[op1a,op1b],[op2a,op2b]...]"""
515 515 for op1,op2 in ex_op:
516 516 if op1 in dict and op2 in dict:
517 517 raise ValueError,'\n*** ERROR in Arguments *** '\
518 518 'Options '+op1+' and '+op2+' are mutually exclusive.'
519 519
520 520 #-----------------------------------------------------------------------------
521 521 def get_py_filename(name):
522 522 """Return a valid python filename in the current directory.
523 523
524 524 If the given name is not a file, it adds '.py' and searches again.
525 525 Raises IOError with an informative message if the file isn't found."""
526 526
527 527 name = os.path.expanduser(name)
528 528 if not os.path.isfile(name) and not name.endswith('.py'):
529 529 name += '.py'
530 530 if os.path.isfile(name):
531 531 return name
532 532 else:
533 533 raise IOError,'File `%s` not found.' % name
534 534
535 535 #-----------------------------------------------------------------------------
536 536 def filefind(fname,alt_dirs = None):
537 537 """Return the given filename either in the current directory, if it
538 538 exists, or in a specified list of directories.
539 539
540 540 ~ expansion is done on all file and directory names.
541 541
542 542 Upon an unsuccessful search, raise an IOError exception."""
543 543
544 544 if alt_dirs is None:
545 545 try:
546 546 alt_dirs = get_home_dir()
547 547 except HomeDirError:
548 548 alt_dirs = os.getcwd()
549 549 search = [fname] + list_strings(alt_dirs)
550 550 search = map(os.path.expanduser,search)
551 551 #print 'search list for',fname,'list:',search # dbg
552 552 fname = search[0]
553 553 if os.path.isfile(fname):
554 554 return fname
555 555 for direc in search[1:]:
556 556 testname = os.path.join(direc,fname)
557 557 #print 'testname',testname # dbg
558 558 if os.path.isfile(testname):
559 559 return testname
560 560 raise IOError,'File' + `fname` + \
561 561 ' not found in current or supplied directories:' + `alt_dirs`
562 562
563 563 #----------------------------------------------------------------------------
564 564 def file_read(filename):
565 565 """Read a file and close it. Returns the file source."""
566 566 fobj = open(filename,'r');
567 567 source = fobj.read();
568 568 fobj.close()
569 569 return source
570 570
571 571 def file_readlines(filename):
572 572 """Read a file and close it. Returns the file source using readlines()."""
573 573 fobj = open(filename,'r');
574 574 lines = fobj.readlines();
575 575 fobj.close()
576 576 return lines
577 577
578 578 #----------------------------------------------------------------------------
579 579 def target_outdated(target,deps):
580 580 """Determine whether a target is out of date.
581 581
582 582 target_outdated(target,deps) -> 1/0
583 583
584 584 deps: list of filenames which MUST exist.
585 585 target: single filename which may or may not exist.
586 586
587 587 If target doesn't exist or is older than any file listed in deps, return
588 588 true, otherwise return false.
589 589 """
590 590 try:
591 591 target_time = os.path.getmtime(target)
592 592 except os.error:
593 593 return 1
594 594 for dep in deps:
595 595 dep_time = os.path.getmtime(dep)
596 596 if dep_time > target_time:
597 597 #print "For target",target,"Dep failed:",dep # dbg
598 598 #print "times (dep,tar):",dep_time,target_time # dbg
599 599 return 1
600 600 return 0
601 601
602 602 #-----------------------------------------------------------------------------
603 603 def target_update(target,deps,cmd):
604 604 """Update a target with a given command given a list of dependencies.
605 605
606 606 target_update(target,deps,cmd) -> runs cmd if target is outdated.
607 607
608 608 This is just a wrapper around target_outdated() which calls the given
609 609 command if target is outdated."""
610 610
611 611 if target_outdated(target,deps):
612 612 xsys(cmd)
613 613
614 614 #----------------------------------------------------------------------------
615 615 def unquote_ends(istr):
616 616 """Remove a single pair of quotes from the endpoints of a string."""
617 617
618 618 if not istr:
619 619 return istr
620 620 if (istr[0]=="'" and istr[-1]=="'") or \
621 621 (istr[0]=='"' and istr[-1]=='"'):
622 622 return istr[1:-1]
623 623 else:
624 624 return istr
625 625
626 626 #----------------------------------------------------------------------------
627 627 def process_cmdline(argv,names=[],defaults={},usage=''):
628 628 """ Process command-line options and arguments.
629 629
630 630 Arguments:
631 631
632 632 - argv: list of arguments, typically sys.argv.
633 633
634 634 - names: list of option names. See DPyGetOpt docs for details on options
635 635 syntax.
636 636
637 637 - defaults: dict of default values.
638 638
639 639 - usage: optional usage notice to print if a wrong argument is passed.
640 640
641 641 Return a dict of options and a list of free arguments."""
642 642
643 643 getopt = DPyGetOpt.DPyGetOpt()
644 644 getopt.setIgnoreCase(0)
645 645 getopt.parseConfiguration(names)
646 646
647 647 try:
648 648 getopt.processArguments(argv)
649 649 except DPyGetOpt.ArgumentError, exc:
650 650 print usage
651 651 warn('"%s"' % exc,level=4)
652 652
653 653 defaults.update(getopt.optionValues)
654 654 args = getopt.freeValues
655 655
656 656 return defaults,args
657 657
658 658 #----------------------------------------------------------------------------
659 659 def optstr2types(ostr):
660 660 """Convert a string of option names to a dict of type mappings.
661 661
662 662 optstr2types(str) -> {None:'string_opts',int:'int_opts',float:'float_opts'}
663 663
664 664 This is used to get the types of all the options in a string formatted
665 665 with the conventions of DPyGetOpt. The 'type' None is used for options
666 666 which are strings (they need no further conversion). This function's main
667 667 use is to get a typemap for use with read_dict().
668 668 """
669 669
670 670 typeconv = {None:'',int:'',float:''}
671 671 typemap = {'s':None,'i':int,'f':float}
672 672 opt_re = re.compile(r'([\w]*)([^:=]*:?=?)([sif]?)')
673 673
674 674 for w in ostr.split():
675 675 oname,alias,otype = opt_re.match(w).groups()
676 676 if otype == '' or alias == '!': # simple switches are integers too
677 677 otype = 'i'
678 678 typeconv[typemap[otype]] += oname + ' '
679 679 return typeconv
680 680
681 681 #----------------------------------------------------------------------------
682 682 def read_dict(filename,type_conv=None,**opt):
683 683 r"""Read a dictionary of key=value pairs from an input file, optionally
684 684 performing conversions on the resulting values.
685 685
686 686 read_dict(filename,type_conv,**opt) -> dict
687 687
688 688 Only one value per line is accepted, the format should be
689 689 # optional comments are ignored
690 690 key value\n
691 691
692 692 Args:
693 693
694 694 - type_conv: A dictionary specifying which keys need to be converted to
695 695 which types. By default all keys are read as strings. This dictionary
696 696 should have as its keys valid conversion functions for strings
697 697 (int,long,float,complex, or your own). The value for each key
698 698 (converter) should be a whitespace separated string containing the names
699 699 of all the entries in the file to be converted using that function. For
700 700 keys to be left alone, use None as the conversion function (only needed
701 701 with purge=1, see below).
702 702
703 703 - opt: dictionary with extra options as below (default in parens)
704 704
705 705 purge(0): if set to 1, all keys *not* listed in type_conv are purged out
706 706 of the dictionary to be returned. If purge is going to be used, the
707 707 set of keys to be left as strings also has to be explicitly specified
708 708 using the (non-existent) conversion function None.
709 709
710 710 fs(None): field separator. This is the key/value separator to be used
711 711 when parsing the file. The None default means any whitespace [behavior
712 712 of string.split()].
713 713
714 714 strip(0): if 1, strip string values of leading/trailinig whitespace.
715 715
716 716 warn(1): warning level if requested keys are not found in file.
717 717 - 0: silently ignore.
718 718 - 1: inform but proceed.
719 719 - 2: raise KeyError exception.
720 720
721 721 no_empty(0): if 1, remove keys with whitespace strings as a value.
722 722
723 723 unique([]): list of keys (or space separated string) which can't be
724 724 repeated. If one such key is found in the file, each new instance
725 725 overwrites the previous one. For keys not listed here, the behavior is
726 726 to make a list of all appearances.
727 727
728 728 Example:
729 729
730 730 If the input file test.ini contains (we put it in a string to keep the test
731 731 self-contained):
732 732
733 733 >>> test_ini = '''\
734 734 ... i 3
735 735 ... x 4.5
736 736 ... y 5.5
737 737 ... s hi ho'''
738 738
739 739 Then we can use it as follows:
740 740 >>> type_conv={int:'i',float:'x',None:'s'}
741 741
742 742 >>> d = read_dict(test_ini)
743 743
744 744 >>> sorted(d.items())
745 745 [('i', '3'), ('s', 'hi ho'), ('x', '4.5'), ('y', '5.5')]
746 746
747 747 >>> d = read_dict(test_ini,type_conv)
748 748
749 749 >>> sorted(d.items())
750 750 [('i', 3), ('s', 'hi ho'), ('x', 4.5), ('y', '5.5')]
751 751
752 752 >>> d = read_dict(test_ini,type_conv,purge=True)
753 753
754 754 >>> sorted(d.items())
755 755 [('i', 3), ('s', 'hi ho'), ('x', 4.5)]
756 756 """
757 757
758 758 # starting config
759 759 opt.setdefault('purge',0)
760 760 opt.setdefault('fs',None) # field sep defaults to any whitespace
761 761 opt.setdefault('strip',0)
762 762 opt.setdefault('warn',1)
763 763 opt.setdefault('no_empty',0)
764 764 opt.setdefault('unique','')
765 765 if type(opt['unique']) in StringTypes:
766 766 unique_keys = qw(opt['unique'])
767 767 elif type(opt['unique']) in (types.TupleType,types.ListType):
768 768 unique_keys = opt['unique']
769 769 else:
770 770 raise ValueError, 'Unique keys must be given as a string, List or Tuple'
771 771
772 772 dict = {}
773 773
774 774 # first read in table of values as strings
775 775 if '\n' in filename:
776 776 lines = filename.splitlines()
777 777 file = None
778 778 else:
779 779 file = open(filename,'r')
780 780 lines = file.readlines()
781 781 for line in lines:
782 782 line = line.strip()
783 783 if len(line) and line[0]=='#': continue
784 784 if len(line)>0:
785 785 lsplit = line.split(opt['fs'],1)
786 786 try:
787 787 key,val = lsplit
788 788 except ValueError:
789 789 key,val = lsplit[0],''
790 790 key = key.strip()
791 791 if opt['strip']: val = val.strip()
792 792 if val == "''" or val == '""': val = ''
793 793 if opt['no_empty'] and (val=='' or val.isspace()):
794 794 continue
795 795 # if a key is found more than once in the file, build a list
796 796 # unless it's in the 'unique' list. In that case, last found in file
797 797 # takes precedence. User beware.
798 798 try:
799 799 if dict[key] and key in unique_keys:
800 800 dict[key] = val
801 801 elif type(dict[key]) is types.ListType:
802 802 dict[key].append(val)
803 803 else:
804 804 dict[key] = [dict[key],val]
805 805 except KeyError:
806 806 dict[key] = val
807 807 # purge if requested
808 808 if opt['purge']:
809 809 accepted_keys = qwflat(type_conv.values())
810 810 for key in dict.keys():
811 811 if key in accepted_keys: continue
812 812 del(dict[key])
813 813 # now convert if requested
814 814 if type_conv==None: return dict
815 815 conversions = type_conv.keys()
816 816 try: conversions.remove(None)
817 817 except: pass
818 818 for convert in conversions:
819 819 for val in qw(type_conv[convert]):
820 820 try:
821 821 dict[val] = convert(dict[val])
822 822 except KeyError,e:
823 823 if opt['warn'] == 0:
824 824 pass
825 825 elif opt['warn'] == 1:
826 826 print >>sys.stderr, 'Warning: key',val,\
827 827 'not found in file',filename
828 828 elif opt['warn'] == 2:
829 829 raise KeyError,e
830 830 else:
831 831 raise ValueError,'Warning level must be 0,1 or 2'
832 832
833 833 return dict
834 834
835 835 #----------------------------------------------------------------------------
836 836 def flag_calls(func):
837 837 """Wrap a function to detect and flag when it gets called.
838 838
839 839 This is a decorator which takes a function and wraps it in a function with
840 840 a 'called' attribute. wrapper.called is initialized to False.
841 841
842 842 The wrapper.called attribute is set to False right before each call to the
843 843 wrapped function, so if the call fails it remains False. After the call
844 844 completes, wrapper.called is set to True and the output is returned.
845 845
846 846 Testing for truth in wrapper.called allows you to determine if a call to
847 847 func() was attempted and succeeded."""
848 848
849 849 def wrapper(*args,**kw):
850 850 wrapper.called = False
851 851 out = func(*args,**kw)
852 852 wrapper.called = True
853 853 return out
854 854
855 855 wrapper.called = False
856 856 wrapper.__doc__ = func.__doc__
857 857 return wrapper
858 858
859 859 #----------------------------------------------------------------------------
860 860 def dhook_wrap(func,*a,**k):
861 861 """Wrap a function call in a sys.displayhook controller.
862 862
863 863 Returns a wrapper around func which calls func, with all its arguments and
864 864 keywords unmodified, using the default sys.displayhook. Since IPython
865 865 modifies sys.displayhook, it breaks the behavior of certain systems that
866 866 rely on the default behavior, notably doctest.
867 867 """
868 868
869 869 def f(*a,**k):
870 870
871 871 dhook_s = sys.displayhook
872 872 sys.displayhook = sys.__displayhook__
873 873 try:
874 874 out = func(*a,**k)
875 875 finally:
876 876 sys.displayhook = dhook_s
877 877
878 878 return out
879 879
880 880 f.__doc__ = func.__doc__
881 881 return f
882 882
883 883 #----------------------------------------------------------------------------
884 884 def doctest_reload():
885 885 """Properly reload doctest to reuse it interactively.
886 886
887 887 This routine:
888 888
889 889 - reloads doctest
890 890
891 891 - resets its global 'master' attribute to None, so that multiple uses of
892 892 the module interactively don't produce cumulative reports.
893 893
894 894 - Monkeypatches its core test runner method to protect it from IPython's
895 895 modified displayhook. Doctest expects the default displayhook behavior
896 896 deep down, so our modification breaks it completely. For this reason, a
897 897 hard monkeypatch seems like a reasonable solution rather than asking
898 898 users to manually use a different doctest runner when under IPython."""
899 899
900 900 import doctest
901 901 reload(doctest)
902 902 doctest.master=None
903 903
904 904 try:
905 905 doctest.DocTestRunner
906 906 except AttributeError:
907 907 # This is only for python 2.3 compatibility, remove once we move to
908 908 # 2.4 only.
909 909 pass
910 910 else:
911 911 doctest.DocTestRunner.run = dhook_wrap(doctest.DocTestRunner.run)
912 912
913 913 #----------------------------------------------------------------------------
914 914 class HomeDirError(Error):
915 915 pass
916 916
917 917 def get_home_dir():
918 918 """Return the closest possible equivalent to a 'home' directory.
919 919
920 920 We first try $HOME. Absent that, on NT it's $HOMEDRIVE\$HOMEPATH.
921 921
922 922 Currently only Posix and NT are implemented, a HomeDirError exception is
923 923 raised for all other OSes. """
924 924
925 925 isdir = os.path.isdir
926 926 env = os.environ
927 927
928 928 # first, check py2exe distribution root directory for _ipython.
929 929 # This overrides all. Normally does not exist.
930 930
931 931 if hasattr(sys, "frozen"): #Is frozen by py2exe
932 932 if '\\library.zip\\' in IPython.__file__.lower():#libraries compressed to zip-file
933 933 root, rest = IPython.__file__.lower().split('library.zip')
934 if isdir(root + '_ipython'):
935 os.environ["IPYKITROOT"] = root.rstrip('\\')
936 return root
937 934 else:
938 root=os.path.abspath(os.path.join(os.path.split(IPython.__file__)[0],"../../"))
935 root=os.path.join(os.path.split(IPython.__file__)[0],"../../")
936 root=os.path.abspath(root).rstrip('\\')
939 937 if isdir(os.path.join(root, '_ipython')):
940 os.environ["IPYKITROOT"] = root.rstrip('\\')
938 os.environ["IPYKITROOT"] = root
941 939 return root
942 940 try:
943 941 homedir = env['HOME']
944 942 if not isdir(homedir):
945 943 # in case a user stuck some string which does NOT resolve to a
946 944 # valid path, it's as good as if we hadn't foud it
947 945 raise KeyError
948 946 return homedir
949 947 except KeyError:
950 948 if os.name == 'posix':
951 949 raise HomeDirError,'undefined $HOME, IPython can not proceed.'
952 950 elif os.name == 'nt':
953 951 # For some strange reason, win9x returns 'nt' for os.name.
954 952 try:
955 953 homedir = os.path.join(env['HOMEDRIVE'],env['HOMEPATH'])
956 954 if not isdir(homedir):
957 955 homedir = os.path.join(env['USERPROFILE'])
958 956 if not isdir(homedir):
959 957 raise HomeDirError
960 958 return homedir
961 except:
959 except KeyError:
962 960 try:
963 961 # Use the registry to get the 'My Documents' folder.
964 962 import _winreg as wreg
965 963 key = wreg.OpenKey(wreg.HKEY_CURRENT_USER,
966 964 "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders")
967 965 homedir = wreg.QueryValueEx(key,'Personal')[0]
968 966 key.Close()
969 967 if not isdir(homedir):
970 968 e = ('Invalid "Personal" folder registry key '
971 969 'typically "My Documents".\n'
972 970 'Value: %s\n'
973 971 'This is not a valid directory on your system.' %
974 972 homedir)
975 973 raise HomeDirError(e)
976 974 return homedir
977 975 except HomeDirError:
978 976 raise
979 977 except:
980 978 return 'C:\\'
981 979 elif os.name == 'dos':
982 980 # Desperate, may do absurd things in classic MacOS. May work under DOS.
983 981 return 'C:\\'
984 982 else:
985 983 raise HomeDirError,'support for your operating system not implemented.'
986 984
987 985
988 986 def get_ipython_dir():
989 987 """Get the IPython directory for this platform and user.
990 988
991 989 This uses the logic in `get_home_dir` to find the home directory
992 990 and the adds either .ipython or _ipython to the end of the path.
993 991 """
994 992 if os.name == 'posix':
995 993 ipdir_def = '.ipython'
996 994 else:
997 995 ipdir_def = '_ipython'
998 996 home_dir = get_home_dir()
999 997 ipdir = os.path.abspath(os.environ.get('IPYTHONDIR',
1000 998 os.path.join(home_dir,ipdir_def)))
1001 999 return ipdir
1002 1000
1003 1001 def get_security_dir():
1004 1002 """Get the IPython security directory.
1005 1003
1006 1004 This directory is the default location for all security related files,
1007 1005 including SSL/TLS certificates and FURL files.
1008 1006
1009 1007 If the directory does not exist, it is created with 0700 permissions.
1010 1008 If it exists, permissions are set to 0700.
1011 1009 """
1012 1010 security_dir = os.path.join(get_ipython_dir(), 'security')
1013 1011 if not os.path.isdir(security_dir):
1014 1012 os.mkdir(security_dir, 0700)
1015 1013 else:
1016 1014 os.chmod(security_dir, 0700)
1017 1015 return security_dir
1018 1016
1019 1017 #****************************************************************************
1020 1018 # strings and text
1021 1019
1022 1020 class LSString(str):
1023 1021 """String derivative with a special access attributes.
1024 1022
1025 1023 These are normal strings, but with the special attributes:
1026 1024
1027 1025 .l (or .list) : value as list (split on newlines).
1028 1026 .n (or .nlstr): original value (the string itself).
1029 1027 .s (or .spstr): value as whitespace-separated string.
1030 1028 .p (or .paths): list of path objects
1031 1029
1032 1030 Any values which require transformations are computed only once and
1033 1031 cached.
1034 1032
1035 1033 Such strings are very useful to efficiently interact with the shell, which
1036 1034 typically only understands whitespace-separated options for commands."""
1037 1035
1038 1036 def get_list(self):
1039 1037 try:
1040 1038 return self.__list
1041 1039 except AttributeError:
1042 1040 self.__list = self.split('\n')
1043 1041 return self.__list
1044 1042
1045 1043 l = list = property(get_list)
1046 1044
1047 1045 def get_spstr(self):
1048 1046 try:
1049 1047 return self.__spstr
1050 1048 except AttributeError:
1051 1049 self.__spstr = self.replace('\n',' ')
1052 1050 return self.__spstr
1053 1051
1054 1052 s = spstr = property(get_spstr)
1055 1053
1056 1054 def get_nlstr(self):
1057 1055 return self
1058 1056
1059 1057 n = nlstr = property(get_nlstr)
1060 1058
1061 1059 def get_paths(self):
1062 1060 try:
1063 1061 return self.__paths
1064 1062 except AttributeError:
1065 1063 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
1066 1064 return self.__paths
1067 1065
1068 1066 p = paths = property(get_paths)
1069 1067
1070 1068 def print_lsstring(arg):
1071 1069 """ Prettier (non-repr-like) and more informative printer for LSString """
1072 1070 print "LSString (.p, .n, .l, .s available). Value:"
1073 1071 print arg
1074 1072
1075 1073 print_lsstring = result_display.when_type(LSString)(print_lsstring)
1076 1074
1077 1075 #----------------------------------------------------------------------------
1078 1076 class SList(list):
1079 1077 """List derivative with a special access attributes.
1080 1078
1081 1079 These are normal lists, but with the special attributes:
1082 1080
1083 1081 .l (or .list) : value as list (the list itself).
1084 1082 .n (or .nlstr): value as a string, joined on newlines.
1085 1083 .s (or .spstr): value as a string, joined on spaces.
1086 1084 .p (or .paths): list of path objects
1087 1085
1088 1086 Any values which require transformations are computed only once and
1089 1087 cached."""
1090 1088
1091 1089 def get_list(self):
1092 1090 return self
1093 1091
1094 1092 l = list = property(get_list)
1095 1093
1096 1094 def get_spstr(self):
1097 1095 try:
1098 1096 return self.__spstr
1099 1097 except AttributeError:
1100 1098 self.__spstr = ' '.join(self)
1101 1099 return self.__spstr
1102 1100
1103 1101 s = spstr = property(get_spstr)
1104 1102
1105 1103 def get_nlstr(self):
1106 1104 try:
1107 1105 return self.__nlstr
1108 1106 except AttributeError:
1109 1107 self.__nlstr = '\n'.join(self)
1110 1108 return self.__nlstr
1111 1109
1112 1110 n = nlstr = property(get_nlstr)
1113 1111
1114 1112 def get_paths(self):
1115 1113 try:
1116 1114 return self.__paths
1117 1115 except AttributeError:
1118 1116 self.__paths = [path(p) for p in self if os.path.exists(p)]
1119 1117 return self.__paths
1120 1118
1121 1119 p = paths = property(get_paths)
1122 1120
1123 1121 def grep(self, pattern, prune = False, field = None):
1124 1122 """ Return all strings matching 'pattern' (a regex or callable)
1125 1123
1126 1124 This is case-insensitive. If prune is true, return all items
1127 1125 NOT matching the pattern.
1128 1126
1129 1127 If field is specified, the match must occur in the specified
1130 1128 whitespace-separated field.
1131 1129
1132 1130 Examples::
1133 1131
1134 1132 a.grep( lambda x: x.startswith('C') )
1135 1133 a.grep('Cha.*log', prune=1)
1136 1134 a.grep('chm', field=-1)
1137 1135 """
1138 1136
1139 1137 def match_target(s):
1140 1138 if field is None:
1141 1139 return s
1142 1140 parts = s.split()
1143 1141 try:
1144 1142 tgt = parts[field]
1145 1143 return tgt
1146 1144 except IndexError:
1147 1145 return ""
1148 1146
1149 1147 if isinstance(pattern, basestring):
1150 1148 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
1151 1149 else:
1152 1150 pred = pattern
1153 1151 if not prune:
1154 1152 return SList([el for el in self if pred(match_target(el))])
1155 1153 else:
1156 1154 return SList([el for el in self if not pred(match_target(el))])
1157 1155 def fields(self, *fields):
1158 1156 """ Collect whitespace-separated fields from string list
1159 1157
1160 1158 Allows quick awk-like usage of string lists.
1161 1159
1162 1160 Example data (in var a, created by 'a = !ls -l')::
1163 1161 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
1164 1162 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
1165 1163
1166 1164 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
1167 1165 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
1168 1166 (note the joining by space).
1169 1167 a.fields(-1) is ['ChangeLog', 'IPython']
1170 1168
1171 1169 IndexErrors are ignored.
1172 1170
1173 1171 Without args, fields() just split()'s the strings.
1174 1172 """
1175 1173 if len(fields) == 0:
1176 1174 return [el.split() for el in self]
1177 1175
1178 1176 res = SList()
1179 1177 for el in [f.split() for f in self]:
1180 1178 lineparts = []
1181 1179
1182 1180 for fd in fields:
1183 1181 try:
1184 1182 lineparts.append(el[fd])
1185 1183 except IndexError:
1186 1184 pass
1187 1185 if lineparts:
1188 1186 res.append(" ".join(lineparts))
1189 1187
1190 1188 return res
1191 1189 def sort(self,field= None, nums = False):
1192 1190 """ sort by specified fields (see fields())
1193 1191
1194 1192 Example::
1195 1193 a.sort(1, nums = True)
1196 1194
1197 1195 Sorts a by second field, in numerical order (so that 21 > 3)
1198 1196
1199 1197 """
1200 1198
1201 1199 #decorate, sort, undecorate
1202 1200 if field is not None:
1203 1201 dsu = [[SList([line]).fields(field), line] for line in self]
1204 1202 else:
1205 1203 dsu = [[line, line] for line in self]
1206 1204 if nums:
1207 1205 for i in range(len(dsu)):
1208 1206 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
1209 1207 try:
1210 1208 n = int(numstr)
1211 1209 except ValueError:
1212 1210 n = 0;
1213 1211 dsu[i][0] = n
1214 1212
1215 1213
1216 1214 dsu.sort()
1217 1215 return SList([t[1] for t in dsu])
1218 1216
1219 1217 def print_slist(arg):
1220 1218 """ Prettier (non-repr-like) and more informative printer for SList """
1221 1219 print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
1222 1220 if hasattr(arg, 'hideonce') and arg.hideonce:
1223 1221 arg.hideonce = False
1224 1222 return
1225 1223
1226 1224 nlprint(arg)
1227 1225
1228 1226 print_slist = result_display.when_type(SList)(print_slist)
1229 1227
1230 1228
1231 1229
1232 1230 #----------------------------------------------------------------------------
1233 1231 def esc_quotes(strng):
1234 1232 """Return the input string with single and double quotes escaped out"""
1235 1233
1236 1234 return strng.replace('"','\\"').replace("'","\\'")
1237 1235
1238 1236 #----------------------------------------------------------------------------
1239 1237 def make_quoted_expr(s):
1240 1238 """Return string s in appropriate quotes, using raw string if possible.
1241 1239
1242 1240 Effectively this turns string: cd \ao\ao\
1243 1241 to: r"cd \ao\ao\_"[:-1]
1244 1242
1245 1243 Note the use of raw string and padding at the end to allow trailing backslash.
1246 1244
1247 1245 """
1248 1246
1249 1247 tail = ''
1250 1248 tailpadding = ''
1251 1249 raw = ''
1252 1250 if "\\" in s:
1253 1251 raw = 'r'
1254 1252 if s.endswith('\\'):
1255 1253 tail = '[:-1]'
1256 1254 tailpadding = '_'
1257 1255 if '"' not in s:
1258 1256 quote = '"'
1259 1257 elif "'" not in s:
1260 1258 quote = "'"
1261 1259 elif '"""' not in s and not s.endswith('"'):
1262 1260 quote = '"""'
1263 1261 elif "'''" not in s and not s.endswith("'"):
1264 1262 quote = "'''"
1265 1263 else:
1266 1264 # give up, backslash-escaped string will do
1267 1265 return '"%s"' % esc_quotes(s)
1268 1266 res = raw + quote + s + tailpadding + quote + tail
1269 1267 return res
1270 1268
1271 1269
1272 1270 #----------------------------------------------------------------------------
1273 1271 def raw_input_multi(header='', ps1='==> ', ps2='..> ',terminate_str = '.'):
1274 1272 """Take multiple lines of input.
1275 1273
1276 1274 A list with each line of input as a separate element is returned when a
1277 1275 termination string is entered (defaults to a single '.'). Input can also
1278 1276 terminate via EOF (^D in Unix, ^Z-RET in Windows).
1279 1277
1280 1278 Lines of input which end in \\ are joined into single entries (and a
1281 1279 secondary continuation prompt is issued as long as the user terminates
1282 1280 lines with \\). This allows entering very long strings which are still
1283 1281 meant to be treated as single entities.
1284 1282 """
1285 1283
1286 1284 try:
1287 1285 if header:
1288 1286 header += '\n'
1289 1287 lines = [raw_input(header + ps1)]
1290 1288 except EOFError:
1291 1289 return []
1292 1290 terminate = [terminate_str]
1293 1291 try:
1294 1292 while lines[-1:] != terminate:
1295 1293 new_line = raw_input(ps1)
1296 1294 while new_line.endswith('\\'):
1297 1295 new_line = new_line[:-1] + raw_input(ps2)
1298 1296 lines.append(new_line)
1299 1297
1300 1298 return lines[:-1] # don't return the termination command
1301 1299 except EOFError:
1302 1300 print
1303 1301 return lines
1304 1302
1305 1303 #----------------------------------------------------------------------------
1306 1304 def raw_input_ext(prompt='', ps2='... '):
1307 1305 """Similar to raw_input(), but accepts extended lines if input ends with \\."""
1308 1306
1309 1307 line = raw_input(prompt)
1310 1308 while line.endswith('\\'):
1311 1309 line = line[:-1] + raw_input(ps2)
1312 1310 return line
1313 1311
1314 1312 #----------------------------------------------------------------------------
1315 1313 def ask_yes_no(prompt,default=None):
1316 1314 """Asks a question and returns a boolean (y/n) answer.
1317 1315
1318 1316 If default is given (one of 'y','n'), it is used if the user input is
1319 1317 empty. Otherwise the question is repeated until an answer is given.
1320 1318
1321 1319 An EOF is treated as the default answer. If there is no default, an
1322 1320 exception is raised to prevent infinite loops.
1323 1321
1324 1322 Valid answers are: y/yes/n/no (match is not case sensitive)."""
1325 1323
1326 1324 answers = {'y':True,'n':False,'yes':True,'no':False}
1327 1325 ans = None
1328 1326 while ans not in answers.keys():
1329 1327 try:
1330 1328 ans = raw_input(prompt+' ').lower()
1331 1329 if not ans: # response was an empty string
1332 1330 ans = default
1333 1331 except KeyboardInterrupt:
1334 1332 pass
1335 1333 except EOFError:
1336 1334 if default in answers.keys():
1337 1335 ans = default
1338 1336 print
1339 1337 else:
1340 1338 raise
1341 1339
1342 1340 return answers[ans]
1343 1341
1344 1342 #----------------------------------------------------------------------------
1345 1343 def marquee(txt='',width=78,mark='*'):
1346 1344 """Return the input string centered in a 'marquee'."""
1347 1345 if not txt:
1348 1346 return (mark*width)[:width]
1349 1347 nmark = (width-len(txt)-2)/len(mark)/2
1350 1348 if nmark < 0: nmark =0
1351 1349 marks = mark*nmark
1352 1350 return '%s %s %s' % (marks,txt,marks)
1353 1351
1354 1352 #----------------------------------------------------------------------------
1355 1353 class EvalDict:
1356 1354 """
1357 1355 Emulate a dict which evaluates its contents in the caller's frame.
1358 1356
1359 1357 Usage:
1360 1358 >>> number = 19
1361 1359
1362 1360 >>> text = "python"
1363 1361
1364 1362 >>> print "%(text.capitalize())s %(number/9.0).1f rules!" % EvalDict()
1365 1363 Python 2.1 rules!
1366 1364 """
1367 1365
1368 1366 # This version is due to sismex01@hebmex.com on c.l.py, and is basically a
1369 1367 # modified (shorter) version of:
1370 1368 # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66018 by
1371 1369 # Skip Montanaro (skip@pobox.com).
1372 1370
1373 1371 def __getitem__(self, name):
1374 1372 frame = sys._getframe(1)
1375 1373 return eval(name, frame.f_globals, frame.f_locals)
1376 1374
1377 1375 EvalString = EvalDict # for backwards compatibility
1378 1376 #----------------------------------------------------------------------------
1379 1377 def qw(words,flat=0,sep=None,maxsplit=-1):
1380 1378 """Similar to Perl's qw() operator, but with some more options.
1381 1379
1382 1380 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
1383 1381
1384 1382 words can also be a list itself, and with flat=1, the output will be
1385 1383 recursively flattened.
1386 1384
1387 1385 Examples:
1388 1386
1389 1387 >>> qw('1 2')
1390 1388 ['1', '2']
1391 1389
1392 1390 >>> qw(['a b','1 2',['m n','p q']])
1393 1391 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
1394 1392
1395 1393 >>> qw(['a b','1 2',['m n','p q']],flat=1)
1396 1394 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
1397 1395 """
1398 1396
1399 1397 if type(words) in StringTypes:
1400 1398 return [word.strip() for word in words.split(sep,maxsplit)
1401 1399 if word and not word.isspace() ]
1402 1400 if flat:
1403 1401 return flatten(map(qw,words,[1]*len(words)))
1404 1402 return map(qw,words)
1405 1403
1406 1404 #----------------------------------------------------------------------------
1407 1405 def qwflat(words,sep=None,maxsplit=-1):
1408 1406 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
1409 1407 return qw(words,1,sep,maxsplit)
1410 1408
1411 1409 #----------------------------------------------------------------------------
1412 1410 def qw_lol(indata):
1413 1411 """qw_lol('a b') -> [['a','b']],
1414 1412 otherwise it's just a call to qw().
1415 1413
1416 1414 We need this to make sure the modules_some keys *always* end up as a
1417 1415 list of lists."""
1418 1416
1419 1417 if type(indata) in StringTypes:
1420 1418 return [qw(indata)]
1421 1419 else:
1422 1420 return qw(indata)
1423 1421
1424 1422 #-----------------------------------------------------------------------------
1425 1423 def list_strings(arg):
1426 1424 """Always return a list of strings, given a string or list of strings
1427 1425 as input."""
1428 1426
1429 1427 if type(arg) in StringTypes: return [arg]
1430 1428 else: return arg
1431 1429
1432 1430 #----------------------------------------------------------------------------
1433 1431 def grep(pat,list,case=1):
1434 1432 """Simple minded grep-like function.
1435 1433 grep(pat,list) returns occurrences of pat in list, None on failure.
1436 1434
1437 1435 It only does simple string matching, with no support for regexps. Use the
1438 1436 option case=0 for case-insensitive matching."""
1439 1437
1440 1438 # This is pretty crude. At least it should implement copying only references
1441 1439 # to the original data in case it's big. Now it copies the data for output.
1442 1440 out=[]
1443 1441 if case:
1444 1442 for term in list:
1445 1443 if term.find(pat)>-1: out.append(term)
1446 1444 else:
1447 1445 lpat=pat.lower()
1448 1446 for term in list:
1449 1447 if term.lower().find(lpat)>-1: out.append(term)
1450 1448
1451 1449 if len(out): return out
1452 1450 else: return None
1453 1451
1454 1452 #----------------------------------------------------------------------------
1455 1453 def dgrep(pat,*opts):
1456 1454 """Return grep() on dir()+dir(__builtins__).
1457 1455
1458 1456 A very common use of grep() when working interactively."""
1459 1457
1460 1458 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
1461 1459
1462 1460 #----------------------------------------------------------------------------
1463 1461 def idgrep(pat):
1464 1462 """Case-insensitive dgrep()"""
1465 1463
1466 1464 return dgrep(pat,0)
1467 1465
1468 1466 #----------------------------------------------------------------------------
1469 1467 def igrep(pat,list):
1470 1468 """Synonym for case-insensitive grep."""
1471 1469
1472 1470 return grep(pat,list,case=0)
1473 1471
1474 1472 #----------------------------------------------------------------------------
1475 1473 def indent(str,nspaces=4,ntabs=0):
1476 1474 """Indent a string a given number of spaces or tabstops.
1477 1475
1478 1476 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
1479 1477 """
1480 1478 if str is None:
1481 1479 return
1482 1480 ind = '\t'*ntabs+' '*nspaces
1483 1481 outstr = '%s%s' % (ind,str.replace(os.linesep,os.linesep+ind))
1484 1482 if outstr.endswith(os.linesep+ind):
1485 1483 return outstr[:-len(ind)]
1486 1484 else:
1487 1485 return outstr
1488 1486
1489 1487 #-----------------------------------------------------------------------------
1490 1488 def native_line_ends(filename,backup=1):
1491 1489 """Convert (in-place) a file to line-ends native to the current OS.
1492 1490
1493 1491 If the optional backup argument is given as false, no backup of the
1494 1492 original file is left. """
1495 1493
1496 1494 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
1497 1495
1498 1496 bak_filename = filename + backup_suffixes[os.name]
1499 1497
1500 1498 original = open(filename).read()
1501 1499 shutil.copy2(filename,bak_filename)
1502 1500 try:
1503 1501 new = open(filename,'wb')
1504 1502 new.write(os.linesep.join(original.splitlines()))
1505 1503 new.write(os.linesep) # ALWAYS put an eol at the end of the file
1506 1504 new.close()
1507 1505 except:
1508 1506 os.rename(bak_filename,filename)
1509 1507 if not backup:
1510 1508 try:
1511 1509 os.remove(bak_filename)
1512 1510 except:
1513 1511 pass
1514 1512
1515 1513 #----------------------------------------------------------------------------
1516 1514 def get_pager_cmd(pager_cmd = None):
1517 1515 """Return a pager command.
1518 1516
1519 1517 Makes some attempts at finding an OS-correct one."""
1520 1518
1521 1519 if os.name == 'posix':
1522 1520 default_pager_cmd = 'less -r' # -r for color control sequences
1523 1521 elif os.name in ['nt','dos']:
1524 1522 default_pager_cmd = 'type'
1525 1523
1526 1524 if pager_cmd is None:
1527 1525 try:
1528 1526 pager_cmd = os.environ['PAGER']
1529 1527 except:
1530 1528 pager_cmd = default_pager_cmd
1531 1529 return pager_cmd
1532 1530
1533 1531 #-----------------------------------------------------------------------------
1534 1532 def get_pager_start(pager,start):
1535 1533 """Return the string for paging files with an offset.
1536 1534
1537 1535 This is the '+N' argument which less and more (under Unix) accept.
1538 1536 """
1539 1537
1540 1538 if pager in ['less','more']:
1541 1539 if start:
1542 1540 start_string = '+' + str(start)
1543 1541 else:
1544 1542 start_string = ''
1545 1543 else:
1546 1544 start_string = ''
1547 1545 return start_string
1548 1546
1549 1547 #----------------------------------------------------------------------------
1550 1548 # (X)emacs on W32 doesn't like to be bypassed with msvcrt.getch()
1551 1549 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
1552 1550 import msvcrt
1553 1551 def page_more():
1554 1552 """ Smart pausing between pages
1555 1553
1556 1554 @return: True if need print more lines, False if quit
1557 1555 """
1558 1556 Term.cout.write('---Return to continue, q to quit--- ')
1559 1557 ans = msvcrt.getch()
1560 1558 if ans in ("q", "Q"):
1561 1559 result = False
1562 1560 else:
1563 1561 result = True
1564 1562 Term.cout.write("\b"*37 + " "*37 + "\b"*37)
1565 1563 return result
1566 1564 else:
1567 1565 def page_more():
1568 1566 ans = raw_input('---Return to continue, q to quit--- ')
1569 1567 if ans.lower().startswith('q'):
1570 1568 return False
1571 1569 else:
1572 1570 return True
1573 1571
1574 1572 esc_re = re.compile(r"(\x1b[^m]+m)")
1575 1573
1576 1574 def page_dumb(strng,start=0,screen_lines=25):
1577 1575 """Very dumb 'pager' in Python, for when nothing else works.
1578 1576
1579 1577 Only moves forward, same interface as page(), except for pager_cmd and
1580 1578 mode."""
1581 1579
1582 1580 out_ln = strng.splitlines()[start:]
1583 1581 screens = chop(out_ln,screen_lines-1)
1584 1582 if len(screens) == 1:
1585 1583 print >>Term.cout, os.linesep.join(screens[0])
1586 1584 else:
1587 1585 last_escape = ""
1588 1586 for scr in screens[0:-1]:
1589 1587 hunk = os.linesep.join(scr)
1590 1588 print >>Term.cout, last_escape + hunk
1591 1589 if not page_more():
1592 1590 return
1593 1591 esc_list = esc_re.findall(hunk)
1594 1592 if len(esc_list) > 0:
1595 1593 last_escape = esc_list[-1]
1596 1594 print >>Term.cout, last_escape + os.linesep.join(screens[-1])
1597 1595
1598 1596 #----------------------------------------------------------------------------
1599 1597 def page(strng,start=0,screen_lines=0,pager_cmd = None):
1600 1598 """Print a string, piping through a pager after a certain length.
1601 1599
1602 1600 The screen_lines parameter specifies the number of *usable* lines of your
1603 1601 terminal screen (total lines minus lines you need to reserve to show other
1604 1602 information).
1605 1603
1606 1604 If you set screen_lines to a number <=0, page() will try to auto-determine
1607 1605 your screen size and will only use up to (screen_size+screen_lines) for
1608 1606 printing, paging after that. That is, if you want auto-detection but need
1609 1607 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
1610 1608 auto-detection without any lines reserved simply use screen_lines = 0.
1611 1609
1612 1610 If a string won't fit in the allowed lines, it is sent through the
1613 1611 specified pager command. If none given, look for PAGER in the environment,
1614 1612 and ultimately default to less.
1615 1613
1616 1614 If no system pager works, the string is sent through a 'dumb pager'
1617 1615 written in python, very simplistic.
1618 1616 """
1619 1617
1620 1618 # Some routines may auto-compute start offsets incorrectly and pass a
1621 1619 # negative value. Offset to 0 for robustness.
1622 1620 start = max(0,start)
1623 1621
1624 1622 # first, try the hook
1625 1623 ip = IPython.ipapi.get()
1626 1624 if ip:
1627 1625 try:
1628 1626 ip.IP.hooks.show_in_pager(strng)
1629 1627 return
1630 1628 except IPython.ipapi.TryNext:
1631 1629 pass
1632 1630
1633 1631 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
1634 1632 TERM = os.environ.get('TERM','dumb')
1635 1633 if TERM in ['dumb','emacs'] and os.name != 'nt':
1636 1634 print strng
1637 1635 return
1638 1636 # chop off the topmost part of the string we don't want to see
1639 1637 str_lines = strng.split(os.linesep)[start:]
1640 1638 str_toprint = os.linesep.join(str_lines)
1641 1639 num_newlines = len(str_lines)
1642 1640 len_str = len(str_toprint)
1643 1641
1644 1642 # Dumb heuristics to guesstimate number of on-screen lines the string
1645 1643 # takes. Very basic, but good enough for docstrings in reasonable
1646 1644 # terminals. If someone later feels like refining it, it's not hard.
1647 1645 numlines = max(num_newlines,int(len_str/80)+1)
1648 1646
1649 1647 if os.name == "nt":
1650 1648 screen_lines_def = get_console_size(defaulty=25)[1]
1651 1649 else:
1652 1650 screen_lines_def = 25 # default value if we can't auto-determine
1653 1651
1654 1652 # auto-determine screen size
1655 1653 if screen_lines <= 0:
1656 1654 if TERM=='xterm':
1657 1655 use_curses = USE_CURSES
1658 1656 else:
1659 1657 # curses causes problems on many terminals other than xterm.
1660 1658 use_curses = False
1661 1659 if use_curses:
1662 1660 # There is a bug in curses, where *sometimes* it fails to properly
1663 1661 # initialize, and then after the endwin() call is made, the
1664 1662 # terminal is left in an unusable state. Rather than trying to
1665 1663 # check everytime for this (by requesting and comparing termios
1666 1664 # flags each time), we just save the initial terminal state and
1667 1665 # unconditionally reset it every time. It's cheaper than making
1668 1666 # the checks.
1669 1667 term_flags = termios.tcgetattr(sys.stdout)
1670 1668 scr = curses.initscr()
1671 1669 screen_lines_real,screen_cols = scr.getmaxyx()
1672 1670 curses.endwin()
1673 1671 # Restore terminal state in case endwin() didn't.
1674 1672 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
1675 1673 # Now we have what we needed: the screen size in rows/columns
1676 1674 screen_lines += screen_lines_real
1677 1675 #print '***Screen size:',screen_lines_real,'lines x',\
1678 1676 #screen_cols,'columns.' # dbg
1679 1677 else:
1680 1678 screen_lines += screen_lines_def
1681 1679
1682 1680 #print 'numlines',numlines,'screenlines',screen_lines # dbg
1683 1681 if numlines <= screen_lines :
1684 1682 #print '*** normal print' # dbg
1685 1683 print >>Term.cout, str_toprint
1686 1684 else:
1687 1685 # Try to open pager and default to internal one if that fails.
1688 1686 # All failure modes are tagged as 'retval=1', to match the return
1689 1687 # value of a failed system command. If any intermediate attempt
1690 1688 # sets retval to 1, at the end we resort to our own page_dumb() pager.
1691 1689 pager_cmd = get_pager_cmd(pager_cmd)
1692 1690 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
1693 1691 if os.name == 'nt':
1694 1692 if pager_cmd.startswith('type'):
1695 1693 # The default WinXP 'type' command is failing on complex strings.
1696 1694 retval = 1
1697 1695 else:
1698 1696 tmpname = tempfile.mktemp('.txt')
1699 1697 tmpfile = file(tmpname,'wt')
1700 1698 tmpfile.write(strng)
1701 1699 tmpfile.close()
1702 1700 cmd = "%s < %s" % (pager_cmd,tmpname)
1703 1701 if os.system(cmd):
1704 1702 retval = 1
1705 1703 else:
1706 1704 retval = None
1707 1705 os.remove(tmpname)
1708 1706 else:
1709 1707 try:
1710 1708 retval = None
1711 1709 # if I use popen4, things hang. No idea why.
1712 1710 #pager,shell_out = os.popen4(pager_cmd)
1713 1711 pager = os.popen(pager_cmd,'w')
1714 1712 pager.write(strng)
1715 1713 pager.close()
1716 1714 retval = pager.close() # success returns None
1717 1715 except IOError,msg: # broken pipe when user quits
1718 1716 if msg.args == (32,'Broken pipe'):
1719 1717 retval = None
1720 1718 else:
1721 1719 retval = 1
1722 1720 except OSError:
1723 1721 # Other strange problems, sometimes seen in Win2k/cygwin
1724 1722 retval = 1
1725 1723 if retval is not None:
1726 1724 page_dumb(strng,screen_lines=screen_lines)
1727 1725
1728 1726 #----------------------------------------------------------------------------
1729 1727 def page_file(fname,start = 0, pager_cmd = None):
1730 1728 """Page a file, using an optional pager command and starting line.
1731 1729 """
1732 1730
1733 1731 pager_cmd = get_pager_cmd(pager_cmd)
1734 1732 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
1735 1733
1736 1734 try:
1737 1735 if os.environ['TERM'] in ['emacs','dumb']:
1738 1736 raise EnvironmentError
1739 1737 xsys(pager_cmd + ' ' + fname)
1740 1738 except:
1741 1739 try:
1742 1740 if start > 0:
1743 1741 start -= 1
1744 1742 page(open(fname).read(),start)
1745 1743 except:
1746 1744 print 'Unable to show file',`fname`
1747 1745
1748 1746
1749 1747 #----------------------------------------------------------------------------
1750 1748 def snip_print(str,width = 75,print_full = 0,header = ''):
1751 1749 """Print a string snipping the midsection to fit in width.
1752 1750
1753 1751 print_full: mode control:
1754 1752 - 0: only snip long strings
1755 1753 - 1: send to page() directly.
1756 1754 - 2: snip long strings and ask for full length viewing with page()
1757 1755 Return 1 if snipping was necessary, 0 otherwise."""
1758 1756
1759 1757 if print_full == 1:
1760 1758 page(header+str)
1761 1759 return 0
1762 1760
1763 1761 print header,
1764 1762 if len(str) < width:
1765 1763 print str
1766 1764 snip = 0
1767 1765 else:
1768 1766 whalf = int((width -5)/2)
1769 1767 print str[:whalf] + ' <...> ' + str[-whalf:]
1770 1768 snip = 1
1771 1769 if snip and print_full == 2:
1772 1770 if raw_input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
1773 1771 page(str)
1774 1772 return snip
1775 1773
1776 1774 #****************************************************************************
1777 1775 # lists, dicts and structures
1778 1776
1779 1777 def belong(candidates,checklist):
1780 1778 """Check whether a list of items appear in a given list of options.
1781 1779
1782 1780 Returns a list of 1 and 0, one for each candidate given."""
1783 1781
1784 1782 return [x in checklist for x in candidates]
1785 1783
1786 1784 #----------------------------------------------------------------------------
1787 1785 def uniq_stable(elems):
1788 1786 """uniq_stable(elems) -> list
1789 1787
1790 1788 Return from an iterable, a list of all the unique elements in the input,
1791 1789 but maintaining the order in which they first appear.
1792 1790
1793 1791 A naive solution to this problem which just makes a dictionary with the
1794 1792 elements as keys fails to respect the stability condition, since
1795 1793 dictionaries are unsorted by nature.
1796 1794
1797 1795 Note: All elements in the input must be valid dictionary keys for this
1798 1796 routine to work, as it internally uses a dictionary for efficiency
1799 1797 reasons."""
1800 1798
1801 1799 unique = []
1802 1800 unique_dict = {}
1803 1801 for nn in elems:
1804 1802 if nn not in unique_dict:
1805 1803 unique.append(nn)
1806 1804 unique_dict[nn] = None
1807 1805 return unique
1808 1806
1809 1807 #----------------------------------------------------------------------------
1810 1808 class NLprinter:
1811 1809 """Print an arbitrarily nested list, indicating index numbers.
1812 1810
1813 1811 An instance of this class called nlprint is available and callable as a
1814 1812 function.
1815 1813
1816 1814 nlprint(list,indent=' ',sep=': ') -> prints indenting each level by 'indent'
1817 1815 and using 'sep' to separate the index from the value. """
1818 1816
1819 1817 def __init__(self):
1820 1818 self.depth = 0
1821 1819
1822 1820 def __call__(self,lst,pos='',**kw):
1823 1821 """Prints the nested list numbering levels."""
1824 1822 kw.setdefault('indent',' ')
1825 1823 kw.setdefault('sep',': ')
1826 1824 kw.setdefault('start',0)
1827 1825 kw.setdefault('stop',len(lst))
1828 1826 # we need to remove start and stop from kw so they don't propagate
1829 1827 # into a recursive call for a nested list.
1830 1828 start = kw['start']; del kw['start']
1831 1829 stop = kw['stop']; del kw['stop']
1832 1830 if self.depth == 0 and 'header' in kw.keys():
1833 1831 print kw['header']
1834 1832
1835 1833 for idx in range(start,stop):
1836 1834 elem = lst[idx]
1837 1835 if type(elem)==type([]):
1838 1836 self.depth += 1
1839 1837 self.__call__(elem,itpl('$pos$idx,'),**kw)
1840 1838 self.depth -= 1
1841 1839 else:
1842 1840 printpl(kw['indent']*self.depth+'$pos$idx$kw["sep"]$elem')
1843 1841
1844 1842 nlprint = NLprinter()
1845 1843 #----------------------------------------------------------------------------
1846 1844 def all_belong(candidates,checklist):
1847 1845 """Check whether a list of items ALL appear in a given list of options.
1848 1846
1849 1847 Returns a single 1 or 0 value."""
1850 1848
1851 1849 return 1-(0 in [x in checklist for x in candidates])
1852 1850
1853 1851 #----------------------------------------------------------------------------
1854 1852 def sort_compare(lst1,lst2,inplace = 1):
1855 1853 """Sort and compare two lists.
1856 1854
1857 1855 By default it does it in place, thus modifying the lists. Use inplace = 0
1858 1856 to avoid that (at the cost of temporary copy creation)."""
1859 1857 if not inplace:
1860 1858 lst1 = lst1[:]
1861 1859 lst2 = lst2[:]
1862 1860 lst1.sort(); lst2.sort()
1863 1861 return lst1 == lst2
1864 1862
1865 1863 #----------------------------------------------------------------------------
1866 1864 def list2dict(lst):
1867 1865 """Takes a list of (key,value) pairs and turns it into a dict."""
1868 1866
1869 1867 dic = {}
1870 1868 for k,v in lst: dic[k] = v
1871 1869 return dic
1872 1870
1873 1871 #----------------------------------------------------------------------------
1874 1872 def list2dict2(lst,default=''):
1875 1873 """Takes a list and turns it into a dict.
1876 1874 Much slower than list2dict, but more versatile. This version can take
1877 1875 lists with sublists of arbitrary length (including sclars)."""
1878 1876
1879 1877 dic = {}
1880 1878 for elem in lst:
1881 1879 if type(elem) in (types.ListType,types.TupleType):
1882 1880 size = len(elem)
1883 1881 if size == 0:
1884 1882 pass
1885 1883 elif size == 1:
1886 1884 dic[elem] = default
1887 1885 else:
1888 1886 k,v = elem[0], elem[1:]
1889 1887 if len(v) == 1: v = v[0]
1890 1888 dic[k] = v
1891 1889 else:
1892 1890 dic[elem] = default
1893 1891 return dic
1894 1892
1895 1893 #----------------------------------------------------------------------------
1896 1894 def flatten(seq):
1897 1895 """Flatten a list of lists (NOT recursive, only works for 2d lists)."""
1898 1896
1899 1897 return [x for subseq in seq for x in subseq]
1900 1898
1901 1899 #----------------------------------------------------------------------------
1902 1900 def get_slice(seq,start=0,stop=None,step=1):
1903 1901 """Get a slice of a sequence with variable step. Specify start,stop,step."""
1904 1902 if stop == None:
1905 1903 stop = len(seq)
1906 1904 item = lambda i: seq[i]
1907 1905 return map(item,xrange(start,stop,step))
1908 1906
1909 1907 #----------------------------------------------------------------------------
1910 1908 def chop(seq,size):
1911 1909 """Chop a sequence into chunks of the given size."""
1912 1910 chunk = lambda i: seq[i:i+size]
1913 1911 return map(chunk,xrange(0,len(seq),size))
1914 1912
1915 1913 #----------------------------------------------------------------------------
1916 1914 # with is a keyword as of python 2.5, so this function is renamed to withobj
1917 1915 # from its old 'with' name.
1918 1916 def with_obj(object, **args):
1919 1917 """Set multiple attributes for an object, similar to Pascal's with.
1920 1918
1921 1919 Example:
1922 1920 with_obj(jim,
1923 1921 born = 1960,
1924 1922 haircolour = 'Brown',
1925 1923 eyecolour = 'Green')
1926 1924
1927 1925 Credit: Greg Ewing, in
1928 1926 http://mail.python.org/pipermail/python-list/2001-May/040703.html.
1929 1927
1930 1928 NOTE: up until IPython 0.7.2, this was called simply 'with', but 'with'
1931 1929 has become a keyword for Python 2.5, so we had to rename it."""
1932 1930
1933 1931 object.__dict__.update(args)
1934 1932
1935 1933 #----------------------------------------------------------------------------
1936 1934 def setattr_list(obj,alist,nspace = None):
1937 1935 """Set a list of attributes for an object taken from a namespace.
1938 1936
1939 1937 setattr_list(obj,alist,nspace) -> sets in obj all the attributes listed in
1940 1938 alist with their values taken from nspace, which must be a dict (something
1941 1939 like locals() will often do) If nspace isn't given, locals() of the
1942 1940 *caller* is used, so in most cases you can omit it.
1943 1941
1944 1942 Note that alist can be given as a string, which will be automatically
1945 1943 split into a list on whitespace. If given as a list, it must be a list of
1946 1944 *strings* (the variable names themselves), not of variables."""
1947 1945
1948 1946 # this grabs the local variables from the *previous* call frame -- that is
1949 1947 # the locals from the function that called setattr_list().
1950 1948 # - snipped from weave.inline()
1951 1949 if nspace is None:
1952 1950 call_frame = sys._getframe().f_back
1953 1951 nspace = call_frame.f_locals
1954 1952
1955 1953 if type(alist) in StringTypes:
1956 1954 alist = alist.split()
1957 1955 for attr in alist:
1958 1956 val = eval(attr,nspace)
1959 1957 setattr(obj,attr,val)
1960 1958
1961 1959 #----------------------------------------------------------------------------
1962 1960 def getattr_list(obj,alist,*args):
1963 1961 """getattr_list(obj,alist[, default]) -> attribute list.
1964 1962
1965 1963 Get a list of named attributes for an object. When a default argument is
1966 1964 given, it is returned when the attribute doesn't exist; without it, an
1967 1965 exception is raised in that case.
1968 1966
1969 1967 Note that alist can be given as a string, which will be automatically
1970 1968 split into a list on whitespace. If given as a list, it must be a list of
1971 1969 *strings* (the variable names themselves), not of variables."""
1972 1970
1973 1971 if type(alist) in StringTypes:
1974 1972 alist = alist.split()
1975 1973 if args:
1976 1974 if len(args)==1:
1977 1975 default = args[0]
1978 1976 return map(lambda attr: getattr(obj,attr,default),alist)
1979 1977 else:
1980 1978 raise ValueError,'getattr_list() takes only one optional argument'
1981 1979 else:
1982 1980 return map(lambda attr: getattr(obj,attr),alist)
1983 1981
1984 1982 #----------------------------------------------------------------------------
1985 1983 def map_method(method,object_list,*argseq,**kw):
1986 1984 """map_method(method,object_list,*args,**kw) -> list
1987 1985
1988 1986 Return a list of the results of applying the methods to the items of the
1989 1987 argument sequence(s). If more than one sequence is given, the method is
1990 1988 called with an argument list consisting of the corresponding item of each
1991 1989 sequence. All sequences must be of the same length.
1992 1990
1993 1991 Keyword arguments are passed verbatim to all objects called.
1994 1992
1995 1993 This is Python code, so it's not nearly as fast as the builtin map()."""
1996 1994
1997 1995 out_list = []
1998 1996 idx = 0
1999 1997 for object in object_list:
2000 1998 try:
2001 1999 handler = getattr(object, method)
2002 2000 except AttributeError:
2003 2001 out_list.append(None)
2004 2002 else:
2005 2003 if argseq:
2006 2004 args = map(lambda lst:lst[idx],argseq)
2007 2005 #print 'ob',object,'hand',handler,'ar',args # dbg
2008 2006 out_list.append(handler(args,**kw))
2009 2007 else:
2010 2008 out_list.append(handler(**kw))
2011 2009 idx += 1
2012 2010 return out_list
2013 2011
2014 2012 #----------------------------------------------------------------------------
2015 2013 def get_class_members(cls):
2016 2014 ret = dir(cls)
2017 2015 if hasattr(cls,'__bases__'):
2018 2016 for base in cls.__bases__:
2019 2017 ret.extend(get_class_members(base))
2020 2018 return ret
2021 2019
2022 2020 #----------------------------------------------------------------------------
2023 2021 def dir2(obj):
2024 2022 """dir2(obj) -> list of strings
2025 2023
2026 2024 Extended version of the Python builtin dir(), which does a few extra
2027 2025 checks, and supports common objects with unusual internals that confuse
2028 2026 dir(), such as Traits and PyCrust.
2029 2027
2030 2028 This version is guaranteed to return only a list of true strings, whereas
2031 2029 dir() returns anything that objects inject into themselves, even if they
2032 2030 are later not really valid for attribute access (many extension libraries
2033 2031 have such bugs).
2034 2032 """
2035 2033
2036 2034 # Start building the attribute list via dir(), and then complete it
2037 2035 # with a few extra special-purpose calls.
2038 2036 words = dir(obj)
2039 2037
2040 2038 if hasattr(obj,'__class__'):
2041 2039 words.append('__class__')
2042 2040 words.extend(get_class_members(obj.__class__))
2043 2041 #if '__base__' in words: 1/0
2044 2042
2045 2043 # Some libraries (such as traits) may introduce duplicates, we want to
2046 2044 # track and clean this up if it happens
2047 2045 may_have_dupes = False
2048 2046
2049 2047 # this is the 'dir' function for objects with Enthought's traits
2050 2048 if hasattr(obj, 'trait_names'):
2051 2049 try:
2052 2050 words.extend(obj.trait_names())
2053 2051 may_have_dupes = True
2054 2052 except TypeError:
2055 2053 # This will happen if `obj` is a class and not an instance.
2056 2054 pass
2057 2055
2058 2056 # Support for PyCrust-style _getAttributeNames magic method.
2059 2057 if hasattr(obj, '_getAttributeNames'):
2060 2058 try:
2061 2059 words.extend(obj._getAttributeNames())
2062 2060 may_have_dupes = True
2063 2061 except TypeError:
2064 2062 # `obj` is a class and not an instance. Ignore
2065 2063 # this error.
2066 2064 pass
2067 2065
2068 2066 if may_have_dupes:
2069 2067 # eliminate possible duplicates, as some traits may also
2070 2068 # appear as normal attributes in the dir() call.
2071 2069 words = list(set(words))
2072 2070 words.sort()
2073 2071
2074 2072 # filter out non-string attributes which may be stuffed by dir() calls
2075 2073 # and poor coding in third-party modules
2076 2074 return [w for w in words if isinstance(w, basestring)]
2077 2075
2078 2076 #----------------------------------------------------------------------------
2079 2077 def import_fail_info(mod_name,fns=None):
2080 2078 """Inform load failure for a module."""
2081 2079
2082 2080 if fns == None:
2083 2081 warn("Loading of %s failed.\n" % (mod_name,))
2084 2082 else:
2085 2083 warn("Loading of %s from %s failed.\n" % (fns,mod_name))
2086 2084
2087 2085 #----------------------------------------------------------------------------
2088 2086 # Proposed popitem() extension, written as a method
2089 2087
2090 2088
2091 2089 class NotGiven: pass
2092 2090
2093 2091 def popkey(dct,key,default=NotGiven):
2094 2092 """Return dct[key] and delete dct[key].
2095 2093
2096 2094 If default is given, return it if dct[key] doesn't exist, otherwise raise
2097 2095 KeyError. """
2098 2096
2099 2097 try:
2100 2098 val = dct[key]
2101 2099 except KeyError:
2102 2100 if default is NotGiven:
2103 2101 raise
2104 2102 else:
2105 2103 return default
2106 2104 else:
2107 2105 del dct[key]
2108 2106 return val
2109 2107
2110 2108 def wrap_deprecated(func, suggest = '<nothing>'):
2111 2109 def newFunc(*args, **kwargs):
2112 2110 warnings.warn("Call to deprecated function %s, use %s instead" %
2113 2111 ( func.__name__, suggest),
2114 2112 category=DeprecationWarning,
2115 2113 stacklevel = 2)
2116 2114 return func(*args, **kwargs)
2117 2115 return newFunc
2118 2116
2119 2117
2120 2118 def _num_cpus_unix():
2121 2119 """Return the number of active CPUs on a Unix system."""
2122 2120 return os.sysconf("SC_NPROCESSORS_ONLN")
2123 2121
2124 2122
2125 2123 def _num_cpus_darwin():
2126 2124 """Return the number of active CPUs on a Darwin system."""
2127 2125 p = subprocess.Popen(['sysctl','-n','hw.ncpu'],stdout=subprocess.PIPE)
2128 2126 return p.stdout.read()
2129 2127
2130 2128
2131 2129 def _num_cpus_windows():
2132 2130 """Return the number of active CPUs on a Windows system."""
2133 2131 return os.environ.get("NUMBER_OF_PROCESSORS")
2134 2132
2135 2133
2136 2134 def num_cpus():
2137 2135 """Return the effective number of CPUs in the system as an integer.
2138 2136
2139 2137 This cross-platform function makes an attempt at finding the total number of
2140 2138 available CPUs in the system, as returned by various underlying system and
2141 2139 python calls.
2142 2140
2143 2141 If it can't find a sensible answer, it returns 1 (though an error *may* make
2144 2142 it return a large positive number that's actually incorrect).
2145 2143 """
2146 2144
2147 2145 # Many thanks to the Parallel Python project (http://www.parallelpython.com)
2148 2146 # for the names of the keys we needed to look up for this function. This
2149 2147 # code was inspired by their equivalent function.
2150 2148
2151 2149 ncpufuncs = {'Linux':_num_cpus_unix,
2152 2150 'Darwin':_num_cpus_darwin,
2153 2151 'Windows':_num_cpus_windows,
2154 2152 # On Vista, python < 2.5.2 has a bug and returns 'Microsoft'
2155 2153 # See http://bugs.python.org/issue1082 for details.
2156 2154 'Microsoft':_num_cpus_windows,
2157 2155 }
2158 2156
2159 2157 ncpufunc = ncpufuncs.get(platform.system(),
2160 2158 # default to unix version (Solaris, AIX, etc)
2161 2159 _num_cpus_unix)
2162 2160
2163 2161 try:
2164 2162 ncpus = max(1,int(ncpufunc()))
2165 2163 except:
2166 2164 ncpus = 1
2167 2165 return ncpus
2168 2166
2169 2167 #*************************** end of file <genutils.py> **********************
@@ -1,46 +1,157 b''
1 1 # encoding: utf-8
2 2
3 3 """Tests for genutils.py"""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from IPython import genutils
19 import os
19
20 import os, sys, IPython
20 21 env = os.environ
21 22
23 from os.path import join, abspath
22 24
23 25 def test_get_home_dir_1():
24 """Make sure we can get the home directory."""
26 """Testcase to see if we can call get_home_dir without Exceptions."""
25 27 home_dir = genutils.get_home_dir()
26
28
27 29 def test_get_home_dir_2():
28 """Make sure we can get the home directory."""
29 old=env["HOME"]
30 env["HOME"]=os.path.join(".","home_test_dir")
30 """Testcase for py2exe logic, un-compressed lib
31 """
32 sys.frozen=True
33 oldstuff=IPython.__file__
34
35 #fake filename for IPython.__init__
36 IPython.__file__=abspath(join(".", "home_test_dir/Lib/IPython/__init__.py"))
37
38 home_dir = genutils.get_home_dir()
39 assert home_dir==abspath(join(".", "home_test_dir"))
40 IPython.__file__=oldstuff
41 del sys.frozen
42
43 def test_get_home_dir_3():
44 """Testcase for py2exe logic, compressed lib
45 """
46
47 sys.frozen=True
48 oldstuff=IPython.__file__
49
50 #fake filename for IPython.__init__
51 IPython.__file__=abspath(join(".", "home_test_dir/Library.zip/IPython/__init__.py"))
52
53 home_dir = genutils.get_home_dir()
54 assert home_dir==abspath(join(".", "home_test_dir")).lower()
55
56 del sys.frozen
57 IPython.__file__=oldstuff
58
59
60 def test_get_home_dir_4():
61 """Testcase $HOME is set, then use its value as home directory."""
62 oldstuff=env["HOME"]
63
64 env["HOME"]=join(".","home_test_dir")
31 65 home_dir = genutils.get_home_dir()
32 66 assert home_dir==env["HOME"]
33 env["HOME"]=old
34 67
68 env["HOME"]=oldstuff
35 69
70 def test_get_home_dir_5():
71 """Testcase $HOME is not set, os=='posix'.
72 This should fail with HomeDirError"""
73 oldstuff=env["HOME"],os.name
74
75 os.name='posix'
76 del os.environ["HOME"]
77 try:
78 genutils.get_home_dir()
79 assert False
80 except genutils.HomeDirError:
81 pass
82 finally:
83 env["HOME"],os.name=oldstuff
84
85 def test_get_home_dir_6():
86 """Testcase $HOME is not set, os=='nt'
87 env['HOMEDRIVE'],env['HOMEPATH'] points to path."""
88
89 oldstuff=env["HOME"],os.name,env['HOMEDRIVE'],env['HOMEPATH']
90
91 os.name='nt'
92 del os.environ["HOME"]
93 env['HOMEDRIVE'],env['HOMEPATH']=os.path.abspath("."),"home_test_dir"
94
95 home_dir = genutils.get_home_dir()
96 assert home_dir==abspath(join(".", "home_test_dir"))
97
98 env["HOME"],os.name,env['HOMEDRIVE'],env['HOMEPATH']=oldstuff
99
100 def test_get_home_dir_8():
101 """Testcase $HOME is not set, os=='nt'
102 env['HOMEDRIVE'],env['HOMEPATH'] do not point to path.
103 env['USERPROFILE'] points to path
104 """
105 oldstuff=(env["HOME"],os.name,env['HOMEDRIVE'],env['HOMEPATH'])
106
107 os.name='nt'
108 del os.environ["HOME"]
109 env['HOMEDRIVE'],env['HOMEPATH']=os.path.abspath("."),"DOES NOT EXIST"
110 env["USERPROFILE"]=abspath(join(".","home_test_dir"))
36 111
112 home_dir = genutils.get_home_dir()
113 assert home_dir==abspath(join(".", "home_test_dir"))
114
115 (env["HOME"],os.name,env['HOMEDRIVE'],env['HOMEPATH'])=oldstuff
116
117 def test_get_home_dir_9():
118 """Testcase $HOME is not set, os=='nt'
119 env['HOMEDRIVE'],env['HOMEPATH'], env['USERPROFILE'] missing
120 """
121 import _winreg as wreg
122 oldstuff = (env["HOME"],os.name,env['HOMEDRIVE'],
123 env['HOMEPATH'],env["USERPROFILE"],
124 wreg.OpenKey, wreg.QueryValueEx,
125 )
126 os.name='nt'
127 del env["HOME"],env['HOMEDRIVE']
128
129 #Stub windows registry functions
130 def OpenKey(x,y):
131 class key:
132 def Close(self):
133 pass
134 return key()
135 def QueryValueEx(x,y):
136 return [abspath(join(".", "home_test_dir"))]
137
138 wreg.OpenKey=OpenKey
139 wreg.QueryValueEx=QueryValueEx
140
141
142 home_dir = genutils.get_home_dir()
143 assert home_dir==abspath(join(".", "home_test_dir"))
37 144
145 (env["HOME"],os.name,env['HOMEDRIVE'],
146 env['HOMEPATH'],env["USERPROFILE"],
147 wreg.OpenKey, wreg.QueryValueEx,) = oldstuff
148
38 149
39 150 def test_get_ipython_dir():
40 """Make sure we can get the ipython directory."""
151 """Testcase to see if we can call get_ipython_dir without Exceptions."""
41 152 ipdir = genutils.get_ipython_dir()
42 153
43 154 def test_get_security_dir():
44 """Make sure we can get the ipython/security directory."""
155 """Testcase to see if we can call get_security_dir without Exceptions."""
45 156 sdir = genutils.get_security_dir()
46 157 No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now