##// END OF EJS Templates
wireproto: don't special case bundlecaps, but sort all scsv arguments...
Joerg Sonnenberger -
r37430:1d459f61 default
parent child Browse files
Show More
@@ -1,1182 +1,1180
1 # wireproto.py - generic wire protocol support functions
1 # wireproto.py - generic wire protocol support functions
2 #
2 #
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import hashlib
10 import hashlib
11 import os
11 import os
12 import tempfile
12 import tempfile
13
13
14 from .i18n import _
14 from .i18n import _
15 from .node import (
15 from .node import (
16 bin,
16 bin,
17 hex,
17 hex,
18 nullid,
18 nullid,
19 )
19 )
20
20
21 from . import (
21 from . import (
22 bundle2,
22 bundle2,
23 changegroup as changegroupmod,
23 changegroup as changegroupmod,
24 discovery,
24 discovery,
25 encoding,
25 encoding,
26 error,
26 error,
27 exchange,
27 exchange,
28 peer,
28 peer,
29 pushkey as pushkeymod,
29 pushkey as pushkeymod,
30 pycompat,
30 pycompat,
31 repository,
31 repository,
32 streamclone,
32 streamclone,
33 util,
33 util,
34 wireprototypes,
34 wireprototypes,
35 )
35 )
36
36
37 from .utils import (
37 from .utils import (
38 procutil,
38 procutil,
39 stringutil,
39 stringutil,
40 )
40 )
41
41
42 urlerr = util.urlerr
42 urlerr = util.urlerr
43 urlreq = util.urlreq
43 urlreq = util.urlreq
44
44
45 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
45 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
46 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
46 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
47 'IncompatibleClient')
47 'IncompatibleClient')
48 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
48 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
49
49
50 class remoteiterbatcher(peer.iterbatcher):
50 class remoteiterbatcher(peer.iterbatcher):
51 def __init__(self, remote):
51 def __init__(self, remote):
52 super(remoteiterbatcher, self).__init__()
52 super(remoteiterbatcher, self).__init__()
53 self._remote = remote
53 self._remote = remote
54
54
55 def __getattr__(self, name):
55 def __getattr__(self, name):
56 # Validate this method is batchable, since submit() only supports
56 # Validate this method is batchable, since submit() only supports
57 # batchable methods.
57 # batchable methods.
58 fn = getattr(self._remote, name)
58 fn = getattr(self._remote, name)
59 if not getattr(fn, 'batchable', None):
59 if not getattr(fn, 'batchable', None):
60 raise error.ProgrammingError('Attempted to batch a non-batchable '
60 raise error.ProgrammingError('Attempted to batch a non-batchable '
61 'call to %r' % name)
61 'call to %r' % name)
62
62
63 return super(remoteiterbatcher, self).__getattr__(name)
63 return super(remoteiterbatcher, self).__getattr__(name)
64
64
65 def submit(self):
65 def submit(self):
66 """Break the batch request into many patch calls and pipeline them.
66 """Break the batch request into many patch calls and pipeline them.
67
67
68 This is mostly valuable over http where request sizes can be
68 This is mostly valuable over http where request sizes can be
69 limited, but can be used in other places as well.
69 limited, but can be used in other places as well.
70 """
70 """
71 # 2-tuple of (command, arguments) that represents what will be
71 # 2-tuple of (command, arguments) that represents what will be
72 # sent over the wire.
72 # sent over the wire.
73 requests = []
73 requests = []
74
74
75 # 4-tuple of (command, final future, @batchable generator, remote
75 # 4-tuple of (command, final future, @batchable generator, remote
76 # future).
76 # future).
77 results = []
77 results = []
78
78
79 for command, args, opts, finalfuture in self.calls:
79 for command, args, opts, finalfuture in self.calls:
80 mtd = getattr(self._remote, command)
80 mtd = getattr(self._remote, command)
81 batchable = mtd.batchable(mtd.__self__, *args, **opts)
81 batchable = mtd.batchable(mtd.__self__, *args, **opts)
82
82
83 commandargs, fremote = next(batchable)
83 commandargs, fremote = next(batchable)
84 assert fremote
84 assert fremote
85 requests.append((command, commandargs))
85 requests.append((command, commandargs))
86 results.append((command, finalfuture, batchable, fremote))
86 results.append((command, finalfuture, batchable, fremote))
87
87
88 if requests:
88 if requests:
89 self._resultiter = self._remote._submitbatch(requests)
89 self._resultiter = self._remote._submitbatch(requests)
90
90
91 self._results = results
91 self._results = results
92
92
93 def results(self):
93 def results(self):
94 for command, finalfuture, batchable, remotefuture in self._results:
94 for command, finalfuture, batchable, remotefuture in self._results:
95 # Get the raw result, set it in the remote future, feed it
95 # Get the raw result, set it in the remote future, feed it
96 # back into the @batchable generator so it can be decoded, and
96 # back into the @batchable generator so it can be decoded, and
97 # set the result on the final future to this value.
97 # set the result on the final future to this value.
98 remoteresult = next(self._resultiter)
98 remoteresult = next(self._resultiter)
99 remotefuture.set(remoteresult)
99 remotefuture.set(remoteresult)
100 finalfuture.set(next(batchable))
100 finalfuture.set(next(batchable))
101
101
102 # Verify our @batchable generators only emit 2 values.
102 # Verify our @batchable generators only emit 2 values.
103 try:
103 try:
104 next(batchable)
104 next(batchable)
105 except StopIteration:
105 except StopIteration:
106 pass
106 pass
107 else:
107 else:
108 raise error.ProgrammingError('%s @batchable generator emitted '
108 raise error.ProgrammingError('%s @batchable generator emitted '
109 'unexpected value count' % command)
109 'unexpected value count' % command)
110
110
111 yield finalfuture.value
111 yield finalfuture.value
112
112
113 # Forward a couple of names from peer to make wireproto interactions
113 # Forward a couple of names from peer to make wireproto interactions
114 # slightly more sensible.
114 # slightly more sensible.
115 batchable = peer.batchable
115 batchable = peer.batchable
116 future = peer.future
116 future = peer.future
117
117
118 # list of nodes encoding / decoding
118 # list of nodes encoding / decoding
119
119
120 def decodelist(l, sep=' '):
120 def decodelist(l, sep=' '):
121 if l:
121 if l:
122 return [bin(v) for v in l.split(sep)]
122 return [bin(v) for v in l.split(sep)]
123 return []
123 return []
124
124
125 def encodelist(l, sep=' '):
125 def encodelist(l, sep=' '):
126 try:
126 try:
127 return sep.join(map(hex, l))
127 return sep.join(map(hex, l))
128 except TypeError:
128 except TypeError:
129 raise
129 raise
130
130
131 # batched call argument encoding
131 # batched call argument encoding
132
132
133 def escapearg(plain):
133 def escapearg(plain):
134 return (plain
134 return (plain
135 .replace(':', ':c')
135 .replace(':', ':c')
136 .replace(',', ':o')
136 .replace(',', ':o')
137 .replace(';', ':s')
137 .replace(';', ':s')
138 .replace('=', ':e'))
138 .replace('=', ':e'))
139
139
140 def unescapearg(escaped):
140 def unescapearg(escaped):
141 return (escaped
141 return (escaped
142 .replace(':e', '=')
142 .replace(':e', '=')
143 .replace(':s', ';')
143 .replace(':s', ';')
144 .replace(':o', ',')
144 .replace(':o', ',')
145 .replace(':c', ':'))
145 .replace(':c', ':'))
146
146
147 def encodebatchcmds(req):
147 def encodebatchcmds(req):
148 """Return a ``cmds`` argument value for the ``batch`` command."""
148 """Return a ``cmds`` argument value for the ``batch`` command."""
149 cmds = []
149 cmds = []
150 for op, argsdict in req:
150 for op, argsdict in req:
151 # Old servers didn't properly unescape argument names. So prevent
151 # Old servers didn't properly unescape argument names. So prevent
152 # the sending of argument names that may not be decoded properly by
152 # the sending of argument names that may not be decoded properly by
153 # servers.
153 # servers.
154 assert all(escapearg(k) == k for k in argsdict)
154 assert all(escapearg(k) == k for k in argsdict)
155
155
156 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
156 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
157 for k, v in argsdict.iteritems())
157 for k, v in argsdict.iteritems())
158 cmds.append('%s %s' % (op, args))
158 cmds.append('%s %s' % (op, args))
159
159
160 return ';'.join(cmds)
160 return ';'.join(cmds)
161
161
162 def clientcompressionsupport(proto):
162 def clientcompressionsupport(proto):
163 """Returns a list of compression methods supported by the client.
163 """Returns a list of compression methods supported by the client.
164
164
165 Returns a list of the compression methods supported by the client
165 Returns a list of the compression methods supported by the client
166 according to the protocol capabilities. If no such capability has
166 according to the protocol capabilities. If no such capability has
167 been announced, fallback to the default of zlib and uncompressed.
167 been announced, fallback to the default of zlib and uncompressed.
168 """
168 """
169 for cap in proto.getprotocaps():
169 for cap in proto.getprotocaps():
170 if cap.startswith('comp='):
170 if cap.startswith('comp='):
171 return cap[5:].split(',')
171 return cap[5:].split(',')
172 return ['zlib', 'none']
172 return ['zlib', 'none']
173
173
174 # mapping of options accepted by getbundle and their types
174 # mapping of options accepted by getbundle and their types
175 #
175 #
176 # Meant to be extended by extensions. It is extensions responsibility to ensure
176 # Meant to be extended by extensions. It is extensions responsibility to ensure
177 # such options are properly processed in exchange.getbundle.
177 # such options are properly processed in exchange.getbundle.
178 #
178 #
179 # supported types are:
179 # supported types are:
180 #
180 #
181 # :nodes: list of binary nodes
181 # :nodes: list of binary nodes
182 # :csv: list of comma-separated values
182 # :csv: list of comma-separated values
183 # :scsv: list of comma-separated values return as set
183 # :scsv: list of comma-separated values return as set
184 # :plain: string with no transformation needed.
184 # :plain: string with no transformation needed.
185 gboptsmap = {'heads': 'nodes',
185 gboptsmap = {'heads': 'nodes',
186 'bookmarks': 'boolean',
186 'bookmarks': 'boolean',
187 'common': 'nodes',
187 'common': 'nodes',
188 'obsmarkers': 'boolean',
188 'obsmarkers': 'boolean',
189 'phases': 'boolean',
189 'phases': 'boolean',
190 'bundlecaps': 'scsv',
190 'bundlecaps': 'scsv',
191 'listkeys': 'csv',
191 'listkeys': 'csv',
192 'cg': 'boolean',
192 'cg': 'boolean',
193 'cbattempted': 'boolean',
193 'cbattempted': 'boolean',
194 'stream': 'boolean',
194 'stream': 'boolean',
195 }
195 }
196
196
197 # client side
197 # client side
198
198
199 class wirepeer(repository.legacypeer):
199 class wirepeer(repository.legacypeer):
200 """Client-side interface for communicating with a peer repository.
200 """Client-side interface for communicating with a peer repository.
201
201
202 Methods commonly call wire protocol commands of the same name.
202 Methods commonly call wire protocol commands of the same name.
203
203
204 See also httppeer.py and sshpeer.py for protocol-specific
204 See also httppeer.py and sshpeer.py for protocol-specific
205 implementations of this interface.
205 implementations of this interface.
206 """
206 """
207 # Begin of ipeercommands interface.
207 # Begin of ipeercommands interface.
208
208
209 def iterbatch(self):
209 def iterbatch(self):
210 return remoteiterbatcher(self)
210 return remoteiterbatcher(self)
211
211
212 @batchable
212 @batchable
213 def lookup(self, key):
213 def lookup(self, key):
214 self.requirecap('lookup', _('look up remote revision'))
214 self.requirecap('lookup', _('look up remote revision'))
215 f = future()
215 f = future()
216 yield {'key': encoding.fromlocal(key)}, f
216 yield {'key': encoding.fromlocal(key)}, f
217 d = f.value
217 d = f.value
218 success, data = d[:-1].split(" ", 1)
218 success, data = d[:-1].split(" ", 1)
219 if int(success):
219 if int(success):
220 yield bin(data)
220 yield bin(data)
221 else:
221 else:
222 self._abort(error.RepoError(data))
222 self._abort(error.RepoError(data))
223
223
224 @batchable
224 @batchable
225 def heads(self):
225 def heads(self):
226 f = future()
226 f = future()
227 yield {}, f
227 yield {}, f
228 d = f.value
228 d = f.value
229 try:
229 try:
230 yield decodelist(d[:-1])
230 yield decodelist(d[:-1])
231 except ValueError:
231 except ValueError:
232 self._abort(error.ResponseError(_("unexpected response:"), d))
232 self._abort(error.ResponseError(_("unexpected response:"), d))
233
233
234 @batchable
234 @batchable
235 def known(self, nodes):
235 def known(self, nodes):
236 f = future()
236 f = future()
237 yield {'nodes': encodelist(nodes)}, f
237 yield {'nodes': encodelist(nodes)}, f
238 d = f.value
238 d = f.value
239 try:
239 try:
240 yield [bool(int(b)) for b in d]
240 yield [bool(int(b)) for b in d]
241 except ValueError:
241 except ValueError:
242 self._abort(error.ResponseError(_("unexpected response:"), d))
242 self._abort(error.ResponseError(_("unexpected response:"), d))
243
243
244 @batchable
244 @batchable
245 def branchmap(self):
245 def branchmap(self):
246 f = future()
246 f = future()
247 yield {}, f
247 yield {}, f
248 d = f.value
248 d = f.value
249 try:
249 try:
250 branchmap = {}
250 branchmap = {}
251 for branchpart in d.splitlines():
251 for branchpart in d.splitlines():
252 branchname, branchheads = branchpart.split(' ', 1)
252 branchname, branchheads = branchpart.split(' ', 1)
253 branchname = encoding.tolocal(urlreq.unquote(branchname))
253 branchname = encoding.tolocal(urlreq.unquote(branchname))
254 branchheads = decodelist(branchheads)
254 branchheads = decodelist(branchheads)
255 branchmap[branchname] = branchheads
255 branchmap[branchname] = branchheads
256 yield branchmap
256 yield branchmap
257 except TypeError:
257 except TypeError:
258 self._abort(error.ResponseError(_("unexpected response:"), d))
258 self._abort(error.ResponseError(_("unexpected response:"), d))
259
259
260 @batchable
260 @batchable
261 def listkeys(self, namespace):
261 def listkeys(self, namespace):
262 if not self.capable('pushkey'):
262 if not self.capable('pushkey'):
263 yield {}, None
263 yield {}, None
264 f = future()
264 f = future()
265 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
265 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
266 yield {'namespace': encoding.fromlocal(namespace)}, f
266 yield {'namespace': encoding.fromlocal(namespace)}, f
267 d = f.value
267 d = f.value
268 self.ui.debug('received listkey for "%s": %i bytes\n'
268 self.ui.debug('received listkey for "%s": %i bytes\n'
269 % (namespace, len(d)))
269 % (namespace, len(d)))
270 yield pushkeymod.decodekeys(d)
270 yield pushkeymod.decodekeys(d)
271
271
272 @batchable
272 @batchable
273 def pushkey(self, namespace, key, old, new):
273 def pushkey(self, namespace, key, old, new):
274 if not self.capable('pushkey'):
274 if not self.capable('pushkey'):
275 yield False, None
275 yield False, None
276 f = future()
276 f = future()
277 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
277 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
278 yield {'namespace': encoding.fromlocal(namespace),
278 yield {'namespace': encoding.fromlocal(namespace),
279 'key': encoding.fromlocal(key),
279 'key': encoding.fromlocal(key),
280 'old': encoding.fromlocal(old),
280 'old': encoding.fromlocal(old),
281 'new': encoding.fromlocal(new)}, f
281 'new': encoding.fromlocal(new)}, f
282 d = f.value
282 d = f.value
283 d, output = d.split('\n', 1)
283 d, output = d.split('\n', 1)
284 try:
284 try:
285 d = bool(int(d))
285 d = bool(int(d))
286 except ValueError:
286 except ValueError:
287 raise error.ResponseError(
287 raise error.ResponseError(
288 _('push failed (unexpected response):'), d)
288 _('push failed (unexpected response):'), d)
289 for l in output.splitlines(True):
289 for l in output.splitlines(True):
290 self.ui.status(_('remote: '), l)
290 self.ui.status(_('remote: '), l)
291 yield d
291 yield d
292
292
293 def stream_out(self):
293 def stream_out(self):
294 return self._callstream('stream_out')
294 return self._callstream('stream_out')
295
295
296 def getbundle(self, source, **kwargs):
296 def getbundle(self, source, **kwargs):
297 kwargs = pycompat.byteskwargs(kwargs)
297 kwargs = pycompat.byteskwargs(kwargs)
298 self.requirecap('getbundle', _('look up remote changes'))
298 self.requirecap('getbundle', _('look up remote changes'))
299 opts = {}
299 opts = {}
300 bundlecaps = kwargs.get('bundlecaps')
300 bundlecaps = kwargs.get('bundlecaps') or set()
301 if bundlecaps is not None:
302 kwargs['bundlecaps'] = sorted(bundlecaps)
303 else:
304 bundlecaps = () # kwargs could have it to None
305 for key, value in kwargs.iteritems():
301 for key, value in kwargs.iteritems():
306 if value is None:
302 if value is None:
307 continue
303 continue
308 keytype = gboptsmap.get(key)
304 keytype = gboptsmap.get(key)
309 if keytype is None:
305 if keytype is None:
310 raise error.ProgrammingError(
306 raise error.ProgrammingError(
311 'Unexpectedly None keytype for key %s' % key)
307 'Unexpectedly None keytype for key %s' % key)
312 elif keytype == 'nodes':
308 elif keytype == 'nodes':
313 value = encodelist(value)
309 value = encodelist(value)
314 elif keytype in ('csv', 'scsv'):
310 elif keytype == 'csv':
315 value = ','.join(value)
311 value = ','.join(value)
312 elif keytype == 'scsv':
313 value = ','.join(sorted(value))
316 elif keytype == 'boolean':
314 elif keytype == 'boolean':
317 value = '%i' % bool(value)
315 value = '%i' % bool(value)
318 elif keytype != 'plain':
316 elif keytype != 'plain':
319 raise KeyError('unknown getbundle option type %s'
317 raise KeyError('unknown getbundle option type %s'
320 % keytype)
318 % keytype)
321 opts[key] = value
319 opts[key] = value
322 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
320 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
323 if any((cap.startswith('HG2') for cap in bundlecaps)):
321 if any((cap.startswith('HG2') for cap in bundlecaps)):
324 return bundle2.getunbundler(self.ui, f)
322 return bundle2.getunbundler(self.ui, f)
325 else:
323 else:
326 return changegroupmod.cg1unpacker(f, 'UN')
324 return changegroupmod.cg1unpacker(f, 'UN')
327
325
328 def unbundle(self, cg, heads, url):
326 def unbundle(self, cg, heads, url):
329 '''Send cg (a readable file-like object representing the
327 '''Send cg (a readable file-like object representing the
330 changegroup to push, typically a chunkbuffer object) to the
328 changegroup to push, typically a chunkbuffer object) to the
331 remote server as a bundle.
329 remote server as a bundle.
332
330
333 When pushing a bundle10 stream, return an integer indicating the
331 When pushing a bundle10 stream, return an integer indicating the
334 result of the push (see changegroup.apply()).
332 result of the push (see changegroup.apply()).
335
333
336 When pushing a bundle20 stream, return a bundle20 stream.
334 When pushing a bundle20 stream, return a bundle20 stream.
337
335
338 `url` is the url the client thinks it's pushing to, which is
336 `url` is the url the client thinks it's pushing to, which is
339 visible to hooks.
337 visible to hooks.
340 '''
338 '''
341
339
342 if heads != ['force'] and self.capable('unbundlehash'):
340 if heads != ['force'] and self.capable('unbundlehash'):
343 heads = encodelist(['hashed',
341 heads = encodelist(['hashed',
344 hashlib.sha1(''.join(sorted(heads))).digest()])
342 hashlib.sha1(''.join(sorted(heads))).digest()])
345 else:
343 else:
346 heads = encodelist(heads)
344 heads = encodelist(heads)
347
345
348 if util.safehasattr(cg, 'deltaheader'):
346 if util.safehasattr(cg, 'deltaheader'):
349 # this a bundle10, do the old style call sequence
347 # this a bundle10, do the old style call sequence
350 ret, output = self._callpush("unbundle", cg, heads=heads)
348 ret, output = self._callpush("unbundle", cg, heads=heads)
351 if ret == "":
349 if ret == "":
352 raise error.ResponseError(
350 raise error.ResponseError(
353 _('push failed:'), output)
351 _('push failed:'), output)
354 try:
352 try:
355 ret = int(ret)
353 ret = int(ret)
356 except ValueError:
354 except ValueError:
357 raise error.ResponseError(
355 raise error.ResponseError(
358 _('push failed (unexpected response):'), ret)
356 _('push failed (unexpected response):'), ret)
359
357
360 for l in output.splitlines(True):
358 for l in output.splitlines(True):
361 self.ui.status(_('remote: '), l)
359 self.ui.status(_('remote: '), l)
362 else:
360 else:
363 # bundle2 push. Send a stream, fetch a stream.
361 # bundle2 push. Send a stream, fetch a stream.
364 stream = self._calltwowaystream('unbundle', cg, heads=heads)
362 stream = self._calltwowaystream('unbundle', cg, heads=heads)
365 ret = bundle2.getunbundler(self.ui, stream)
363 ret = bundle2.getunbundler(self.ui, stream)
366 return ret
364 return ret
367
365
368 # End of ipeercommands interface.
366 # End of ipeercommands interface.
369
367
370 # Begin of ipeerlegacycommands interface.
368 # Begin of ipeerlegacycommands interface.
371
369
372 def branches(self, nodes):
370 def branches(self, nodes):
373 n = encodelist(nodes)
371 n = encodelist(nodes)
374 d = self._call("branches", nodes=n)
372 d = self._call("branches", nodes=n)
375 try:
373 try:
376 br = [tuple(decodelist(b)) for b in d.splitlines()]
374 br = [tuple(decodelist(b)) for b in d.splitlines()]
377 return br
375 return br
378 except ValueError:
376 except ValueError:
379 self._abort(error.ResponseError(_("unexpected response:"), d))
377 self._abort(error.ResponseError(_("unexpected response:"), d))
380
378
381 def between(self, pairs):
379 def between(self, pairs):
382 batch = 8 # avoid giant requests
380 batch = 8 # avoid giant requests
383 r = []
381 r = []
384 for i in xrange(0, len(pairs), batch):
382 for i in xrange(0, len(pairs), batch):
385 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
383 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
386 d = self._call("between", pairs=n)
384 d = self._call("between", pairs=n)
387 try:
385 try:
388 r.extend(l and decodelist(l) or [] for l in d.splitlines())
386 r.extend(l and decodelist(l) or [] for l in d.splitlines())
389 except ValueError:
387 except ValueError:
390 self._abort(error.ResponseError(_("unexpected response:"), d))
388 self._abort(error.ResponseError(_("unexpected response:"), d))
391 return r
389 return r
392
390
393 def changegroup(self, nodes, kind):
391 def changegroup(self, nodes, kind):
394 n = encodelist(nodes)
392 n = encodelist(nodes)
395 f = self._callcompressable("changegroup", roots=n)
393 f = self._callcompressable("changegroup", roots=n)
396 return changegroupmod.cg1unpacker(f, 'UN')
394 return changegroupmod.cg1unpacker(f, 'UN')
397
395
398 def changegroupsubset(self, bases, heads, kind):
396 def changegroupsubset(self, bases, heads, kind):
399 self.requirecap('changegroupsubset', _('look up remote changes'))
397 self.requirecap('changegroupsubset', _('look up remote changes'))
400 bases = encodelist(bases)
398 bases = encodelist(bases)
401 heads = encodelist(heads)
399 heads = encodelist(heads)
402 f = self._callcompressable("changegroupsubset",
400 f = self._callcompressable("changegroupsubset",
403 bases=bases, heads=heads)
401 bases=bases, heads=heads)
404 return changegroupmod.cg1unpacker(f, 'UN')
402 return changegroupmod.cg1unpacker(f, 'UN')
405
403
406 # End of ipeerlegacycommands interface.
404 # End of ipeerlegacycommands interface.
407
405
408 def _submitbatch(self, req):
406 def _submitbatch(self, req):
409 """run batch request <req> on the server
407 """run batch request <req> on the server
410
408
411 Returns an iterator of the raw responses from the server.
409 Returns an iterator of the raw responses from the server.
412 """
410 """
413 ui = self.ui
411 ui = self.ui
414 if ui.debugflag and ui.configbool('devel', 'debug.peer-request'):
412 if ui.debugflag and ui.configbool('devel', 'debug.peer-request'):
415 ui.debug('devel-peer-request: batched-content\n')
413 ui.debug('devel-peer-request: batched-content\n')
416 for op, args in req:
414 for op, args in req:
417 msg = 'devel-peer-request: - %s (%d arguments)\n'
415 msg = 'devel-peer-request: - %s (%d arguments)\n'
418 ui.debug(msg % (op, len(args)))
416 ui.debug(msg % (op, len(args)))
419
417
420 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
418 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
421 chunk = rsp.read(1024)
419 chunk = rsp.read(1024)
422 work = [chunk]
420 work = [chunk]
423 while chunk:
421 while chunk:
424 while ';' not in chunk and chunk:
422 while ';' not in chunk and chunk:
425 chunk = rsp.read(1024)
423 chunk = rsp.read(1024)
426 work.append(chunk)
424 work.append(chunk)
427 merged = ''.join(work)
425 merged = ''.join(work)
428 while ';' in merged:
426 while ';' in merged:
429 one, merged = merged.split(';', 1)
427 one, merged = merged.split(';', 1)
430 yield unescapearg(one)
428 yield unescapearg(one)
431 chunk = rsp.read(1024)
429 chunk = rsp.read(1024)
432 work = [merged, chunk]
430 work = [merged, chunk]
433 yield unescapearg(''.join(work))
431 yield unescapearg(''.join(work))
434
432
435 def _submitone(self, op, args):
433 def _submitone(self, op, args):
436 return self._call(op, **pycompat.strkwargs(args))
434 return self._call(op, **pycompat.strkwargs(args))
437
435
438 def debugwireargs(self, one, two, three=None, four=None, five=None):
436 def debugwireargs(self, one, two, three=None, four=None, five=None):
439 # don't pass optional arguments left at their default value
437 # don't pass optional arguments left at their default value
440 opts = {}
438 opts = {}
441 if three is not None:
439 if three is not None:
442 opts[r'three'] = three
440 opts[r'three'] = three
443 if four is not None:
441 if four is not None:
444 opts[r'four'] = four
442 opts[r'four'] = four
445 return self._call('debugwireargs', one=one, two=two, **opts)
443 return self._call('debugwireargs', one=one, two=two, **opts)
446
444
447 def _call(self, cmd, **args):
445 def _call(self, cmd, **args):
448 """execute <cmd> on the server
446 """execute <cmd> on the server
449
447
450 The command is expected to return a simple string.
448 The command is expected to return a simple string.
451
449
452 returns the server reply as a string."""
450 returns the server reply as a string."""
453 raise NotImplementedError()
451 raise NotImplementedError()
454
452
455 def _callstream(self, cmd, **args):
453 def _callstream(self, cmd, **args):
456 """execute <cmd> on the server
454 """execute <cmd> on the server
457
455
458 The command is expected to return a stream. Note that if the
456 The command is expected to return a stream. Note that if the
459 command doesn't return a stream, _callstream behaves
457 command doesn't return a stream, _callstream behaves
460 differently for ssh and http peers.
458 differently for ssh and http peers.
461
459
462 returns the server reply as a file like object.
460 returns the server reply as a file like object.
463 """
461 """
464 raise NotImplementedError()
462 raise NotImplementedError()
465
463
466 def _callcompressable(self, cmd, **args):
464 def _callcompressable(self, cmd, **args):
467 """execute <cmd> on the server
465 """execute <cmd> on the server
468
466
469 The command is expected to return a stream.
467 The command is expected to return a stream.
470
468
471 The stream may have been compressed in some implementations. This
469 The stream may have been compressed in some implementations. This
472 function takes care of the decompression. This is the only difference
470 function takes care of the decompression. This is the only difference
473 with _callstream.
471 with _callstream.
474
472
475 returns the server reply as a file like object.
473 returns the server reply as a file like object.
476 """
474 """
477 raise NotImplementedError()
475 raise NotImplementedError()
478
476
479 def _callpush(self, cmd, fp, **args):
477 def _callpush(self, cmd, fp, **args):
480 """execute a <cmd> on server
478 """execute a <cmd> on server
481
479
482 The command is expected to be related to a push. Push has a special
480 The command is expected to be related to a push. Push has a special
483 return method.
481 return method.
484
482
485 returns the server reply as a (ret, output) tuple. ret is either
483 returns the server reply as a (ret, output) tuple. ret is either
486 empty (error) or a stringified int.
484 empty (error) or a stringified int.
487 """
485 """
488 raise NotImplementedError()
486 raise NotImplementedError()
489
487
490 def _calltwowaystream(self, cmd, fp, **args):
488 def _calltwowaystream(self, cmd, fp, **args):
491 """execute <cmd> on server
489 """execute <cmd> on server
492
490
493 The command will send a stream to the server and get a stream in reply.
491 The command will send a stream to the server and get a stream in reply.
494 """
492 """
495 raise NotImplementedError()
493 raise NotImplementedError()
496
494
497 def _abort(self, exception):
495 def _abort(self, exception):
498 """clearly abort the wire protocol connection and raise the exception
496 """clearly abort the wire protocol connection and raise the exception
499 """
497 """
500 raise NotImplementedError()
498 raise NotImplementedError()
501
499
502 # server side
500 # server side
503
501
504 # wire protocol command can either return a string or one of these classes.
502 # wire protocol command can either return a string or one of these classes.
505
503
506 def getdispatchrepo(repo, proto, command):
504 def getdispatchrepo(repo, proto, command):
507 """Obtain the repo used for processing wire protocol commands.
505 """Obtain the repo used for processing wire protocol commands.
508
506
509 The intent of this function is to serve as a monkeypatch point for
507 The intent of this function is to serve as a monkeypatch point for
510 extensions that need commands to operate on different repo views under
508 extensions that need commands to operate on different repo views under
511 specialized circumstances.
509 specialized circumstances.
512 """
510 """
513 return repo.filtered('served')
511 return repo.filtered('served')
514
512
515 def dispatch(repo, proto, command):
513 def dispatch(repo, proto, command):
516 repo = getdispatchrepo(repo, proto, command)
514 repo = getdispatchrepo(repo, proto, command)
517
515
518 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
516 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
519 commandtable = commandsv2 if transportversion == 2 else commands
517 commandtable = commandsv2 if transportversion == 2 else commands
520 func, spec = commandtable[command]
518 func, spec = commandtable[command]
521
519
522 args = proto.getargs(spec)
520 args = proto.getargs(spec)
523 return func(repo, proto, *args)
521 return func(repo, proto, *args)
524
522
525 def options(cmd, keys, others):
523 def options(cmd, keys, others):
526 opts = {}
524 opts = {}
527 for k in keys:
525 for k in keys:
528 if k in others:
526 if k in others:
529 opts[k] = others[k]
527 opts[k] = others[k]
530 del others[k]
528 del others[k]
531 if others:
529 if others:
532 procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
530 procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
533 % (cmd, ",".join(others)))
531 % (cmd, ",".join(others)))
534 return opts
532 return opts
535
533
536 def bundle1allowed(repo, action):
534 def bundle1allowed(repo, action):
537 """Whether a bundle1 operation is allowed from the server.
535 """Whether a bundle1 operation is allowed from the server.
538
536
539 Priority is:
537 Priority is:
540
538
541 1. server.bundle1gd.<action> (if generaldelta active)
539 1. server.bundle1gd.<action> (if generaldelta active)
542 2. server.bundle1.<action>
540 2. server.bundle1.<action>
543 3. server.bundle1gd (if generaldelta active)
541 3. server.bundle1gd (if generaldelta active)
544 4. server.bundle1
542 4. server.bundle1
545 """
543 """
546 ui = repo.ui
544 ui = repo.ui
547 gd = 'generaldelta' in repo.requirements
545 gd = 'generaldelta' in repo.requirements
548
546
549 if gd:
547 if gd:
550 v = ui.configbool('server', 'bundle1gd.%s' % action)
548 v = ui.configbool('server', 'bundle1gd.%s' % action)
551 if v is not None:
549 if v is not None:
552 return v
550 return v
553
551
554 v = ui.configbool('server', 'bundle1.%s' % action)
552 v = ui.configbool('server', 'bundle1.%s' % action)
555 if v is not None:
553 if v is not None:
556 return v
554 return v
557
555
558 if gd:
556 if gd:
559 v = ui.configbool('server', 'bundle1gd')
557 v = ui.configbool('server', 'bundle1gd')
560 if v is not None:
558 if v is not None:
561 return v
559 return v
562
560
563 return ui.configbool('server', 'bundle1')
561 return ui.configbool('server', 'bundle1')
564
562
565 def supportedcompengines(ui, role):
563 def supportedcompengines(ui, role):
566 """Obtain the list of supported compression engines for a request."""
564 """Obtain the list of supported compression engines for a request."""
567 assert role in (util.CLIENTROLE, util.SERVERROLE)
565 assert role in (util.CLIENTROLE, util.SERVERROLE)
568
566
569 compengines = util.compengines.supportedwireengines(role)
567 compengines = util.compengines.supportedwireengines(role)
570
568
571 # Allow config to override default list and ordering.
569 # Allow config to override default list and ordering.
572 if role == util.SERVERROLE:
570 if role == util.SERVERROLE:
573 configengines = ui.configlist('server', 'compressionengines')
571 configengines = ui.configlist('server', 'compressionengines')
574 config = 'server.compressionengines'
572 config = 'server.compressionengines'
575 else:
573 else:
576 # This is currently implemented mainly to facilitate testing. In most
574 # This is currently implemented mainly to facilitate testing. In most
577 # cases, the server should be in charge of choosing a compression engine
575 # cases, the server should be in charge of choosing a compression engine
578 # because a server has the most to lose from a sub-optimal choice. (e.g.
576 # because a server has the most to lose from a sub-optimal choice. (e.g.
579 # CPU DoS due to an expensive engine or a network DoS due to poor
577 # CPU DoS due to an expensive engine or a network DoS due to poor
580 # compression ratio).
578 # compression ratio).
581 configengines = ui.configlist('experimental',
579 configengines = ui.configlist('experimental',
582 'clientcompressionengines')
580 'clientcompressionengines')
583 config = 'experimental.clientcompressionengines'
581 config = 'experimental.clientcompressionengines'
584
582
585 # No explicit config. Filter out the ones that aren't supposed to be
583 # No explicit config. Filter out the ones that aren't supposed to be
586 # advertised and return default ordering.
584 # advertised and return default ordering.
587 if not configengines:
585 if not configengines:
588 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
586 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
589 return [e for e in compengines
587 return [e for e in compengines
590 if getattr(e.wireprotosupport(), attr) > 0]
588 if getattr(e.wireprotosupport(), attr) > 0]
591
589
592 # If compression engines are listed in the config, assume there is a good
590 # If compression engines are listed in the config, assume there is a good
593 # reason for it (like server operators wanting to achieve specific
591 # reason for it (like server operators wanting to achieve specific
594 # performance characteristics). So fail fast if the config references
592 # performance characteristics). So fail fast if the config references
595 # unusable compression engines.
593 # unusable compression engines.
596 validnames = set(e.name() for e in compengines)
594 validnames = set(e.name() for e in compengines)
597 invalidnames = set(e for e in configengines if e not in validnames)
595 invalidnames = set(e for e in configengines if e not in validnames)
598 if invalidnames:
596 if invalidnames:
599 raise error.Abort(_('invalid compression engine defined in %s: %s') %
597 raise error.Abort(_('invalid compression engine defined in %s: %s') %
600 (config, ', '.join(sorted(invalidnames))))
598 (config, ', '.join(sorted(invalidnames))))
601
599
602 compengines = [e for e in compengines if e.name() in configengines]
600 compengines = [e for e in compengines if e.name() in configengines]
603 compengines = sorted(compengines,
601 compengines = sorted(compengines,
604 key=lambda e: configengines.index(e.name()))
602 key=lambda e: configengines.index(e.name()))
605
603
606 if not compengines:
604 if not compengines:
607 raise error.Abort(_('%s config option does not specify any known '
605 raise error.Abort(_('%s config option does not specify any known '
608 'compression engines') % config,
606 'compression engines') % config,
609 hint=_('usable compression engines: %s') %
607 hint=_('usable compression engines: %s') %
610 ', '.sorted(validnames))
608 ', '.sorted(validnames))
611
609
612 return compengines
610 return compengines
613
611
614 class commandentry(object):
612 class commandentry(object):
615 """Represents a declared wire protocol command."""
613 """Represents a declared wire protocol command."""
616 def __init__(self, func, args='', transports=None,
614 def __init__(self, func, args='', transports=None,
617 permission='push'):
615 permission='push'):
618 self.func = func
616 self.func = func
619 self.args = args
617 self.args = args
620 self.transports = transports or set()
618 self.transports = transports or set()
621 self.permission = permission
619 self.permission = permission
622
620
623 def _merge(self, func, args):
621 def _merge(self, func, args):
624 """Merge this instance with an incoming 2-tuple.
622 """Merge this instance with an incoming 2-tuple.
625
623
626 This is called when a caller using the old 2-tuple API attempts
624 This is called when a caller using the old 2-tuple API attempts
627 to replace an instance. The incoming values are merged with
625 to replace an instance. The incoming values are merged with
628 data not captured by the 2-tuple and a new instance containing
626 data not captured by the 2-tuple and a new instance containing
629 the union of the two objects is returned.
627 the union of the two objects is returned.
630 """
628 """
631 return commandentry(func, args=args, transports=set(self.transports),
629 return commandentry(func, args=args, transports=set(self.transports),
632 permission=self.permission)
630 permission=self.permission)
633
631
634 # Old code treats instances as 2-tuples. So expose that interface.
632 # Old code treats instances as 2-tuples. So expose that interface.
635 def __iter__(self):
633 def __iter__(self):
636 yield self.func
634 yield self.func
637 yield self.args
635 yield self.args
638
636
639 def __getitem__(self, i):
637 def __getitem__(self, i):
640 if i == 0:
638 if i == 0:
641 return self.func
639 return self.func
642 elif i == 1:
640 elif i == 1:
643 return self.args
641 return self.args
644 else:
642 else:
645 raise IndexError('can only access elements 0 and 1')
643 raise IndexError('can only access elements 0 and 1')
646
644
647 class commanddict(dict):
645 class commanddict(dict):
648 """Container for registered wire protocol commands.
646 """Container for registered wire protocol commands.
649
647
650 It behaves like a dict. But __setitem__ is overwritten to allow silent
648 It behaves like a dict. But __setitem__ is overwritten to allow silent
651 coercion of values from 2-tuples for API compatibility.
649 coercion of values from 2-tuples for API compatibility.
652 """
650 """
653 def __setitem__(self, k, v):
651 def __setitem__(self, k, v):
654 if isinstance(v, commandentry):
652 if isinstance(v, commandentry):
655 pass
653 pass
656 # Cast 2-tuples to commandentry instances.
654 # Cast 2-tuples to commandentry instances.
657 elif isinstance(v, tuple):
655 elif isinstance(v, tuple):
658 if len(v) != 2:
656 if len(v) != 2:
659 raise ValueError('command tuples must have exactly 2 elements')
657 raise ValueError('command tuples must have exactly 2 elements')
660
658
661 # It is common for extensions to wrap wire protocol commands via
659 # It is common for extensions to wrap wire protocol commands via
662 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
660 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
663 # doing this aren't aware of the new API that uses objects to store
661 # doing this aren't aware of the new API that uses objects to store
664 # command entries, we automatically merge old state with new.
662 # command entries, we automatically merge old state with new.
665 if k in self:
663 if k in self:
666 v = self[k]._merge(v[0], v[1])
664 v = self[k]._merge(v[0], v[1])
667 else:
665 else:
668 # Use default values from @wireprotocommand.
666 # Use default values from @wireprotocommand.
669 v = commandentry(v[0], args=v[1],
667 v = commandentry(v[0], args=v[1],
670 transports=set(wireprototypes.TRANSPORTS),
668 transports=set(wireprototypes.TRANSPORTS),
671 permission='push')
669 permission='push')
672 else:
670 else:
673 raise ValueError('command entries must be commandentry instances '
671 raise ValueError('command entries must be commandentry instances '
674 'or 2-tuples')
672 'or 2-tuples')
675
673
676 return super(commanddict, self).__setitem__(k, v)
674 return super(commanddict, self).__setitem__(k, v)
677
675
678 def commandavailable(self, command, proto):
676 def commandavailable(self, command, proto):
679 """Determine if a command is available for the requested protocol."""
677 """Determine if a command is available for the requested protocol."""
680 assert proto.name in wireprototypes.TRANSPORTS
678 assert proto.name in wireprototypes.TRANSPORTS
681
679
682 entry = self.get(command)
680 entry = self.get(command)
683
681
684 if not entry:
682 if not entry:
685 return False
683 return False
686
684
687 if proto.name not in entry.transports:
685 if proto.name not in entry.transports:
688 return False
686 return False
689
687
690 return True
688 return True
691
689
692 # Constants specifying which transports a wire protocol command should be
690 # Constants specifying which transports a wire protocol command should be
693 # available on. For use with @wireprotocommand.
691 # available on. For use with @wireprotocommand.
694 POLICY_ALL = 'all'
692 POLICY_ALL = 'all'
695 POLICY_V1_ONLY = 'v1-only'
693 POLICY_V1_ONLY = 'v1-only'
696 POLICY_V2_ONLY = 'v2-only'
694 POLICY_V2_ONLY = 'v2-only'
697
695
698 # For version 1 transports.
696 # For version 1 transports.
699 commands = commanddict()
697 commands = commanddict()
700
698
701 # For version 2 transports.
699 # For version 2 transports.
702 commandsv2 = commanddict()
700 commandsv2 = commanddict()
703
701
704 def wireprotocommand(name, args='', transportpolicy=POLICY_ALL,
702 def wireprotocommand(name, args='', transportpolicy=POLICY_ALL,
705 permission='push'):
703 permission='push'):
706 """Decorator to declare a wire protocol command.
704 """Decorator to declare a wire protocol command.
707
705
708 ``name`` is the name of the wire protocol command being provided.
706 ``name`` is the name of the wire protocol command being provided.
709
707
710 ``args`` is a space-delimited list of named arguments that the command
708 ``args`` is a space-delimited list of named arguments that the command
711 accepts. ``*`` is a special value that says to accept all arguments.
709 accepts. ``*`` is a special value that says to accept all arguments.
712
710
713 ``transportpolicy`` is a POLICY_* constant denoting which transports
711 ``transportpolicy`` is a POLICY_* constant denoting which transports
714 this wire protocol command should be exposed to. By default, commands
712 this wire protocol command should be exposed to. By default, commands
715 are exposed to all wire protocol transports.
713 are exposed to all wire protocol transports.
716
714
717 ``permission`` defines the permission type needed to run this command.
715 ``permission`` defines the permission type needed to run this command.
718 Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
716 Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
719 respectively. Default is to assume command requires ``push`` permissions
717 respectively. Default is to assume command requires ``push`` permissions
720 because otherwise commands not declaring their permissions could modify
718 because otherwise commands not declaring their permissions could modify
721 a repository that is supposed to be read-only.
719 a repository that is supposed to be read-only.
722 """
720 """
723 if transportpolicy == POLICY_ALL:
721 if transportpolicy == POLICY_ALL:
724 transports = set(wireprototypes.TRANSPORTS)
722 transports = set(wireprototypes.TRANSPORTS)
725 transportversions = {1, 2}
723 transportversions = {1, 2}
726 elif transportpolicy == POLICY_V1_ONLY:
724 elif transportpolicy == POLICY_V1_ONLY:
727 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
725 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
728 if v['version'] == 1}
726 if v['version'] == 1}
729 transportversions = {1}
727 transportversions = {1}
730 elif transportpolicy == POLICY_V2_ONLY:
728 elif transportpolicy == POLICY_V2_ONLY:
731 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
729 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
732 if v['version'] == 2}
730 if v['version'] == 2}
733 transportversions = {2}
731 transportversions = {2}
734 else:
732 else:
735 raise error.ProgrammingError('invalid transport policy value: %s' %
733 raise error.ProgrammingError('invalid transport policy value: %s' %
736 transportpolicy)
734 transportpolicy)
737
735
738 # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
736 # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
739 # SSHv2.
737 # SSHv2.
740 # TODO undo this hack when SSH is using the unified frame protocol.
738 # TODO undo this hack when SSH is using the unified frame protocol.
741 if name == b'batch':
739 if name == b'batch':
742 transports.add(wireprototypes.SSHV2)
740 transports.add(wireprototypes.SSHV2)
743
741
744 if permission not in ('push', 'pull'):
742 if permission not in ('push', 'pull'):
745 raise error.ProgrammingError('invalid wire protocol permission; '
743 raise error.ProgrammingError('invalid wire protocol permission; '
746 'got %s; expected "push" or "pull"' %
744 'got %s; expected "push" or "pull"' %
747 permission)
745 permission)
748
746
749 def register(func):
747 def register(func):
750 if 1 in transportversions:
748 if 1 in transportversions:
751 if name in commands:
749 if name in commands:
752 raise error.ProgrammingError('%s command already registered '
750 raise error.ProgrammingError('%s command already registered '
753 'for version 1' % name)
751 'for version 1' % name)
754 commands[name] = commandentry(func, args=args,
752 commands[name] = commandentry(func, args=args,
755 transports=transports,
753 transports=transports,
756 permission=permission)
754 permission=permission)
757 if 2 in transportversions:
755 if 2 in transportversions:
758 if name in commandsv2:
756 if name in commandsv2:
759 raise error.ProgrammingError('%s command already registered '
757 raise error.ProgrammingError('%s command already registered '
760 'for version 2' % name)
758 'for version 2' % name)
761 commandsv2[name] = commandentry(func, args=args,
759 commandsv2[name] = commandentry(func, args=args,
762 transports=transports,
760 transports=transports,
763 permission=permission)
761 permission=permission)
764
762
765 return func
763 return func
766 return register
764 return register
767
765
768 # TODO define a more appropriate permissions type to use for this.
766 # TODO define a more appropriate permissions type to use for this.
769 @wireprotocommand('batch', 'cmds *', permission='pull',
767 @wireprotocommand('batch', 'cmds *', permission='pull',
770 transportpolicy=POLICY_V1_ONLY)
768 transportpolicy=POLICY_V1_ONLY)
771 def batch(repo, proto, cmds, others):
769 def batch(repo, proto, cmds, others):
772 repo = repo.filtered("served")
770 repo = repo.filtered("served")
773 res = []
771 res = []
774 for pair in cmds.split(';'):
772 for pair in cmds.split(';'):
775 op, args = pair.split(' ', 1)
773 op, args = pair.split(' ', 1)
776 vals = {}
774 vals = {}
777 for a in args.split(','):
775 for a in args.split(','):
778 if a:
776 if a:
779 n, v = a.split('=')
777 n, v = a.split('=')
780 vals[unescapearg(n)] = unescapearg(v)
778 vals[unescapearg(n)] = unescapearg(v)
781 func, spec = commands[op]
779 func, spec = commands[op]
782
780
783 # Validate that client has permissions to perform this command.
781 # Validate that client has permissions to perform this command.
784 perm = commands[op].permission
782 perm = commands[op].permission
785 assert perm in ('push', 'pull')
783 assert perm in ('push', 'pull')
786 proto.checkperm(perm)
784 proto.checkperm(perm)
787
785
788 if spec:
786 if spec:
789 keys = spec.split()
787 keys = spec.split()
790 data = {}
788 data = {}
791 for k in keys:
789 for k in keys:
792 if k == '*':
790 if k == '*':
793 star = {}
791 star = {}
794 for key in vals.keys():
792 for key in vals.keys():
795 if key not in keys:
793 if key not in keys:
796 star[key] = vals[key]
794 star[key] = vals[key]
797 data['*'] = star
795 data['*'] = star
798 else:
796 else:
799 data[k] = vals[k]
797 data[k] = vals[k]
800 result = func(repo, proto, *[data[k] for k in keys])
798 result = func(repo, proto, *[data[k] for k in keys])
801 else:
799 else:
802 result = func(repo, proto)
800 result = func(repo, proto)
803 if isinstance(result, wireprototypes.ooberror):
801 if isinstance(result, wireprototypes.ooberror):
804 return result
802 return result
805
803
806 # For now, all batchable commands must return bytesresponse or
804 # For now, all batchable commands must return bytesresponse or
807 # raw bytes (for backwards compatibility).
805 # raw bytes (for backwards compatibility).
808 assert isinstance(result, (wireprototypes.bytesresponse, bytes))
806 assert isinstance(result, (wireprototypes.bytesresponse, bytes))
809 if isinstance(result, wireprototypes.bytesresponse):
807 if isinstance(result, wireprototypes.bytesresponse):
810 result = result.data
808 result = result.data
811 res.append(escapearg(result))
809 res.append(escapearg(result))
812
810
813 return wireprototypes.bytesresponse(';'.join(res))
811 return wireprototypes.bytesresponse(';'.join(res))
814
812
815 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
813 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
816 permission='pull')
814 permission='pull')
817 def between(repo, proto, pairs):
815 def between(repo, proto, pairs):
818 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
816 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
819 r = []
817 r = []
820 for b in repo.between(pairs):
818 for b in repo.between(pairs):
821 r.append(encodelist(b) + "\n")
819 r.append(encodelist(b) + "\n")
822
820
823 return wireprototypes.bytesresponse(''.join(r))
821 return wireprototypes.bytesresponse(''.join(r))
824
822
825 @wireprotocommand('branchmap', permission='pull')
823 @wireprotocommand('branchmap', permission='pull')
826 def branchmap(repo, proto):
824 def branchmap(repo, proto):
827 branchmap = repo.branchmap()
825 branchmap = repo.branchmap()
828 heads = []
826 heads = []
829 for branch, nodes in branchmap.iteritems():
827 for branch, nodes in branchmap.iteritems():
830 branchname = urlreq.quote(encoding.fromlocal(branch))
828 branchname = urlreq.quote(encoding.fromlocal(branch))
831 branchnodes = encodelist(nodes)
829 branchnodes = encodelist(nodes)
832 heads.append('%s %s' % (branchname, branchnodes))
830 heads.append('%s %s' % (branchname, branchnodes))
833
831
834 return wireprototypes.bytesresponse('\n'.join(heads))
832 return wireprototypes.bytesresponse('\n'.join(heads))
835
833
836 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
834 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
837 permission='pull')
835 permission='pull')
838 def branches(repo, proto, nodes):
836 def branches(repo, proto, nodes):
839 nodes = decodelist(nodes)
837 nodes = decodelist(nodes)
840 r = []
838 r = []
841 for b in repo.branches(nodes):
839 for b in repo.branches(nodes):
842 r.append(encodelist(b) + "\n")
840 r.append(encodelist(b) + "\n")
843
841
844 return wireprototypes.bytesresponse(''.join(r))
842 return wireprototypes.bytesresponse(''.join(r))
845
843
846 @wireprotocommand('clonebundles', '', permission='pull')
844 @wireprotocommand('clonebundles', '', permission='pull')
847 def clonebundles(repo, proto):
845 def clonebundles(repo, proto):
848 """Server command for returning info for available bundles to seed clones.
846 """Server command for returning info for available bundles to seed clones.
849
847
850 Clients will parse this response and determine what bundle to fetch.
848 Clients will parse this response and determine what bundle to fetch.
851
849
852 Extensions may wrap this command to filter or dynamically emit data
850 Extensions may wrap this command to filter or dynamically emit data
853 depending on the request. e.g. you could advertise URLs for the closest
851 depending on the request. e.g. you could advertise URLs for the closest
854 data center given the client's IP address.
852 data center given the client's IP address.
855 """
853 """
856 return wireprototypes.bytesresponse(
854 return wireprototypes.bytesresponse(
857 repo.vfs.tryread('clonebundles.manifest'))
855 repo.vfs.tryread('clonebundles.manifest'))
858
856
859 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
857 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
860 'known', 'getbundle', 'unbundlehash']
858 'known', 'getbundle', 'unbundlehash']
861
859
862 def _capabilities(repo, proto):
860 def _capabilities(repo, proto):
863 """return a list of capabilities for a repo
861 """return a list of capabilities for a repo
864
862
865 This function exists to allow extensions to easily wrap capabilities
863 This function exists to allow extensions to easily wrap capabilities
866 computation
864 computation
867
865
868 - returns a lists: easy to alter
866 - returns a lists: easy to alter
869 - change done here will be propagated to both `capabilities` and `hello`
867 - change done here will be propagated to both `capabilities` and `hello`
870 command without any other action needed.
868 command without any other action needed.
871 """
869 """
872 # copy to prevent modification of the global list
870 # copy to prevent modification of the global list
873 caps = list(wireprotocaps)
871 caps = list(wireprotocaps)
874
872
875 # Command of same name as capability isn't exposed to version 1 of
873 # Command of same name as capability isn't exposed to version 1 of
876 # transports. So conditionally add it.
874 # transports. So conditionally add it.
877 if commands.commandavailable('changegroupsubset', proto):
875 if commands.commandavailable('changegroupsubset', proto):
878 caps.append('changegroupsubset')
876 caps.append('changegroupsubset')
879
877
880 if streamclone.allowservergeneration(repo):
878 if streamclone.allowservergeneration(repo):
881 if repo.ui.configbool('server', 'preferuncompressed'):
879 if repo.ui.configbool('server', 'preferuncompressed'):
882 caps.append('stream-preferred')
880 caps.append('stream-preferred')
883 requiredformats = repo.requirements & repo.supportedformats
881 requiredformats = repo.requirements & repo.supportedformats
884 # if our local revlogs are just revlogv1, add 'stream' cap
882 # if our local revlogs are just revlogv1, add 'stream' cap
885 if not requiredformats - {'revlogv1'}:
883 if not requiredformats - {'revlogv1'}:
886 caps.append('stream')
884 caps.append('stream')
887 # otherwise, add 'streamreqs' detailing our local revlog format
885 # otherwise, add 'streamreqs' detailing our local revlog format
888 else:
886 else:
889 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
887 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
890 if repo.ui.configbool('experimental', 'bundle2-advertise'):
888 if repo.ui.configbool('experimental', 'bundle2-advertise'):
891 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
889 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
892 caps.append('bundle2=' + urlreq.quote(capsblob))
890 caps.append('bundle2=' + urlreq.quote(capsblob))
893 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
891 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
894
892
895 return proto.addcapabilities(repo, caps)
893 return proto.addcapabilities(repo, caps)
896
894
897 # If you are writing an extension and consider wrapping this function. Wrap
895 # If you are writing an extension and consider wrapping this function. Wrap
898 # `_capabilities` instead.
896 # `_capabilities` instead.
899 @wireprotocommand('capabilities', permission='pull')
897 @wireprotocommand('capabilities', permission='pull')
900 def capabilities(repo, proto):
898 def capabilities(repo, proto):
901 return wireprototypes.bytesresponse(' '.join(_capabilities(repo, proto)))
899 return wireprototypes.bytesresponse(' '.join(_capabilities(repo, proto)))
902
900
903 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
901 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
904 permission='pull')
902 permission='pull')
905 def changegroup(repo, proto, roots):
903 def changegroup(repo, proto, roots):
906 nodes = decodelist(roots)
904 nodes = decodelist(roots)
907 outgoing = discovery.outgoing(repo, missingroots=nodes,
905 outgoing = discovery.outgoing(repo, missingroots=nodes,
908 missingheads=repo.heads())
906 missingheads=repo.heads())
909 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
907 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
910 gen = iter(lambda: cg.read(32768), '')
908 gen = iter(lambda: cg.read(32768), '')
911 return wireprototypes.streamres(gen=gen)
909 return wireprototypes.streamres(gen=gen)
912
910
913 @wireprotocommand('changegroupsubset', 'bases heads',
911 @wireprotocommand('changegroupsubset', 'bases heads',
914 transportpolicy=POLICY_V1_ONLY,
912 transportpolicy=POLICY_V1_ONLY,
915 permission='pull')
913 permission='pull')
916 def changegroupsubset(repo, proto, bases, heads):
914 def changegroupsubset(repo, proto, bases, heads):
917 bases = decodelist(bases)
915 bases = decodelist(bases)
918 heads = decodelist(heads)
916 heads = decodelist(heads)
919 outgoing = discovery.outgoing(repo, missingroots=bases,
917 outgoing = discovery.outgoing(repo, missingroots=bases,
920 missingheads=heads)
918 missingheads=heads)
921 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
919 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
922 gen = iter(lambda: cg.read(32768), '')
920 gen = iter(lambda: cg.read(32768), '')
923 return wireprototypes.streamres(gen=gen)
921 return wireprototypes.streamres(gen=gen)
924
922
925 @wireprotocommand('debugwireargs', 'one two *',
923 @wireprotocommand('debugwireargs', 'one two *',
926 permission='pull')
924 permission='pull')
927 def debugwireargs(repo, proto, one, two, others):
925 def debugwireargs(repo, proto, one, two, others):
928 # only accept optional args from the known set
926 # only accept optional args from the known set
929 opts = options('debugwireargs', ['three', 'four'], others)
927 opts = options('debugwireargs', ['three', 'four'], others)
930 return wireprototypes.bytesresponse(repo.debugwireargs(
928 return wireprototypes.bytesresponse(repo.debugwireargs(
931 one, two, **pycompat.strkwargs(opts)))
929 one, two, **pycompat.strkwargs(opts)))
932
930
933 @wireprotocommand('getbundle', '*', permission='pull')
931 @wireprotocommand('getbundle', '*', permission='pull')
934 def getbundle(repo, proto, others):
932 def getbundle(repo, proto, others):
935 opts = options('getbundle', gboptsmap.keys(), others)
933 opts = options('getbundle', gboptsmap.keys(), others)
936 for k, v in opts.iteritems():
934 for k, v in opts.iteritems():
937 keytype = gboptsmap[k]
935 keytype = gboptsmap[k]
938 if keytype == 'nodes':
936 if keytype == 'nodes':
939 opts[k] = decodelist(v)
937 opts[k] = decodelist(v)
940 elif keytype == 'csv':
938 elif keytype == 'csv':
941 opts[k] = list(v.split(','))
939 opts[k] = list(v.split(','))
942 elif keytype == 'scsv':
940 elif keytype == 'scsv':
943 opts[k] = set(v.split(','))
941 opts[k] = set(v.split(','))
944 elif keytype == 'boolean':
942 elif keytype == 'boolean':
945 # Client should serialize False as '0', which is a non-empty string
943 # Client should serialize False as '0', which is a non-empty string
946 # so it evaluates as a True bool.
944 # so it evaluates as a True bool.
947 if v == '0':
945 if v == '0':
948 opts[k] = False
946 opts[k] = False
949 else:
947 else:
950 opts[k] = bool(v)
948 opts[k] = bool(v)
951 elif keytype != 'plain':
949 elif keytype != 'plain':
952 raise KeyError('unknown getbundle option type %s'
950 raise KeyError('unknown getbundle option type %s'
953 % keytype)
951 % keytype)
954
952
955 if not bundle1allowed(repo, 'pull'):
953 if not bundle1allowed(repo, 'pull'):
956 if not exchange.bundle2requested(opts.get('bundlecaps')):
954 if not exchange.bundle2requested(opts.get('bundlecaps')):
957 if proto.name == 'http-v1':
955 if proto.name == 'http-v1':
958 return wireprototypes.ooberror(bundle2required)
956 return wireprototypes.ooberror(bundle2required)
959 raise error.Abort(bundle2requiredmain,
957 raise error.Abort(bundle2requiredmain,
960 hint=bundle2requiredhint)
958 hint=bundle2requiredhint)
961
959
962 prefercompressed = True
960 prefercompressed = True
963
961
964 try:
962 try:
965 if repo.ui.configbool('server', 'disablefullbundle'):
963 if repo.ui.configbool('server', 'disablefullbundle'):
966 # Check to see if this is a full clone.
964 # Check to see if this is a full clone.
967 clheads = set(repo.changelog.heads())
965 clheads = set(repo.changelog.heads())
968 changegroup = opts.get('cg', True)
966 changegroup = opts.get('cg', True)
969 heads = set(opts.get('heads', set()))
967 heads = set(opts.get('heads', set()))
970 common = set(opts.get('common', set()))
968 common = set(opts.get('common', set()))
971 common.discard(nullid)
969 common.discard(nullid)
972 if changegroup and not common and clheads == heads:
970 if changegroup and not common and clheads == heads:
973 raise error.Abort(
971 raise error.Abort(
974 _('server has pull-based clones disabled'),
972 _('server has pull-based clones disabled'),
975 hint=_('remove --pull if specified or upgrade Mercurial'))
973 hint=_('remove --pull if specified or upgrade Mercurial'))
976
974
977 info, chunks = exchange.getbundlechunks(repo, 'serve',
975 info, chunks = exchange.getbundlechunks(repo, 'serve',
978 **pycompat.strkwargs(opts))
976 **pycompat.strkwargs(opts))
979 prefercompressed = info.get('prefercompressed', True)
977 prefercompressed = info.get('prefercompressed', True)
980 except error.Abort as exc:
978 except error.Abort as exc:
981 # cleanly forward Abort error to the client
979 # cleanly forward Abort error to the client
982 if not exchange.bundle2requested(opts.get('bundlecaps')):
980 if not exchange.bundle2requested(opts.get('bundlecaps')):
983 if proto.name == 'http-v1':
981 if proto.name == 'http-v1':
984 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
982 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
985 raise # cannot do better for bundle1 + ssh
983 raise # cannot do better for bundle1 + ssh
986 # bundle2 request expect a bundle2 reply
984 # bundle2 request expect a bundle2 reply
987 bundler = bundle2.bundle20(repo.ui)
985 bundler = bundle2.bundle20(repo.ui)
988 manargs = [('message', pycompat.bytestr(exc))]
986 manargs = [('message', pycompat.bytestr(exc))]
989 advargs = []
987 advargs = []
990 if exc.hint is not None:
988 if exc.hint is not None:
991 advargs.append(('hint', exc.hint))
989 advargs.append(('hint', exc.hint))
992 bundler.addpart(bundle2.bundlepart('error:abort',
990 bundler.addpart(bundle2.bundlepart('error:abort',
993 manargs, advargs))
991 manargs, advargs))
994 chunks = bundler.getchunks()
992 chunks = bundler.getchunks()
995 prefercompressed = False
993 prefercompressed = False
996
994
997 return wireprototypes.streamres(
995 return wireprototypes.streamres(
998 gen=chunks, prefer_uncompressed=not prefercompressed)
996 gen=chunks, prefer_uncompressed=not prefercompressed)
999
997
1000 @wireprotocommand('heads', permission='pull')
998 @wireprotocommand('heads', permission='pull')
1001 def heads(repo, proto):
999 def heads(repo, proto):
1002 h = repo.heads()
1000 h = repo.heads()
1003 return wireprototypes.bytesresponse(encodelist(h) + '\n')
1001 return wireprototypes.bytesresponse(encodelist(h) + '\n')
1004
1002
1005 @wireprotocommand('hello', permission='pull')
1003 @wireprotocommand('hello', permission='pull')
1006 def hello(repo, proto):
1004 def hello(repo, proto):
1007 """Called as part of SSH handshake to obtain server info.
1005 """Called as part of SSH handshake to obtain server info.
1008
1006
1009 Returns a list of lines describing interesting things about the
1007 Returns a list of lines describing interesting things about the
1010 server, in an RFC822-like format.
1008 server, in an RFC822-like format.
1011
1009
1012 Currently, the only one defined is ``capabilities``, which consists of a
1010 Currently, the only one defined is ``capabilities``, which consists of a
1013 line of space separated tokens describing server abilities:
1011 line of space separated tokens describing server abilities:
1014
1012
1015 capabilities: <token0> <token1> <token2>
1013 capabilities: <token0> <token1> <token2>
1016 """
1014 """
1017 caps = capabilities(repo, proto).data
1015 caps = capabilities(repo, proto).data
1018 return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
1016 return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
1019
1017
1020 @wireprotocommand('listkeys', 'namespace', permission='pull')
1018 @wireprotocommand('listkeys', 'namespace', permission='pull')
1021 def listkeys(repo, proto, namespace):
1019 def listkeys(repo, proto, namespace):
1022 d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
1020 d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
1023 return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
1021 return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
1024
1022
1025 @wireprotocommand('lookup', 'key', permission='pull')
1023 @wireprotocommand('lookup', 'key', permission='pull')
1026 def lookup(repo, proto, key):
1024 def lookup(repo, proto, key):
1027 try:
1025 try:
1028 k = encoding.tolocal(key)
1026 k = encoding.tolocal(key)
1029 n = repo.lookup(k)
1027 n = repo.lookup(k)
1030 r = hex(n)
1028 r = hex(n)
1031 success = 1
1029 success = 1
1032 except Exception as inst:
1030 except Exception as inst:
1033 r = stringutil.forcebytestr(inst)
1031 r = stringutil.forcebytestr(inst)
1034 success = 0
1032 success = 0
1035 return wireprototypes.bytesresponse('%d %s\n' % (success, r))
1033 return wireprototypes.bytesresponse('%d %s\n' % (success, r))
1036
1034
1037 @wireprotocommand('known', 'nodes *', permission='pull')
1035 @wireprotocommand('known', 'nodes *', permission='pull')
1038 def known(repo, proto, nodes, others):
1036 def known(repo, proto, nodes, others):
1039 v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
1037 v = ''.join(b and '1' or '0' for b in repo.known(decodelist(nodes)))
1040 return wireprototypes.bytesresponse(v)
1038 return wireprototypes.bytesresponse(v)
1041
1039
1042 @wireprotocommand('protocaps', 'caps', permission='pull',
1040 @wireprotocommand('protocaps', 'caps', permission='pull',
1043 transportpolicy=POLICY_V1_ONLY)
1041 transportpolicy=POLICY_V1_ONLY)
1044 def protocaps(repo, proto, caps):
1042 def protocaps(repo, proto, caps):
1045 if proto.name == wireprototypes.SSHV1:
1043 if proto.name == wireprototypes.SSHV1:
1046 proto._protocaps = set(caps.split(' '))
1044 proto._protocaps = set(caps.split(' '))
1047 return wireprototypes.bytesresponse('OK')
1045 return wireprototypes.bytesresponse('OK')
1048
1046
1049 @wireprotocommand('pushkey', 'namespace key old new', permission='push')
1047 @wireprotocommand('pushkey', 'namespace key old new', permission='push')
1050 def pushkey(repo, proto, namespace, key, old, new):
1048 def pushkey(repo, proto, namespace, key, old, new):
1051 # compatibility with pre-1.8 clients which were accidentally
1049 # compatibility with pre-1.8 clients which were accidentally
1052 # sending raw binary nodes rather than utf-8-encoded hex
1050 # sending raw binary nodes rather than utf-8-encoded hex
1053 if len(new) == 20 and stringutil.escapestr(new) != new:
1051 if len(new) == 20 and stringutil.escapestr(new) != new:
1054 # looks like it could be a binary node
1052 # looks like it could be a binary node
1055 try:
1053 try:
1056 new.decode('utf-8')
1054 new.decode('utf-8')
1057 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
1055 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
1058 except UnicodeDecodeError:
1056 except UnicodeDecodeError:
1059 pass # binary, leave unmodified
1057 pass # binary, leave unmodified
1060 else:
1058 else:
1061 new = encoding.tolocal(new) # normal path
1059 new = encoding.tolocal(new) # normal path
1062
1060
1063 with proto.mayberedirectstdio() as output:
1061 with proto.mayberedirectstdio() as output:
1064 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
1062 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
1065 encoding.tolocal(old), new) or False
1063 encoding.tolocal(old), new) or False
1066
1064
1067 output = output.getvalue() if output else ''
1065 output = output.getvalue() if output else ''
1068 return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
1066 return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
1069
1067
1070 @wireprotocommand('stream_out', permission='pull')
1068 @wireprotocommand('stream_out', permission='pull')
1071 def stream(repo, proto):
1069 def stream(repo, proto):
1072 '''If the server supports streaming clone, it advertises the "stream"
1070 '''If the server supports streaming clone, it advertises the "stream"
1073 capability with a value representing the version and flags of the repo
1071 capability with a value representing the version and flags of the repo
1074 it is serving. Client checks to see if it understands the format.
1072 it is serving. Client checks to see if it understands the format.
1075 '''
1073 '''
1076 return wireprototypes.streamreslegacy(
1074 return wireprototypes.streamreslegacy(
1077 streamclone.generatev1wireproto(repo))
1075 streamclone.generatev1wireproto(repo))
1078
1076
1079 @wireprotocommand('unbundle', 'heads', permission='push')
1077 @wireprotocommand('unbundle', 'heads', permission='push')
1080 def unbundle(repo, proto, heads):
1078 def unbundle(repo, proto, heads):
1081 their_heads = decodelist(heads)
1079 their_heads = decodelist(heads)
1082
1080
1083 with proto.mayberedirectstdio() as output:
1081 with proto.mayberedirectstdio() as output:
1084 try:
1082 try:
1085 exchange.check_heads(repo, their_heads, 'preparing changes')
1083 exchange.check_heads(repo, their_heads, 'preparing changes')
1086
1084
1087 # write bundle data to temporary file because it can be big
1085 # write bundle data to temporary file because it can be big
1088 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1086 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1089 fp = os.fdopen(fd, r'wb+')
1087 fp = os.fdopen(fd, r'wb+')
1090 r = 0
1088 r = 0
1091 try:
1089 try:
1092 proto.forwardpayload(fp)
1090 proto.forwardpayload(fp)
1093 fp.seek(0)
1091 fp.seek(0)
1094 gen = exchange.readbundle(repo.ui, fp, None)
1092 gen = exchange.readbundle(repo.ui, fp, None)
1095 if (isinstance(gen, changegroupmod.cg1unpacker)
1093 if (isinstance(gen, changegroupmod.cg1unpacker)
1096 and not bundle1allowed(repo, 'push')):
1094 and not bundle1allowed(repo, 'push')):
1097 if proto.name == 'http-v1':
1095 if proto.name == 'http-v1':
1098 # need to special case http because stderr do not get to
1096 # need to special case http because stderr do not get to
1099 # the http client on failed push so we need to abuse
1097 # the http client on failed push so we need to abuse
1100 # some other error type to make sure the message get to
1098 # some other error type to make sure the message get to
1101 # the user.
1099 # the user.
1102 return wireprototypes.ooberror(bundle2required)
1100 return wireprototypes.ooberror(bundle2required)
1103 raise error.Abort(bundle2requiredmain,
1101 raise error.Abort(bundle2requiredmain,
1104 hint=bundle2requiredhint)
1102 hint=bundle2requiredhint)
1105
1103
1106 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1104 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1107 proto.client())
1105 proto.client())
1108 if util.safehasattr(r, 'addpart'):
1106 if util.safehasattr(r, 'addpart'):
1109 # The return looks streamable, we are in the bundle2 case
1107 # The return looks streamable, we are in the bundle2 case
1110 # and should return a stream.
1108 # and should return a stream.
1111 return wireprototypes.streamreslegacy(gen=r.getchunks())
1109 return wireprototypes.streamreslegacy(gen=r.getchunks())
1112 return wireprototypes.pushres(
1110 return wireprototypes.pushres(
1113 r, output.getvalue() if output else '')
1111 r, output.getvalue() if output else '')
1114
1112
1115 finally:
1113 finally:
1116 fp.close()
1114 fp.close()
1117 os.unlink(tempname)
1115 os.unlink(tempname)
1118
1116
1119 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1117 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1120 # handle non-bundle2 case first
1118 # handle non-bundle2 case first
1121 if not getattr(exc, 'duringunbundle2', False):
1119 if not getattr(exc, 'duringunbundle2', False):
1122 try:
1120 try:
1123 raise
1121 raise
1124 except error.Abort:
1122 except error.Abort:
1125 # The old code we moved used procutil.stderr directly.
1123 # The old code we moved used procutil.stderr directly.
1126 # We did not change it to minimise code change.
1124 # We did not change it to minimise code change.
1127 # This need to be moved to something proper.
1125 # This need to be moved to something proper.
1128 # Feel free to do it.
1126 # Feel free to do it.
1129 procutil.stderr.write("abort: %s\n" % exc)
1127 procutil.stderr.write("abort: %s\n" % exc)
1130 if exc.hint is not None:
1128 if exc.hint is not None:
1131 procutil.stderr.write("(%s)\n" % exc.hint)
1129 procutil.stderr.write("(%s)\n" % exc.hint)
1132 procutil.stderr.flush()
1130 procutil.stderr.flush()
1133 return wireprototypes.pushres(
1131 return wireprototypes.pushres(
1134 0, output.getvalue() if output else '')
1132 0, output.getvalue() if output else '')
1135 except error.PushRaced:
1133 except error.PushRaced:
1136 return wireprototypes.pusherr(
1134 return wireprototypes.pusherr(
1137 pycompat.bytestr(exc),
1135 pycompat.bytestr(exc),
1138 output.getvalue() if output else '')
1136 output.getvalue() if output else '')
1139
1137
1140 bundler = bundle2.bundle20(repo.ui)
1138 bundler = bundle2.bundle20(repo.ui)
1141 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1139 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1142 bundler.addpart(out)
1140 bundler.addpart(out)
1143 try:
1141 try:
1144 try:
1142 try:
1145 raise
1143 raise
1146 except error.PushkeyFailed as exc:
1144 except error.PushkeyFailed as exc:
1147 # check client caps
1145 # check client caps
1148 remotecaps = getattr(exc, '_replycaps', None)
1146 remotecaps = getattr(exc, '_replycaps', None)
1149 if (remotecaps is not None
1147 if (remotecaps is not None
1150 and 'pushkey' not in remotecaps.get('error', ())):
1148 and 'pushkey' not in remotecaps.get('error', ())):
1151 # no support remote side, fallback to Abort handler.
1149 # no support remote side, fallback to Abort handler.
1152 raise
1150 raise
1153 part = bundler.newpart('error:pushkey')
1151 part = bundler.newpart('error:pushkey')
1154 part.addparam('in-reply-to', exc.partid)
1152 part.addparam('in-reply-to', exc.partid)
1155 if exc.namespace is not None:
1153 if exc.namespace is not None:
1156 part.addparam('namespace', exc.namespace,
1154 part.addparam('namespace', exc.namespace,
1157 mandatory=False)
1155 mandatory=False)
1158 if exc.key is not None:
1156 if exc.key is not None:
1159 part.addparam('key', exc.key, mandatory=False)
1157 part.addparam('key', exc.key, mandatory=False)
1160 if exc.new is not None:
1158 if exc.new is not None:
1161 part.addparam('new', exc.new, mandatory=False)
1159 part.addparam('new', exc.new, mandatory=False)
1162 if exc.old is not None:
1160 if exc.old is not None:
1163 part.addparam('old', exc.old, mandatory=False)
1161 part.addparam('old', exc.old, mandatory=False)
1164 if exc.ret is not None:
1162 if exc.ret is not None:
1165 part.addparam('ret', exc.ret, mandatory=False)
1163 part.addparam('ret', exc.ret, mandatory=False)
1166 except error.BundleValueError as exc:
1164 except error.BundleValueError as exc:
1167 errpart = bundler.newpart('error:unsupportedcontent')
1165 errpart = bundler.newpart('error:unsupportedcontent')
1168 if exc.parttype is not None:
1166 if exc.parttype is not None:
1169 errpart.addparam('parttype', exc.parttype)
1167 errpart.addparam('parttype', exc.parttype)
1170 if exc.params:
1168 if exc.params:
1171 errpart.addparam('params', '\0'.join(exc.params))
1169 errpart.addparam('params', '\0'.join(exc.params))
1172 except error.Abort as exc:
1170 except error.Abort as exc:
1173 manargs = [('message', stringutil.forcebytestr(exc))]
1171 manargs = [('message', stringutil.forcebytestr(exc))]
1174 advargs = []
1172 advargs = []
1175 if exc.hint is not None:
1173 if exc.hint is not None:
1176 advargs.append(('hint', exc.hint))
1174 advargs.append(('hint', exc.hint))
1177 bundler.addpart(bundle2.bundlepart('error:abort',
1175 bundler.addpart(bundle2.bundlepart('error:abort',
1178 manargs, advargs))
1176 manargs, advargs))
1179 except error.PushRaced as exc:
1177 except error.PushRaced as exc:
1180 bundler.newpart('error:pushraced',
1178 bundler.newpart('error:pushraced',
1181 [('message', stringutil.forcebytestr(exc))])
1179 [('message', stringutil.forcebytestr(exc))])
1182 return wireprototypes.streamreslegacy(gen=bundler.getchunks())
1180 return wireprototypes.streamreslegacy(gen=bundler.getchunks())
General Comments 0
You need to be logged in to leave comments. Login now