##// END OF EJS Templates
extensions: add unwrapfunction to undo wrapfunction...
Jun Wu -
r29765:19578bb8 default
parent child Browse files
Show More
@@ -0,0 +1,39 b''
1 from __future__ import absolute_import, print_function
2
3 from mercurial import extensions
4
5 def genwrapper(x):
6 def f(orig, *args, **kwds):
7 return [x] + orig(*args, **kwds)
8 f.x = x
9 return f
10
11 def getid(wrapper):
12 return getattr(wrapper, 'x', '-')
13
14 wrappers = [genwrapper(i) for i in range(5)]
15
16 class dummyclass(object):
17 def getstack(self):
18 return ['orig']
19
20 dummy = dummyclass()
21
22 def batchwrap(wrappers):
23 for w in wrappers:
24 extensions.wrapfunction(dummy, 'getstack', w)
25 print('wrap %d: %s' % (getid(w), dummy.getstack()))
26
27 def batchunwrap(wrappers):
28 for w in wrappers:
29 result = None
30 try:
31 result = extensions.unwrapfunction(dummy, 'getstack', w)
32 msg = str(dummy.getstack())
33 except (ValueError, IndexError) as e:
34 msg = e.__class__.__name__
35 print('unwrap %s: %s: %s' % (getid(w), getid(result), msg))
36
37 batchwrap(wrappers + [wrappers[0]])
38 batchunwrap([(wrappers[i] if i >= 0 else None)
39 for i in [3, None, 0, 4, 0, 2, 1, None]])
@@ -0,0 +1,14 b''
1 wrap 0: [0, 'orig']
2 wrap 1: [1, 0, 'orig']
3 wrap 2: [2, 1, 0, 'orig']
4 wrap 3: [3, 2, 1, 0, 'orig']
5 wrap 4: [4, 3, 2, 1, 0, 'orig']
6 wrap 0: [0, 4, 3, 2, 1, 0, 'orig']
7 unwrap 3: 3: [0, 4, 2, 1, 0, 'orig']
8 unwrap -: 0: [4, 2, 1, 0, 'orig']
9 unwrap 0: 0: [4, 2, 1, 'orig']
10 unwrap 4: 4: [2, 1, 'orig']
11 unwrap 0: -: ValueError
12 unwrap 2: 2: [1, 'orig']
13 unwrap 1: 1: ['orig']
14 unwrap -: -: IndexError
@@ -1,515 +1,535 b''
1 1 # extensions.py - extension handling for mercurial
2 2 #
3 3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import imp
11 11 import os
12 12
13 13 from .i18n import (
14 14 _,
15 15 gettext,
16 16 )
17 17
18 18 from . import (
19 19 cmdutil,
20 20 error,
21 21 util,
22 22 )
23 23
24 24 _extensions = {}
25 25 _aftercallbacks = {}
26 26 _order = []
27 27 _builtin = set(['hbisect', 'bookmarks', 'parentrevspec', 'progress', 'interhg',
28 28 'inotify', 'hgcia'])
29 29
30 30 def extensions(ui=None):
31 31 if ui:
32 32 def enabled(name):
33 33 for format in ['%s', 'hgext.%s']:
34 34 conf = ui.config('extensions', format % name)
35 35 if conf is not None and not conf.startswith('!'):
36 36 return True
37 37 else:
38 38 enabled = lambda name: True
39 39 for name in _order:
40 40 module = _extensions[name]
41 41 if module and enabled(name):
42 42 yield name, module
43 43
44 44 def find(name):
45 45 '''return module with given extension name'''
46 46 mod = None
47 47 try:
48 48 mod = _extensions[name]
49 49 except KeyError:
50 50 for k, v in _extensions.iteritems():
51 51 if k.endswith('.' + name) or k.endswith('/' + name):
52 52 mod = v
53 53 break
54 54 if not mod:
55 55 raise KeyError(name)
56 56 return mod
57 57
58 58 def loadpath(path, module_name):
59 59 module_name = module_name.replace('.', '_')
60 60 path = util.normpath(util.expandpath(path))
61 61 if os.path.isdir(path):
62 62 # module/__init__.py style
63 63 d, f = os.path.split(path)
64 64 fd, fpath, desc = imp.find_module(f, [d])
65 65 return imp.load_module(module_name, fd, fpath, desc)
66 66 else:
67 67 try:
68 68 return imp.load_source(module_name, path)
69 69 except IOError as exc:
70 70 if not exc.filename:
71 71 exc.filename = path # python does not fill this
72 72 raise
73 73
74 74 def _importh(name):
75 75 """import and return the <name> module"""
76 76 mod = __import__(name)
77 77 components = name.split('.')
78 78 for comp in components[1:]:
79 79 mod = getattr(mod, comp)
80 80 return mod
81 81
82 82 def _reportimporterror(ui, err, failed, next):
83 83 ui.debug('could not import %s (%s): trying %s\n'
84 84 % (failed, err, next))
85 85 if ui.debugflag:
86 86 ui.traceback()
87 87
88 88 def load(ui, name, path):
89 89 if name.startswith('hgext.') or name.startswith('hgext/'):
90 90 shortname = name[6:]
91 91 else:
92 92 shortname = name
93 93 if shortname in _builtin:
94 94 return None
95 95 if shortname in _extensions:
96 96 return _extensions[shortname]
97 97 _extensions[shortname] = None
98 98 if path:
99 99 # the module will be loaded in sys.modules
100 100 # choose an unique name so that it doesn't
101 101 # conflicts with other modules
102 102 mod = loadpath(path, 'hgext.%s' % name)
103 103 else:
104 104 try:
105 105 mod = _importh("hgext.%s" % name)
106 106 except ImportError as err:
107 107 _reportimporterror(ui, err, "hgext.%s" % name, name)
108 108 try:
109 109 mod = _importh("hgext3rd.%s" % name)
110 110 except ImportError as err:
111 111 _reportimporterror(ui, err, "hgext3rd.%s" % name, name)
112 112 mod = _importh(name)
113 113
114 114 # Before we do anything with the extension, check against minimum stated
115 115 # compatibility. This gives extension authors a mechanism to have their
116 116 # extensions short circuit when loaded with a known incompatible version
117 117 # of Mercurial.
118 118 minver = getattr(mod, 'minimumhgversion', None)
119 119 if minver and util.versiontuple(minver, 2) > util.versiontuple(n=2):
120 120 ui.warn(_('(third party extension %s requires version %s or newer '
121 121 'of Mercurial; disabling)\n') % (shortname, minver))
122 122 return
123 123
124 124 _extensions[shortname] = mod
125 125 _order.append(shortname)
126 126 for fn in _aftercallbacks.get(shortname, []):
127 127 fn(loaded=True)
128 128 return mod
129 129
130 130 def _runuisetup(name, ui):
131 131 uisetup = getattr(_extensions[name], 'uisetup', None)
132 132 if uisetup:
133 133 uisetup(ui)
134 134
135 135 def _runextsetup(name, ui):
136 136 extsetup = getattr(_extensions[name], 'extsetup', None)
137 137 if extsetup:
138 138 try:
139 139 extsetup(ui)
140 140 except TypeError:
141 141 if extsetup.func_code.co_argcount != 0:
142 142 raise
143 143 extsetup() # old extsetup with no ui argument
144 144
145 145 def loadall(ui):
146 146 result = ui.configitems("extensions")
147 147 newindex = len(_order)
148 148 for (name, path) in result:
149 149 if path:
150 150 if path[0] == '!':
151 151 continue
152 152 try:
153 153 load(ui, name, path)
154 154 except KeyboardInterrupt:
155 155 raise
156 156 except Exception as inst:
157 157 if path:
158 158 ui.warn(_("*** failed to import extension %s from %s: %s\n")
159 159 % (name, path, inst))
160 160 else:
161 161 ui.warn(_("*** failed to import extension %s: %s\n")
162 162 % (name, inst))
163 163 ui.traceback()
164 164
165 165 for name in _order[newindex:]:
166 166 _runuisetup(name, ui)
167 167
168 168 for name in _order[newindex:]:
169 169 _runextsetup(name, ui)
170 170
171 171 # Call aftercallbacks that were never met.
172 172 for shortname in _aftercallbacks:
173 173 if shortname in _extensions:
174 174 continue
175 175
176 176 for fn in _aftercallbacks[shortname]:
177 177 fn(loaded=False)
178 178
179 179 # loadall() is called multiple times and lingering _aftercallbacks
180 180 # entries could result in double execution. See issue4646.
181 181 _aftercallbacks.clear()
182 182
183 183 def afterloaded(extension, callback):
184 184 '''Run the specified function after a named extension is loaded.
185 185
186 186 If the named extension is already loaded, the callback will be called
187 187 immediately.
188 188
189 189 If the named extension never loads, the callback will be called after
190 190 all extensions have been loaded.
191 191
192 192 The callback receives the named argument ``loaded``, which is a boolean
193 193 indicating whether the dependent extension actually loaded.
194 194 '''
195 195
196 196 if extension in _extensions:
197 197 callback(loaded=True)
198 198 else:
199 199 _aftercallbacks.setdefault(extension, []).append(callback)
200 200
201 201 def bind(func, *args):
202 202 '''Partial function application
203 203
204 204 Returns a new function that is the partial application of args and kwargs
205 205 to func. For example,
206 206
207 207 f(1, 2, bar=3) === bind(f, 1)(2, bar=3)'''
208 208 assert callable(func)
209 209 def closure(*a, **kw):
210 210 return func(*(args + a), **kw)
211 211 return closure
212 212
213 213 def _updatewrapper(wrap, origfn, unboundwrapper):
214 214 '''Copy and add some useful attributes to wrapper'''
215 215 wrap.__module__ = getattr(origfn, '__module__')
216 216 wrap.__doc__ = getattr(origfn, '__doc__')
217 217 wrap.__dict__.update(getattr(origfn, '__dict__', {}))
218 218 wrap._origfunc = origfn
219 219 wrap._unboundwrapper = unboundwrapper
220 220
221 221 def wrapcommand(table, command, wrapper, synopsis=None, docstring=None):
222 222 '''Wrap the command named `command' in table
223 223
224 224 Replace command in the command table with wrapper. The wrapped command will
225 225 be inserted into the command table specified by the table argument.
226 226
227 227 The wrapper will be called like
228 228
229 229 wrapper(orig, *args, **kwargs)
230 230
231 231 where orig is the original (wrapped) function, and *args, **kwargs
232 232 are the arguments passed to it.
233 233
234 234 Optionally append to the command synopsis and docstring, used for help.
235 235 For example, if your extension wraps the ``bookmarks`` command to add the
236 236 flags ``--remote`` and ``--all`` you might call this function like so:
237 237
238 238 synopsis = ' [-a] [--remote]'
239 239 docstring = """
240 240
241 241 The ``remotenames`` extension adds the ``--remote`` and ``--all`` (``-a``)
242 242 flags to the bookmarks command. Either flag will show the remote bookmarks
243 243 known to the repository; ``--remote`` will also suppress the output of the
244 244 local bookmarks.
245 245 """
246 246
247 247 extensions.wrapcommand(commands.table, 'bookmarks', exbookmarks,
248 248 synopsis, docstring)
249 249 '''
250 250 assert callable(wrapper)
251 251 aliases, entry = cmdutil.findcmd(command, table)
252 252 for alias, e in table.iteritems():
253 253 if e is entry:
254 254 key = alias
255 255 break
256 256
257 257 origfn = entry[0]
258 258 wrap = bind(util.checksignature(wrapper), util.checksignature(origfn))
259 259 _updatewrapper(wrap, origfn, wrapper)
260 260 if docstring is not None:
261 261 wrap.__doc__ += docstring
262 262
263 263 newentry = list(entry)
264 264 newentry[0] = wrap
265 265 if synopsis is not None:
266 266 newentry[2] += synopsis
267 267 table[key] = tuple(newentry)
268 268 return entry
269 269
270 270 def wrapfunction(container, funcname, wrapper):
271 271 '''Wrap the function named funcname in container
272 272
273 273 Replace the funcname member in the given container with the specified
274 274 wrapper. The container is typically a module, class, or instance.
275 275
276 276 The wrapper will be called like
277 277
278 278 wrapper(orig, *args, **kwargs)
279 279
280 280 where orig is the original (wrapped) function, and *args, **kwargs
281 281 are the arguments passed to it.
282 282
283 283 Wrapping methods of the repository object is not recommended since
284 284 it conflicts with extensions that extend the repository by
285 285 subclassing. All extensions that need to extend methods of
286 286 localrepository should use this subclassing trick: namely,
287 287 reposetup() should look like
288 288
289 289 def reposetup(ui, repo):
290 290 class myrepo(repo.__class__):
291 291 def whatever(self, *args, **kwargs):
292 292 [...extension stuff...]
293 293 super(myrepo, self).whatever(*args, **kwargs)
294 294 [...extension stuff...]
295 295
296 296 repo.__class__ = myrepo
297 297
298 298 In general, combining wrapfunction() with subclassing does not
299 299 work. Since you cannot control what other extensions are loaded by
300 300 your end users, you should play nicely with others by using the
301 301 subclass trick.
302 302 '''
303 303 assert callable(wrapper)
304 304
305 305 origfn = getattr(container, funcname)
306 306 assert callable(origfn)
307 307 wrap = bind(wrapper, origfn)
308 308 _updatewrapper(wrap, origfn, wrapper)
309 309 setattr(container, funcname, wrap)
310 310 return origfn
311 311
312 def unwrapfunction(container, funcname, wrapper=None):
313 '''undo wrapfunction
314
315 If wrappers is None, undo the last wrap. Otherwise removes the wrapper
316 from the chain of wrappers.
317
318 Return the removed wrapper.
319 Raise IndexError if wrapper is None and nothing to unwrap; ValueError if
320 wrapper is not None but is not found in the wrapper chain.
321 '''
322 chain = getwrapperchain(container, funcname)
323 origfn = chain.pop()
324 if wrapper is None:
325 wrapper = chain[0]
326 chain.remove(wrapper)
327 setattr(container, funcname, origfn)
328 for w in reversed(chain):
329 wrapfunction(container, funcname, w)
330 return wrapper
331
312 332 def getwrapperchain(container, funcname):
313 333 '''get a chain of wrappers of a function
314 334
315 335 Return a list of functions: [newest wrapper, ..., oldest wrapper, origfunc]
316 336
317 337 The wrapper functions are the ones passed to wrapfunction, whose first
318 338 argument is origfunc.
319 339 '''
320 340 result = []
321 341 fn = getattr(container, funcname)
322 342 while fn:
323 343 assert callable(fn)
324 344 result.append(getattr(fn, '_unboundwrapper', fn))
325 345 fn = getattr(fn, '_origfunc', None)
326 346 return result
327 347
328 348 def _disabledpaths(strip_init=False):
329 349 '''find paths of disabled extensions. returns a dict of {name: path}
330 350 removes /__init__.py from packages if strip_init is True'''
331 351 import hgext
332 352 extpath = os.path.dirname(os.path.abspath(hgext.__file__))
333 353 try: # might not be a filesystem path
334 354 files = os.listdir(extpath)
335 355 except OSError:
336 356 return {}
337 357
338 358 exts = {}
339 359 for e in files:
340 360 if e.endswith('.py'):
341 361 name = e.rsplit('.', 1)[0]
342 362 path = os.path.join(extpath, e)
343 363 else:
344 364 name = e
345 365 path = os.path.join(extpath, e, '__init__.py')
346 366 if not os.path.exists(path):
347 367 continue
348 368 if strip_init:
349 369 path = os.path.dirname(path)
350 370 if name in exts or name in _order or name == '__init__':
351 371 continue
352 372 exts[name] = path
353 373 return exts
354 374
355 375 def _moduledoc(file):
356 376 '''return the top-level python documentation for the given file
357 377
358 378 Loosely inspired by pydoc.source_synopsis(), but rewritten to
359 379 handle triple quotes and to return the whole text instead of just
360 380 the synopsis'''
361 381 result = []
362 382
363 383 line = file.readline()
364 384 while line[:1] == '#' or not line.strip():
365 385 line = file.readline()
366 386 if not line:
367 387 break
368 388
369 389 start = line[:3]
370 390 if start == '"""' or start == "'''":
371 391 line = line[3:]
372 392 while line:
373 393 if line.rstrip().endswith(start):
374 394 line = line.split(start)[0]
375 395 if line:
376 396 result.append(line)
377 397 break
378 398 elif not line:
379 399 return None # unmatched delimiter
380 400 result.append(line)
381 401 line = file.readline()
382 402 else:
383 403 return None
384 404
385 405 return ''.join(result)
386 406
387 407 def _disabledhelp(path):
388 408 '''retrieve help synopsis of a disabled extension (without importing)'''
389 409 try:
390 410 file = open(path)
391 411 except IOError:
392 412 return
393 413 else:
394 414 doc = _moduledoc(file)
395 415 file.close()
396 416
397 417 if doc: # extracting localized synopsis
398 418 return gettext(doc).splitlines()[0]
399 419 else:
400 420 return _('(no help text available)')
401 421
402 422 def disabled():
403 423 '''find disabled extensions from hgext. returns a dict of {name: desc}'''
404 424 try:
405 425 from hgext import __index__
406 426 return dict((name, gettext(desc))
407 427 for name, desc in __index__.docs.iteritems()
408 428 if name not in _order)
409 429 except (ImportError, AttributeError):
410 430 pass
411 431
412 432 paths = _disabledpaths()
413 433 if not paths:
414 434 return {}
415 435
416 436 exts = {}
417 437 for name, path in paths.iteritems():
418 438 doc = _disabledhelp(path)
419 439 if doc:
420 440 exts[name] = doc
421 441
422 442 return exts
423 443
424 444 def disabledext(name):
425 445 '''find a specific disabled extension from hgext. returns desc'''
426 446 try:
427 447 from hgext import __index__
428 448 if name in _order: # enabled
429 449 return
430 450 else:
431 451 return gettext(__index__.docs.get(name))
432 452 except (ImportError, AttributeError):
433 453 pass
434 454
435 455 paths = _disabledpaths()
436 456 if name in paths:
437 457 return _disabledhelp(paths[name])
438 458
439 459 def disabledcmd(ui, cmd, strict=False):
440 460 '''import disabled extensions until cmd is found.
441 461 returns (cmdname, extname, module)'''
442 462
443 463 paths = _disabledpaths(strip_init=True)
444 464 if not paths:
445 465 raise error.UnknownCommand(cmd)
446 466
447 467 def findcmd(cmd, name, path):
448 468 try:
449 469 mod = loadpath(path, 'hgext.%s' % name)
450 470 except Exception:
451 471 return
452 472 try:
453 473 aliases, entry = cmdutil.findcmd(cmd,
454 474 getattr(mod, 'cmdtable', {}), strict)
455 475 except (error.AmbiguousCommand, error.UnknownCommand):
456 476 return
457 477 except Exception:
458 478 ui.warn(_('warning: error finding commands in %s\n') % path)
459 479 ui.traceback()
460 480 return
461 481 for c in aliases:
462 482 if c.startswith(cmd):
463 483 cmd = c
464 484 break
465 485 else:
466 486 cmd = aliases[0]
467 487 return (cmd, name, mod)
468 488
469 489 ext = None
470 490 # first, search for an extension with the same name as the command
471 491 path = paths.pop(cmd, None)
472 492 if path:
473 493 ext = findcmd(cmd, cmd, path)
474 494 if not ext:
475 495 # otherwise, interrogate each extension until there's a match
476 496 for name, path in paths.iteritems():
477 497 ext = findcmd(cmd, name, path)
478 498 if ext:
479 499 break
480 500 if ext and 'DEPRECATED' not in ext.__doc__:
481 501 return ext
482 502
483 503 raise error.UnknownCommand(cmd)
484 504
485 505 def enabled(shortname=True):
486 506 '''return a dict of {name: desc} of extensions'''
487 507 exts = {}
488 508 for ename, ext in extensions():
489 509 doc = (gettext(ext.__doc__) or _('(no help text available)'))
490 510 if shortname:
491 511 ename = ename.split('.')[-1]
492 512 exts[ename] = doc.splitlines()[0].strip()
493 513
494 514 return exts
495 515
496 516 def notloaded():
497 517 '''return short names of extensions that failed to load'''
498 518 return [name for name, mod in _extensions.iteritems() if mod is None]
499 519
500 520 def moduleversion(module):
501 521 '''return version information from given module as a string'''
502 522 if (util.safehasattr(module, 'getversion')
503 523 and callable(module.getversion)):
504 524 version = module.getversion()
505 525 elif util.safehasattr(module, '__version__'):
506 526 version = module.__version__
507 527 else:
508 528 version = ''
509 529 if isinstance(version, (list, tuple)):
510 530 version = '.'.join(str(o) for o in version)
511 531 return version
512 532
513 533 def ismoduleinternal(module):
514 534 exttestedwith = getattr(module, 'testedwith', None)
515 535 return exttestedwith == "internal"
General Comments 0
You need to be logged in to leave comments. Login now