##// END OF EJS Templates
extensions: extract partial application into a bind() function...
Eric Sumner -
r24734:fb6cb1b8 default
parent child Browse files
Show More
@@ -1,431 +1,439 b''
1 1 # extensions.py - extension handling for mercurial
2 2 #
3 3 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 import imp, os
9 9 import util, cmdutil, error
10 10 from i18n import _, gettext
11 11
12 12 _extensions = {}
13 13 _aftercallbacks = {}
14 14 _order = []
15 15 _ignore = ['hbisect', 'bookmarks', 'parentrevspec', 'interhg', 'inotify']
16 16
17 17 def extensions(ui=None):
18 18 if ui:
19 19 def enabled(name):
20 20 for format in ['%s', 'hgext.%s']:
21 21 conf = ui.config('extensions', format % name)
22 22 if conf is not None and not conf.startswith('!'):
23 23 return True
24 24 else:
25 25 enabled = lambda name: True
26 26 for name in _order:
27 27 module = _extensions[name]
28 28 if module and enabled(name):
29 29 yield name, module
30 30
31 31 def find(name):
32 32 '''return module with given extension name'''
33 33 mod = None
34 34 try:
35 35 mod = _extensions[name]
36 36 except KeyError:
37 37 for k, v in _extensions.iteritems():
38 38 if k.endswith('.' + name) or k.endswith('/' + name):
39 39 mod = v
40 40 break
41 41 if not mod:
42 42 raise KeyError(name)
43 43 return mod
44 44
45 45 def loadpath(path, module_name):
46 46 module_name = module_name.replace('.', '_')
47 47 path = util.normpath(util.expandpath(path))
48 48 if os.path.isdir(path):
49 49 # module/__init__.py style
50 50 d, f = os.path.split(path)
51 51 fd, fpath, desc = imp.find_module(f, [d])
52 52 return imp.load_module(module_name, fd, fpath, desc)
53 53 else:
54 54 try:
55 55 return imp.load_source(module_name, path)
56 56 except IOError, exc:
57 57 if not exc.filename:
58 58 exc.filename = path # python does not fill this
59 59 raise
60 60
61 61 def load(ui, name, path):
62 62 if name.startswith('hgext.') or name.startswith('hgext/'):
63 63 shortname = name[6:]
64 64 else:
65 65 shortname = name
66 66 if shortname in _ignore:
67 67 return None
68 68 if shortname in _extensions:
69 69 return _extensions[shortname]
70 70 _extensions[shortname] = None
71 71 if path:
72 72 # the module will be loaded in sys.modules
73 73 # choose an unique name so that it doesn't
74 74 # conflicts with other modules
75 75 mod = loadpath(path, 'hgext.%s' % name)
76 76 else:
77 77 def importh(name):
78 78 mod = __import__(name)
79 79 components = name.split('.')
80 80 for comp in components[1:]:
81 81 mod = getattr(mod, comp)
82 82 return mod
83 83 try:
84 84 mod = importh("hgext.%s" % name)
85 85 except ImportError, err:
86 86 ui.debug('could not import hgext.%s (%s): trying %s\n'
87 87 % (name, err, name))
88 88 mod = importh(name)
89 89 _extensions[shortname] = mod
90 90 _order.append(shortname)
91 91 for fn in _aftercallbacks.get(shortname, []):
92 92 fn(loaded=True)
93 93 return mod
94 94
95 95 def loadall(ui):
96 96 result = ui.configitems("extensions")
97 97 newindex = len(_order)
98 98 for (name, path) in result:
99 99 if path:
100 100 if path[0] == '!':
101 101 continue
102 102 try:
103 103 load(ui, name, path)
104 104 except KeyboardInterrupt:
105 105 raise
106 106 except Exception, inst:
107 107 if path:
108 108 ui.warn(_("*** failed to import extension %s from %s: %s\n")
109 109 % (name, path, inst))
110 110 else:
111 111 ui.warn(_("*** failed to import extension %s: %s\n")
112 112 % (name, inst))
113 113
114 114 for name in _order[newindex:]:
115 115 uisetup = getattr(_extensions[name], 'uisetup', None)
116 116 if uisetup:
117 117 uisetup(ui)
118 118
119 119 for name in _order[newindex:]:
120 120 extsetup = getattr(_extensions[name], 'extsetup', None)
121 121 if extsetup:
122 122 try:
123 123 extsetup(ui)
124 124 except TypeError:
125 125 if extsetup.func_code.co_argcount != 0:
126 126 raise
127 127 extsetup() # old extsetup with no ui argument
128 128
129 129 # Call aftercallbacks that were never met.
130 130 for shortname in _aftercallbacks:
131 131 if shortname in _extensions:
132 132 continue
133 133
134 134 for fn in _aftercallbacks[shortname]:
135 135 fn(loaded=False)
136 136
137 137 def afterloaded(extension, callback):
138 138 '''Run the specified function after a named extension is loaded.
139 139
140 140 If the named extension is already loaded, the callback will be called
141 141 immediately.
142 142
143 143 If the named extension never loads, the callback will be called after
144 144 all extensions have been loaded.
145 145
146 146 The callback receives the named argument ``loaded``, which is a boolean
147 147 indicating whether the dependent extension actually loaded.
148 148 '''
149 149
150 150 if extension in _extensions:
151 151 callback(loaded=True)
152 152 else:
153 153 _aftercallbacks.setdefault(extension, []).append(callback)
154 154
155 def bind(func, *args):
156 '''Partial function application
157
158 Returns a new function that is the partial application of args and kwargs
159 to func. For example,
160
161 f(1, 2, bar=3) === bind(f, 1)(2, bar=3)'''
162 assert callable(func)
163 def closure(*a, **kw):
164 return func(*(args + a), **kw)
165 return closure
166
155 167 def wrapcommand(table, command, wrapper, synopsis=None, docstring=None):
156 168 '''Wrap the command named `command' in table
157 169
158 170 Replace command in the command table with wrapper. The wrapped command will
159 171 be inserted into the command table specified by the table argument.
160 172
161 173 The wrapper will be called like
162 174
163 175 wrapper(orig, *args, **kwargs)
164 176
165 177 where orig is the original (wrapped) function, and *args, **kwargs
166 178 are the arguments passed to it.
167 179
168 180 Optionally append to the command synopsis and docstring, used for help.
169 181 For example, if your extension wraps the ``bookmarks`` command to add the
170 182 flags ``--remote`` and ``--all`` you might call this function like so:
171 183
172 184 synopsis = ' [-a] [--remote]'
173 185 docstring = """
174 186
175 187 The ``remotenames`` extension adds the ``--remote`` and ``--all`` (``-a``)
176 188 flags to the bookmarks command. Either flag will show the remote bookmarks
177 189 known to the repository; ``--remote`` will also supress the output of the
178 190 local bookmarks.
179 191 """
180 192
181 193 extensions.wrapcommand(commands.table, 'bookmarks', exbookmarks,
182 194 synopsis, docstring)
183 195 '''
184 196 assert callable(wrapper)
185 197 aliases, entry = cmdutil.findcmd(command, table)
186 198 for alias, e in table.iteritems():
187 199 if e is entry:
188 200 key = alias
189 201 break
190 202
191 203 origfn = entry[0]
192 def wrap(*args, **kwargs):
193 return util.checksignature(wrapper)(
194 util.checksignature(origfn), *args, **kwargs)
204 wrap = bind(util.checksignature(wrapper), util.checksignature(origfn))
195 205
196 206 wrap.__module__ = getattr(origfn, '__module__')
197 207
198 208 doc = getattr(origfn, '__doc__')
199 209 if docstring is not None:
200 210 doc += docstring
201 211 wrap.__doc__ = doc
202 212
203 213 newentry = list(entry)
204 214 newentry[0] = wrap
205 215 if synopsis is not None:
206 216 newentry[2] += synopsis
207 217 table[key] = tuple(newentry)
208 218 return entry
209 219
210 220 def wrapfunction(container, funcname, wrapper):
211 221 '''Wrap the function named funcname in container
212 222
213 223 Replace the funcname member in the given container with the specified
214 224 wrapper. The container is typically a module, class, or instance.
215 225
216 226 The wrapper will be called like
217 227
218 228 wrapper(orig, *args, **kwargs)
219 229
220 230 where orig is the original (wrapped) function, and *args, **kwargs
221 231 are the arguments passed to it.
222 232
223 233 Wrapping methods of the repository object is not recommended since
224 234 it conflicts with extensions that extend the repository by
225 235 subclassing. All extensions that need to extend methods of
226 236 localrepository should use this subclassing trick: namely,
227 237 reposetup() should look like
228 238
229 239 def reposetup(ui, repo):
230 240 class myrepo(repo.__class__):
231 241 def whatever(self, *args, **kwargs):
232 242 [...extension stuff...]
233 243 super(myrepo, self).whatever(*args, **kwargs)
234 244 [...extension stuff...]
235 245
236 246 repo.__class__ = myrepo
237 247
238 248 In general, combining wrapfunction() with subclassing does not
239 249 work. Since you cannot control what other extensions are loaded by
240 250 your end users, you should play nicely with others by using the
241 251 subclass trick.
242 252 '''
243 253 assert callable(wrapper)
244 def wrap(*args, **kwargs):
245 return wrapper(origfn, *args, **kwargs)
246 254
247 255 origfn = getattr(container, funcname)
248 256 assert callable(origfn)
249 setattr(container, funcname, wrap)
257 setattr(container, funcname, bind(wrapper, origfn))
250 258 return origfn
251 259
252 260 def _disabledpaths(strip_init=False):
253 261 '''find paths of disabled extensions. returns a dict of {name: path}
254 262 removes /__init__.py from packages if strip_init is True'''
255 263 import hgext
256 264 extpath = os.path.dirname(os.path.abspath(hgext.__file__))
257 265 try: # might not be a filesystem path
258 266 files = os.listdir(extpath)
259 267 except OSError:
260 268 return {}
261 269
262 270 exts = {}
263 271 for e in files:
264 272 if e.endswith('.py'):
265 273 name = e.rsplit('.', 1)[0]
266 274 path = os.path.join(extpath, e)
267 275 else:
268 276 name = e
269 277 path = os.path.join(extpath, e, '__init__.py')
270 278 if not os.path.exists(path):
271 279 continue
272 280 if strip_init:
273 281 path = os.path.dirname(path)
274 282 if name in exts or name in _order or name == '__init__':
275 283 continue
276 284 exts[name] = path
277 285 return exts
278 286
279 287 def _moduledoc(file):
280 288 '''return the top-level python documentation for the given file
281 289
282 290 Loosely inspired by pydoc.source_synopsis(), but rewritten to
283 291 handle triple quotes and to return the whole text instead of just
284 292 the synopsis'''
285 293 result = []
286 294
287 295 line = file.readline()
288 296 while line[:1] == '#' or not line.strip():
289 297 line = file.readline()
290 298 if not line:
291 299 break
292 300
293 301 start = line[:3]
294 302 if start == '"""' or start == "'''":
295 303 line = line[3:]
296 304 while line:
297 305 if line.rstrip().endswith(start):
298 306 line = line.split(start)[0]
299 307 if line:
300 308 result.append(line)
301 309 break
302 310 elif not line:
303 311 return None # unmatched delimiter
304 312 result.append(line)
305 313 line = file.readline()
306 314 else:
307 315 return None
308 316
309 317 return ''.join(result)
310 318
311 319 def _disabledhelp(path):
312 320 '''retrieve help synopsis of a disabled extension (without importing)'''
313 321 try:
314 322 file = open(path)
315 323 except IOError:
316 324 return
317 325 else:
318 326 doc = _moduledoc(file)
319 327 file.close()
320 328
321 329 if doc: # extracting localized synopsis
322 330 return gettext(doc).splitlines()[0]
323 331 else:
324 332 return _('(no help text available)')
325 333
326 334 def disabled():
327 335 '''find disabled extensions from hgext. returns a dict of {name: desc}'''
328 336 try:
329 337 from hgext import __index__
330 338 return dict((name, gettext(desc))
331 339 for name, desc in __index__.docs.iteritems()
332 340 if name not in _order)
333 341 except (ImportError, AttributeError):
334 342 pass
335 343
336 344 paths = _disabledpaths()
337 345 if not paths:
338 346 return {}
339 347
340 348 exts = {}
341 349 for name, path in paths.iteritems():
342 350 doc = _disabledhelp(path)
343 351 if doc:
344 352 exts[name] = doc
345 353
346 354 return exts
347 355
348 356 def disabledext(name):
349 357 '''find a specific disabled extension from hgext. returns desc'''
350 358 try:
351 359 from hgext import __index__
352 360 if name in _order: # enabled
353 361 return
354 362 else:
355 363 return gettext(__index__.docs.get(name))
356 364 except (ImportError, AttributeError):
357 365 pass
358 366
359 367 paths = _disabledpaths()
360 368 if name in paths:
361 369 return _disabledhelp(paths[name])
362 370
363 371 def disabledcmd(ui, cmd, strict=False):
364 372 '''import disabled extensions until cmd is found.
365 373 returns (cmdname, extname, module)'''
366 374
367 375 paths = _disabledpaths(strip_init=True)
368 376 if not paths:
369 377 raise error.UnknownCommand(cmd)
370 378
371 379 def findcmd(cmd, name, path):
372 380 try:
373 381 mod = loadpath(path, 'hgext.%s' % name)
374 382 except Exception:
375 383 return
376 384 try:
377 385 aliases, entry = cmdutil.findcmd(cmd,
378 386 getattr(mod, 'cmdtable', {}), strict)
379 387 except (error.AmbiguousCommand, error.UnknownCommand):
380 388 return
381 389 except Exception:
382 390 ui.warn(_('warning: error finding commands in %s\n') % path)
383 391 ui.traceback()
384 392 return
385 393 for c in aliases:
386 394 if c.startswith(cmd):
387 395 cmd = c
388 396 break
389 397 else:
390 398 cmd = aliases[0]
391 399 return (cmd, name, mod)
392 400
393 401 ext = None
394 402 # first, search for an extension with the same name as the command
395 403 path = paths.pop(cmd, None)
396 404 if path:
397 405 ext = findcmd(cmd, cmd, path)
398 406 if not ext:
399 407 # otherwise, interrogate each extension until there's a match
400 408 for name, path in paths.iteritems():
401 409 ext = findcmd(cmd, name, path)
402 410 if ext:
403 411 break
404 412 if ext and 'DEPRECATED' not in ext.__doc__:
405 413 return ext
406 414
407 415 raise error.UnknownCommand(cmd)
408 416
409 417 def enabled(shortname=True):
410 418 '''return a dict of {name: desc} of extensions'''
411 419 exts = {}
412 420 for ename, ext in extensions():
413 421 doc = (gettext(ext.__doc__) or _('(no help text available)'))
414 422 if shortname:
415 423 ename = ename.split('.')[-1]
416 424 exts[ename] = doc.splitlines()[0].strip()
417 425
418 426 return exts
419 427
420 428 def moduleversion(module):
421 429 '''return version information from given module as a string'''
422 430 if (util.safehasattr(module, 'getversion')
423 431 and callable(module.getversion)):
424 432 version = module.getversion()
425 433 elif util.safehasattr(module, '__version__'):
426 434 version = module.__version__
427 435 else:
428 436 version = ''
429 437 if isinstance(version, (list, tuple)):
430 438 version = '.'.join(str(o) for o in version)
431 439 return version
General Comments 0
You need to be logged in to leave comments. Login now