##// END OF EJS Templates
wireprotoserver: add context manager mechanism for redirecting stdio...
Gregory Szorc -
r36083:2ad145fb default
parent child Browse files
Show More
@@ -1,1101 +1,1093 b''
1 # wireproto.py - generic wire protocol support functions
1 # wireproto.py - generic wire protocol support functions
2 #
2 #
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import hashlib
10 import hashlib
11 import os
11 import os
12 import tempfile
12 import tempfile
13
13
14 from .i18n import _
14 from .i18n import _
15 from .node import (
15 from .node import (
16 bin,
16 bin,
17 hex,
17 hex,
18 nullid,
18 nullid,
19 )
19 )
20
20
21 from . import (
21 from . import (
22 bundle2,
22 bundle2,
23 changegroup as changegroupmod,
23 changegroup as changegroupmod,
24 discovery,
24 discovery,
25 encoding,
25 encoding,
26 error,
26 error,
27 exchange,
27 exchange,
28 peer,
28 peer,
29 pushkey as pushkeymod,
29 pushkey as pushkeymod,
30 pycompat,
30 pycompat,
31 repository,
31 repository,
32 streamclone,
32 streamclone,
33 util,
33 util,
34 )
34 )
35
35
36 urlerr = util.urlerr
36 urlerr = util.urlerr
37 urlreq = util.urlreq
37 urlreq = util.urlreq
38
38
39 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
39 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
40 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
40 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
41 'IncompatibleClient')
41 'IncompatibleClient')
42 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
42 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
43
43
44 class remoteiterbatcher(peer.iterbatcher):
44 class remoteiterbatcher(peer.iterbatcher):
45 def __init__(self, remote):
45 def __init__(self, remote):
46 super(remoteiterbatcher, self).__init__()
46 super(remoteiterbatcher, self).__init__()
47 self._remote = remote
47 self._remote = remote
48
48
49 def __getattr__(self, name):
49 def __getattr__(self, name):
50 # Validate this method is batchable, since submit() only supports
50 # Validate this method is batchable, since submit() only supports
51 # batchable methods.
51 # batchable methods.
52 fn = getattr(self._remote, name)
52 fn = getattr(self._remote, name)
53 if not getattr(fn, 'batchable', None):
53 if not getattr(fn, 'batchable', None):
54 raise error.ProgrammingError('Attempted to batch a non-batchable '
54 raise error.ProgrammingError('Attempted to batch a non-batchable '
55 'call to %r' % name)
55 'call to %r' % name)
56
56
57 return super(remoteiterbatcher, self).__getattr__(name)
57 return super(remoteiterbatcher, self).__getattr__(name)
58
58
59 def submit(self):
59 def submit(self):
60 """Break the batch request into many patch calls and pipeline them.
60 """Break the batch request into many patch calls and pipeline them.
61
61
62 This is mostly valuable over http where request sizes can be
62 This is mostly valuable over http where request sizes can be
63 limited, but can be used in other places as well.
63 limited, but can be used in other places as well.
64 """
64 """
65 # 2-tuple of (command, arguments) that represents what will be
65 # 2-tuple of (command, arguments) that represents what will be
66 # sent over the wire.
66 # sent over the wire.
67 requests = []
67 requests = []
68
68
69 # 4-tuple of (command, final future, @batchable generator, remote
69 # 4-tuple of (command, final future, @batchable generator, remote
70 # future).
70 # future).
71 results = []
71 results = []
72
72
73 for command, args, opts, finalfuture in self.calls:
73 for command, args, opts, finalfuture in self.calls:
74 mtd = getattr(self._remote, command)
74 mtd = getattr(self._remote, command)
75 batchable = mtd.batchable(mtd.__self__, *args, **opts)
75 batchable = mtd.batchable(mtd.__self__, *args, **opts)
76
76
77 commandargs, fremote = next(batchable)
77 commandargs, fremote = next(batchable)
78 assert fremote
78 assert fremote
79 requests.append((command, commandargs))
79 requests.append((command, commandargs))
80 results.append((command, finalfuture, batchable, fremote))
80 results.append((command, finalfuture, batchable, fremote))
81
81
82 if requests:
82 if requests:
83 self._resultiter = self._remote._submitbatch(requests)
83 self._resultiter = self._remote._submitbatch(requests)
84
84
85 self._results = results
85 self._results = results
86
86
87 def results(self):
87 def results(self):
88 for command, finalfuture, batchable, remotefuture in self._results:
88 for command, finalfuture, batchable, remotefuture in self._results:
89 # Get the raw result, set it in the remote future, feed it
89 # Get the raw result, set it in the remote future, feed it
90 # back into the @batchable generator so it can be decoded, and
90 # back into the @batchable generator so it can be decoded, and
91 # set the result on the final future to this value.
91 # set the result on the final future to this value.
92 remoteresult = next(self._resultiter)
92 remoteresult = next(self._resultiter)
93 remotefuture.set(remoteresult)
93 remotefuture.set(remoteresult)
94 finalfuture.set(next(batchable))
94 finalfuture.set(next(batchable))
95
95
96 # Verify our @batchable generators only emit 2 values.
96 # Verify our @batchable generators only emit 2 values.
97 try:
97 try:
98 next(batchable)
98 next(batchable)
99 except StopIteration:
99 except StopIteration:
100 pass
100 pass
101 else:
101 else:
102 raise error.ProgrammingError('%s @batchable generator emitted '
102 raise error.ProgrammingError('%s @batchable generator emitted '
103 'unexpected value count' % command)
103 'unexpected value count' % command)
104
104
105 yield finalfuture.value
105 yield finalfuture.value
106
106
107 # Forward a couple of names from peer to make wireproto interactions
107 # Forward a couple of names from peer to make wireproto interactions
108 # slightly more sensible.
108 # slightly more sensible.
109 batchable = peer.batchable
109 batchable = peer.batchable
110 future = peer.future
110 future = peer.future
111
111
112 # list of nodes encoding / decoding
112 # list of nodes encoding / decoding
113
113
114 def decodelist(l, sep=' '):
114 def decodelist(l, sep=' '):
115 if l:
115 if l:
116 return [bin(v) for v in l.split(sep)]
116 return [bin(v) for v in l.split(sep)]
117 return []
117 return []
118
118
119 def encodelist(l, sep=' '):
119 def encodelist(l, sep=' '):
120 try:
120 try:
121 return sep.join(map(hex, l))
121 return sep.join(map(hex, l))
122 except TypeError:
122 except TypeError:
123 raise
123 raise
124
124
125 # batched call argument encoding
125 # batched call argument encoding
126
126
127 def escapearg(plain):
127 def escapearg(plain):
128 return (plain
128 return (plain
129 .replace(':', ':c')
129 .replace(':', ':c')
130 .replace(',', ':o')
130 .replace(',', ':o')
131 .replace(';', ':s')
131 .replace(';', ':s')
132 .replace('=', ':e'))
132 .replace('=', ':e'))
133
133
134 def unescapearg(escaped):
134 def unescapearg(escaped):
135 return (escaped
135 return (escaped
136 .replace(':e', '=')
136 .replace(':e', '=')
137 .replace(':s', ';')
137 .replace(':s', ';')
138 .replace(':o', ',')
138 .replace(':o', ',')
139 .replace(':c', ':'))
139 .replace(':c', ':'))
140
140
141 def encodebatchcmds(req):
141 def encodebatchcmds(req):
142 """Return a ``cmds`` argument value for the ``batch`` command."""
142 """Return a ``cmds`` argument value for the ``batch`` command."""
143 cmds = []
143 cmds = []
144 for op, argsdict in req:
144 for op, argsdict in req:
145 # Old servers didn't properly unescape argument names. So prevent
145 # Old servers didn't properly unescape argument names. So prevent
146 # the sending of argument names that may not be decoded properly by
146 # the sending of argument names that may not be decoded properly by
147 # servers.
147 # servers.
148 assert all(escapearg(k) == k for k in argsdict)
148 assert all(escapearg(k) == k for k in argsdict)
149
149
150 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
150 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
151 for k, v in argsdict.iteritems())
151 for k, v in argsdict.iteritems())
152 cmds.append('%s %s' % (op, args))
152 cmds.append('%s %s' % (op, args))
153
153
154 return ';'.join(cmds)
154 return ';'.join(cmds)
155
155
156 # mapping of options accepted by getbundle and their types
156 # mapping of options accepted by getbundle and their types
157 #
157 #
158 # Meant to be extended by extensions. It is extensions responsibility to ensure
158 # Meant to be extended by extensions. It is extensions responsibility to ensure
159 # such options are properly processed in exchange.getbundle.
159 # such options are properly processed in exchange.getbundle.
160 #
160 #
161 # supported types are:
161 # supported types are:
162 #
162 #
163 # :nodes: list of binary nodes
163 # :nodes: list of binary nodes
164 # :csv: list of comma-separated values
164 # :csv: list of comma-separated values
165 # :scsv: list of comma-separated values return as set
165 # :scsv: list of comma-separated values return as set
166 # :plain: string with no transformation needed.
166 # :plain: string with no transformation needed.
167 gboptsmap = {'heads': 'nodes',
167 gboptsmap = {'heads': 'nodes',
168 'bookmarks': 'boolean',
168 'bookmarks': 'boolean',
169 'common': 'nodes',
169 'common': 'nodes',
170 'obsmarkers': 'boolean',
170 'obsmarkers': 'boolean',
171 'phases': 'boolean',
171 'phases': 'boolean',
172 'bundlecaps': 'scsv',
172 'bundlecaps': 'scsv',
173 'listkeys': 'csv',
173 'listkeys': 'csv',
174 'cg': 'boolean',
174 'cg': 'boolean',
175 'cbattempted': 'boolean',
175 'cbattempted': 'boolean',
176 'stream': 'boolean',
176 'stream': 'boolean',
177 }
177 }
178
178
179 # client side
179 # client side
180
180
181 class wirepeer(repository.legacypeer):
181 class wirepeer(repository.legacypeer):
182 """Client-side interface for communicating with a peer repository.
182 """Client-side interface for communicating with a peer repository.
183
183
184 Methods commonly call wire protocol commands of the same name.
184 Methods commonly call wire protocol commands of the same name.
185
185
186 See also httppeer.py and sshpeer.py for protocol-specific
186 See also httppeer.py and sshpeer.py for protocol-specific
187 implementations of this interface.
187 implementations of this interface.
188 """
188 """
189 # Begin of basewirepeer interface.
189 # Begin of basewirepeer interface.
190
190
191 def iterbatch(self):
191 def iterbatch(self):
192 return remoteiterbatcher(self)
192 return remoteiterbatcher(self)
193
193
194 @batchable
194 @batchable
195 def lookup(self, key):
195 def lookup(self, key):
196 self.requirecap('lookup', _('look up remote revision'))
196 self.requirecap('lookup', _('look up remote revision'))
197 f = future()
197 f = future()
198 yield {'key': encoding.fromlocal(key)}, f
198 yield {'key': encoding.fromlocal(key)}, f
199 d = f.value
199 d = f.value
200 success, data = d[:-1].split(" ", 1)
200 success, data = d[:-1].split(" ", 1)
201 if int(success):
201 if int(success):
202 yield bin(data)
202 yield bin(data)
203 else:
203 else:
204 self._abort(error.RepoError(data))
204 self._abort(error.RepoError(data))
205
205
206 @batchable
206 @batchable
207 def heads(self):
207 def heads(self):
208 f = future()
208 f = future()
209 yield {}, f
209 yield {}, f
210 d = f.value
210 d = f.value
211 try:
211 try:
212 yield decodelist(d[:-1])
212 yield decodelist(d[:-1])
213 except ValueError:
213 except ValueError:
214 self._abort(error.ResponseError(_("unexpected response:"), d))
214 self._abort(error.ResponseError(_("unexpected response:"), d))
215
215
216 @batchable
216 @batchable
217 def known(self, nodes):
217 def known(self, nodes):
218 f = future()
218 f = future()
219 yield {'nodes': encodelist(nodes)}, f
219 yield {'nodes': encodelist(nodes)}, f
220 d = f.value
220 d = f.value
221 try:
221 try:
222 yield [bool(int(b)) for b in d]
222 yield [bool(int(b)) for b in d]
223 except ValueError:
223 except ValueError:
224 self._abort(error.ResponseError(_("unexpected response:"), d))
224 self._abort(error.ResponseError(_("unexpected response:"), d))
225
225
226 @batchable
226 @batchable
227 def branchmap(self):
227 def branchmap(self):
228 f = future()
228 f = future()
229 yield {}, f
229 yield {}, f
230 d = f.value
230 d = f.value
231 try:
231 try:
232 branchmap = {}
232 branchmap = {}
233 for branchpart in d.splitlines():
233 for branchpart in d.splitlines():
234 branchname, branchheads = branchpart.split(' ', 1)
234 branchname, branchheads = branchpart.split(' ', 1)
235 branchname = encoding.tolocal(urlreq.unquote(branchname))
235 branchname = encoding.tolocal(urlreq.unquote(branchname))
236 branchheads = decodelist(branchheads)
236 branchheads = decodelist(branchheads)
237 branchmap[branchname] = branchheads
237 branchmap[branchname] = branchheads
238 yield branchmap
238 yield branchmap
239 except TypeError:
239 except TypeError:
240 self._abort(error.ResponseError(_("unexpected response:"), d))
240 self._abort(error.ResponseError(_("unexpected response:"), d))
241
241
242 @batchable
242 @batchable
243 def listkeys(self, namespace):
243 def listkeys(self, namespace):
244 if not self.capable('pushkey'):
244 if not self.capable('pushkey'):
245 yield {}, None
245 yield {}, None
246 f = future()
246 f = future()
247 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
247 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
248 yield {'namespace': encoding.fromlocal(namespace)}, f
248 yield {'namespace': encoding.fromlocal(namespace)}, f
249 d = f.value
249 d = f.value
250 self.ui.debug('received listkey for "%s": %i bytes\n'
250 self.ui.debug('received listkey for "%s": %i bytes\n'
251 % (namespace, len(d)))
251 % (namespace, len(d)))
252 yield pushkeymod.decodekeys(d)
252 yield pushkeymod.decodekeys(d)
253
253
254 @batchable
254 @batchable
255 def pushkey(self, namespace, key, old, new):
255 def pushkey(self, namespace, key, old, new):
256 if not self.capable('pushkey'):
256 if not self.capable('pushkey'):
257 yield False, None
257 yield False, None
258 f = future()
258 f = future()
259 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
259 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
260 yield {'namespace': encoding.fromlocal(namespace),
260 yield {'namespace': encoding.fromlocal(namespace),
261 'key': encoding.fromlocal(key),
261 'key': encoding.fromlocal(key),
262 'old': encoding.fromlocal(old),
262 'old': encoding.fromlocal(old),
263 'new': encoding.fromlocal(new)}, f
263 'new': encoding.fromlocal(new)}, f
264 d = f.value
264 d = f.value
265 d, output = d.split('\n', 1)
265 d, output = d.split('\n', 1)
266 try:
266 try:
267 d = bool(int(d))
267 d = bool(int(d))
268 except ValueError:
268 except ValueError:
269 raise error.ResponseError(
269 raise error.ResponseError(
270 _('push failed (unexpected response):'), d)
270 _('push failed (unexpected response):'), d)
271 for l in output.splitlines(True):
271 for l in output.splitlines(True):
272 self.ui.status(_('remote: '), l)
272 self.ui.status(_('remote: '), l)
273 yield d
273 yield d
274
274
275 def stream_out(self):
275 def stream_out(self):
276 return self._callstream('stream_out')
276 return self._callstream('stream_out')
277
277
278 def getbundle(self, source, **kwargs):
278 def getbundle(self, source, **kwargs):
279 kwargs = pycompat.byteskwargs(kwargs)
279 kwargs = pycompat.byteskwargs(kwargs)
280 self.requirecap('getbundle', _('look up remote changes'))
280 self.requirecap('getbundle', _('look up remote changes'))
281 opts = {}
281 opts = {}
282 bundlecaps = kwargs.get('bundlecaps')
282 bundlecaps = kwargs.get('bundlecaps')
283 if bundlecaps is not None:
283 if bundlecaps is not None:
284 kwargs['bundlecaps'] = sorted(bundlecaps)
284 kwargs['bundlecaps'] = sorted(bundlecaps)
285 else:
285 else:
286 bundlecaps = () # kwargs could have it to None
286 bundlecaps = () # kwargs could have it to None
287 for key, value in kwargs.iteritems():
287 for key, value in kwargs.iteritems():
288 if value is None:
288 if value is None:
289 continue
289 continue
290 keytype = gboptsmap.get(key)
290 keytype = gboptsmap.get(key)
291 if keytype is None:
291 if keytype is None:
292 raise error.ProgrammingError(
292 raise error.ProgrammingError(
293 'Unexpectedly None keytype for key %s' % key)
293 'Unexpectedly None keytype for key %s' % key)
294 elif keytype == 'nodes':
294 elif keytype == 'nodes':
295 value = encodelist(value)
295 value = encodelist(value)
296 elif keytype in ('csv', 'scsv'):
296 elif keytype in ('csv', 'scsv'):
297 value = ','.join(value)
297 value = ','.join(value)
298 elif keytype == 'boolean':
298 elif keytype == 'boolean':
299 value = '%i' % bool(value)
299 value = '%i' % bool(value)
300 elif keytype != 'plain':
300 elif keytype != 'plain':
301 raise KeyError('unknown getbundle option type %s'
301 raise KeyError('unknown getbundle option type %s'
302 % keytype)
302 % keytype)
303 opts[key] = value
303 opts[key] = value
304 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
304 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
305 if any((cap.startswith('HG2') for cap in bundlecaps)):
305 if any((cap.startswith('HG2') for cap in bundlecaps)):
306 return bundle2.getunbundler(self.ui, f)
306 return bundle2.getunbundler(self.ui, f)
307 else:
307 else:
308 return changegroupmod.cg1unpacker(f, 'UN')
308 return changegroupmod.cg1unpacker(f, 'UN')
309
309
310 def unbundle(self, cg, heads, url):
310 def unbundle(self, cg, heads, url):
311 '''Send cg (a readable file-like object representing the
311 '''Send cg (a readable file-like object representing the
312 changegroup to push, typically a chunkbuffer object) to the
312 changegroup to push, typically a chunkbuffer object) to the
313 remote server as a bundle.
313 remote server as a bundle.
314
314
315 When pushing a bundle10 stream, return an integer indicating the
315 When pushing a bundle10 stream, return an integer indicating the
316 result of the push (see changegroup.apply()).
316 result of the push (see changegroup.apply()).
317
317
318 When pushing a bundle20 stream, return a bundle20 stream.
318 When pushing a bundle20 stream, return a bundle20 stream.
319
319
320 `url` is the url the client thinks it's pushing to, which is
320 `url` is the url the client thinks it's pushing to, which is
321 visible to hooks.
321 visible to hooks.
322 '''
322 '''
323
323
324 if heads != ['force'] and self.capable('unbundlehash'):
324 if heads != ['force'] and self.capable('unbundlehash'):
325 heads = encodelist(['hashed',
325 heads = encodelist(['hashed',
326 hashlib.sha1(''.join(sorted(heads))).digest()])
326 hashlib.sha1(''.join(sorted(heads))).digest()])
327 else:
327 else:
328 heads = encodelist(heads)
328 heads = encodelist(heads)
329
329
330 if util.safehasattr(cg, 'deltaheader'):
330 if util.safehasattr(cg, 'deltaheader'):
331 # this a bundle10, do the old style call sequence
331 # this a bundle10, do the old style call sequence
332 ret, output = self._callpush("unbundle", cg, heads=heads)
332 ret, output = self._callpush("unbundle", cg, heads=heads)
333 if ret == "":
333 if ret == "":
334 raise error.ResponseError(
334 raise error.ResponseError(
335 _('push failed:'), output)
335 _('push failed:'), output)
336 try:
336 try:
337 ret = int(ret)
337 ret = int(ret)
338 except ValueError:
338 except ValueError:
339 raise error.ResponseError(
339 raise error.ResponseError(
340 _('push failed (unexpected response):'), ret)
340 _('push failed (unexpected response):'), ret)
341
341
342 for l in output.splitlines(True):
342 for l in output.splitlines(True):
343 self.ui.status(_('remote: '), l)
343 self.ui.status(_('remote: '), l)
344 else:
344 else:
345 # bundle2 push. Send a stream, fetch a stream.
345 # bundle2 push. Send a stream, fetch a stream.
346 stream = self._calltwowaystream('unbundle', cg, heads=heads)
346 stream = self._calltwowaystream('unbundle', cg, heads=heads)
347 ret = bundle2.getunbundler(self.ui, stream)
347 ret = bundle2.getunbundler(self.ui, stream)
348 return ret
348 return ret
349
349
350 # End of basewirepeer interface.
350 # End of basewirepeer interface.
351
351
352 # Begin of baselegacywirepeer interface.
352 # Begin of baselegacywirepeer interface.
353
353
354 def branches(self, nodes):
354 def branches(self, nodes):
355 n = encodelist(nodes)
355 n = encodelist(nodes)
356 d = self._call("branches", nodes=n)
356 d = self._call("branches", nodes=n)
357 try:
357 try:
358 br = [tuple(decodelist(b)) for b in d.splitlines()]
358 br = [tuple(decodelist(b)) for b in d.splitlines()]
359 return br
359 return br
360 except ValueError:
360 except ValueError:
361 self._abort(error.ResponseError(_("unexpected response:"), d))
361 self._abort(error.ResponseError(_("unexpected response:"), d))
362
362
363 def between(self, pairs):
363 def between(self, pairs):
364 batch = 8 # avoid giant requests
364 batch = 8 # avoid giant requests
365 r = []
365 r = []
366 for i in xrange(0, len(pairs), batch):
366 for i in xrange(0, len(pairs), batch):
367 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
367 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
368 d = self._call("between", pairs=n)
368 d = self._call("between", pairs=n)
369 try:
369 try:
370 r.extend(l and decodelist(l) or [] for l in d.splitlines())
370 r.extend(l and decodelist(l) or [] for l in d.splitlines())
371 except ValueError:
371 except ValueError:
372 self._abort(error.ResponseError(_("unexpected response:"), d))
372 self._abort(error.ResponseError(_("unexpected response:"), d))
373 return r
373 return r
374
374
375 def changegroup(self, nodes, kind):
375 def changegroup(self, nodes, kind):
376 n = encodelist(nodes)
376 n = encodelist(nodes)
377 f = self._callcompressable("changegroup", roots=n)
377 f = self._callcompressable("changegroup", roots=n)
378 return changegroupmod.cg1unpacker(f, 'UN')
378 return changegroupmod.cg1unpacker(f, 'UN')
379
379
380 def changegroupsubset(self, bases, heads, kind):
380 def changegroupsubset(self, bases, heads, kind):
381 self.requirecap('changegroupsubset', _('look up remote changes'))
381 self.requirecap('changegroupsubset', _('look up remote changes'))
382 bases = encodelist(bases)
382 bases = encodelist(bases)
383 heads = encodelist(heads)
383 heads = encodelist(heads)
384 f = self._callcompressable("changegroupsubset",
384 f = self._callcompressable("changegroupsubset",
385 bases=bases, heads=heads)
385 bases=bases, heads=heads)
386 return changegroupmod.cg1unpacker(f, 'UN')
386 return changegroupmod.cg1unpacker(f, 'UN')
387
387
388 # End of baselegacywirepeer interface.
388 # End of baselegacywirepeer interface.
389
389
390 def _submitbatch(self, req):
390 def _submitbatch(self, req):
391 """run batch request <req> on the server
391 """run batch request <req> on the server
392
392
393 Returns an iterator of the raw responses from the server.
393 Returns an iterator of the raw responses from the server.
394 """
394 """
395 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
395 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
396 chunk = rsp.read(1024)
396 chunk = rsp.read(1024)
397 work = [chunk]
397 work = [chunk]
398 while chunk:
398 while chunk:
399 while ';' not in chunk and chunk:
399 while ';' not in chunk and chunk:
400 chunk = rsp.read(1024)
400 chunk = rsp.read(1024)
401 work.append(chunk)
401 work.append(chunk)
402 merged = ''.join(work)
402 merged = ''.join(work)
403 while ';' in merged:
403 while ';' in merged:
404 one, merged = merged.split(';', 1)
404 one, merged = merged.split(';', 1)
405 yield unescapearg(one)
405 yield unescapearg(one)
406 chunk = rsp.read(1024)
406 chunk = rsp.read(1024)
407 work = [merged, chunk]
407 work = [merged, chunk]
408 yield unescapearg(''.join(work))
408 yield unescapearg(''.join(work))
409
409
410 def _submitone(self, op, args):
410 def _submitone(self, op, args):
411 return self._call(op, **pycompat.strkwargs(args))
411 return self._call(op, **pycompat.strkwargs(args))
412
412
413 def debugwireargs(self, one, two, three=None, four=None, five=None):
413 def debugwireargs(self, one, two, three=None, four=None, five=None):
414 # don't pass optional arguments left at their default value
414 # don't pass optional arguments left at their default value
415 opts = {}
415 opts = {}
416 if three is not None:
416 if three is not None:
417 opts[r'three'] = three
417 opts[r'three'] = three
418 if four is not None:
418 if four is not None:
419 opts[r'four'] = four
419 opts[r'four'] = four
420 return self._call('debugwireargs', one=one, two=two, **opts)
420 return self._call('debugwireargs', one=one, two=two, **opts)
421
421
422 def _call(self, cmd, **args):
422 def _call(self, cmd, **args):
423 """execute <cmd> on the server
423 """execute <cmd> on the server
424
424
425 The command is expected to return a simple string.
425 The command is expected to return a simple string.
426
426
427 returns the server reply as a string."""
427 returns the server reply as a string."""
428 raise NotImplementedError()
428 raise NotImplementedError()
429
429
430 def _callstream(self, cmd, **args):
430 def _callstream(self, cmd, **args):
431 """execute <cmd> on the server
431 """execute <cmd> on the server
432
432
433 The command is expected to return a stream. Note that if the
433 The command is expected to return a stream. Note that if the
434 command doesn't return a stream, _callstream behaves
434 command doesn't return a stream, _callstream behaves
435 differently for ssh and http peers.
435 differently for ssh and http peers.
436
436
437 returns the server reply as a file like object.
437 returns the server reply as a file like object.
438 """
438 """
439 raise NotImplementedError()
439 raise NotImplementedError()
440
440
441 def _callcompressable(self, cmd, **args):
441 def _callcompressable(self, cmd, **args):
442 """execute <cmd> on the server
442 """execute <cmd> on the server
443
443
444 The command is expected to return a stream.
444 The command is expected to return a stream.
445
445
446 The stream may have been compressed in some implementations. This
446 The stream may have been compressed in some implementations. This
447 function takes care of the decompression. This is the only difference
447 function takes care of the decompression. This is the only difference
448 with _callstream.
448 with _callstream.
449
449
450 returns the server reply as a file like object.
450 returns the server reply as a file like object.
451 """
451 """
452 raise NotImplementedError()
452 raise NotImplementedError()
453
453
454 def _callpush(self, cmd, fp, **args):
454 def _callpush(self, cmd, fp, **args):
455 """execute a <cmd> on server
455 """execute a <cmd> on server
456
456
457 The command is expected to be related to a push. Push has a special
457 The command is expected to be related to a push. Push has a special
458 return method.
458 return method.
459
459
460 returns the server reply as a (ret, output) tuple. ret is either
460 returns the server reply as a (ret, output) tuple. ret is either
461 empty (error) or a stringified int.
461 empty (error) or a stringified int.
462 """
462 """
463 raise NotImplementedError()
463 raise NotImplementedError()
464
464
465 def _calltwowaystream(self, cmd, fp, **args):
465 def _calltwowaystream(self, cmd, fp, **args):
466 """execute <cmd> on server
466 """execute <cmd> on server
467
467
468 The command will send a stream to the server and get a stream in reply.
468 The command will send a stream to the server and get a stream in reply.
469 """
469 """
470 raise NotImplementedError()
470 raise NotImplementedError()
471
471
472 def _abort(self, exception):
472 def _abort(self, exception):
473 """clearly abort the wire protocol connection and raise the exception
473 """clearly abort the wire protocol connection and raise the exception
474 """
474 """
475 raise NotImplementedError()
475 raise NotImplementedError()
476
476
477 # server side
477 # server side
478
478
479 # wire protocol command can either return a string or one of these classes.
479 # wire protocol command can either return a string or one of these classes.
480 class streamres(object):
480 class streamres(object):
481 """wireproto reply: binary stream
481 """wireproto reply: binary stream
482
482
483 The call was successful and the result is a stream.
483 The call was successful and the result is a stream.
484
484
485 Accepts a generator containing chunks of data to be sent to the client.
485 Accepts a generator containing chunks of data to be sent to the client.
486
486
487 ``prefer_uncompressed`` indicates that the data is expected to be
487 ``prefer_uncompressed`` indicates that the data is expected to be
488 uncompressable and that the stream should therefore use the ``none``
488 uncompressable and that the stream should therefore use the ``none``
489 engine.
489 engine.
490 """
490 """
491 def __init__(self, gen=None, prefer_uncompressed=False):
491 def __init__(self, gen=None, prefer_uncompressed=False):
492 self.gen = gen
492 self.gen = gen
493 self.prefer_uncompressed = prefer_uncompressed
493 self.prefer_uncompressed = prefer_uncompressed
494
494
495 class streamres_legacy(object):
495 class streamres_legacy(object):
496 """wireproto reply: uncompressed binary stream
496 """wireproto reply: uncompressed binary stream
497
497
498 The call was successful and the result is a stream.
498 The call was successful and the result is a stream.
499
499
500 Accepts a generator containing chunks of data to be sent to the client.
500 Accepts a generator containing chunks of data to be sent to the client.
501
501
502 Like ``streamres``, but sends an uncompressed data for "version 1" clients
502 Like ``streamres``, but sends an uncompressed data for "version 1" clients
503 using the application/mercurial-0.1 media type.
503 using the application/mercurial-0.1 media type.
504 """
504 """
505 def __init__(self, gen=None):
505 def __init__(self, gen=None):
506 self.gen = gen
506 self.gen = gen
507
507
508 class pushres(object):
508 class pushres(object):
509 """wireproto reply: success with simple integer return
509 """wireproto reply: success with simple integer return
510
510
511 The call was successful and returned an integer contained in `self.res`.
511 The call was successful and returned an integer contained in `self.res`.
512 """
512 """
513 def __init__(self, res):
513 def __init__(self, res):
514 self.res = res
514 self.res = res
515
515
516 class pusherr(object):
516 class pusherr(object):
517 """wireproto reply: failure
517 """wireproto reply: failure
518
518
519 The call failed. The `self.res` attribute contains the error message.
519 The call failed. The `self.res` attribute contains the error message.
520 """
520 """
521 def __init__(self, res):
521 def __init__(self, res):
522 self.res = res
522 self.res = res
523
523
524 class ooberror(object):
524 class ooberror(object):
525 """wireproto reply: failure of a batch of operation
525 """wireproto reply: failure of a batch of operation
526
526
527 Something failed during a batch call. The error message is stored in
527 Something failed during a batch call. The error message is stored in
528 `self.message`.
528 `self.message`.
529 """
529 """
530 def __init__(self, message):
530 def __init__(self, message):
531 self.message = message
531 self.message = message
532
532
533 def getdispatchrepo(repo, proto, command):
533 def getdispatchrepo(repo, proto, command):
534 """Obtain the repo used for processing wire protocol commands.
534 """Obtain the repo used for processing wire protocol commands.
535
535
536 The intent of this function is to serve as a monkeypatch point for
536 The intent of this function is to serve as a monkeypatch point for
537 extensions that need commands to operate on different repo views under
537 extensions that need commands to operate on different repo views under
538 specialized circumstances.
538 specialized circumstances.
539 """
539 """
540 return repo.filtered('served')
540 return repo.filtered('served')
541
541
542 def dispatch(repo, proto, command):
542 def dispatch(repo, proto, command):
543 repo = getdispatchrepo(repo, proto, command)
543 repo = getdispatchrepo(repo, proto, command)
544 func, spec = commands[command]
544 func, spec = commands[command]
545 args = proto.getargs(spec)
545 args = proto.getargs(spec)
546 return func(repo, proto, *args)
546 return func(repo, proto, *args)
547
547
548 def options(cmd, keys, others):
548 def options(cmd, keys, others):
549 opts = {}
549 opts = {}
550 for k in keys:
550 for k in keys:
551 if k in others:
551 if k in others:
552 opts[k] = others[k]
552 opts[k] = others[k]
553 del others[k]
553 del others[k]
554 if others:
554 if others:
555 util.stderr.write("warning: %s ignored unexpected arguments %s\n"
555 util.stderr.write("warning: %s ignored unexpected arguments %s\n"
556 % (cmd, ",".join(others)))
556 % (cmd, ",".join(others)))
557 return opts
557 return opts
558
558
559 def bundle1allowed(repo, action):
559 def bundle1allowed(repo, action):
560 """Whether a bundle1 operation is allowed from the server.
560 """Whether a bundle1 operation is allowed from the server.
561
561
562 Priority is:
562 Priority is:
563
563
564 1. server.bundle1gd.<action> (if generaldelta active)
564 1. server.bundle1gd.<action> (if generaldelta active)
565 2. server.bundle1.<action>
565 2. server.bundle1.<action>
566 3. server.bundle1gd (if generaldelta active)
566 3. server.bundle1gd (if generaldelta active)
567 4. server.bundle1
567 4. server.bundle1
568 """
568 """
569 ui = repo.ui
569 ui = repo.ui
570 gd = 'generaldelta' in repo.requirements
570 gd = 'generaldelta' in repo.requirements
571
571
572 if gd:
572 if gd:
573 v = ui.configbool('server', 'bundle1gd.%s' % action)
573 v = ui.configbool('server', 'bundle1gd.%s' % action)
574 if v is not None:
574 if v is not None:
575 return v
575 return v
576
576
577 v = ui.configbool('server', 'bundle1.%s' % action)
577 v = ui.configbool('server', 'bundle1.%s' % action)
578 if v is not None:
578 if v is not None:
579 return v
579 return v
580
580
581 if gd:
581 if gd:
582 v = ui.configbool('server', 'bundle1gd')
582 v = ui.configbool('server', 'bundle1gd')
583 if v is not None:
583 if v is not None:
584 return v
584 return v
585
585
586 return ui.configbool('server', 'bundle1')
586 return ui.configbool('server', 'bundle1')
587
587
588 def supportedcompengines(ui, proto, role):
588 def supportedcompengines(ui, proto, role):
589 """Obtain the list of supported compression engines for a request."""
589 """Obtain the list of supported compression engines for a request."""
590 assert role in (util.CLIENTROLE, util.SERVERROLE)
590 assert role in (util.CLIENTROLE, util.SERVERROLE)
591
591
592 compengines = util.compengines.supportedwireengines(role)
592 compengines = util.compengines.supportedwireengines(role)
593
593
594 # Allow config to override default list and ordering.
594 # Allow config to override default list and ordering.
595 if role == util.SERVERROLE:
595 if role == util.SERVERROLE:
596 configengines = ui.configlist('server', 'compressionengines')
596 configengines = ui.configlist('server', 'compressionengines')
597 config = 'server.compressionengines'
597 config = 'server.compressionengines'
598 else:
598 else:
599 # This is currently implemented mainly to facilitate testing. In most
599 # This is currently implemented mainly to facilitate testing. In most
600 # cases, the server should be in charge of choosing a compression engine
600 # cases, the server should be in charge of choosing a compression engine
601 # because a server has the most to lose from a sub-optimal choice. (e.g.
601 # because a server has the most to lose from a sub-optimal choice. (e.g.
602 # CPU DoS due to an expensive engine or a network DoS due to poor
602 # CPU DoS due to an expensive engine or a network DoS due to poor
603 # compression ratio).
603 # compression ratio).
604 configengines = ui.configlist('experimental',
604 configengines = ui.configlist('experimental',
605 'clientcompressionengines')
605 'clientcompressionengines')
606 config = 'experimental.clientcompressionengines'
606 config = 'experimental.clientcompressionengines'
607
607
608 # No explicit config. Filter out the ones that aren't supposed to be
608 # No explicit config. Filter out the ones that aren't supposed to be
609 # advertised and return default ordering.
609 # advertised and return default ordering.
610 if not configengines:
610 if not configengines:
611 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
611 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
612 return [e for e in compengines
612 return [e for e in compengines
613 if getattr(e.wireprotosupport(), attr) > 0]
613 if getattr(e.wireprotosupport(), attr) > 0]
614
614
615 # If compression engines are listed in the config, assume there is a good
615 # If compression engines are listed in the config, assume there is a good
616 # reason for it (like server operators wanting to achieve specific
616 # reason for it (like server operators wanting to achieve specific
617 # performance characteristics). So fail fast if the config references
617 # performance characteristics). So fail fast if the config references
618 # unusable compression engines.
618 # unusable compression engines.
619 validnames = set(e.name() for e in compengines)
619 validnames = set(e.name() for e in compengines)
620 invalidnames = set(e for e in configengines if e not in validnames)
620 invalidnames = set(e for e in configengines if e not in validnames)
621 if invalidnames:
621 if invalidnames:
622 raise error.Abort(_('invalid compression engine defined in %s: %s') %
622 raise error.Abort(_('invalid compression engine defined in %s: %s') %
623 (config, ', '.join(sorted(invalidnames))))
623 (config, ', '.join(sorted(invalidnames))))
624
624
625 compengines = [e for e in compengines if e.name() in configengines]
625 compengines = [e for e in compengines if e.name() in configengines]
626 compengines = sorted(compengines,
626 compengines = sorted(compengines,
627 key=lambda e: configengines.index(e.name()))
627 key=lambda e: configengines.index(e.name()))
628
628
629 if not compengines:
629 if not compengines:
630 raise error.Abort(_('%s config option does not specify any known '
630 raise error.Abort(_('%s config option does not specify any known '
631 'compression engines') % config,
631 'compression engines') % config,
632 hint=_('usable compression engines: %s') %
632 hint=_('usable compression engines: %s') %
633 ', '.sorted(validnames))
633 ', '.sorted(validnames))
634
634
635 return compengines
635 return compengines
636
636
637 class commandentry(object):
637 class commandentry(object):
638 """Represents a declared wire protocol command."""
638 """Represents a declared wire protocol command."""
639 def __init__(self, func, args=''):
639 def __init__(self, func, args=''):
640 self.func = func
640 self.func = func
641 self.args = args
641 self.args = args
642
642
643 def _merge(self, func, args):
643 def _merge(self, func, args):
644 """Merge this instance with an incoming 2-tuple.
644 """Merge this instance with an incoming 2-tuple.
645
645
646 This is called when a caller using the old 2-tuple API attempts
646 This is called when a caller using the old 2-tuple API attempts
647 to replace an instance. The incoming values are merged with
647 to replace an instance. The incoming values are merged with
648 data not captured by the 2-tuple and a new instance containing
648 data not captured by the 2-tuple and a new instance containing
649 the union of the two objects is returned.
649 the union of the two objects is returned.
650 """
650 """
651 return commandentry(func, args)
651 return commandentry(func, args)
652
652
653 # Old code treats instances as 2-tuples. So expose that interface.
653 # Old code treats instances as 2-tuples. So expose that interface.
654 def __iter__(self):
654 def __iter__(self):
655 yield self.func
655 yield self.func
656 yield self.args
656 yield self.args
657
657
658 def __getitem__(self, i):
658 def __getitem__(self, i):
659 if i == 0:
659 if i == 0:
660 return self.func
660 return self.func
661 elif i == 1:
661 elif i == 1:
662 return self.args
662 return self.args
663 else:
663 else:
664 raise IndexError('can only access elements 0 and 1')
664 raise IndexError('can only access elements 0 and 1')
665
665
666 class commanddict(dict):
666 class commanddict(dict):
667 """Container for registered wire protocol commands.
667 """Container for registered wire protocol commands.
668
668
669 It behaves like a dict. But __setitem__ is overwritten to allow silent
669 It behaves like a dict. But __setitem__ is overwritten to allow silent
670 coercion of values from 2-tuples for API compatibility.
670 coercion of values from 2-tuples for API compatibility.
671 """
671 """
672 def __setitem__(self, k, v):
672 def __setitem__(self, k, v):
673 if isinstance(v, commandentry):
673 if isinstance(v, commandentry):
674 pass
674 pass
675 # Cast 2-tuples to commandentry instances.
675 # Cast 2-tuples to commandentry instances.
676 elif isinstance(v, tuple):
676 elif isinstance(v, tuple):
677 if len(v) != 2:
677 if len(v) != 2:
678 raise ValueError('command tuples must have exactly 2 elements')
678 raise ValueError('command tuples must have exactly 2 elements')
679
679
680 # It is common for extensions to wrap wire protocol commands via
680 # It is common for extensions to wrap wire protocol commands via
681 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
681 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
682 # doing this aren't aware of the new API that uses objects to store
682 # doing this aren't aware of the new API that uses objects to store
683 # command entries, we automatically merge old state with new.
683 # command entries, we automatically merge old state with new.
684 if k in self:
684 if k in self:
685 v = self[k]._merge(v[0], v[1])
685 v = self[k]._merge(v[0], v[1])
686 else:
686 else:
687 v = commandentry(v[0], v[1])
687 v = commandentry(v[0], v[1])
688 else:
688 else:
689 raise ValueError('command entries must be commandentry instances '
689 raise ValueError('command entries must be commandentry instances '
690 'or 2-tuples')
690 'or 2-tuples')
691
691
692 return super(commanddict, self).__setitem__(k, v)
692 return super(commanddict, self).__setitem__(k, v)
693
693
694 def commandavailable(self, command, proto):
694 def commandavailable(self, command, proto):
695 """Determine if a command is available for the requested protocol."""
695 """Determine if a command is available for the requested protocol."""
696 # For now, commands are available for all protocols. So do a simple
696 # For now, commands are available for all protocols. So do a simple
697 # membership test.
697 # membership test.
698 return command in self
698 return command in self
699
699
700 commands = commanddict()
700 commands = commanddict()
701
701
702 def wireprotocommand(name, args=''):
702 def wireprotocommand(name, args=''):
703 """Decorator to declare a wire protocol command.
703 """Decorator to declare a wire protocol command.
704
704
705 ``name`` is the name of the wire protocol command being provided.
705 ``name`` is the name of the wire protocol command being provided.
706
706
707 ``args`` is a space-delimited list of named arguments that the command
707 ``args`` is a space-delimited list of named arguments that the command
708 accepts. ``*`` is a special value that says to accept all arguments.
708 accepts. ``*`` is a special value that says to accept all arguments.
709 """
709 """
710 def register(func):
710 def register(func):
711 commands[name] = commandentry(func, args)
711 commands[name] = commandentry(func, args)
712 return func
712 return func
713 return register
713 return register
714
714
715 @wireprotocommand('batch', 'cmds *')
715 @wireprotocommand('batch', 'cmds *')
716 def batch(repo, proto, cmds, others):
716 def batch(repo, proto, cmds, others):
717 repo = repo.filtered("served")
717 repo = repo.filtered("served")
718 res = []
718 res = []
719 for pair in cmds.split(';'):
719 for pair in cmds.split(';'):
720 op, args = pair.split(' ', 1)
720 op, args = pair.split(' ', 1)
721 vals = {}
721 vals = {}
722 for a in args.split(','):
722 for a in args.split(','):
723 if a:
723 if a:
724 n, v = a.split('=')
724 n, v = a.split('=')
725 vals[unescapearg(n)] = unescapearg(v)
725 vals[unescapearg(n)] = unescapearg(v)
726 func, spec = commands[op]
726 func, spec = commands[op]
727 if spec:
727 if spec:
728 keys = spec.split()
728 keys = spec.split()
729 data = {}
729 data = {}
730 for k in keys:
730 for k in keys:
731 if k == '*':
731 if k == '*':
732 star = {}
732 star = {}
733 for key in vals.keys():
733 for key in vals.keys():
734 if key not in keys:
734 if key not in keys:
735 star[key] = vals[key]
735 star[key] = vals[key]
736 data['*'] = star
736 data['*'] = star
737 else:
737 else:
738 data[k] = vals[k]
738 data[k] = vals[k]
739 result = func(repo, proto, *[data[k] for k in keys])
739 result = func(repo, proto, *[data[k] for k in keys])
740 else:
740 else:
741 result = func(repo, proto)
741 result = func(repo, proto)
742 if isinstance(result, ooberror):
742 if isinstance(result, ooberror):
743 return result
743 return result
744 res.append(escapearg(result))
744 res.append(escapearg(result))
745 return ';'.join(res)
745 return ';'.join(res)
746
746
747 @wireprotocommand('between', 'pairs')
747 @wireprotocommand('between', 'pairs')
748 def between(repo, proto, pairs):
748 def between(repo, proto, pairs):
749 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
749 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
750 r = []
750 r = []
751 for b in repo.between(pairs):
751 for b in repo.between(pairs):
752 r.append(encodelist(b) + "\n")
752 r.append(encodelist(b) + "\n")
753 return "".join(r)
753 return "".join(r)
754
754
755 @wireprotocommand('branchmap')
755 @wireprotocommand('branchmap')
756 def branchmap(repo, proto):
756 def branchmap(repo, proto):
757 branchmap = repo.branchmap()
757 branchmap = repo.branchmap()
758 heads = []
758 heads = []
759 for branch, nodes in branchmap.iteritems():
759 for branch, nodes in branchmap.iteritems():
760 branchname = urlreq.quote(encoding.fromlocal(branch))
760 branchname = urlreq.quote(encoding.fromlocal(branch))
761 branchnodes = encodelist(nodes)
761 branchnodes = encodelist(nodes)
762 heads.append('%s %s' % (branchname, branchnodes))
762 heads.append('%s %s' % (branchname, branchnodes))
763 return '\n'.join(heads)
763 return '\n'.join(heads)
764
764
765 @wireprotocommand('branches', 'nodes')
765 @wireprotocommand('branches', 'nodes')
766 def branches(repo, proto, nodes):
766 def branches(repo, proto, nodes):
767 nodes = decodelist(nodes)
767 nodes = decodelist(nodes)
768 r = []
768 r = []
769 for b in repo.branches(nodes):
769 for b in repo.branches(nodes):
770 r.append(encodelist(b) + "\n")
770 r.append(encodelist(b) + "\n")
771 return "".join(r)
771 return "".join(r)
772
772
773 @wireprotocommand('clonebundles', '')
773 @wireprotocommand('clonebundles', '')
774 def clonebundles(repo, proto):
774 def clonebundles(repo, proto):
775 """Server command for returning info for available bundles to seed clones.
775 """Server command for returning info for available bundles to seed clones.
776
776
777 Clients will parse this response and determine what bundle to fetch.
777 Clients will parse this response and determine what bundle to fetch.
778
778
779 Extensions may wrap this command to filter or dynamically emit data
779 Extensions may wrap this command to filter or dynamically emit data
780 depending on the request. e.g. you could advertise URLs for the closest
780 depending on the request. e.g. you could advertise URLs for the closest
781 data center given the client's IP address.
781 data center given the client's IP address.
782 """
782 """
783 return repo.vfs.tryread('clonebundles.manifest')
783 return repo.vfs.tryread('clonebundles.manifest')
784
784
785 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
785 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
786 'known', 'getbundle', 'unbundlehash', 'batch']
786 'known', 'getbundle', 'unbundlehash', 'batch']
787
787
788 def _capabilities(repo, proto):
788 def _capabilities(repo, proto):
789 """return a list of capabilities for a repo
789 """return a list of capabilities for a repo
790
790
791 This function exists to allow extensions to easily wrap capabilities
791 This function exists to allow extensions to easily wrap capabilities
792 computation
792 computation
793
793
794 - returns a lists: easy to alter
794 - returns a lists: easy to alter
795 - change done here will be propagated to both `capabilities` and `hello`
795 - change done here will be propagated to both `capabilities` and `hello`
796 command without any other action needed.
796 command without any other action needed.
797 """
797 """
798 # copy to prevent modification of the global list
798 # copy to prevent modification of the global list
799 caps = list(wireprotocaps)
799 caps = list(wireprotocaps)
800 if streamclone.allowservergeneration(repo):
800 if streamclone.allowservergeneration(repo):
801 if repo.ui.configbool('server', 'preferuncompressed'):
801 if repo.ui.configbool('server', 'preferuncompressed'):
802 caps.append('stream-preferred')
802 caps.append('stream-preferred')
803 requiredformats = repo.requirements & repo.supportedformats
803 requiredformats = repo.requirements & repo.supportedformats
804 # if our local revlogs are just revlogv1, add 'stream' cap
804 # if our local revlogs are just revlogv1, add 'stream' cap
805 if not requiredformats - {'revlogv1'}:
805 if not requiredformats - {'revlogv1'}:
806 caps.append('stream')
806 caps.append('stream')
807 # otherwise, add 'streamreqs' detailing our local revlog format
807 # otherwise, add 'streamreqs' detailing our local revlog format
808 else:
808 else:
809 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
809 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
810 if repo.ui.configbool('experimental', 'bundle2-advertise'):
810 if repo.ui.configbool('experimental', 'bundle2-advertise'):
811 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
811 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
812 caps.append('bundle2=' + urlreq.quote(capsblob))
812 caps.append('bundle2=' + urlreq.quote(capsblob))
813 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
813 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
814
814
815 if proto.name == 'http':
815 if proto.name == 'http':
816 caps.append('httpheader=%d' %
816 caps.append('httpheader=%d' %
817 repo.ui.configint('server', 'maxhttpheaderlen'))
817 repo.ui.configint('server', 'maxhttpheaderlen'))
818 if repo.ui.configbool('experimental', 'httppostargs'):
818 if repo.ui.configbool('experimental', 'httppostargs'):
819 caps.append('httppostargs')
819 caps.append('httppostargs')
820
820
821 # FUTURE advertise 0.2rx once support is implemented
821 # FUTURE advertise 0.2rx once support is implemented
822 # FUTURE advertise minrx and mintx after consulting config option
822 # FUTURE advertise minrx and mintx after consulting config option
823 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
823 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
824
824
825 compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
825 compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
826 if compengines:
826 if compengines:
827 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
827 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
828 for e in compengines)
828 for e in compengines)
829 caps.append('compression=%s' % comptypes)
829 caps.append('compression=%s' % comptypes)
830
830
831 return caps
831 return caps
832
832
833 # If you are writing an extension and consider wrapping this function. Wrap
833 # If you are writing an extension and consider wrapping this function. Wrap
834 # `_capabilities` instead.
834 # `_capabilities` instead.
835 @wireprotocommand('capabilities')
835 @wireprotocommand('capabilities')
836 def capabilities(repo, proto):
836 def capabilities(repo, proto):
837 return ' '.join(_capabilities(repo, proto))
837 return ' '.join(_capabilities(repo, proto))
838
838
839 @wireprotocommand('changegroup', 'roots')
839 @wireprotocommand('changegroup', 'roots')
840 def changegroup(repo, proto, roots):
840 def changegroup(repo, proto, roots):
841 nodes = decodelist(roots)
841 nodes = decodelist(roots)
842 outgoing = discovery.outgoing(repo, missingroots=nodes,
842 outgoing = discovery.outgoing(repo, missingroots=nodes,
843 missingheads=repo.heads())
843 missingheads=repo.heads())
844 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
844 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
845 gen = iter(lambda: cg.read(32768), '')
845 gen = iter(lambda: cg.read(32768), '')
846 return streamres(gen=gen)
846 return streamres(gen=gen)
847
847
848 @wireprotocommand('changegroupsubset', 'bases heads')
848 @wireprotocommand('changegroupsubset', 'bases heads')
849 def changegroupsubset(repo, proto, bases, heads):
849 def changegroupsubset(repo, proto, bases, heads):
850 bases = decodelist(bases)
850 bases = decodelist(bases)
851 heads = decodelist(heads)
851 heads = decodelist(heads)
852 outgoing = discovery.outgoing(repo, missingroots=bases,
852 outgoing = discovery.outgoing(repo, missingroots=bases,
853 missingheads=heads)
853 missingheads=heads)
854 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
854 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
855 gen = iter(lambda: cg.read(32768), '')
855 gen = iter(lambda: cg.read(32768), '')
856 return streamres(gen=gen)
856 return streamres(gen=gen)
857
857
858 @wireprotocommand('debugwireargs', 'one two *')
858 @wireprotocommand('debugwireargs', 'one two *')
859 def debugwireargs(repo, proto, one, two, others):
859 def debugwireargs(repo, proto, one, two, others):
860 # only accept optional args from the known set
860 # only accept optional args from the known set
861 opts = options('debugwireargs', ['three', 'four'], others)
861 opts = options('debugwireargs', ['three', 'four'], others)
862 return repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
862 return repo.debugwireargs(one, two, **pycompat.strkwargs(opts))
863
863
864 @wireprotocommand('getbundle', '*')
864 @wireprotocommand('getbundle', '*')
865 def getbundle(repo, proto, others):
865 def getbundle(repo, proto, others):
866 opts = options('getbundle', gboptsmap.keys(), others)
866 opts = options('getbundle', gboptsmap.keys(), others)
867 for k, v in opts.iteritems():
867 for k, v in opts.iteritems():
868 keytype = gboptsmap[k]
868 keytype = gboptsmap[k]
869 if keytype == 'nodes':
869 if keytype == 'nodes':
870 opts[k] = decodelist(v)
870 opts[k] = decodelist(v)
871 elif keytype == 'csv':
871 elif keytype == 'csv':
872 opts[k] = list(v.split(','))
872 opts[k] = list(v.split(','))
873 elif keytype == 'scsv':
873 elif keytype == 'scsv':
874 opts[k] = set(v.split(','))
874 opts[k] = set(v.split(','))
875 elif keytype == 'boolean':
875 elif keytype == 'boolean':
876 # Client should serialize False as '0', which is a non-empty string
876 # Client should serialize False as '0', which is a non-empty string
877 # so it evaluates as a True bool.
877 # so it evaluates as a True bool.
878 if v == '0':
878 if v == '0':
879 opts[k] = False
879 opts[k] = False
880 else:
880 else:
881 opts[k] = bool(v)
881 opts[k] = bool(v)
882 elif keytype != 'plain':
882 elif keytype != 'plain':
883 raise KeyError('unknown getbundle option type %s'
883 raise KeyError('unknown getbundle option type %s'
884 % keytype)
884 % keytype)
885
885
886 if not bundle1allowed(repo, 'pull'):
886 if not bundle1allowed(repo, 'pull'):
887 if not exchange.bundle2requested(opts.get('bundlecaps')):
887 if not exchange.bundle2requested(opts.get('bundlecaps')):
888 if proto.name == 'http':
888 if proto.name == 'http':
889 return ooberror(bundle2required)
889 return ooberror(bundle2required)
890 raise error.Abort(bundle2requiredmain,
890 raise error.Abort(bundle2requiredmain,
891 hint=bundle2requiredhint)
891 hint=bundle2requiredhint)
892
892
893 prefercompressed = True
893 prefercompressed = True
894
894
895 try:
895 try:
896 if repo.ui.configbool('server', 'disablefullbundle'):
896 if repo.ui.configbool('server', 'disablefullbundle'):
897 # Check to see if this is a full clone.
897 # Check to see if this is a full clone.
898 clheads = set(repo.changelog.heads())
898 clheads = set(repo.changelog.heads())
899 changegroup = opts.get('cg', True)
899 changegroup = opts.get('cg', True)
900 heads = set(opts.get('heads', set()))
900 heads = set(opts.get('heads', set()))
901 common = set(opts.get('common', set()))
901 common = set(opts.get('common', set()))
902 common.discard(nullid)
902 common.discard(nullid)
903 if changegroup and not common and clheads == heads:
903 if changegroup and not common and clheads == heads:
904 raise error.Abort(
904 raise error.Abort(
905 _('server has pull-based clones disabled'),
905 _('server has pull-based clones disabled'),
906 hint=_('remove --pull if specified or upgrade Mercurial'))
906 hint=_('remove --pull if specified or upgrade Mercurial'))
907
907
908 info, chunks = exchange.getbundlechunks(repo, 'serve',
908 info, chunks = exchange.getbundlechunks(repo, 'serve',
909 **pycompat.strkwargs(opts))
909 **pycompat.strkwargs(opts))
910 prefercompressed = info.get('prefercompressed', True)
910 prefercompressed = info.get('prefercompressed', True)
911 except error.Abort as exc:
911 except error.Abort as exc:
912 # cleanly forward Abort error to the client
912 # cleanly forward Abort error to the client
913 if not exchange.bundle2requested(opts.get('bundlecaps')):
913 if not exchange.bundle2requested(opts.get('bundlecaps')):
914 if proto.name == 'http':
914 if proto.name == 'http':
915 return ooberror(str(exc) + '\n')
915 return ooberror(str(exc) + '\n')
916 raise # cannot do better for bundle1 + ssh
916 raise # cannot do better for bundle1 + ssh
917 # bundle2 request expect a bundle2 reply
917 # bundle2 request expect a bundle2 reply
918 bundler = bundle2.bundle20(repo.ui)
918 bundler = bundle2.bundle20(repo.ui)
919 manargs = [('message', str(exc))]
919 manargs = [('message', str(exc))]
920 advargs = []
920 advargs = []
921 if exc.hint is not None:
921 if exc.hint is not None:
922 advargs.append(('hint', exc.hint))
922 advargs.append(('hint', exc.hint))
923 bundler.addpart(bundle2.bundlepart('error:abort',
923 bundler.addpart(bundle2.bundlepart('error:abort',
924 manargs, advargs))
924 manargs, advargs))
925 chunks = bundler.getchunks()
925 chunks = bundler.getchunks()
926 prefercompressed = False
926 prefercompressed = False
927
927
928 return streamres(gen=chunks, prefer_uncompressed=not prefercompressed)
928 return streamres(gen=chunks, prefer_uncompressed=not prefercompressed)
929
929
930 @wireprotocommand('heads')
930 @wireprotocommand('heads')
931 def heads(repo, proto):
931 def heads(repo, proto):
932 h = repo.heads()
932 h = repo.heads()
933 return encodelist(h) + "\n"
933 return encodelist(h) + "\n"
934
934
935 @wireprotocommand('hello')
935 @wireprotocommand('hello')
936 def hello(repo, proto):
936 def hello(repo, proto):
937 '''the hello command returns a set of lines describing various
937 '''the hello command returns a set of lines describing various
938 interesting things about the server, in an RFC822-like format.
938 interesting things about the server, in an RFC822-like format.
939 Currently the only one defined is "capabilities", which
939 Currently the only one defined is "capabilities", which
940 consists of a line in the form:
940 consists of a line in the form:
941
941
942 capabilities: space separated list of tokens
942 capabilities: space separated list of tokens
943 '''
943 '''
944 return "capabilities: %s\n" % (capabilities(repo, proto))
944 return "capabilities: %s\n" % (capabilities(repo, proto))
945
945
946 @wireprotocommand('listkeys', 'namespace')
946 @wireprotocommand('listkeys', 'namespace')
947 def listkeys(repo, proto, namespace):
947 def listkeys(repo, proto, namespace):
948 d = repo.listkeys(encoding.tolocal(namespace)).items()
948 d = repo.listkeys(encoding.tolocal(namespace)).items()
949 return pushkeymod.encodekeys(d)
949 return pushkeymod.encodekeys(d)
950
950
951 @wireprotocommand('lookup', 'key')
951 @wireprotocommand('lookup', 'key')
952 def lookup(repo, proto, key):
952 def lookup(repo, proto, key):
953 try:
953 try:
954 k = encoding.tolocal(key)
954 k = encoding.tolocal(key)
955 c = repo[k]
955 c = repo[k]
956 r = c.hex()
956 r = c.hex()
957 success = 1
957 success = 1
958 except Exception as inst:
958 except Exception as inst:
959 r = str(inst)
959 r = str(inst)
960 success = 0
960 success = 0
961 return "%d %s\n" % (success, r)
961 return "%d %s\n" % (success, r)
962
962
963 @wireprotocommand('known', 'nodes *')
963 @wireprotocommand('known', 'nodes *')
964 def known(repo, proto, nodes, others):
964 def known(repo, proto, nodes, others):
965 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
965 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
966
966
967 @wireprotocommand('pushkey', 'namespace key old new')
967 @wireprotocommand('pushkey', 'namespace key old new')
968 def pushkey(repo, proto, namespace, key, old, new):
968 def pushkey(repo, proto, namespace, key, old, new):
969 # compatibility with pre-1.8 clients which were accidentally
969 # compatibility with pre-1.8 clients which were accidentally
970 # sending raw binary nodes rather than utf-8-encoded hex
970 # sending raw binary nodes rather than utf-8-encoded hex
971 if len(new) == 20 and util.escapestr(new) != new:
971 if len(new) == 20 and util.escapestr(new) != new:
972 # looks like it could be a binary node
972 # looks like it could be a binary node
973 try:
973 try:
974 new.decode('utf-8')
974 new.decode('utf-8')
975 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
975 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
976 except UnicodeDecodeError:
976 except UnicodeDecodeError:
977 pass # binary, leave unmodified
977 pass # binary, leave unmodified
978 else:
978 else:
979 new = encoding.tolocal(new) # normal path
979 new = encoding.tolocal(new) # normal path
980
980
981 if util.safehasattr(proto, 'restore'):
981 with proto.mayberedirectstdio() as output:
982
983 proto.redirect()
984
985 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
982 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
986 encoding.tolocal(old), new) or False
983 encoding.tolocal(old), new) or False
987
984
988 output = proto.restore()
985 output = output.getvalue() if output else ''
989
986 return '%s\n%s' % (int(r), output)
990 return '%s\n%s' % (int(r), output)
991
992 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
993 encoding.tolocal(old), new)
994 return '%s\n' % int(r)
995
987
996 @wireprotocommand('stream_out')
988 @wireprotocommand('stream_out')
997 def stream(repo, proto):
989 def stream(repo, proto):
998 '''If the server supports streaming clone, it advertises the "stream"
990 '''If the server supports streaming clone, it advertises the "stream"
999 capability with a value representing the version and flags of the repo
991 capability with a value representing the version and flags of the repo
1000 it is serving. Client checks to see if it understands the format.
992 it is serving. Client checks to see if it understands the format.
1001 '''
993 '''
1002 return streamres_legacy(streamclone.generatev1wireproto(repo))
994 return streamres_legacy(streamclone.generatev1wireproto(repo))
1003
995
1004 @wireprotocommand('unbundle', 'heads')
996 @wireprotocommand('unbundle', 'heads')
1005 def unbundle(repo, proto, heads):
997 def unbundle(repo, proto, heads):
1006 their_heads = decodelist(heads)
998 their_heads = decodelist(heads)
1007
999
1008 try:
1000 try:
1009 proto.redirect()
1001 proto.redirect()
1010
1002
1011 exchange.check_heads(repo, their_heads, 'preparing changes')
1003 exchange.check_heads(repo, their_heads, 'preparing changes')
1012
1004
1013 # write bundle data to temporary file because it can be big
1005 # write bundle data to temporary file because it can be big
1014 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1006 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1015 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
1007 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
1016 r = 0
1008 r = 0
1017 try:
1009 try:
1018 proto.getfile(fp)
1010 proto.getfile(fp)
1019 fp.seek(0)
1011 fp.seek(0)
1020 gen = exchange.readbundle(repo.ui, fp, None)
1012 gen = exchange.readbundle(repo.ui, fp, None)
1021 if (isinstance(gen, changegroupmod.cg1unpacker)
1013 if (isinstance(gen, changegroupmod.cg1unpacker)
1022 and not bundle1allowed(repo, 'push')):
1014 and not bundle1allowed(repo, 'push')):
1023 if proto.name == 'http':
1015 if proto.name == 'http':
1024 # need to special case http because stderr do not get to
1016 # need to special case http because stderr do not get to
1025 # the http client on failed push so we need to abuse some
1017 # the http client on failed push so we need to abuse some
1026 # other error type to make sure the message get to the
1018 # other error type to make sure the message get to the
1027 # user.
1019 # user.
1028 return ooberror(bundle2required)
1020 return ooberror(bundle2required)
1029 raise error.Abort(bundle2requiredmain,
1021 raise error.Abort(bundle2requiredmain,
1030 hint=bundle2requiredhint)
1022 hint=bundle2requiredhint)
1031
1023
1032 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1024 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1033 proto._client())
1025 proto._client())
1034 if util.safehasattr(r, 'addpart'):
1026 if util.safehasattr(r, 'addpart'):
1035 # The return looks streamable, we are in the bundle2 case and
1027 # The return looks streamable, we are in the bundle2 case and
1036 # should return a stream.
1028 # should return a stream.
1037 return streamres_legacy(gen=r.getchunks())
1029 return streamres_legacy(gen=r.getchunks())
1038 return pushres(r)
1030 return pushres(r)
1039
1031
1040 finally:
1032 finally:
1041 fp.close()
1033 fp.close()
1042 os.unlink(tempname)
1034 os.unlink(tempname)
1043
1035
1044 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1036 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1045 # handle non-bundle2 case first
1037 # handle non-bundle2 case first
1046 if not getattr(exc, 'duringunbundle2', False):
1038 if not getattr(exc, 'duringunbundle2', False):
1047 try:
1039 try:
1048 raise
1040 raise
1049 except error.Abort:
1041 except error.Abort:
1050 # The old code we moved used util.stderr directly.
1042 # The old code we moved used util.stderr directly.
1051 # We did not change it to minimise code change.
1043 # We did not change it to minimise code change.
1052 # This need to be moved to something proper.
1044 # This need to be moved to something proper.
1053 # Feel free to do it.
1045 # Feel free to do it.
1054 util.stderr.write("abort: %s\n" % exc)
1046 util.stderr.write("abort: %s\n" % exc)
1055 if exc.hint is not None:
1047 if exc.hint is not None:
1056 util.stderr.write("(%s)\n" % exc.hint)
1048 util.stderr.write("(%s)\n" % exc.hint)
1057 return pushres(0)
1049 return pushres(0)
1058 except error.PushRaced:
1050 except error.PushRaced:
1059 return pusherr(str(exc))
1051 return pusherr(str(exc))
1060
1052
1061 bundler = bundle2.bundle20(repo.ui)
1053 bundler = bundle2.bundle20(repo.ui)
1062 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1054 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1063 bundler.addpart(out)
1055 bundler.addpart(out)
1064 try:
1056 try:
1065 try:
1057 try:
1066 raise
1058 raise
1067 except error.PushkeyFailed as exc:
1059 except error.PushkeyFailed as exc:
1068 # check client caps
1060 # check client caps
1069 remotecaps = getattr(exc, '_replycaps', None)
1061 remotecaps = getattr(exc, '_replycaps', None)
1070 if (remotecaps is not None
1062 if (remotecaps is not None
1071 and 'pushkey' not in remotecaps.get('error', ())):
1063 and 'pushkey' not in remotecaps.get('error', ())):
1072 # no support remote side, fallback to Abort handler.
1064 # no support remote side, fallback to Abort handler.
1073 raise
1065 raise
1074 part = bundler.newpart('error:pushkey')
1066 part = bundler.newpart('error:pushkey')
1075 part.addparam('in-reply-to', exc.partid)
1067 part.addparam('in-reply-to', exc.partid)
1076 if exc.namespace is not None:
1068 if exc.namespace is not None:
1077 part.addparam('namespace', exc.namespace, mandatory=False)
1069 part.addparam('namespace', exc.namespace, mandatory=False)
1078 if exc.key is not None:
1070 if exc.key is not None:
1079 part.addparam('key', exc.key, mandatory=False)
1071 part.addparam('key', exc.key, mandatory=False)
1080 if exc.new is not None:
1072 if exc.new is not None:
1081 part.addparam('new', exc.new, mandatory=False)
1073 part.addparam('new', exc.new, mandatory=False)
1082 if exc.old is not None:
1074 if exc.old is not None:
1083 part.addparam('old', exc.old, mandatory=False)
1075 part.addparam('old', exc.old, mandatory=False)
1084 if exc.ret is not None:
1076 if exc.ret is not None:
1085 part.addparam('ret', exc.ret, mandatory=False)
1077 part.addparam('ret', exc.ret, mandatory=False)
1086 except error.BundleValueError as exc:
1078 except error.BundleValueError as exc:
1087 errpart = bundler.newpart('error:unsupportedcontent')
1079 errpart = bundler.newpart('error:unsupportedcontent')
1088 if exc.parttype is not None:
1080 if exc.parttype is not None:
1089 errpart.addparam('parttype', exc.parttype)
1081 errpart.addparam('parttype', exc.parttype)
1090 if exc.params:
1082 if exc.params:
1091 errpart.addparam('params', '\0'.join(exc.params))
1083 errpart.addparam('params', '\0'.join(exc.params))
1092 except error.Abort as exc:
1084 except error.Abort as exc:
1093 manargs = [('message', str(exc))]
1085 manargs = [('message', str(exc))]
1094 advargs = []
1086 advargs = []
1095 if exc.hint is not None:
1087 if exc.hint is not None:
1096 advargs.append(('hint', exc.hint))
1088 advargs.append(('hint', exc.hint))
1097 bundler.addpart(bundle2.bundlepart('error:abort',
1089 bundler.addpart(bundle2.bundlepart('error:abort',
1098 manargs, advargs))
1090 manargs, advargs))
1099 except error.PushRaced as exc:
1091 except error.PushRaced as exc:
1100 bundler.newpart('error:pushraced', [('message', str(exc))])
1092 bundler.newpart('error:pushraced', [('message', str(exc))])
1101 return streamres_legacy(gen=bundler.getchunks())
1093 return streamres_legacy(gen=bundler.getchunks())
@@ -1,447 +1,481 b''
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
1 # Copyright 21 May 2005 - (c) 2005 Jake Edge <jake@edge2.net>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2 # Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
3 #
3 #
4 # This software may be used and distributed according to the terms of the
4 # This software may be used and distributed according to the terms of the
5 # GNU General Public License version 2 or any later version.
5 # GNU General Public License version 2 or any later version.
6
6
7 from __future__ import absolute_import
7 from __future__ import absolute_import
8
8
9 import abc
9 import abc
10 import cgi
10 import cgi
11 import contextlib
11 import struct
12 import struct
12 import sys
13 import sys
13
14
14 from .i18n import _
15 from .i18n import _
15 from . import (
16 from . import (
16 encoding,
17 encoding,
17 error,
18 error,
18 hook,
19 hook,
19 pycompat,
20 pycompat,
20 util,
21 util,
21 wireproto,
22 wireproto,
22 )
23 )
23
24
24 stringio = util.stringio
25 stringio = util.stringio
25
26
26 urlerr = util.urlerr
27 urlerr = util.urlerr
27 urlreq = util.urlreq
28 urlreq = util.urlreq
28
29
29 HTTP_OK = 200
30 HTTP_OK = 200
30
31
31 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE = 'application/mercurial-0.1'
32 HGTYPE2 = 'application/mercurial-0.2'
33 HGTYPE2 = 'application/mercurial-0.2'
33 HGERRTYPE = 'application/hg-error'
34 HGERRTYPE = 'application/hg-error'
34
35
35 # Names of the SSH protocol implementations.
36 # Names of the SSH protocol implementations.
36 SSHV1 = 'ssh-v1'
37 SSHV1 = 'ssh-v1'
37 # This is advertised over the wire. Incremental the counter at the end
38 # This is advertised over the wire. Incremental the counter at the end
38 # to reflect BC breakages.
39 # to reflect BC breakages.
39 SSHV2 = 'exp-ssh-v2-0001'
40 SSHV2 = 'exp-ssh-v2-0001'
40
41
41 class baseprotocolhandler(object):
42 class baseprotocolhandler(object):
42 """Abstract base class for wire protocol handlers.
43 """Abstract base class for wire protocol handlers.
43
44
44 A wire protocol handler serves as an interface between protocol command
45 A wire protocol handler serves as an interface between protocol command
45 handlers and the wire protocol transport layer. Protocol handlers provide
46 handlers and the wire protocol transport layer. Protocol handlers provide
46 methods to read command arguments, redirect stdio for the duration of
47 methods to read command arguments, redirect stdio for the duration of
47 the request, handle response types, etc.
48 the request, handle response types, etc.
48 """
49 """
49
50
50 __metaclass__ = abc.ABCMeta
51 __metaclass__ = abc.ABCMeta
51
52
52 @abc.abstractproperty
53 @abc.abstractproperty
53 def name(self):
54 def name(self):
54 """The name of the protocol implementation.
55 """The name of the protocol implementation.
55
56
56 Used for uniquely identifying the transport type.
57 Used for uniquely identifying the transport type.
57 """
58 """
58
59
59 @abc.abstractmethod
60 @abc.abstractmethod
60 def getargs(self, args):
61 def getargs(self, args):
61 """return the value for arguments in <args>
62 """return the value for arguments in <args>
62
63
63 returns a list of values (same order as <args>)"""
64 returns a list of values (same order as <args>)"""
64
65
65 @abc.abstractmethod
66 @abc.abstractmethod
66 def getfile(self, fp):
67 def getfile(self, fp):
67 """write the whole content of a file into a file like object
68 """write the whole content of a file into a file like object
68
69
69 The file is in the form::
70 The file is in the form::
70
71
71 (<chunk-size>\n<chunk>)+0\n
72 (<chunk-size>\n<chunk>)+0\n
72
73
73 chunk size is the ascii version of the int.
74 chunk size is the ascii version of the int.
74 """
75 """
75
76
76 @abc.abstractmethod
77 @abc.abstractmethod
78 def mayberedirectstdio(self):
79 """Context manager to possibly redirect stdio.
80
81 The context manager yields a file-object like object that receives
82 stdout and stderr output when the context manager is active. Or it
83 yields ``None`` if no I/O redirection occurs.
84
85 The intent of this context manager is to capture stdio output
86 so it may be sent in the response. Some transports support streaming
87 stdio to the client in real time. For these transports, stdio output
88 won't be captured.
89 """
90
91 @abc.abstractmethod
77 def redirect(self):
92 def redirect(self):
78 """may setup interception for stdout and stderr
93 """may setup interception for stdout and stderr
79
94
80 See also the `restore` method."""
95 See also the `restore` method."""
81
96
82 # If the `redirect` function does install interception, the `restore`
97 # If the `redirect` function does install interception, the `restore`
83 # function MUST be defined. If interception is not used, this function
98 # function MUST be defined. If interception is not used, this function
84 # MUST NOT be defined.
99 # MUST NOT be defined.
85 #
100 #
86 # left commented here on purpose
101 # left commented here on purpose
87 #
102 #
88 #def restore(self):
103 #def restore(self):
89 # """reinstall previous stdout and stderr and return intercepted stdout
104 # """reinstall previous stdout and stderr and return intercepted stdout
90 # """
105 # """
91 # raise NotImplementedError()
106 # raise NotImplementedError()
92
107
93 def decodevaluefromheaders(req, headerprefix):
108 def decodevaluefromheaders(req, headerprefix):
94 """Decode a long value from multiple HTTP request headers.
109 """Decode a long value from multiple HTTP request headers.
95
110
96 Returns the value as a bytes, not a str.
111 Returns the value as a bytes, not a str.
97 """
112 """
98 chunks = []
113 chunks = []
99 i = 1
114 i = 1
100 prefix = headerprefix.upper().replace(r'-', r'_')
115 prefix = headerprefix.upper().replace(r'-', r'_')
101 while True:
116 while True:
102 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
117 v = req.env.get(r'HTTP_%s_%d' % (prefix, i))
103 if v is None:
118 if v is None:
104 break
119 break
105 chunks.append(pycompat.bytesurl(v))
120 chunks.append(pycompat.bytesurl(v))
106 i += 1
121 i += 1
107
122
108 return ''.join(chunks)
123 return ''.join(chunks)
109
124
110 class webproto(baseprotocolhandler):
125 class webproto(baseprotocolhandler):
111 def __init__(self, req, ui):
126 def __init__(self, req, ui):
112 self._req = req
127 self._req = req
113 self._ui = ui
128 self._ui = ui
114
129
115 @property
130 @property
116 def name(self):
131 def name(self):
117 return 'http'
132 return 'http'
118
133
119 def getargs(self, args):
134 def getargs(self, args):
120 knownargs = self._args()
135 knownargs = self._args()
121 data = {}
136 data = {}
122 keys = args.split()
137 keys = args.split()
123 for k in keys:
138 for k in keys:
124 if k == '*':
139 if k == '*':
125 star = {}
140 star = {}
126 for key in knownargs.keys():
141 for key in knownargs.keys():
127 if key != 'cmd' and key not in keys:
142 if key != 'cmd' and key not in keys:
128 star[key] = knownargs[key][0]
143 star[key] = knownargs[key][0]
129 data['*'] = star
144 data['*'] = star
130 else:
145 else:
131 data[k] = knownargs[k][0]
146 data[k] = knownargs[k][0]
132 return [data[k] for k in keys]
147 return [data[k] for k in keys]
133
148
134 def _args(self):
149 def _args(self):
135 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
150 args = util.rapply(pycompat.bytesurl, self._req.form.copy())
136 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
151 postlen = int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
137 if postlen:
152 if postlen:
138 args.update(cgi.parse_qs(
153 args.update(cgi.parse_qs(
139 self._req.read(postlen), keep_blank_values=True))
154 self._req.read(postlen), keep_blank_values=True))
140 return args
155 return args
141
156
142 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
157 argvalue = decodevaluefromheaders(self._req, r'X-HgArg')
143 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
158 args.update(cgi.parse_qs(argvalue, keep_blank_values=True))
144 return args
159 return args
145
160
146 def getfile(self, fp):
161 def getfile(self, fp):
147 length = int(self._req.env[r'CONTENT_LENGTH'])
162 length = int(self._req.env[r'CONTENT_LENGTH'])
148 # If httppostargs is used, we need to read Content-Length
163 # If httppostargs is used, we need to read Content-Length
149 # minus the amount that was consumed by args.
164 # minus the amount that was consumed by args.
150 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
165 length -= int(self._req.env.get(r'HTTP_X_HGARGS_POST', 0))
151 for s in util.filechunkiter(self._req, limit=length):
166 for s in util.filechunkiter(self._req, limit=length):
152 fp.write(s)
167 fp.write(s)
153
168
169 @contextlib.contextmanager
170 def mayberedirectstdio(self):
171 oldout = self._ui.fout
172 olderr = self._ui.ferr
173
174 out = util.stringio()
175
176 try:
177 self._ui.fout = out
178 self._ui.ferr = out
179 yield out
180 finally:
181 self._ui.fout = oldout
182 self._ui.ferr = olderr
183
154 def redirect(self):
184 def redirect(self):
155 self._oldio = self._ui.fout, self._ui.ferr
185 self._oldio = self._ui.fout, self._ui.ferr
156 self._ui.ferr = self._ui.fout = stringio()
186 self._ui.ferr = self._ui.fout = stringio()
157
187
158 def restore(self):
188 def restore(self):
159 val = self._ui.fout.getvalue()
189 val = self._ui.fout.getvalue()
160 self._ui.ferr, self._ui.fout = self._oldio
190 self._ui.ferr, self._ui.fout = self._oldio
161 return val
191 return val
162
192
163 def _client(self):
193 def _client(self):
164 return 'remote:%s:%s:%s' % (
194 return 'remote:%s:%s:%s' % (
165 self._req.env.get('wsgi.url_scheme') or 'http',
195 self._req.env.get('wsgi.url_scheme') or 'http',
166 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
196 urlreq.quote(self._req.env.get('REMOTE_HOST', '')),
167 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
197 urlreq.quote(self._req.env.get('REMOTE_USER', '')))
168
198
169 def responsetype(self, prefer_uncompressed):
199 def responsetype(self, prefer_uncompressed):
170 """Determine the appropriate response type and compression settings.
200 """Determine the appropriate response type and compression settings.
171
201
172 Returns a tuple of (mediatype, compengine, engineopts).
202 Returns a tuple of (mediatype, compengine, engineopts).
173 """
203 """
174 # Determine the response media type and compression engine based
204 # Determine the response media type and compression engine based
175 # on the request parameters.
205 # on the request parameters.
176 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
206 protocaps = decodevaluefromheaders(self._req, r'X-HgProto').split(' ')
177
207
178 if '0.2' in protocaps:
208 if '0.2' in protocaps:
179 # All clients are expected to support uncompressed data.
209 # All clients are expected to support uncompressed data.
180 if prefer_uncompressed:
210 if prefer_uncompressed:
181 return HGTYPE2, util._noopengine(), {}
211 return HGTYPE2, util._noopengine(), {}
182
212
183 # Default as defined by wire protocol spec.
213 # Default as defined by wire protocol spec.
184 compformats = ['zlib', 'none']
214 compformats = ['zlib', 'none']
185 for cap in protocaps:
215 for cap in protocaps:
186 if cap.startswith('comp='):
216 if cap.startswith('comp='):
187 compformats = cap[5:].split(',')
217 compformats = cap[5:].split(',')
188 break
218 break
189
219
190 # Now find an agreed upon compression format.
220 # Now find an agreed upon compression format.
191 for engine in wireproto.supportedcompengines(self._ui, self,
221 for engine in wireproto.supportedcompengines(self._ui, self,
192 util.SERVERROLE):
222 util.SERVERROLE):
193 if engine.wireprotosupport().name in compformats:
223 if engine.wireprotosupport().name in compformats:
194 opts = {}
224 opts = {}
195 level = self._ui.configint('server',
225 level = self._ui.configint('server',
196 '%slevel' % engine.name())
226 '%slevel' % engine.name())
197 if level is not None:
227 if level is not None:
198 opts['level'] = level
228 opts['level'] = level
199
229
200 return HGTYPE2, engine, opts
230 return HGTYPE2, engine, opts
201
231
202 # No mutually supported compression format. Fall back to the
232 # No mutually supported compression format. Fall back to the
203 # legacy protocol.
233 # legacy protocol.
204
234
205 # Don't allow untrusted settings because disabling compression or
235 # Don't allow untrusted settings because disabling compression or
206 # setting a very high compression level could lead to flooding
236 # setting a very high compression level could lead to flooding
207 # the server's network or CPU.
237 # the server's network or CPU.
208 opts = {'level': self._ui.configint('server', 'zliblevel')}
238 opts = {'level': self._ui.configint('server', 'zliblevel')}
209 return HGTYPE, util.compengines['zlib'], opts
239 return HGTYPE, util.compengines['zlib'], opts
210
240
211 def iscmd(cmd):
241 def iscmd(cmd):
212 return cmd in wireproto.commands
242 return cmd in wireproto.commands
213
243
214 def parsehttprequest(repo, req, query):
244 def parsehttprequest(repo, req, query):
215 """Parse the HTTP request for a wire protocol request.
245 """Parse the HTTP request for a wire protocol request.
216
246
217 If the current request appears to be a wire protocol request, this
247 If the current request appears to be a wire protocol request, this
218 function returns a dict with details about that request, including
248 function returns a dict with details about that request, including
219 an ``abstractprotocolserver`` instance suitable for handling the
249 an ``abstractprotocolserver`` instance suitable for handling the
220 request. Otherwise, ``None`` is returned.
250 request. Otherwise, ``None`` is returned.
221
251
222 ``req`` is a ``wsgirequest`` instance.
252 ``req`` is a ``wsgirequest`` instance.
223 """
253 """
224 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
254 # HTTP version 1 wire protocol requests are denoted by a "cmd" query
225 # string parameter. If it isn't present, this isn't a wire protocol
255 # string parameter. If it isn't present, this isn't a wire protocol
226 # request.
256 # request.
227 if r'cmd' not in req.form:
257 if r'cmd' not in req.form:
228 return None
258 return None
229
259
230 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
260 cmd = pycompat.sysbytes(req.form[r'cmd'][0])
231
261
232 # The "cmd" request parameter is used by both the wire protocol and hgweb.
262 # The "cmd" request parameter is used by both the wire protocol and hgweb.
233 # While not all wire protocol commands are available for all transports,
263 # While not all wire protocol commands are available for all transports,
234 # if we see a "cmd" value that resembles a known wire protocol command, we
264 # if we see a "cmd" value that resembles a known wire protocol command, we
235 # route it to a protocol handler. This is better than routing possible
265 # route it to a protocol handler. This is better than routing possible
236 # wire protocol requests to hgweb because it prevents hgweb from using
266 # wire protocol requests to hgweb because it prevents hgweb from using
237 # known wire protocol commands and it is less confusing for machine
267 # known wire protocol commands and it is less confusing for machine
238 # clients.
268 # clients.
239 if cmd not in wireproto.commands:
269 if cmd not in wireproto.commands:
240 return None
270 return None
241
271
242 proto = webproto(req, repo.ui)
272 proto = webproto(req, repo.ui)
243
273
244 return {
274 return {
245 'cmd': cmd,
275 'cmd': cmd,
246 'proto': proto,
276 'proto': proto,
247 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
277 'dispatch': lambda: _callhttp(repo, req, proto, cmd),
248 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
278 'handleerror': lambda ex: _handlehttperror(ex, req, cmd),
249 }
279 }
250
280
251 def _callhttp(repo, req, proto, cmd):
281 def _callhttp(repo, req, proto, cmd):
252 def genversion2(gen, engine, engineopts):
282 def genversion2(gen, engine, engineopts):
253 # application/mercurial-0.2 always sends a payload header
283 # application/mercurial-0.2 always sends a payload header
254 # identifying the compression engine.
284 # identifying the compression engine.
255 name = engine.wireprotosupport().name
285 name = engine.wireprotosupport().name
256 assert 0 < len(name) < 256
286 assert 0 < len(name) < 256
257 yield struct.pack('B', len(name))
287 yield struct.pack('B', len(name))
258 yield name
288 yield name
259
289
260 for chunk in gen:
290 for chunk in gen:
261 yield chunk
291 yield chunk
262
292
263 rsp = wireproto.dispatch(repo, proto, cmd)
293 rsp = wireproto.dispatch(repo, proto, cmd)
264
294
265 if not wireproto.commands.commandavailable(cmd, proto):
295 if not wireproto.commands.commandavailable(cmd, proto):
266 req.respond(HTTP_OK, HGERRTYPE,
296 req.respond(HTTP_OK, HGERRTYPE,
267 body=_('requested wire protocol command is not available '
297 body=_('requested wire protocol command is not available '
268 'over HTTP'))
298 'over HTTP'))
269 return []
299 return []
270
300
271 if isinstance(rsp, bytes):
301 if isinstance(rsp, bytes):
272 req.respond(HTTP_OK, HGTYPE, body=rsp)
302 req.respond(HTTP_OK, HGTYPE, body=rsp)
273 return []
303 return []
274 elif isinstance(rsp, wireproto.streamres_legacy):
304 elif isinstance(rsp, wireproto.streamres_legacy):
275 gen = rsp.gen
305 gen = rsp.gen
276 req.respond(HTTP_OK, HGTYPE)
306 req.respond(HTTP_OK, HGTYPE)
277 return gen
307 return gen
278 elif isinstance(rsp, wireproto.streamres):
308 elif isinstance(rsp, wireproto.streamres):
279 gen = rsp.gen
309 gen = rsp.gen
280
310
281 # This code for compression should not be streamres specific. It
311 # This code for compression should not be streamres specific. It
282 # is here because we only compress streamres at the moment.
312 # is here because we only compress streamres at the moment.
283 mediatype, engine, engineopts = proto.responsetype(
313 mediatype, engine, engineopts = proto.responsetype(
284 rsp.prefer_uncompressed)
314 rsp.prefer_uncompressed)
285 gen = engine.compressstream(gen, engineopts)
315 gen = engine.compressstream(gen, engineopts)
286
316
287 if mediatype == HGTYPE2:
317 if mediatype == HGTYPE2:
288 gen = genversion2(gen, engine, engineopts)
318 gen = genversion2(gen, engine, engineopts)
289
319
290 req.respond(HTTP_OK, mediatype)
320 req.respond(HTTP_OK, mediatype)
291 return gen
321 return gen
292 elif isinstance(rsp, wireproto.pushres):
322 elif isinstance(rsp, wireproto.pushres):
293 val = proto.restore()
323 val = proto.restore()
294 rsp = '%d\n%s' % (rsp.res, val)
324 rsp = '%d\n%s' % (rsp.res, val)
295 req.respond(HTTP_OK, HGTYPE, body=rsp)
325 req.respond(HTTP_OK, HGTYPE, body=rsp)
296 return []
326 return []
297 elif isinstance(rsp, wireproto.pusherr):
327 elif isinstance(rsp, wireproto.pusherr):
298 # This is the httplib workaround documented in _handlehttperror().
328 # This is the httplib workaround documented in _handlehttperror().
299 req.drain()
329 req.drain()
300
330
301 proto.restore()
331 proto.restore()
302 rsp = '0\n%s\n' % rsp.res
332 rsp = '0\n%s\n' % rsp.res
303 req.respond(HTTP_OK, HGTYPE, body=rsp)
333 req.respond(HTTP_OK, HGTYPE, body=rsp)
304 return []
334 return []
305 elif isinstance(rsp, wireproto.ooberror):
335 elif isinstance(rsp, wireproto.ooberror):
306 rsp = rsp.message
336 rsp = rsp.message
307 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
337 req.respond(HTTP_OK, HGERRTYPE, body=rsp)
308 return []
338 return []
309 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
339 raise error.ProgrammingError('hgweb.protocol internal failure', rsp)
310
340
311 def _handlehttperror(e, req, cmd):
341 def _handlehttperror(e, req, cmd):
312 """Called when an ErrorResponse is raised during HTTP request processing."""
342 """Called when an ErrorResponse is raised during HTTP request processing."""
313
343
314 # Clients using Python's httplib are stateful: the HTTP client
344 # Clients using Python's httplib are stateful: the HTTP client
315 # won't process an HTTP response until all request data is
345 # won't process an HTTP response until all request data is
316 # sent to the server. The intent of this code is to ensure
346 # sent to the server. The intent of this code is to ensure
317 # we always read HTTP request data from the client, thus
347 # we always read HTTP request data from the client, thus
318 # ensuring httplib transitions to a state that allows it to read
348 # ensuring httplib transitions to a state that allows it to read
319 # the HTTP response. In other words, it helps prevent deadlocks
349 # the HTTP response. In other words, it helps prevent deadlocks
320 # on clients using httplib.
350 # on clients using httplib.
321
351
322 if (req.env[r'REQUEST_METHOD'] == r'POST' and
352 if (req.env[r'REQUEST_METHOD'] == r'POST' and
323 # But not if Expect: 100-continue is being used.
353 # But not if Expect: 100-continue is being used.
324 (req.env.get('HTTP_EXPECT',
354 (req.env.get('HTTP_EXPECT',
325 '').lower() != '100-continue') or
355 '').lower() != '100-continue') or
326 # Or the non-httplib HTTP library is being advertised by
356 # Or the non-httplib HTTP library is being advertised by
327 # the client.
357 # the client.
328 req.env.get('X-HgHttp2', '')):
358 req.env.get('X-HgHttp2', '')):
329 req.drain()
359 req.drain()
330 else:
360 else:
331 req.headers.append((r'Connection', r'Close'))
361 req.headers.append((r'Connection', r'Close'))
332
362
333 # TODO This response body assumes the failed command was
363 # TODO This response body assumes the failed command was
334 # "unbundle." That assumption is not always valid.
364 # "unbundle." That assumption is not always valid.
335 req.respond(e, HGTYPE, body='0\n%s\n' % e)
365 req.respond(e, HGTYPE, body='0\n%s\n' % e)
336
366
337 return ''
367 return ''
338
368
339 def _sshv1respondbytes(fout, value):
369 def _sshv1respondbytes(fout, value):
340 """Send a bytes response for protocol version 1."""
370 """Send a bytes response for protocol version 1."""
341 fout.write('%d\n' % len(value))
371 fout.write('%d\n' % len(value))
342 fout.write(value)
372 fout.write(value)
343 fout.flush()
373 fout.flush()
344
374
345 def _sshv1respondstream(fout, source):
375 def _sshv1respondstream(fout, source):
346 write = fout.write
376 write = fout.write
347 for chunk in source.gen:
377 for chunk in source.gen:
348 write(chunk)
378 write(chunk)
349 fout.flush()
379 fout.flush()
350
380
351 def _sshv1respondooberror(fout, ferr, rsp):
381 def _sshv1respondooberror(fout, ferr, rsp):
352 ferr.write(b'%s\n-\n' % rsp)
382 ferr.write(b'%s\n-\n' % rsp)
353 ferr.flush()
383 ferr.flush()
354 fout.write(b'\n')
384 fout.write(b'\n')
355 fout.flush()
385 fout.flush()
356
386
357 class sshv1protocolhandler(baseprotocolhandler):
387 class sshv1protocolhandler(baseprotocolhandler):
358 """Handler for requests services via version 1 of SSH protocol."""
388 """Handler for requests services via version 1 of SSH protocol."""
359 def __init__(self, ui, fin, fout):
389 def __init__(self, ui, fin, fout):
360 self._ui = ui
390 self._ui = ui
361 self._fin = fin
391 self._fin = fin
362 self._fout = fout
392 self._fout = fout
363
393
364 @property
394 @property
365 def name(self):
395 def name(self):
366 return 'ssh'
396 return 'ssh'
367
397
368 def getargs(self, args):
398 def getargs(self, args):
369 data = {}
399 data = {}
370 keys = args.split()
400 keys = args.split()
371 for n in xrange(len(keys)):
401 for n in xrange(len(keys)):
372 argline = self._fin.readline()[:-1]
402 argline = self._fin.readline()[:-1]
373 arg, l = argline.split()
403 arg, l = argline.split()
374 if arg not in keys:
404 if arg not in keys:
375 raise error.Abort(_("unexpected parameter %r") % arg)
405 raise error.Abort(_("unexpected parameter %r") % arg)
376 if arg == '*':
406 if arg == '*':
377 star = {}
407 star = {}
378 for k in xrange(int(l)):
408 for k in xrange(int(l)):
379 argline = self._fin.readline()[:-1]
409 argline = self._fin.readline()[:-1]
380 arg, l = argline.split()
410 arg, l = argline.split()
381 val = self._fin.read(int(l))
411 val = self._fin.read(int(l))
382 star[arg] = val
412 star[arg] = val
383 data['*'] = star
413 data['*'] = star
384 else:
414 else:
385 val = self._fin.read(int(l))
415 val = self._fin.read(int(l))
386 data[arg] = val
416 data[arg] = val
387 return [data[k] for k in keys]
417 return [data[k] for k in keys]
388
418
389 def getfile(self, fpout):
419 def getfile(self, fpout):
390 _sshv1respondbytes(self._fout, b'')
420 _sshv1respondbytes(self._fout, b'')
391 count = int(self._fin.readline())
421 count = int(self._fin.readline())
392 while count:
422 while count:
393 fpout.write(self._fin.read(count))
423 fpout.write(self._fin.read(count))
394 count = int(self._fin.readline())
424 count = int(self._fin.readline())
395
425
426 @contextlib.contextmanager
427 def mayberedirectstdio(self):
428 yield None
429
396 def redirect(self):
430 def redirect(self):
397 pass
431 pass
398
432
399 def _client(self):
433 def _client(self):
400 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
434 client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
401 return 'remote:ssh:' + client
435 return 'remote:ssh:' + client
402
436
403 class sshserver(object):
437 class sshserver(object):
404 def __init__(self, ui, repo):
438 def __init__(self, ui, repo):
405 self._ui = ui
439 self._ui = ui
406 self._repo = repo
440 self._repo = repo
407 self._fin = ui.fin
441 self._fin = ui.fin
408 self._fout = ui.fout
442 self._fout = ui.fout
409
443
410 hook.redirect(True)
444 hook.redirect(True)
411 ui.fout = repo.ui.fout = ui.ferr
445 ui.fout = repo.ui.fout = ui.ferr
412
446
413 # Prevent insertion/deletion of CRs
447 # Prevent insertion/deletion of CRs
414 util.setbinary(self._fin)
448 util.setbinary(self._fin)
415 util.setbinary(self._fout)
449 util.setbinary(self._fout)
416
450
417 self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
451 self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
418
452
419 def serve_forever(self):
453 def serve_forever(self):
420 while self.serve_one():
454 while self.serve_one():
421 pass
455 pass
422 sys.exit(0)
456 sys.exit(0)
423
457
424 def serve_one(self):
458 def serve_one(self):
425 cmd = self._fin.readline()[:-1]
459 cmd = self._fin.readline()[:-1]
426 if cmd and wireproto.commands.commandavailable(cmd, self._proto):
460 if cmd and wireproto.commands.commandavailable(cmd, self._proto):
427 rsp = wireproto.dispatch(self._repo, self._proto, cmd)
461 rsp = wireproto.dispatch(self._repo, self._proto, cmd)
428
462
429 if isinstance(rsp, bytes):
463 if isinstance(rsp, bytes):
430 _sshv1respondbytes(self._fout, rsp)
464 _sshv1respondbytes(self._fout, rsp)
431 elif isinstance(rsp, wireproto.streamres):
465 elif isinstance(rsp, wireproto.streamres):
432 _sshv1respondstream(self._fout, rsp)
466 _sshv1respondstream(self._fout, rsp)
433 elif isinstance(rsp, wireproto.streamres_legacy):
467 elif isinstance(rsp, wireproto.streamres_legacy):
434 _sshv1respondstream(self._fout, rsp)
468 _sshv1respondstream(self._fout, rsp)
435 elif isinstance(rsp, wireproto.pushres):
469 elif isinstance(rsp, wireproto.pushres):
436 _sshv1respondbytes(self._fout, b'')
470 _sshv1respondbytes(self._fout, b'')
437 _sshv1respondbytes(self._fout, bytes(rsp.res))
471 _sshv1respondbytes(self._fout, bytes(rsp.res))
438 elif isinstance(rsp, wireproto.pusherr):
472 elif isinstance(rsp, wireproto.pusherr):
439 _sshv1respondbytes(self._fout, rsp.res)
473 _sshv1respondbytes(self._fout, rsp.res)
440 elif isinstance(rsp, wireproto.ooberror):
474 elif isinstance(rsp, wireproto.ooberror):
441 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
475 _sshv1respondooberror(self._fout, self._ui.ferr, rsp.message)
442 else:
476 else:
443 raise error.ProgrammingError('unhandled response type from '
477 raise error.ProgrammingError('unhandled response type from '
444 'wire protocol command: %s' % rsp)
478 'wire protocol command: %s' % rsp)
445 elif cmd:
479 elif cmd:
446 _sshv1respondbytes(self._fout, b'')
480 _sshv1respondbytes(self._fout, b'')
447 return cmd != ''
481 return cmd != ''
General Comments 0
You need to be logged in to leave comments. Login now