##// END OF EJS Templates
python3: replace im_{self,func} with __{self,func}__ globally...
Augie Fackler -
r34727:daf12f69 default
parent child Browse files
Show More
@@ -1,1066 +1,1066 b''
1 # wireproto.py - generic wire protocol support functions
1 # wireproto.py - generic wire protocol support functions
2 #
2 #
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import hashlib
10 import hashlib
11 import os
11 import os
12 import tempfile
12 import tempfile
13
13
14 from .i18n import _
14 from .i18n import _
15 from .node import (
15 from .node import (
16 bin,
16 bin,
17 hex,
17 hex,
18 nullid,
18 nullid,
19 )
19 )
20
20
21 from . import (
21 from . import (
22 bundle2,
22 bundle2,
23 changegroup as changegroupmod,
23 changegroup as changegroupmod,
24 discovery,
24 discovery,
25 encoding,
25 encoding,
26 error,
26 error,
27 exchange,
27 exchange,
28 peer,
28 peer,
29 pushkey as pushkeymod,
29 pushkey as pushkeymod,
30 pycompat,
30 pycompat,
31 repository,
31 repository,
32 streamclone,
32 streamclone,
33 util,
33 util,
34 )
34 )
35
35
36 urlerr = util.urlerr
36 urlerr = util.urlerr
37 urlreq = util.urlreq
37 urlreq = util.urlreq
38
38
39 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
39 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
40 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
40 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
41 'IncompatibleClient')
41 'IncompatibleClient')
42 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
42 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
43
43
44 class abstractserverproto(object):
44 class abstractserverproto(object):
45 """abstract class that summarizes the protocol API
45 """abstract class that summarizes the protocol API
46
46
47 Used as reference and documentation.
47 Used as reference and documentation.
48 """
48 """
49
49
50 def getargs(self, args):
50 def getargs(self, args):
51 """return the value for arguments in <args>
51 """return the value for arguments in <args>
52
52
53 returns a list of values (same order as <args>)"""
53 returns a list of values (same order as <args>)"""
54 raise NotImplementedError()
54 raise NotImplementedError()
55
55
56 def getfile(self, fp):
56 def getfile(self, fp):
57 """write the whole content of a file into a file like object
57 """write the whole content of a file into a file like object
58
58
59 The file is in the form::
59 The file is in the form::
60
60
61 (<chunk-size>\n<chunk>)+0\n
61 (<chunk-size>\n<chunk>)+0\n
62
62
63 chunk size is the ascii version of the int.
63 chunk size is the ascii version of the int.
64 """
64 """
65 raise NotImplementedError()
65 raise NotImplementedError()
66
66
67 def redirect(self):
67 def redirect(self):
68 """may setup interception for stdout and stderr
68 """may setup interception for stdout and stderr
69
69
70 See also the `restore` method."""
70 See also the `restore` method."""
71 raise NotImplementedError()
71 raise NotImplementedError()
72
72
73 # If the `redirect` function does install interception, the `restore`
73 # If the `redirect` function does install interception, the `restore`
74 # function MUST be defined. If interception is not used, this function
74 # function MUST be defined. If interception is not used, this function
75 # MUST NOT be defined.
75 # MUST NOT be defined.
76 #
76 #
77 # left commented here on purpose
77 # left commented here on purpose
78 #
78 #
79 #def restore(self):
79 #def restore(self):
80 # """reinstall previous stdout and stderr and return intercepted stdout
80 # """reinstall previous stdout and stderr and return intercepted stdout
81 # """
81 # """
82 # raise NotImplementedError()
82 # raise NotImplementedError()
83
83
84 class remoteiterbatcher(peer.iterbatcher):
84 class remoteiterbatcher(peer.iterbatcher):
85 def __init__(self, remote):
85 def __init__(self, remote):
86 super(remoteiterbatcher, self).__init__()
86 super(remoteiterbatcher, self).__init__()
87 self._remote = remote
87 self._remote = remote
88
88
89 def __getattr__(self, name):
89 def __getattr__(self, name):
90 # Validate this method is batchable, since submit() only supports
90 # Validate this method is batchable, since submit() only supports
91 # batchable methods.
91 # batchable methods.
92 fn = getattr(self._remote, name)
92 fn = getattr(self._remote, name)
93 if not getattr(fn, 'batchable', None):
93 if not getattr(fn, 'batchable', None):
94 raise error.ProgrammingError('Attempted to batch a non-batchable '
94 raise error.ProgrammingError('Attempted to batch a non-batchable '
95 'call to %r' % name)
95 'call to %r' % name)
96
96
97 return super(remoteiterbatcher, self).__getattr__(name)
97 return super(remoteiterbatcher, self).__getattr__(name)
98
98
99 def submit(self):
99 def submit(self):
100 """Break the batch request into many patch calls and pipeline them.
100 """Break the batch request into many patch calls and pipeline them.
101
101
102 This is mostly valuable over http where request sizes can be
102 This is mostly valuable over http where request sizes can be
103 limited, but can be used in other places as well.
103 limited, but can be used in other places as well.
104 """
104 """
105 # 2-tuple of (command, arguments) that represents what will be
105 # 2-tuple of (command, arguments) that represents what will be
106 # sent over the wire.
106 # sent over the wire.
107 requests = []
107 requests = []
108
108
109 # 4-tuple of (command, final future, @batchable generator, remote
109 # 4-tuple of (command, final future, @batchable generator, remote
110 # future).
110 # future).
111 results = []
111 results = []
112
112
113 for command, args, opts, finalfuture in self.calls:
113 for command, args, opts, finalfuture in self.calls:
114 mtd = getattr(self._remote, command)
114 mtd = getattr(self._remote, command)
115 batchable = mtd.batchable(mtd.im_self, *args, **opts)
115 batchable = mtd.batchable(mtd.__self__, *args, **opts)
116
116
117 commandargs, fremote = next(batchable)
117 commandargs, fremote = next(batchable)
118 assert fremote
118 assert fremote
119 requests.append((command, commandargs))
119 requests.append((command, commandargs))
120 results.append((command, finalfuture, batchable, fremote))
120 results.append((command, finalfuture, batchable, fremote))
121
121
122 if requests:
122 if requests:
123 self._resultiter = self._remote._submitbatch(requests)
123 self._resultiter = self._remote._submitbatch(requests)
124
124
125 self._results = results
125 self._results = results
126
126
127 def results(self):
127 def results(self):
128 for command, finalfuture, batchable, remotefuture in self._results:
128 for command, finalfuture, batchable, remotefuture in self._results:
129 # Get the raw result, set it in the remote future, feed it
129 # Get the raw result, set it in the remote future, feed it
130 # back into the @batchable generator so it can be decoded, and
130 # back into the @batchable generator so it can be decoded, and
131 # set the result on the final future to this value.
131 # set the result on the final future to this value.
132 remoteresult = next(self._resultiter)
132 remoteresult = next(self._resultiter)
133 remotefuture.set(remoteresult)
133 remotefuture.set(remoteresult)
134 finalfuture.set(next(batchable))
134 finalfuture.set(next(batchable))
135
135
136 # Verify our @batchable generators only emit 2 values.
136 # Verify our @batchable generators only emit 2 values.
137 try:
137 try:
138 next(batchable)
138 next(batchable)
139 except StopIteration:
139 except StopIteration:
140 pass
140 pass
141 else:
141 else:
142 raise error.ProgrammingError('%s @batchable generator emitted '
142 raise error.ProgrammingError('%s @batchable generator emitted '
143 'unexpected value count' % command)
143 'unexpected value count' % command)
144
144
145 yield finalfuture.value
145 yield finalfuture.value
146
146
147 # Forward a couple of names from peer to make wireproto interactions
147 # Forward a couple of names from peer to make wireproto interactions
148 # slightly more sensible.
148 # slightly more sensible.
149 batchable = peer.batchable
149 batchable = peer.batchable
150 future = peer.future
150 future = peer.future
151
151
152 # list of nodes encoding / decoding
152 # list of nodes encoding / decoding
153
153
154 def decodelist(l, sep=' '):
154 def decodelist(l, sep=' '):
155 if l:
155 if l:
156 return map(bin, l.split(sep))
156 return map(bin, l.split(sep))
157 return []
157 return []
158
158
159 def encodelist(l, sep=' '):
159 def encodelist(l, sep=' '):
160 try:
160 try:
161 return sep.join(map(hex, l))
161 return sep.join(map(hex, l))
162 except TypeError:
162 except TypeError:
163 raise
163 raise
164
164
165 # batched call argument encoding
165 # batched call argument encoding
166
166
167 def escapearg(plain):
167 def escapearg(plain):
168 return (plain
168 return (plain
169 .replace(':', ':c')
169 .replace(':', ':c')
170 .replace(',', ':o')
170 .replace(',', ':o')
171 .replace(';', ':s')
171 .replace(';', ':s')
172 .replace('=', ':e'))
172 .replace('=', ':e'))
173
173
174 def unescapearg(escaped):
174 def unescapearg(escaped):
175 return (escaped
175 return (escaped
176 .replace(':e', '=')
176 .replace(':e', '=')
177 .replace(':s', ';')
177 .replace(':s', ';')
178 .replace(':o', ',')
178 .replace(':o', ',')
179 .replace(':c', ':'))
179 .replace(':c', ':'))
180
180
181 def encodebatchcmds(req):
181 def encodebatchcmds(req):
182 """Return a ``cmds`` argument value for the ``batch`` command."""
182 """Return a ``cmds`` argument value for the ``batch`` command."""
183 cmds = []
183 cmds = []
184 for op, argsdict in req:
184 for op, argsdict in req:
185 # Old servers didn't properly unescape argument names. So prevent
185 # Old servers didn't properly unescape argument names. So prevent
186 # the sending of argument names that may not be decoded properly by
186 # the sending of argument names that may not be decoded properly by
187 # servers.
187 # servers.
188 assert all(escapearg(k) == k for k in argsdict)
188 assert all(escapearg(k) == k for k in argsdict)
189
189
190 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
190 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
191 for k, v in argsdict.iteritems())
191 for k, v in argsdict.iteritems())
192 cmds.append('%s %s' % (op, args))
192 cmds.append('%s %s' % (op, args))
193
193
194 return ';'.join(cmds)
194 return ';'.join(cmds)
195
195
196 # mapping of options accepted by getbundle and their types
196 # mapping of options accepted by getbundle and their types
197 #
197 #
198 # Meant to be extended by extensions. It is extensions responsibility to ensure
198 # Meant to be extended by extensions. It is extensions responsibility to ensure
199 # such options are properly processed in exchange.getbundle.
199 # such options are properly processed in exchange.getbundle.
200 #
200 #
201 # supported types are:
201 # supported types are:
202 #
202 #
203 # :nodes: list of binary nodes
203 # :nodes: list of binary nodes
204 # :csv: list of comma-separated values
204 # :csv: list of comma-separated values
205 # :scsv: list of comma-separated values return as set
205 # :scsv: list of comma-separated values return as set
206 # :plain: string with no transformation needed.
206 # :plain: string with no transformation needed.
207 gboptsmap = {'heads': 'nodes',
207 gboptsmap = {'heads': 'nodes',
208 'common': 'nodes',
208 'common': 'nodes',
209 'obsmarkers': 'boolean',
209 'obsmarkers': 'boolean',
210 'phases': 'boolean',
210 'phases': 'boolean',
211 'bundlecaps': 'scsv',
211 'bundlecaps': 'scsv',
212 'listkeys': 'csv',
212 'listkeys': 'csv',
213 'cg': 'boolean',
213 'cg': 'boolean',
214 'cbattempted': 'boolean'}
214 'cbattempted': 'boolean'}
215
215
216 # client side
216 # client side
217
217
218 class wirepeer(repository.legacypeer):
218 class wirepeer(repository.legacypeer):
219 """Client-side interface for communicating with a peer repository.
219 """Client-side interface for communicating with a peer repository.
220
220
221 Methods commonly call wire protocol commands of the same name.
221 Methods commonly call wire protocol commands of the same name.
222
222
223 See also httppeer.py and sshpeer.py for protocol-specific
223 See also httppeer.py and sshpeer.py for protocol-specific
224 implementations of this interface.
224 implementations of this interface.
225 """
225 """
226 # Begin of basewirepeer interface.
226 # Begin of basewirepeer interface.
227
227
228 def iterbatch(self):
228 def iterbatch(self):
229 return remoteiterbatcher(self)
229 return remoteiterbatcher(self)
230
230
231 @batchable
231 @batchable
232 def lookup(self, key):
232 def lookup(self, key):
233 self.requirecap('lookup', _('look up remote revision'))
233 self.requirecap('lookup', _('look up remote revision'))
234 f = future()
234 f = future()
235 yield {'key': encoding.fromlocal(key)}, f
235 yield {'key': encoding.fromlocal(key)}, f
236 d = f.value
236 d = f.value
237 success, data = d[:-1].split(" ", 1)
237 success, data = d[:-1].split(" ", 1)
238 if int(success):
238 if int(success):
239 yield bin(data)
239 yield bin(data)
240 else:
240 else:
241 self._abort(error.RepoError(data))
241 self._abort(error.RepoError(data))
242
242
243 @batchable
243 @batchable
244 def heads(self):
244 def heads(self):
245 f = future()
245 f = future()
246 yield {}, f
246 yield {}, f
247 d = f.value
247 d = f.value
248 try:
248 try:
249 yield decodelist(d[:-1])
249 yield decodelist(d[:-1])
250 except ValueError:
250 except ValueError:
251 self._abort(error.ResponseError(_("unexpected response:"), d))
251 self._abort(error.ResponseError(_("unexpected response:"), d))
252
252
253 @batchable
253 @batchable
254 def known(self, nodes):
254 def known(self, nodes):
255 f = future()
255 f = future()
256 yield {'nodes': encodelist(nodes)}, f
256 yield {'nodes': encodelist(nodes)}, f
257 d = f.value
257 d = f.value
258 try:
258 try:
259 yield [bool(int(b)) for b in d]
259 yield [bool(int(b)) for b in d]
260 except ValueError:
260 except ValueError:
261 self._abort(error.ResponseError(_("unexpected response:"), d))
261 self._abort(error.ResponseError(_("unexpected response:"), d))
262
262
263 @batchable
263 @batchable
264 def branchmap(self):
264 def branchmap(self):
265 f = future()
265 f = future()
266 yield {}, f
266 yield {}, f
267 d = f.value
267 d = f.value
268 try:
268 try:
269 branchmap = {}
269 branchmap = {}
270 for branchpart in d.splitlines():
270 for branchpart in d.splitlines():
271 branchname, branchheads = branchpart.split(' ', 1)
271 branchname, branchheads = branchpart.split(' ', 1)
272 branchname = encoding.tolocal(urlreq.unquote(branchname))
272 branchname = encoding.tolocal(urlreq.unquote(branchname))
273 branchheads = decodelist(branchheads)
273 branchheads = decodelist(branchheads)
274 branchmap[branchname] = branchheads
274 branchmap[branchname] = branchheads
275 yield branchmap
275 yield branchmap
276 except TypeError:
276 except TypeError:
277 self._abort(error.ResponseError(_("unexpected response:"), d))
277 self._abort(error.ResponseError(_("unexpected response:"), d))
278
278
279 @batchable
279 @batchable
280 def listkeys(self, namespace):
280 def listkeys(self, namespace):
281 if not self.capable('pushkey'):
281 if not self.capable('pushkey'):
282 yield {}, None
282 yield {}, None
283 f = future()
283 f = future()
284 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
284 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
285 yield {'namespace': encoding.fromlocal(namespace)}, f
285 yield {'namespace': encoding.fromlocal(namespace)}, f
286 d = f.value
286 d = f.value
287 self.ui.debug('received listkey for "%s": %i bytes\n'
287 self.ui.debug('received listkey for "%s": %i bytes\n'
288 % (namespace, len(d)))
288 % (namespace, len(d)))
289 yield pushkeymod.decodekeys(d)
289 yield pushkeymod.decodekeys(d)
290
290
291 @batchable
291 @batchable
292 def pushkey(self, namespace, key, old, new):
292 def pushkey(self, namespace, key, old, new):
293 if not self.capable('pushkey'):
293 if not self.capable('pushkey'):
294 yield False, None
294 yield False, None
295 f = future()
295 f = future()
296 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
296 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
297 yield {'namespace': encoding.fromlocal(namespace),
297 yield {'namespace': encoding.fromlocal(namespace),
298 'key': encoding.fromlocal(key),
298 'key': encoding.fromlocal(key),
299 'old': encoding.fromlocal(old),
299 'old': encoding.fromlocal(old),
300 'new': encoding.fromlocal(new)}, f
300 'new': encoding.fromlocal(new)}, f
301 d = f.value
301 d = f.value
302 d, output = d.split('\n', 1)
302 d, output = d.split('\n', 1)
303 try:
303 try:
304 d = bool(int(d))
304 d = bool(int(d))
305 except ValueError:
305 except ValueError:
306 raise error.ResponseError(
306 raise error.ResponseError(
307 _('push failed (unexpected response):'), d)
307 _('push failed (unexpected response):'), d)
308 for l in output.splitlines(True):
308 for l in output.splitlines(True):
309 self.ui.status(_('remote: '), l)
309 self.ui.status(_('remote: '), l)
310 yield d
310 yield d
311
311
312 def stream_out(self):
312 def stream_out(self):
313 return self._callstream('stream_out')
313 return self._callstream('stream_out')
314
314
315 def getbundle(self, source, **kwargs):
315 def getbundle(self, source, **kwargs):
316 self.requirecap('getbundle', _('look up remote changes'))
316 self.requirecap('getbundle', _('look up remote changes'))
317 opts = {}
317 opts = {}
318 bundlecaps = kwargs.get('bundlecaps')
318 bundlecaps = kwargs.get('bundlecaps')
319 if bundlecaps is not None:
319 if bundlecaps is not None:
320 kwargs['bundlecaps'] = sorted(bundlecaps)
320 kwargs['bundlecaps'] = sorted(bundlecaps)
321 else:
321 else:
322 bundlecaps = () # kwargs could have it to None
322 bundlecaps = () # kwargs could have it to None
323 for key, value in kwargs.iteritems():
323 for key, value in kwargs.iteritems():
324 if value is None:
324 if value is None:
325 continue
325 continue
326 keytype = gboptsmap.get(key)
326 keytype = gboptsmap.get(key)
327 if keytype is None:
327 if keytype is None:
328 assert False, 'unexpected'
328 assert False, 'unexpected'
329 elif keytype == 'nodes':
329 elif keytype == 'nodes':
330 value = encodelist(value)
330 value = encodelist(value)
331 elif keytype in ('csv', 'scsv'):
331 elif keytype in ('csv', 'scsv'):
332 value = ','.join(value)
332 value = ','.join(value)
333 elif keytype == 'boolean':
333 elif keytype == 'boolean':
334 value = '%i' % bool(value)
334 value = '%i' % bool(value)
335 elif keytype != 'plain':
335 elif keytype != 'plain':
336 raise KeyError('unknown getbundle option type %s'
336 raise KeyError('unknown getbundle option type %s'
337 % keytype)
337 % keytype)
338 opts[key] = value
338 opts[key] = value
339 f = self._callcompressable("getbundle", **opts)
339 f = self._callcompressable("getbundle", **opts)
340 if any((cap.startswith('HG2') for cap in bundlecaps)):
340 if any((cap.startswith('HG2') for cap in bundlecaps)):
341 return bundle2.getunbundler(self.ui, f)
341 return bundle2.getunbundler(self.ui, f)
342 else:
342 else:
343 return changegroupmod.cg1unpacker(f, 'UN')
343 return changegroupmod.cg1unpacker(f, 'UN')
344
344
345 def unbundle(self, cg, heads, url):
345 def unbundle(self, cg, heads, url):
346 '''Send cg (a readable file-like object representing the
346 '''Send cg (a readable file-like object representing the
347 changegroup to push, typically a chunkbuffer object) to the
347 changegroup to push, typically a chunkbuffer object) to the
348 remote server as a bundle.
348 remote server as a bundle.
349
349
350 When pushing a bundle10 stream, return an integer indicating the
350 When pushing a bundle10 stream, return an integer indicating the
351 result of the push (see changegroup.apply()).
351 result of the push (see changegroup.apply()).
352
352
353 When pushing a bundle20 stream, return a bundle20 stream.
353 When pushing a bundle20 stream, return a bundle20 stream.
354
354
355 `url` is the url the client thinks it's pushing to, which is
355 `url` is the url the client thinks it's pushing to, which is
356 visible to hooks.
356 visible to hooks.
357 '''
357 '''
358
358
359 if heads != ['force'] and self.capable('unbundlehash'):
359 if heads != ['force'] and self.capable('unbundlehash'):
360 heads = encodelist(['hashed',
360 heads = encodelist(['hashed',
361 hashlib.sha1(''.join(sorted(heads))).digest()])
361 hashlib.sha1(''.join(sorted(heads))).digest()])
362 else:
362 else:
363 heads = encodelist(heads)
363 heads = encodelist(heads)
364
364
365 if util.safehasattr(cg, 'deltaheader'):
365 if util.safehasattr(cg, 'deltaheader'):
366 # this a bundle10, do the old style call sequence
366 # this a bundle10, do the old style call sequence
367 ret, output = self._callpush("unbundle", cg, heads=heads)
367 ret, output = self._callpush("unbundle", cg, heads=heads)
368 if ret == "":
368 if ret == "":
369 raise error.ResponseError(
369 raise error.ResponseError(
370 _('push failed:'), output)
370 _('push failed:'), output)
371 try:
371 try:
372 ret = int(ret)
372 ret = int(ret)
373 except ValueError:
373 except ValueError:
374 raise error.ResponseError(
374 raise error.ResponseError(
375 _('push failed (unexpected response):'), ret)
375 _('push failed (unexpected response):'), ret)
376
376
377 for l in output.splitlines(True):
377 for l in output.splitlines(True):
378 self.ui.status(_('remote: '), l)
378 self.ui.status(_('remote: '), l)
379 else:
379 else:
380 # bundle2 push. Send a stream, fetch a stream.
380 # bundle2 push. Send a stream, fetch a stream.
381 stream = self._calltwowaystream('unbundle', cg, heads=heads)
381 stream = self._calltwowaystream('unbundle', cg, heads=heads)
382 ret = bundle2.getunbundler(self.ui, stream)
382 ret = bundle2.getunbundler(self.ui, stream)
383 return ret
383 return ret
384
384
385 # End of basewirepeer interface.
385 # End of basewirepeer interface.
386
386
387 # Begin of baselegacywirepeer interface.
387 # Begin of baselegacywirepeer interface.
388
388
389 def branches(self, nodes):
389 def branches(self, nodes):
390 n = encodelist(nodes)
390 n = encodelist(nodes)
391 d = self._call("branches", nodes=n)
391 d = self._call("branches", nodes=n)
392 try:
392 try:
393 br = [tuple(decodelist(b)) for b in d.splitlines()]
393 br = [tuple(decodelist(b)) for b in d.splitlines()]
394 return br
394 return br
395 except ValueError:
395 except ValueError:
396 self._abort(error.ResponseError(_("unexpected response:"), d))
396 self._abort(error.ResponseError(_("unexpected response:"), d))
397
397
398 def between(self, pairs):
398 def between(self, pairs):
399 batch = 8 # avoid giant requests
399 batch = 8 # avoid giant requests
400 r = []
400 r = []
401 for i in xrange(0, len(pairs), batch):
401 for i in xrange(0, len(pairs), batch):
402 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
402 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
403 d = self._call("between", pairs=n)
403 d = self._call("between", pairs=n)
404 try:
404 try:
405 r.extend(l and decodelist(l) or [] for l in d.splitlines())
405 r.extend(l and decodelist(l) or [] for l in d.splitlines())
406 except ValueError:
406 except ValueError:
407 self._abort(error.ResponseError(_("unexpected response:"), d))
407 self._abort(error.ResponseError(_("unexpected response:"), d))
408 return r
408 return r
409
409
410 def changegroup(self, nodes, kind):
410 def changegroup(self, nodes, kind):
411 n = encodelist(nodes)
411 n = encodelist(nodes)
412 f = self._callcompressable("changegroup", roots=n)
412 f = self._callcompressable("changegroup", roots=n)
413 return changegroupmod.cg1unpacker(f, 'UN')
413 return changegroupmod.cg1unpacker(f, 'UN')
414
414
415 def changegroupsubset(self, bases, heads, kind):
415 def changegroupsubset(self, bases, heads, kind):
416 self.requirecap('changegroupsubset', _('look up remote changes'))
416 self.requirecap('changegroupsubset', _('look up remote changes'))
417 bases = encodelist(bases)
417 bases = encodelist(bases)
418 heads = encodelist(heads)
418 heads = encodelist(heads)
419 f = self._callcompressable("changegroupsubset",
419 f = self._callcompressable("changegroupsubset",
420 bases=bases, heads=heads)
420 bases=bases, heads=heads)
421 return changegroupmod.cg1unpacker(f, 'UN')
421 return changegroupmod.cg1unpacker(f, 'UN')
422
422
423 # End of baselegacywirepeer interface.
423 # End of baselegacywirepeer interface.
424
424
425 def _submitbatch(self, req):
425 def _submitbatch(self, req):
426 """run batch request <req> on the server
426 """run batch request <req> on the server
427
427
428 Returns an iterator of the raw responses from the server.
428 Returns an iterator of the raw responses from the server.
429 """
429 """
430 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
430 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
431 chunk = rsp.read(1024)
431 chunk = rsp.read(1024)
432 work = [chunk]
432 work = [chunk]
433 while chunk:
433 while chunk:
434 while ';' not in chunk and chunk:
434 while ';' not in chunk and chunk:
435 chunk = rsp.read(1024)
435 chunk = rsp.read(1024)
436 work.append(chunk)
436 work.append(chunk)
437 merged = ''.join(work)
437 merged = ''.join(work)
438 while ';' in merged:
438 while ';' in merged:
439 one, merged = merged.split(';', 1)
439 one, merged = merged.split(';', 1)
440 yield unescapearg(one)
440 yield unescapearg(one)
441 chunk = rsp.read(1024)
441 chunk = rsp.read(1024)
442 work = [merged, chunk]
442 work = [merged, chunk]
443 yield unescapearg(''.join(work))
443 yield unescapearg(''.join(work))
444
444
445 def _submitone(self, op, args):
445 def _submitone(self, op, args):
446 return self._call(op, **args)
446 return self._call(op, **args)
447
447
448 def debugwireargs(self, one, two, three=None, four=None, five=None):
448 def debugwireargs(self, one, two, three=None, four=None, five=None):
449 # don't pass optional arguments left at their default value
449 # don't pass optional arguments left at their default value
450 opts = {}
450 opts = {}
451 if three is not None:
451 if three is not None:
452 opts['three'] = three
452 opts['three'] = three
453 if four is not None:
453 if four is not None:
454 opts['four'] = four
454 opts['four'] = four
455 return self._call('debugwireargs', one=one, two=two, **opts)
455 return self._call('debugwireargs', one=one, two=two, **opts)
456
456
457 def _call(self, cmd, **args):
457 def _call(self, cmd, **args):
458 """execute <cmd> on the server
458 """execute <cmd> on the server
459
459
460 The command is expected to return a simple string.
460 The command is expected to return a simple string.
461
461
462 returns the server reply as a string."""
462 returns the server reply as a string."""
463 raise NotImplementedError()
463 raise NotImplementedError()
464
464
465 def _callstream(self, cmd, **args):
465 def _callstream(self, cmd, **args):
466 """execute <cmd> on the server
466 """execute <cmd> on the server
467
467
468 The command is expected to return a stream. Note that if the
468 The command is expected to return a stream. Note that if the
469 command doesn't return a stream, _callstream behaves
469 command doesn't return a stream, _callstream behaves
470 differently for ssh and http peers.
470 differently for ssh and http peers.
471
471
472 returns the server reply as a file like object.
472 returns the server reply as a file like object.
473 """
473 """
474 raise NotImplementedError()
474 raise NotImplementedError()
475
475
476 def _callcompressable(self, cmd, **args):
476 def _callcompressable(self, cmd, **args):
477 """execute <cmd> on the server
477 """execute <cmd> on the server
478
478
479 The command is expected to return a stream.
479 The command is expected to return a stream.
480
480
481 The stream may have been compressed in some implementations. This
481 The stream may have been compressed in some implementations. This
482 function takes care of the decompression. This is the only difference
482 function takes care of the decompression. This is the only difference
483 with _callstream.
483 with _callstream.
484
484
485 returns the server reply as a file like object.
485 returns the server reply as a file like object.
486 """
486 """
487 raise NotImplementedError()
487 raise NotImplementedError()
488
488
489 def _callpush(self, cmd, fp, **args):
489 def _callpush(self, cmd, fp, **args):
490 """execute a <cmd> on server
490 """execute a <cmd> on server
491
491
492 The command is expected to be related to a push. Push has a special
492 The command is expected to be related to a push. Push has a special
493 return method.
493 return method.
494
494
495 returns the server reply as a (ret, output) tuple. ret is either
495 returns the server reply as a (ret, output) tuple. ret is either
496 empty (error) or a stringified int.
496 empty (error) or a stringified int.
497 """
497 """
498 raise NotImplementedError()
498 raise NotImplementedError()
499
499
500 def _calltwowaystream(self, cmd, fp, **args):
500 def _calltwowaystream(self, cmd, fp, **args):
501 """execute <cmd> on server
501 """execute <cmd> on server
502
502
503 The command will send a stream to the server and get a stream in reply.
503 The command will send a stream to the server and get a stream in reply.
504 """
504 """
505 raise NotImplementedError()
505 raise NotImplementedError()
506
506
507 def _abort(self, exception):
507 def _abort(self, exception):
508 """clearly abort the wire protocol connection and raise the exception
508 """clearly abort the wire protocol connection and raise the exception
509 """
509 """
510 raise NotImplementedError()
510 raise NotImplementedError()
511
511
512 # server side
512 # server side
513
513
514 # wire protocol command can either return a string or one of these classes.
514 # wire protocol command can either return a string or one of these classes.
515 class streamres(object):
515 class streamres(object):
516 """wireproto reply: binary stream
516 """wireproto reply: binary stream
517
517
518 The call was successful and the result is a stream.
518 The call was successful and the result is a stream.
519
519
520 Accepts either a generator or an object with a ``read(size)`` method.
520 Accepts either a generator or an object with a ``read(size)`` method.
521
521
522 ``v1compressible`` indicates whether this data can be compressed to
522 ``v1compressible`` indicates whether this data can be compressed to
523 "version 1" clients (technically: HTTP peers using
523 "version 1" clients (technically: HTTP peers using
524 application/mercurial-0.1 media type). This flag should NOT be used on
524 application/mercurial-0.1 media type). This flag should NOT be used on
525 new commands because new clients should support a more modern compression
525 new commands because new clients should support a more modern compression
526 mechanism.
526 mechanism.
527 """
527 """
528 def __init__(self, gen=None, reader=None, v1compressible=False):
528 def __init__(self, gen=None, reader=None, v1compressible=False):
529 self.gen = gen
529 self.gen = gen
530 self.reader = reader
530 self.reader = reader
531 self.v1compressible = v1compressible
531 self.v1compressible = v1compressible
532
532
533 class pushres(object):
533 class pushres(object):
534 """wireproto reply: success with simple integer return
534 """wireproto reply: success with simple integer return
535
535
536 The call was successful and returned an integer contained in `self.res`.
536 The call was successful and returned an integer contained in `self.res`.
537 """
537 """
538 def __init__(self, res):
538 def __init__(self, res):
539 self.res = res
539 self.res = res
540
540
541 class pusherr(object):
541 class pusherr(object):
542 """wireproto reply: failure
542 """wireproto reply: failure
543
543
544 The call failed. The `self.res` attribute contains the error message.
544 The call failed. The `self.res` attribute contains the error message.
545 """
545 """
546 def __init__(self, res):
546 def __init__(self, res):
547 self.res = res
547 self.res = res
548
548
549 class ooberror(object):
549 class ooberror(object):
550 """wireproto reply: failure of a batch of operation
550 """wireproto reply: failure of a batch of operation
551
551
552 Something failed during a batch call. The error message is stored in
552 Something failed during a batch call. The error message is stored in
553 `self.message`.
553 `self.message`.
554 """
554 """
555 def __init__(self, message):
555 def __init__(self, message):
556 self.message = message
556 self.message = message
557
557
558 def getdispatchrepo(repo, proto, command):
558 def getdispatchrepo(repo, proto, command):
559 """Obtain the repo used for processing wire protocol commands.
559 """Obtain the repo used for processing wire protocol commands.
560
560
561 The intent of this function is to serve as a monkeypatch point for
561 The intent of this function is to serve as a monkeypatch point for
562 extensions that need commands to operate on different repo views under
562 extensions that need commands to operate on different repo views under
563 specialized circumstances.
563 specialized circumstances.
564 """
564 """
565 return repo.filtered('served')
565 return repo.filtered('served')
566
566
567 def dispatch(repo, proto, command):
567 def dispatch(repo, proto, command):
568 repo = getdispatchrepo(repo, proto, command)
568 repo = getdispatchrepo(repo, proto, command)
569 func, spec = commands[command]
569 func, spec = commands[command]
570 args = proto.getargs(spec)
570 args = proto.getargs(spec)
571 return func(repo, proto, *args)
571 return func(repo, proto, *args)
572
572
573 def options(cmd, keys, others):
573 def options(cmd, keys, others):
574 opts = {}
574 opts = {}
575 for k in keys:
575 for k in keys:
576 if k in others:
576 if k in others:
577 opts[k] = others[k]
577 opts[k] = others[k]
578 del others[k]
578 del others[k]
579 if others:
579 if others:
580 util.stderr.write("warning: %s ignored unexpected arguments %s\n"
580 util.stderr.write("warning: %s ignored unexpected arguments %s\n"
581 % (cmd, ",".join(others)))
581 % (cmd, ",".join(others)))
582 return opts
582 return opts
583
583
584 def bundle1allowed(repo, action):
584 def bundle1allowed(repo, action):
585 """Whether a bundle1 operation is allowed from the server.
585 """Whether a bundle1 operation is allowed from the server.
586
586
587 Priority is:
587 Priority is:
588
588
589 1. server.bundle1gd.<action> (if generaldelta active)
589 1. server.bundle1gd.<action> (if generaldelta active)
590 2. server.bundle1.<action>
590 2. server.bundle1.<action>
591 3. server.bundle1gd (if generaldelta active)
591 3. server.bundle1gd (if generaldelta active)
592 4. server.bundle1
592 4. server.bundle1
593 """
593 """
594 ui = repo.ui
594 ui = repo.ui
595 gd = 'generaldelta' in repo.requirements
595 gd = 'generaldelta' in repo.requirements
596
596
597 if gd:
597 if gd:
598 v = ui.configbool('server', 'bundle1gd.%s' % action)
598 v = ui.configbool('server', 'bundle1gd.%s' % action)
599 if v is not None:
599 if v is not None:
600 return v
600 return v
601
601
602 v = ui.configbool('server', 'bundle1.%s' % action)
602 v = ui.configbool('server', 'bundle1.%s' % action)
603 if v is not None:
603 if v is not None:
604 return v
604 return v
605
605
606 if gd:
606 if gd:
607 v = ui.configbool('server', 'bundle1gd')
607 v = ui.configbool('server', 'bundle1gd')
608 if v is not None:
608 if v is not None:
609 return v
609 return v
610
610
611 return ui.configbool('server', 'bundle1')
611 return ui.configbool('server', 'bundle1')
612
612
613 def supportedcompengines(ui, proto, role):
613 def supportedcompengines(ui, proto, role):
614 """Obtain the list of supported compression engines for a request."""
614 """Obtain the list of supported compression engines for a request."""
615 assert role in (util.CLIENTROLE, util.SERVERROLE)
615 assert role in (util.CLIENTROLE, util.SERVERROLE)
616
616
617 compengines = util.compengines.supportedwireengines(role)
617 compengines = util.compengines.supportedwireengines(role)
618
618
619 # Allow config to override default list and ordering.
619 # Allow config to override default list and ordering.
620 if role == util.SERVERROLE:
620 if role == util.SERVERROLE:
621 configengines = ui.configlist('server', 'compressionengines')
621 configengines = ui.configlist('server', 'compressionengines')
622 config = 'server.compressionengines'
622 config = 'server.compressionengines'
623 else:
623 else:
624 # This is currently implemented mainly to facilitate testing. In most
624 # This is currently implemented mainly to facilitate testing. In most
625 # cases, the server should be in charge of choosing a compression engine
625 # cases, the server should be in charge of choosing a compression engine
626 # because a server has the most to lose from a sub-optimal choice. (e.g.
626 # because a server has the most to lose from a sub-optimal choice. (e.g.
627 # CPU DoS due to an expensive engine or a network DoS due to poor
627 # CPU DoS due to an expensive engine or a network DoS due to poor
628 # compression ratio).
628 # compression ratio).
629 configengines = ui.configlist('experimental',
629 configengines = ui.configlist('experimental',
630 'clientcompressionengines')
630 'clientcompressionengines')
631 config = 'experimental.clientcompressionengines'
631 config = 'experimental.clientcompressionengines'
632
632
633 # No explicit config. Filter out the ones that aren't supposed to be
633 # No explicit config. Filter out the ones that aren't supposed to be
634 # advertised and return default ordering.
634 # advertised and return default ordering.
635 if not configengines:
635 if not configengines:
636 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
636 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
637 return [e for e in compengines
637 return [e for e in compengines
638 if getattr(e.wireprotosupport(), attr) > 0]
638 if getattr(e.wireprotosupport(), attr) > 0]
639
639
640 # If compression engines are listed in the config, assume there is a good
640 # If compression engines are listed in the config, assume there is a good
641 # reason for it (like server operators wanting to achieve specific
641 # reason for it (like server operators wanting to achieve specific
642 # performance characteristics). So fail fast if the config references
642 # performance characteristics). So fail fast if the config references
643 # unusable compression engines.
643 # unusable compression engines.
644 validnames = set(e.name() for e in compengines)
644 validnames = set(e.name() for e in compengines)
645 invalidnames = set(e for e in configengines if e not in validnames)
645 invalidnames = set(e for e in configengines if e not in validnames)
646 if invalidnames:
646 if invalidnames:
647 raise error.Abort(_('invalid compression engine defined in %s: %s') %
647 raise error.Abort(_('invalid compression engine defined in %s: %s') %
648 (config, ', '.join(sorted(invalidnames))))
648 (config, ', '.join(sorted(invalidnames))))
649
649
650 compengines = [e for e in compengines if e.name() in configengines]
650 compengines = [e for e in compengines if e.name() in configengines]
651 compengines = sorted(compengines,
651 compengines = sorted(compengines,
652 key=lambda e: configengines.index(e.name()))
652 key=lambda e: configengines.index(e.name()))
653
653
654 if not compengines:
654 if not compengines:
655 raise error.Abort(_('%s config option does not specify any known '
655 raise error.Abort(_('%s config option does not specify any known '
656 'compression engines') % config,
656 'compression engines') % config,
657 hint=_('usable compression engines: %s') %
657 hint=_('usable compression engines: %s') %
658 ', '.sorted(validnames))
658 ', '.sorted(validnames))
659
659
660 return compengines
660 return compengines
661
661
662 # list of commands
662 # list of commands
663 commands = {}
663 commands = {}
664
664
665 def wireprotocommand(name, args=''):
665 def wireprotocommand(name, args=''):
666 """decorator for wire protocol command"""
666 """decorator for wire protocol command"""
667 def register(func):
667 def register(func):
668 commands[name] = (func, args)
668 commands[name] = (func, args)
669 return func
669 return func
670 return register
670 return register
671
671
672 @wireprotocommand('batch', 'cmds *')
672 @wireprotocommand('batch', 'cmds *')
673 def batch(repo, proto, cmds, others):
673 def batch(repo, proto, cmds, others):
674 repo = repo.filtered("served")
674 repo = repo.filtered("served")
675 res = []
675 res = []
676 for pair in cmds.split(';'):
676 for pair in cmds.split(';'):
677 op, args = pair.split(' ', 1)
677 op, args = pair.split(' ', 1)
678 vals = {}
678 vals = {}
679 for a in args.split(','):
679 for a in args.split(','):
680 if a:
680 if a:
681 n, v = a.split('=')
681 n, v = a.split('=')
682 vals[unescapearg(n)] = unescapearg(v)
682 vals[unescapearg(n)] = unescapearg(v)
683 func, spec = commands[op]
683 func, spec = commands[op]
684 if spec:
684 if spec:
685 keys = spec.split()
685 keys = spec.split()
686 data = {}
686 data = {}
687 for k in keys:
687 for k in keys:
688 if k == '*':
688 if k == '*':
689 star = {}
689 star = {}
690 for key in vals.keys():
690 for key in vals.keys():
691 if key not in keys:
691 if key not in keys:
692 star[key] = vals[key]
692 star[key] = vals[key]
693 data['*'] = star
693 data['*'] = star
694 else:
694 else:
695 data[k] = vals[k]
695 data[k] = vals[k]
696 result = func(repo, proto, *[data[k] for k in keys])
696 result = func(repo, proto, *[data[k] for k in keys])
697 else:
697 else:
698 result = func(repo, proto)
698 result = func(repo, proto)
699 if isinstance(result, ooberror):
699 if isinstance(result, ooberror):
700 return result
700 return result
701 res.append(escapearg(result))
701 res.append(escapearg(result))
702 return ';'.join(res)
702 return ';'.join(res)
703
703
704 @wireprotocommand('between', 'pairs')
704 @wireprotocommand('between', 'pairs')
705 def between(repo, proto, pairs):
705 def between(repo, proto, pairs):
706 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
706 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
707 r = []
707 r = []
708 for b in repo.between(pairs):
708 for b in repo.between(pairs):
709 r.append(encodelist(b) + "\n")
709 r.append(encodelist(b) + "\n")
710 return "".join(r)
710 return "".join(r)
711
711
712 @wireprotocommand('branchmap')
712 @wireprotocommand('branchmap')
713 def branchmap(repo, proto):
713 def branchmap(repo, proto):
714 branchmap = repo.branchmap()
714 branchmap = repo.branchmap()
715 heads = []
715 heads = []
716 for branch, nodes in branchmap.iteritems():
716 for branch, nodes in branchmap.iteritems():
717 branchname = urlreq.quote(encoding.fromlocal(branch))
717 branchname = urlreq.quote(encoding.fromlocal(branch))
718 branchnodes = encodelist(nodes)
718 branchnodes = encodelist(nodes)
719 heads.append('%s %s' % (branchname, branchnodes))
719 heads.append('%s %s' % (branchname, branchnodes))
720 return '\n'.join(heads)
720 return '\n'.join(heads)
721
721
722 @wireprotocommand('branches', 'nodes')
722 @wireprotocommand('branches', 'nodes')
723 def branches(repo, proto, nodes):
723 def branches(repo, proto, nodes):
724 nodes = decodelist(nodes)
724 nodes = decodelist(nodes)
725 r = []
725 r = []
726 for b in repo.branches(nodes):
726 for b in repo.branches(nodes):
727 r.append(encodelist(b) + "\n")
727 r.append(encodelist(b) + "\n")
728 return "".join(r)
728 return "".join(r)
729
729
730 @wireprotocommand('clonebundles', '')
730 @wireprotocommand('clonebundles', '')
731 def clonebundles(repo, proto):
731 def clonebundles(repo, proto):
732 """Server command for returning info for available bundles to seed clones.
732 """Server command for returning info for available bundles to seed clones.
733
733
734 Clients will parse this response and determine what bundle to fetch.
734 Clients will parse this response and determine what bundle to fetch.
735
735
736 Extensions may wrap this command to filter or dynamically emit data
736 Extensions may wrap this command to filter or dynamically emit data
737 depending on the request. e.g. you could advertise URLs for the closest
737 depending on the request. e.g. you could advertise URLs for the closest
738 data center given the client's IP address.
738 data center given the client's IP address.
739 """
739 """
740 return repo.vfs.tryread('clonebundles.manifest')
740 return repo.vfs.tryread('clonebundles.manifest')
741
741
742 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
742 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
743 'known', 'getbundle', 'unbundlehash', 'batch']
743 'known', 'getbundle', 'unbundlehash', 'batch']
744
744
745 def _capabilities(repo, proto):
745 def _capabilities(repo, proto):
746 """return a list of capabilities for a repo
746 """return a list of capabilities for a repo
747
747
748 This function exists to allow extensions to easily wrap capabilities
748 This function exists to allow extensions to easily wrap capabilities
749 computation
749 computation
750
750
751 - returns a lists: easy to alter
751 - returns a lists: easy to alter
752 - change done here will be propagated to both `capabilities` and `hello`
752 - change done here will be propagated to both `capabilities` and `hello`
753 command without any other action needed.
753 command without any other action needed.
754 """
754 """
755 # copy to prevent modification of the global list
755 # copy to prevent modification of the global list
756 caps = list(wireprotocaps)
756 caps = list(wireprotocaps)
757 if streamclone.allowservergeneration(repo):
757 if streamclone.allowservergeneration(repo):
758 if repo.ui.configbool('server', 'preferuncompressed'):
758 if repo.ui.configbool('server', 'preferuncompressed'):
759 caps.append('stream-preferred')
759 caps.append('stream-preferred')
760 requiredformats = repo.requirements & repo.supportedformats
760 requiredformats = repo.requirements & repo.supportedformats
761 # if our local revlogs are just revlogv1, add 'stream' cap
761 # if our local revlogs are just revlogv1, add 'stream' cap
762 if not requiredformats - {'revlogv1'}:
762 if not requiredformats - {'revlogv1'}:
763 caps.append('stream')
763 caps.append('stream')
764 # otherwise, add 'streamreqs' detailing our local revlog format
764 # otherwise, add 'streamreqs' detailing our local revlog format
765 else:
765 else:
766 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
766 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
767 if repo.ui.configbool('experimental', 'bundle2-advertise'):
767 if repo.ui.configbool('experimental', 'bundle2-advertise'):
768 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo))
768 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo))
769 caps.append('bundle2=' + urlreq.quote(capsblob))
769 caps.append('bundle2=' + urlreq.quote(capsblob))
770 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
770 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
771
771
772 if proto.name == 'http':
772 if proto.name == 'http':
773 caps.append('httpheader=%d' %
773 caps.append('httpheader=%d' %
774 repo.ui.configint('server', 'maxhttpheaderlen'))
774 repo.ui.configint('server', 'maxhttpheaderlen'))
775 if repo.ui.configbool('experimental', 'httppostargs'):
775 if repo.ui.configbool('experimental', 'httppostargs'):
776 caps.append('httppostargs')
776 caps.append('httppostargs')
777
777
778 # FUTURE advertise 0.2rx once support is implemented
778 # FUTURE advertise 0.2rx once support is implemented
779 # FUTURE advertise minrx and mintx after consulting config option
779 # FUTURE advertise minrx and mintx after consulting config option
780 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
780 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
781
781
782 compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
782 compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
783 if compengines:
783 if compengines:
784 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
784 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
785 for e in compengines)
785 for e in compengines)
786 caps.append('compression=%s' % comptypes)
786 caps.append('compression=%s' % comptypes)
787
787
788 return caps
788 return caps
789
789
790 # If you are writing an extension and consider wrapping this function. Wrap
790 # If you are writing an extension and consider wrapping this function. Wrap
791 # `_capabilities` instead.
791 # `_capabilities` instead.
792 @wireprotocommand('capabilities')
792 @wireprotocommand('capabilities')
793 def capabilities(repo, proto):
793 def capabilities(repo, proto):
794 return ' '.join(_capabilities(repo, proto))
794 return ' '.join(_capabilities(repo, proto))
795
795
796 @wireprotocommand('changegroup', 'roots')
796 @wireprotocommand('changegroup', 'roots')
797 def changegroup(repo, proto, roots):
797 def changegroup(repo, proto, roots):
798 nodes = decodelist(roots)
798 nodes = decodelist(roots)
799 outgoing = discovery.outgoing(repo, missingroots=nodes,
799 outgoing = discovery.outgoing(repo, missingroots=nodes,
800 missingheads=repo.heads())
800 missingheads=repo.heads())
801 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
801 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
802 return streamres(reader=cg, v1compressible=True)
802 return streamres(reader=cg, v1compressible=True)
803
803
804 @wireprotocommand('changegroupsubset', 'bases heads')
804 @wireprotocommand('changegroupsubset', 'bases heads')
805 def changegroupsubset(repo, proto, bases, heads):
805 def changegroupsubset(repo, proto, bases, heads):
806 bases = decodelist(bases)
806 bases = decodelist(bases)
807 heads = decodelist(heads)
807 heads = decodelist(heads)
808 outgoing = discovery.outgoing(repo, missingroots=bases,
808 outgoing = discovery.outgoing(repo, missingroots=bases,
809 missingheads=heads)
809 missingheads=heads)
810 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
810 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
811 return streamres(reader=cg, v1compressible=True)
811 return streamres(reader=cg, v1compressible=True)
812
812
813 @wireprotocommand('debugwireargs', 'one two *')
813 @wireprotocommand('debugwireargs', 'one two *')
814 def debugwireargs(repo, proto, one, two, others):
814 def debugwireargs(repo, proto, one, two, others):
815 # only accept optional args from the known set
815 # only accept optional args from the known set
816 opts = options('debugwireargs', ['three', 'four'], others)
816 opts = options('debugwireargs', ['three', 'four'], others)
817 return repo.debugwireargs(one, two, **opts)
817 return repo.debugwireargs(one, two, **opts)
818
818
819 @wireprotocommand('getbundle', '*')
819 @wireprotocommand('getbundle', '*')
820 def getbundle(repo, proto, others):
820 def getbundle(repo, proto, others):
821 opts = options('getbundle', gboptsmap.keys(), others)
821 opts = options('getbundle', gboptsmap.keys(), others)
822 for k, v in opts.iteritems():
822 for k, v in opts.iteritems():
823 keytype = gboptsmap[k]
823 keytype = gboptsmap[k]
824 if keytype == 'nodes':
824 if keytype == 'nodes':
825 opts[k] = decodelist(v)
825 opts[k] = decodelist(v)
826 elif keytype == 'csv':
826 elif keytype == 'csv':
827 opts[k] = list(v.split(','))
827 opts[k] = list(v.split(','))
828 elif keytype == 'scsv':
828 elif keytype == 'scsv':
829 opts[k] = set(v.split(','))
829 opts[k] = set(v.split(','))
830 elif keytype == 'boolean':
830 elif keytype == 'boolean':
831 # Client should serialize False as '0', which is a non-empty string
831 # Client should serialize False as '0', which is a non-empty string
832 # so it evaluates as a True bool.
832 # so it evaluates as a True bool.
833 if v == '0':
833 if v == '0':
834 opts[k] = False
834 opts[k] = False
835 else:
835 else:
836 opts[k] = bool(v)
836 opts[k] = bool(v)
837 elif keytype != 'plain':
837 elif keytype != 'plain':
838 raise KeyError('unknown getbundle option type %s'
838 raise KeyError('unknown getbundle option type %s'
839 % keytype)
839 % keytype)
840
840
841 if not bundle1allowed(repo, 'pull'):
841 if not bundle1allowed(repo, 'pull'):
842 if not exchange.bundle2requested(opts.get('bundlecaps')):
842 if not exchange.bundle2requested(opts.get('bundlecaps')):
843 if proto.name == 'http':
843 if proto.name == 'http':
844 return ooberror(bundle2required)
844 return ooberror(bundle2required)
845 raise error.Abort(bundle2requiredmain,
845 raise error.Abort(bundle2requiredmain,
846 hint=bundle2requiredhint)
846 hint=bundle2requiredhint)
847
847
848 try:
848 try:
849 if repo.ui.configbool('server', 'disablefullbundle'):
849 if repo.ui.configbool('server', 'disablefullbundle'):
850 # Check to see if this is a full clone.
850 # Check to see if this is a full clone.
851 clheads = set(repo.changelog.heads())
851 clheads = set(repo.changelog.heads())
852 heads = set(opts.get('heads', set()))
852 heads = set(opts.get('heads', set()))
853 common = set(opts.get('common', set()))
853 common = set(opts.get('common', set()))
854 common.discard(nullid)
854 common.discard(nullid)
855 if not common and clheads == heads:
855 if not common and clheads == heads:
856 raise error.Abort(
856 raise error.Abort(
857 _('server has pull-based clones disabled'),
857 _('server has pull-based clones disabled'),
858 hint=_('remove --pull if specified or upgrade Mercurial'))
858 hint=_('remove --pull if specified or upgrade Mercurial'))
859
859
860 chunks = exchange.getbundlechunks(repo, 'serve', **opts)
860 chunks = exchange.getbundlechunks(repo, 'serve', **opts)
861 except error.Abort as exc:
861 except error.Abort as exc:
862 # cleanly forward Abort error to the client
862 # cleanly forward Abort error to the client
863 if not exchange.bundle2requested(opts.get('bundlecaps')):
863 if not exchange.bundle2requested(opts.get('bundlecaps')):
864 if proto.name == 'http':
864 if proto.name == 'http':
865 return ooberror(str(exc) + '\n')
865 return ooberror(str(exc) + '\n')
866 raise # cannot do better for bundle1 + ssh
866 raise # cannot do better for bundle1 + ssh
867 # bundle2 request expect a bundle2 reply
867 # bundle2 request expect a bundle2 reply
868 bundler = bundle2.bundle20(repo.ui)
868 bundler = bundle2.bundle20(repo.ui)
869 manargs = [('message', str(exc))]
869 manargs = [('message', str(exc))]
870 advargs = []
870 advargs = []
871 if exc.hint is not None:
871 if exc.hint is not None:
872 advargs.append(('hint', exc.hint))
872 advargs.append(('hint', exc.hint))
873 bundler.addpart(bundle2.bundlepart('error:abort',
873 bundler.addpart(bundle2.bundlepart('error:abort',
874 manargs, advargs))
874 manargs, advargs))
875 return streamres(gen=bundler.getchunks(), v1compressible=True)
875 return streamres(gen=bundler.getchunks(), v1compressible=True)
876 return streamres(gen=chunks, v1compressible=True)
876 return streamres(gen=chunks, v1compressible=True)
877
877
878 @wireprotocommand('heads')
878 @wireprotocommand('heads')
879 def heads(repo, proto):
879 def heads(repo, proto):
880 h = repo.heads()
880 h = repo.heads()
881 return encodelist(h) + "\n"
881 return encodelist(h) + "\n"
882
882
883 @wireprotocommand('hello')
883 @wireprotocommand('hello')
884 def hello(repo, proto):
884 def hello(repo, proto):
885 '''the hello command returns a set of lines describing various
885 '''the hello command returns a set of lines describing various
886 interesting things about the server, in an RFC822-like format.
886 interesting things about the server, in an RFC822-like format.
887 Currently the only one defined is "capabilities", which
887 Currently the only one defined is "capabilities", which
888 consists of a line in the form:
888 consists of a line in the form:
889
889
890 capabilities: space separated list of tokens
890 capabilities: space separated list of tokens
891 '''
891 '''
892 return "capabilities: %s\n" % (capabilities(repo, proto))
892 return "capabilities: %s\n" % (capabilities(repo, proto))
893
893
894 @wireprotocommand('listkeys', 'namespace')
894 @wireprotocommand('listkeys', 'namespace')
895 def listkeys(repo, proto, namespace):
895 def listkeys(repo, proto, namespace):
896 d = repo.listkeys(encoding.tolocal(namespace)).items()
896 d = repo.listkeys(encoding.tolocal(namespace)).items()
897 return pushkeymod.encodekeys(d)
897 return pushkeymod.encodekeys(d)
898
898
899 @wireprotocommand('lookup', 'key')
899 @wireprotocommand('lookup', 'key')
900 def lookup(repo, proto, key):
900 def lookup(repo, proto, key):
901 try:
901 try:
902 k = encoding.tolocal(key)
902 k = encoding.tolocal(key)
903 c = repo[k]
903 c = repo[k]
904 r = c.hex()
904 r = c.hex()
905 success = 1
905 success = 1
906 except Exception as inst:
906 except Exception as inst:
907 r = str(inst)
907 r = str(inst)
908 success = 0
908 success = 0
909 return "%s %s\n" % (success, r)
909 return "%s %s\n" % (success, r)
910
910
911 @wireprotocommand('known', 'nodes *')
911 @wireprotocommand('known', 'nodes *')
912 def known(repo, proto, nodes, others):
912 def known(repo, proto, nodes, others):
913 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
913 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
914
914
915 @wireprotocommand('pushkey', 'namespace key old new')
915 @wireprotocommand('pushkey', 'namespace key old new')
916 def pushkey(repo, proto, namespace, key, old, new):
916 def pushkey(repo, proto, namespace, key, old, new):
917 # compatibility with pre-1.8 clients which were accidentally
917 # compatibility with pre-1.8 clients which were accidentally
918 # sending raw binary nodes rather than utf-8-encoded hex
918 # sending raw binary nodes rather than utf-8-encoded hex
919 if len(new) == 20 and util.escapestr(new) != new:
919 if len(new) == 20 and util.escapestr(new) != new:
920 # looks like it could be a binary node
920 # looks like it could be a binary node
921 try:
921 try:
922 new.decode('utf-8')
922 new.decode('utf-8')
923 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
923 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
924 except UnicodeDecodeError:
924 except UnicodeDecodeError:
925 pass # binary, leave unmodified
925 pass # binary, leave unmodified
926 else:
926 else:
927 new = encoding.tolocal(new) # normal path
927 new = encoding.tolocal(new) # normal path
928
928
929 if util.safehasattr(proto, 'restore'):
929 if util.safehasattr(proto, 'restore'):
930
930
931 proto.redirect()
931 proto.redirect()
932
932
933 try:
933 try:
934 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
934 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
935 encoding.tolocal(old), new) or False
935 encoding.tolocal(old), new) or False
936 except error.Abort:
936 except error.Abort:
937 r = False
937 r = False
938
938
939 output = proto.restore()
939 output = proto.restore()
940
940
941 return '%s\n%s' % (int(r), output)
941 return '%s\n%s' % (int(r), output)
942
942
943 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
943 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
944 encoding.tolocal(old), new)
944 encoding.tolocal(old), new)
945 return '%s\n' % int(r)
945 return '%s\n' % int(r)
946
946
947 @wireprotocommand('stream_out')
947 @wireprotocommand('stream_out')
948 def stream(repo, proto):
948 def stream(repo, proto):
949 '''If the server supports streaming clone, it advertises the "stream"
949 '''If the server supports streaming clone, it advertises the "stream"
950 capability with a value representing the version and flags of the repo
950 capability with a value representing the version and flags of the repo
951 it is serving. Client checks to see if it understands the format.
951 it is serving. Client checks to see if it understands the format.
952 '''
952 '''
953 if not streamclone.allowservergeneration(repo):
953 if not streamclone.allowservergeneration(repo):
954 return '1\n'
954 return '1\n'
955
955
956 def getstream(it):
956 def getstream(it):
957 yield '0\n'
957 yield '0\n'
958 for chunk in it:
958 for chunk in it:
959 yield chunk
959 yield chunk
960
960
961 try:
961 try:
962 # LockError may be raised before the first result is yielded. Don't
962 # LockError may be raised before the first result is yielded. Don't
963 # emit output until we're sure we got the lock successfully.
963 # emit output until we're sure we got the lock successfully.
964 it = streamclone.generatev1wireproto(repo)
964 it = streamclone.generatev1wireproto(repo)
965 return streamres(gen=getstream(it))
965 return streamres(gen=getstream(it))
966 except error.LockError:
966 except error.LockError:
967 return '2\n'
967 return '2\n'
968
968
969 @wireprotocommand('unbundle', 'heads')
969 @wireprotocommand('unbundle', 'heads')
970 def unbundle(repo, proto, heads):
970 def unbundle(repo, proto, heads):
971 their_heads = decodelist(heads)
971 their_heads = decodelist(heads)
972
972
973 try:
973 try:
974 proto.redirect()
974 proto.redirect()
975
975
976 exchange.check_heads(repo, their_heads, 'preparing changes')
976 exchange.check_heads(repo, their_heads, 'preparing changes')
977
977
978 # write bundle data to temporary file because it can be big
978 # write bundle data to temporary file because it can be big
979 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
979 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
980 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
980 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
981 r = 0
981 r = 0
982 try:
982 try:
983 proto.getfile(fp)
983 proto.getfile(fp)
984 fp.seek(0)
984 fp.seek(0)
985 gen = exchange.readbundle(repo.ui, fp, None)
985 gen = exchange.readbundle(repo.ui, fp, None)
986 if (isinstance(gen, changegroupmod.cg1unpacker)
986 if (isinstance(gen, changegroupmod.cg1unpacker)
987 and not bundle1allowed(repo, 'push')):
987 and not bundle1allowed(repo, 'push')):
988 if proto.name == 'http':
988 if proto.name == 'http':
989 # need to special case http because stderr do not get to
989 # need to special case http because stderr do not get to
990 # the http client on failed push so we need to abuse some
990 # the http client on failed push so we need to abuse some
991 # other error type to make sure the message get to the
991 # other error type to make sure the message get to the
992 # user.
992 # user.
993 return ooberror(bundle2required)
993 return ooberror(bundle2required)
994 raise error.Abort(bundle2requiredmain,
994 raise error.Abort(bundle2requiredmain,
995 hint=bundle2requiredhint)
995 hint=bundle2requiredhint)
996
996
997 r = exchange.unbundle(repo, gen, their_heads, 'serve',
997 r = exchange.unbundle(repo, gen, their_heads, 'serve',
998 proto._client())
998 proto._client())
999 if util.safehasattr(r, 'addpart'):
999 if util.safehasattr(r, 'addpart'):
1000 # The return looks streamable, we are in the bundle2 case and
1000 # The return looks streamable, we are in the bundle2 case and
1001 # should return a stream.
1001 # should return a stream.
1002 return streamres(gen=r.getchunks())
1002 return streamres(gen=r.getchunks())
1003 return pushres(r)
1003 return pushres(r)
1004
1004
1005 finally:
1005 finally:
1006 fp.close()
1006 fp.close()
1007 os.unlink(tempname)
1007 os.unlink(tempname)
1008
1008
1009 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1009 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1010 # handle non-bundle2 case first
1010 # handle non-bundle2 case first
1011 if not getattr(exc, 'duringunbundle2', False):
1011 if not getattr(exc, 'duringunbundle2', False):
1012 try:
1012 try:
1013 raise
1013 raise
1014 except error.Abort:
1014 except error.Abort:
1015 # The old code we moved used util.stderr directly.
1015 # The old code we moved used util.stderr directly.
1016 # We did not change it to minimise code change.
1016 # We did not change it to minimise code change.
1017 # This need to be moved to something proper.
1017 # This need to be moved to something proper.
1018 # Feel free to do it.
1018 # Feel free to do it.
1019 util.stderr.write("abort: %s\n" % exc)
1019 util.stderr.write("abort: %s\n" % exc)
1020 if exc.hint is not None:
1020 if exc.hint is not None:
1021 util.stderr.write("(%s)\n" % exc.hint)
1021 util.stderr.write("(%s)\n" % exc.hint)
1022 return pushres(0)
1022 return pushres(0)
1023 except error.PushRaced:
1023 except error.PushRaced:
1024 return pusherr(str(exc))
1024 return pusherr(str(exc))
1025
1025
1026 bundler = bundle2.bundle20(repo.ui)
1026 bundler = bundle2.bundle20(repo.ui)
1027 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1027 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1028 bundler.addpart(out)
1028 bundler.addpart(out)
1029 try:
1029 try:
1030 try:
1030 try:
1031 raise
1031 raise
1032 except error.PushkeyFailed as exc:
1032 except error.PushkeyFailed as exc:
1033 # check client caps
1033 # check client caps
1034 remotecaps = getattr(exc, '_replycaps', None)
1034 remotecaps = getattr(exc, '_replycaps', None)
1035 if (remotecaps is not None
1035 if (remotecaps is not None
1036 and 'pushkey' not in remotecaps.get('error', ())):
1036 and 'pushkey' not in remotecaps.get('error', ())):
1037 # no support remote side, fallback to Abort handler.
1037 # no support remote side, fallback to Abort handler.
1038 raise
1038 raise
1039 part = bundler.newpart('error:pushkey')
1039 part = bundler.newpart('error:pushkey')
1040 part.addparam('in-reply-to', exc.partid)
1040 part.addparam('in-reply-to', exc.partid)
1041 if exc.namespace is not None:
1041 if exc.namespace is not None:
1042 part.addparam('namespace', exc.namespace, mandatory=False)
1042 part.addparam('namespace', exc.namespace, mandatory=False)
1043 if exc.key is not None:
1043 if exc.key is not None:
1044 part.addparam('key', exc.key, mandatory=False)
1044 part.addparam('key', exc.key, mandatory=False)
1045 if exc.new is not None:
1045 if exc.new is not None:
1046 part.addparam('new', exc.new, mandatory=False)
1046 part.addparam('new', exc.new, mandatory=False)
1047 if exc.old is not None:
1047 if exc.old is not None:
1048 part.addparam('old', exc.old, mandatory=False)
1048 part.addparam('old', exc.old, mandatory=False)
1049 if exc.ret is not None:
1049 if exc.ret is not None:
1050 part.addparam('ret', exc.ret, mandatory=False)
1050 part.addparam('ret', exc.ret, mandatory=False)
1051 except error.BundleValueError as exc:
1051 except error.BundleValueError as exc:
1052 errpart = bundler.newpart('error:unsupportedcontent')
1052 errpart = bundler.newpart('error:unsupportedcontent')
1053 if exc.parttype is not None:
1053 if exc.parttype is not None:
1054 errpart.addparam('parttype', exc.parttype)
1054 errpart.addparam('parttype', exc.parttype)
1055 if exc.params:
1055 if exc.params:
1056 errpart.addparam('params', '\0'.join(exc.params))
1056 errpart.addparam('params', '\0'.join(exc.params))
1057 except error.Abort as exc:
1057 except error.Abort as exc:
1058 manargs = [('message', str(exc))]
1058 manargs = [('message', str(exc))]
1059 advargs = []
1059 advargs = []
1060 if exc.hint is not None:
1060 if exc.hint is not None:
1061 advargs.append(('hint', exc.hint))
1061 advargs.append(('hint', exc.hint))
1062 bundler.addpart(bundle2.bundlepart('error:abort',
1062 bundler.addpart(bundle2.bundlepart('error:abort',
1063 manargs, advargs))
1063 manargs, advargs))
1064 except error.PushRaced as exc:
1064 except error.PushRaced as exc:
1065 bundler.newpart('error:pushraced', [('message', str(exc))])
1065 bundler.newpart('error:pushraced', [('message', str(exc))])
1066 return streamres(gen=bundler.getchunks())
1066 return streamres(gen=bundler.getchunks())
@@ -1,296 +1,296 b''
1 from __future__ import absolute_import
1 from __future__ import absolute_import
2
2
3 import copy
3 import copy
4 import errno
4 import errno
5 import os
5 import os
6 import silenttestrunner
6 import silenttestrunner
7 import tempfile
7 import tempfile
8 import types
8 import types
9 import unittest
9 import unittest
10
10
11 from mercurial import (
11 from mercurial import (
12 error,
12 error,
13 lock,
13 lock,
14 vfs as vfsmod,
14 vfs as vfsmod,
15 )
15 )
16
16
17 testlockname = 'testlock'
17 testlockname = 'testlock'
18
18
19 # work around http://bugs.python.org/issue1515
19 # work around http://bugs.python.org/issue1515
20 if types.MethodType not in copy._deepcopy_dispatch:
20 if types.MethodType not in copy._deepcopy_dispatch:
21 def _deepcopy_method(x, memo):
21 def _deepcopy_method(x, memo):
22 return type(x)(x.im_func, copy.deepcopy(x.im_self, memo), x.im_class)
22 return type(x)(x.__func__, copy.deepcopy(x.__self__, memo), x.im_class)
23 copy._deepcopy_dispatch[types.MethodType] = _deepcopy_method
23 copy._deepcopy_dispatch[types.MethodType] = _deepcopy_method
24
24
25 class lockwrapper(lock.lock):
25 class lockwrapper(lock.lock):
26 def __init__(self, pidoffset, *args, **kwargs):
26 def __init__(self, pidoffset, *args, **kwargs):
27 # lock.lock.__init__() calls lock(), so the pidoffset assignment needs
27 # lock.lock.__init__() calls lock(), so the pidoffset assignment needs
28 # to be earlier
28 # to be earlier
29 self._pidoffset = pidoffset
29 self._pidoffset = pidoffset
30 super(lockwrapper, self).__init__(*args, **kwargs)
30 super(lockwrapper, self).__init__(*args, **kwargs)
31 def _getpid(self):
31 def _getpid(self):
32 return super(lockwrapper, self)._getpid() + self._pidoffset
32 return super(lockwrapper, self)._getpid() + self._pidoffset
33
33
34 class teststate(object):
34 class teststate(object):
35 def __init__(self, testcase, dir, pidoffset=0):
35 def __init__(self, testcase, dir, pidoffset=0):
36 self._testcase = testcase
36 self._testcase = testcase
37 self._acquirecalled = False
37 self._acquirecalled = False
38 self._releasecalled = False
38 self._releasecalled = False
39 self._postreleasecalled = False
39 self._postreleasecalled = False
40 self.vfs = vfsmod.vfs(dir, audit=False)
40 self.vfs = vfsmod.vfs(dir, audit=False)
41 self._pidoffset = pidoffset
41 self._pidoffset = pidoffset
42
42
43 def makelock(self, *args, **kwargs):
43 def makelock(self, *args, **kwargs):
44 l = lockwrapper(self._pidoffset, self.vfs, testlockname,
44 l = lockwrapper(self._pidoffset, self.vfs, testlockname,
45 releasefn=self.releasefn, acquirefn=self.acquirefn,
45 releasefn=self.releasefn, acquirefn=self.acquirefn,
46 *args, **kwargs)
46 *args, **kwargs)
47 l.postrelease.append(self.postreleasefn)
47 l.postrelease.append(self.postreleasefn)
48 return l
48 return l
49
49
50 def acquirefn(self):
50 def acquirefn(self):
51 self._acquirecalled = True
51 self._acquirecalled = True
52
52
53 def releasefn(self):
53 def releasefn(self):
54 self._releasecalled = True
54 self._releasecalled = True
55
55
56 def postreleasefn(self):
56 def postreleasefn(self):
57 self._postreleasecalled = True
57 self._postreleasecalled = True
58
58
59 def assertacquirecalled(self, called):
59 def assertacquirecalled(self, called):
60 self._testcase.assertEqual(
60 self._testcase.assertEqual(
61 self._acquirecalled, called,
61 self._acquirecalled, called,
62 'expected acquire to be %s but was actually %s' % (
62 'expected acquire to be %s but was actually %s' % (
63 self._tocalled(called),
63 self._tocalled(called),
64 self._tocalled(self._acquirecalled),
64 self._tocalled(self._acquirecalled),
65 ))
65 ))
66
66
67 def resetacquirefn(self):
67 def resetacquirefn(self):
68 self._acquirecalled = False
68 self._acquirecalled = False
69
69
70 def assertreleasecalled(self, called):
70 def assertreleasecalled(self, called):
71 self._testcase.assertEqual(
71 self._testcase.assertEqual(
72 self._releasecalled, called,
72 self._releasecalled, called,
73 'expected release to be %s but was actually %s' % (
73 'expected release to be %s but was actually %s' % (
74 self._tocalled(called),
74 self._tocalled(called),
75 self._tocalled(self._releasecalled),
75 self._tocalled(self._releasecalled),
76 ))
76 ))
77
77
78 def assertpostreleasecalled(self, called):
78 def assertpostreleasecalled(self, called):
79 self._testcase.assertEqual(
79 self._testcase.assertEqual(
80 self._postreleasecalled, called,
80 self._postreleasecalled, called,
81 'expected postrelease to be %s but was actually %s' % (
81 'expected postrelease to be %s but was actually %s' % (
82 self._tocalled(called),
82 self._tocalled(called),
83 self._tocalled(self._postreleasecalled),
83 self._tocalled(self._postreleasecalled),
84 ))
84 ))
85
85
86 def assertlockexists(self, exists):
86 def assertlockexists(self, exists):
87 actual = self.vfs.lexists(testlockname)
87 actual = self.vfs.lexists(testlockname)
88 self._testcase.assertEqual(
88 self._testcase.assertEqual(
89 actual, exists,
89 actual, exists,
90 'expected lock to %s but actually did %s' % (
90 'expected lock to %s but actually did %s' % (
91 self._toexists(exists),
91 self._toexists(exists),
92 self._toexists(actual),
92 self._toexists(actual),
93 ))
93 ))
94
94
95 def _tocalled(self, called):
95 def _tocalled(self, called):
96 if called:
96 if called:
97 return 'called'
97 return 'called'
98 else:
98 else:
99 return 'not called'
99 return 'not called'
100
100
101 def _toexists(self, exists):
101 def _toexists(self, exists):
102 if exists:
102 if exists:
103 return 'exist'
103 return 'exist'
104 else:
104 else:
105 return 'not exist'
105 return 'not exist'
106
106
107 class testlock(unittest.TestCase):
107 class testlock(unittest.TestCase):
108 def testlock(self):
108 def testlock(self):
109 state = teststate(self, tempfile.mkdtemp(dir=os.getcwd()))
109 state = teststate(self, tempfile.mkdtemp(dir=os.getcwd()))
110 lock = state.makelock()
110 lock = state.makelock()
111 state.assertacquirecalled(True)
111 state.assertacquirecalled(True)
112 lock.release()
112 lock.release()
113 state.assertreleasecalled(True)
113 state.assertreleasecalled(True)
114 state.assertpostreleasecalled(True)
114 state.assertpostreleasecalled(True)
115 state.assertlockexists(False)
115 state.assertlockexists(False)
116
116
117 def testrecursivelock(self):
117 def testrecursivelock(self):
118 state = teststate(self, tempfile.mkdtemp(dir=os.getcwd()))
118 state = teststate(self, tempfile.mkdtemp(dir=os.getcwd()))
119 lock = state.makelock()
119 lock = state.makelock()
120 state.assertacquirecalled(True)
120 state.assertacquirecalled(True)
121
121
122 state.resetacquirefn()
122 state.resetacquirefn()
123 lock.lock()
123 lock.lock()
124 # recursive lock should not call acquirefn again
124 # recursive lock should not call acquirefn again
125 state.assertacquirecalled(False)
125 state.assertacquirecalled(False)
126
126
127 lock.release() # brings lock refcount down from 2 to 1
127 lock.release() # brings lock refcount down from 2 to 1
128 state.assertreleasecalled(False)
128 state.assertreleasecalled(False)
129 state.assertpostreleasecalled(False)
129 state.assertpostreleasecalled(False)
130 state.assertlockexists(True)
130 state.assertlockexists(True)
131
131
132 lock.release() # releases the lock
132 lock.release() # releases the lock
133 state.assertreleasecalled(True)
133 state.assertreleasecalled(True)
134 state.assertpostreleasecalled(True)
134 state.assertpostreleasecalled(True)
135 state.assertlockexists(False)
135 state.assertlockexists(False)
136
136
137 def testlockfork(self):
137 def testlockfork(self):
138 state = teststate(self, tempfile.mkdtemp(dir=os.getcwd()))
138 state = teststate(self, tempfile.mkdtemp(dir=os.getcwd()))
139 lock = state.makelock()
139 lock = state.makelock()
140 state.assertacquirecalled(True)
140 state.assertacquirecalled(True)
141
141
142 # fake a fork
142 # fake a fork
143 forklock = copy.deepcopy(lock)
143 forklock = copy.deepcopy(lock)
144 forklock._pidoffset = 1
144 forklock._pidoffset = 1
145 forklock.release()
145 forklock.release()
146 state.assertreleasecalled(False)
146 state.assertreleasecalled(False)
147 state.assertpostreleasecalled(False)
147 state.assertpostreleasecalled(False)
148 state.assertlockexists(True)
148 state.assertlockexists(True)
149
149
150 # release the actual lock
150 # release the actual lock
151 lock.release()
151 lock.release()
152 state.assertreleasecalled(True)
152 state.assertreleasecalled(True)
153 state.assertpostreleasecalled(True)
153 state.assertpostreleasecalled(True)
154 state.assertlockexists(False)
154 state.assertlockexists(False)
155
155
156 def testinheritlock(self):
156 def testinheritlock(self):
157 d = tempfile.mkdtemp(dir=os.getcwd())
157 d = tempfile.mkdtemp(dir=os.getcwd())
158 parentstate = teststate(self, d)
158 parentstate = teststate(self, d)
159 parentlock = parentstate.makelock()
159 parentlock = parentstate.makelock()
160 parentstate.assertacquirecalled(True)
160 parentstate.assertacquirecalled(True)
161
161
162 # set up lock inheritance
162 # set up lock inheritance
163 with parentlock.inherit() as lockname:
163 with parentlock.inherit() as lockname:
164 parentstate.assertreleasecalled(True)
164 parentstate.assertreleasecalled(True)
165 parentstate.assertpostreleasecalled(False)
165 parentstate.assertpostreleasecalled(False)
166 parentstate.assertlockexists(True)
166 parentstate.assertlockexists(True)
167
167
168 childstate = teststate(self, d, pidoffset=1)
168 childstate = teststate(self, d, pidoffset=1)
169 childlock = childstate.makelock(parentlock=lockname)
169 childlock = childstate.makelock(parentlock=lockname)
170 childstate.assertacquirecalled(True)
170 childstate.assertacquirecalled(True)
171
171
172 childlock.release()
172 childlock.release()
173 childstate.assertreleasecalled(True)
173 childstate.assertreleasecalled(True)
174 childstate.assertpostreleasecalled(False)
174 childstate.assertpostreleasecalled(False)
175 childstate.assertlockexists(True)
175 childstate.assertlockexists(True)
176
176
177 parentstate.resetacquirefn()
177 parentstate.resetacquirefn()
178
178
179 parentstate.assertacquirecalled(True)
179 parentstate.assertacquirecalled(True)
180
180
181 parentlock.release()
181 parentlock.release()
182 parentstate.assertreleasecalled(True)
182 parentstate.assertreleasecalled(True)
183 parentstate.assertpostreleasecalled(True)
183 parentstate.assertpostreleasecalled(True)
184 parentstate.assertlockexists(False)
184 parentstate.assertlockexists(False)
185
185
186 def testmultilock(self):
186 def testmultilock(self):
187 d = tempfile.mkdtemp(dir=os.getcwd())
187 d = tempfile.mkdtemp(dir=os.getcwd())
188 state0 = teststate(self, d)
188 state0 = teststate(self, d)
189 lock0 = state0.makelock()
189 lock0 = state0.makelock()
190 state0.assertacquirecalled(True)
190 state0.assertacquirecalled(True)
191
191
192 with lock0.inherit() as lock0name:
192 with lock0.inherit() as lock0name:
193 state0.assertreleasecalled(True)
193 state0.assertreleasecalled(True)
194 state0.assertpostreleasecalled(False)
194 state0.assertpostreleasecalled(False)
195 state0.assertlockexists(True)
195 state0.assertlockexists(True)
196
196
197 state1 = teststate(self, d, pidoffset=1)
197 state1 = teststate(self, d, pidoffset=1)
198 lock1 = state1.makelock(parentlock=lock0name)
198 lock1 = state1.makelock(parentlock=lock0name)
199 state1.assertacquirecalled(True)
199 state1.assertacquirecalled(True)
200
200
201 # from within lock1, acquire another lock
201 # from within lock1, acquire another lock
202 with lock1.inherit() as lock1name:
202 with lock1.inherit() as lock1name:
203 # since the file on disk is lock0's this should have the same
203 # since the file on disk is lock0's this should have the same
204 # name
204 # name
205 self.assertEqual(lock0name, lock1name)
205 self.assertEqual(lock0name, lock1name)
206
206
207 state2 = teststate(self, d, pidoffset=2)
207 state2 = teststate(self, d, pidoffset=2)
208 lock2 = state2.makelock(parentlock=lock1name)
208 lock2 = state2.makelock(parentlock=lock1name)
209 state2.assertacquirecalled(True)
209 state2.assertacquirecalled(True)
210
210
211 lock2.release()
211 lock2.release()
212 state2.assertreleasecalled(True)
212 state2.assertreleasecalled(True)
213 state2.assertpostreleasecalled(False)
213 state2.assertpostreleasecalled(False)
214 state2.assertlockexists(True)
214 state2.assertlockexists(True)
215
215
216 state1.resetacquirefn()
216 state1.resetacquirefn()
217
217
218 state1.assertacquirecalled(True)
218 state1.assertacquirecalled(True)
219
219
220 lock1.release()
220 lock1.release()
221 state1.assertreleasecalled(True)
221 state1.assertreleasecalled(True)
222 state1.assertpostreleasecalled(False)
222 state1.assertpostreleasecalled(False)
223 state1.assertlockexists(True)
223 state1.assertlockexists(True)
224
224
225 lock0.release()
225 lock0.release()
226
226
227 def testinheritlockfork(self):
227 def testinheritlockfork(self):
228 d = tempfile.mkdtemp(dir=os.getcwd())
228 d = tempfile.mkdtemp(dir=os.getcwd())
229 parentstate = teststate(self, d)
229 parentstate = teststate(self, d)
230 parentlock = parentstate.makelock()
230 parentlock = parentstate.makelock()
231 parentstate.assertacquirecalled(True)
231 parentstate.assertacquirecalled(True)
232
232
233 # set up lock inheritance
233 # set up lock inheritance
234 with parentlock.inherit() as lockname:
234 with parentlock.inherit() as lockname:
235 childstate = teststate(self, d, pidoffset=1)
235 childstate = teststate(self, d, pidoffset=1)
236 childlock = childstate.makelock(parentlock=lockname)
236 childlock = childstate.makelock(parentlock=lockname)
237 childstate.assertacquirecalled(True)
237 childstate.assertacquirecalled(True)
238
238
239 # fork the child lock
239 # fork the child lock
240 forkchildlock = copy.deepcopy(childlock)
240 forkchildlock = copy.deepcopy(childlock)
241 forkchildlock._pidoffset += 1
241 forkchildlock._pidoffset += 1
242 forkchildlock.release()
242 forkchildlock.release()
243 childstate.assertreleasecalled(False)
243 childstate.assertreleasecalled(False)
244 childstate.assertpostreleasecalled(False)
244 childstate.assertpostreleasecalled(False)
245 childstate.assertlockexists(True)
245 childstate.assertlockexists(True)
246
246
247 # release the child lock
247 # release the child lock
248 childlock.release()
248 childlock.release()
249 childstate.assertreleasecalled(True)
249 childstate.assertreleasecalled(True)
250 childstate.assertpostreleasecalled(False)
250 childstate.assertpostreleasecalled(False)
251 childstate.assertlockexists(True)
251 childstate.assertlockexists(True)
252
252
253 parentlock.release()
253 parentlock.release()
254
254
255 def testinheritcheck(self):
255 def testinheritcheck(self):
256 d = tempfile.mkdtemp(dir=os.getcwd())
256 d = tempfile.mkdtemp(dir=os.getcwd())
257 state = teststate(self, d)
257 state = teststate(self, d)
258 def check():
258 def check():
259 raise error.LockInheritanceContractViolation('check failed')
259 raise error.LockInheritanceContractViolation('check failed')
260 lock = state.makelock(inheritchecker=check)
260 lock = state.makelock(inheritchecker=check)
261 state.assertacquirecalled(True)
261 state.assertacquirecalled(True)
262
262
263 with self.assertRaises(error.LockInheritanceContractViolation):
263 with self.assertRaises(error.LockInheritanceContractViolation):
264 with lock.inherit():
264 with lock.inherit():
265 pass
265 pass
266
266
267 lock.release()
267 lock.release()
268
268
269 def testfrequentlockunlock(self):
269 def testfrequentlockunlock(self):
270 """This tests whether lock acquisition fails as expected, even if
270 """This tests whether lock acquisition fails as expected, even if
271 (1) lock can't be acquired (makelock fails by EEXIST), and
271 (1) lock can't be acquired (makelock fails by EEXIST), and
272 (2) locker info can't be read in (readlock fails by ENOENT) while
272 (2) locker info can't be read in (readlock fails by ENOENT) while
273 retrying 5 times.
273 retrying 5 times.
274 """
274 """
275
275
276 d = tempfile.mkdtemp(dir=os.getcwd())
276 d = tempfile.mkdtemp(dir=os.getcwd())
277 state = teststate(self, d)
277 state = teststate(self, d)
278
278
279 def emulatefrequentlock(*args):
279 def emulatefrequentlock(*args):
280 raise OSError(errno.EEXIST, "File exists")
280 raise OSError(errno.EEXIST, "File exists")
281 def emulatefrequentunlock(*args):
281 def emulatefrequentunlock(*args):
282 raise OSError(errno.ENOENT, "No such file or directory")
282 raise OSError(errno.ENOENT, "No such file or directory")
283
283
284 state.vfs.makelock = emulatefrequentlock
284 state.vfs.makelock = emulatefrequentlock
285 state.vfs.readlock = emulatefrequentunlock
285 state.vfs.readlock = emulatefrequentunlock
286
286
287 try:
287 try:
288 state.makelock(timeout=0)
288 state.makelock(timeout=0)
289 self.fail("unexpected lock acquisition")
289 self.fail("unexpected lock acquisition")
290 except error.LockHeld as why:
290 except error.LockHeld as why:
291 self.assertTrue(why.errno == errno.ETIMEDOUT)
291 self.assertTrue(why.errno == errno.ETIMEDOUT)
292 self.assertTrue(why.locker == "")
292 self.assertTrue(why.locker == "")
293 state.assertlockexists(False)
293 state.assertlockexists(False)
294
294
295 if __name__ == '__main__':
295 if __name__ == '__main__':
296 silenttestrunner.main(__name__)
296 silenttestrunner.main(__name__)
General Comments 0
You need to be logged in to leave comments. Login now