wireproto: move version 1 peer functionality to standalone module (API)...
Gregory Szorc
r37632:a81d02ea default
@@ -1,1187 +1,1188 @@
# Infinite push
#
# Copyright 2016 Facebook, Inc.
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.
""" store some pushes in a remote blob store on the server (EXPERIMENTAL)

    [infinitepush]
    # Server-side and client-side option. Pattern of the infinitepush bookmark
    branchpattern = PATTERN

    # Server or client
    server = False

    # Server-side option. Possible values: 'disk' or 'sql'. Fails if not set
    indextype = disk

    # Server-side option. Used only if indextype=sql.
    # Format: 'IP:PORT:DB_NAME:USER:PASSWORD'
    sqlhost = IP:PORT:DB_NAME:USER:PASSWORD

    # Server-side option. Used only if indextype=disk.
    # Filesystem path to the index store
    indexpath = PATH

    # Server-side option. Possible values: 'disk' or 'external'
    # Fails if not set
    storetype = disk

    # Server-side option.
    # Path to the binary that will save bundle to the bundlestore
    # Formatted cmd line will be passed to it (see `put_args`)
    put_binary = put

    # Server-side option. Used only if storetype=external.
    # Format cmd-line string for put binary. Placeholder: {filename}
    put_args = {filename}

    # Server-side option.
    # Path to the binary that gets a bundle from the bundlestore.
    # Formatted cmd line will be passed to it (see `get_args`)
    get_binary = get

    # Server-side option. Used only if storetype=external.
    # Format cmd-line string for get binary. Placeholders: {filename} {handle}
    get_args = {filename} {handle}

    # Server-side option
    logfile = FILE

    # Server-side option
    loglevel = DEBUG

    # Server-side option. Used only if indextype=sql.
    # Sets mysql wait_timeout option.
    waittimeout = 300

    # Server-side option. Used only if indextype=sql.
    # Sets mysql innodb_lock_wait_timeout option.
    locktimeout = 120

    # Server-side option. Used only if indextype=sql.
    # Name of the repository
    reponame = ''

    # Client-side option. Used by --list-remote option. List of remote scratch
    # patterns to list if no patterns are specified.
    defaultremotepatterns = ['*']

    # Instructs infinitepush to forward all received bundle2 parts to the
    # bundle for storage. Defaults to False.
    storeallparts = True

    # Routes each incoming push to the bundlestore. Defaults to False.
    pushtobundlestore = True

    [remotenames]
    # Client-side option
    # This option should be set only if remotenames extension is enabled.
    # Whether remote bookmarks are tracked by remotenames extension.
    bookmarks = True
"""

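# Illustrative configuration (not part of this changeset): a minimal setup
# using only the options documented above might look like the following hgrc
# fragments. The paths and the branch pattern are placeholders.
#
#   # server repository
#   [infinitepush]
#   server = True
#   indextype = disk
#   indexpath = /path/to/index
#   storetype = disk
#   branchpattern = re:scratch/.+
#
#   # client repository
#   [infinitepush]
#   server = False
#   branchpattern = re:scratch/.+
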
from __future__ import absolute_import

import collections
import contextlib
import errno
import functools
import logging
import os
import random
import re
import socket
import subprocess
import tempfile
import time

from mercurial.node import (
    bin,
    hex,
)

from mercurial.i18n import _

from mercurial.utils import (
    procutil,
    stringutil,
)

from mercurial import (
    bundle2,
    changegroup,
    commands,
    discovery,
    encoding,
    error,
    exchange,
    extensions,
    hg,
    localrepo,
    peer,
    phases,
    pushkey,
    pycompat,
    registrar,
    util,
    wireproto,
    wireprototypes,
+   wireprotov1peer,
)

from . import (
    bundleparts,
    common,
)

# Note for extension authors: ONLY specify testedwith = 'ships-with-hg-core' for
# extensions which SHIP WITH MERCURIAL. Non-mainline extensions should
# be specifying the version(s) of Mercurial they are tested with, or
# leave the attribute unspecified.
testedwith = 'ships-with-hg-core'

configtable = {}
configitem = registrar.configitem(configtable)

configitem('infinitepush', 'server',
    default=False,
)
configitem('infinitepush', 'storetype',
    default='',
)
configitem('infinitepush', 'indextype',
    default='',
)
configitem('infinitepush', 'indexpath',
    default='',
)
configitem('infinitepush', 'storeallparts',
    default=False,
)
configitem('infinitepush', 'reponame',
    default='',
)
configitem('scratchbranch', 'storepath',
    default='',
)
configitem('infinitepush', 'branchpattern',
    default='',
)
configitem('infinitepush', 'pushtobundlestore',
    default=False,
)
configitem('experimental', 'server-bundlestore-bookmark',
    default='',
)
configitem('experimental', 'infinitepush-scratchpush',
    default=False,
)

experimental = 'experimental'
configbookmark = 'server-bundlestore-bookmark'
configscratchpush = 'infinitepush-scratchpush'

scratchbranchparttype = bundleparts.scratchbranchparttype
revsetpredicate = registrar.revsetpredicate()
templatekeyword = registrar.templatekeyword()
_scratchbranchmatcher = lambda x: False
_maybehash = re.compile(r'^[a-f0-9]+$').search

def _buildexternalbundlestore(ui):
    put_args = ui.configlist('infinitepush', 'put_args', [])
    put_binary = ui.config('infinitepush', 'put_binary')
    if not put_binary:
        raise error.Abort('put binary is not specified')
    get_args = ui.configlist('infinitepush', 'get_args', [])
    get_binary = ui.config('infinitepush', 'get_binary')
    if not get_binary:
        raise error.Abort('get binary is not specified')
    from . import store
    return store.externalbundlestore(put_binary, put_args, get_binary, get_args)

def _buildsqlindex(ui):
    sqlhost = ui.config('infinitepush', 'sqlhost')
    if not sqlhost:
        raise error.Abort(_('please set infinitepush.sqlhost'))
    host, port, db, user, password = sqlhost.split(':')
    reponame = ui.config('infinitepush', 'reponame')
    if not reponame:
        raise error.Abort(_('please set infinitepush.reponame'))

    logfile = ui.config('infinitepush', 'logfile', '')
    waittimeout = ui.configint('infinitepush', 'waittimeout', 300)
    locktimeout = ui.configint('infinitepush', 'locktimeout', 120)
    from . import sqlindexapi
    return sqlindexapi.sqlindexapi(
        reponame, host, port, db, user, password,
        logfile, _getloglevel(ui), waittimeout=waittimeout,
        locktimeout=locktimeout)

def _getloglevel(ui):
    loglevel = ui.config('infinitepush', 'loglevel', 'DEBUG')
    numeric_loglevel = getattr(logging, loglevel.upper(), None)
    if not isinstance(numeric_loglevel, int):
        raise error.Abort(_('invalid log level %s') % loglevel)
    return numeric_loglevel

def _tryhoist(ui, remotebookmark):
    '''returns a bookmark with the hoisted part removed

    The remotenames extension has a 'hoist' config that allows remote
    bookmarks to be used without specifying the remote path. For example,
    'hg update master' works as well as 'hg update remote/master'. We want
    to allow the same in infinitepush.
    '''

    if common.isremotebooksenabled(ui):
        hoist = ui.config('remotenames', 'hoistedpeer') + '/'
        if remotebookmark.startswith(hoist):
            return remotebookmark[len(hoist):]
    return remotebookmark

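# Illustrative example (assumes the remotenames extension is enabled with
# remotenames.hoistedpeer = 'default'): _tryhoist(ui, 'default/scratch/foo')
# would return 'scratch/foo', while a name that does not start with the
# hoisted peer prefix is returned unchanged.
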
class bundlestore(object):
    def __init__(self, repo):
        self._repo = repo
        storetype = self._repo.ui.config('infinitepush', 'storetype')
        if storetype == 'disk':
            from . import store
            self.store = store.filebundlestore(self._repo.ui, self._repo)
        elif storetype == 'external':
            self.store = _buildexternalbundlestore(self._repo.ui)
        else:
            raise error.Abort(
                _('unknown infinitepush store type specified %s') % storetype)

        indextype = self._repo.ui.config('infinitepush', 'indextype')
        if indextype == 'disk':
            from . import fileindexapi
            self.index = fileindexapi.fileindexapi(self._repo)
        elif indextype == 'sql':
            self.index = _buildsqlindex(self._repo.ui)
        else:
            raise error.Abort(
                _('unknown infinitepush index type specified %s') % indextype)

def _isserver(ui):
    return ui.configbool('infinitepush', 'server')

def reposetup(ui, repo):
    if _isserver(ui) and repo.local():
        repo.bundlestore = bundlestore(repo)

def extsetup(ui):
    commonsetup(ui)
    if _isserver(ui):
        serverextsetup(ui)
    else:
        clientextsetup(ui)

def commonsetup(ui):
    wireproto.commands['listkeyspatterns'] = (
        wireprotolistkeyspatterns, 'namespace patterns')
    scratchbranchpat = ui.config('infinitepush', 'branchpattern')
    if scratchbranchpat:
        global _scratchbranchmatcher
        kind, pat, _scratchbranchmatcher = \
                stringutil.stringmatcher(scratchbranchpat)

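# Illustrative example: with infinitepush.branchpattern = 're:scratch/.+',
# stringutil.stringmatcher() above yields a matcher for which
# _scratchbranchmatcher('scratch/mytask') is True and
# _scratchbranchmatcher('default') is False. The pattern value here is an
# assumption for illustration, not something fixed by this change.
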
def serverextsetup(ui):
    origpushkeyhandler = bundle2.parthandlermapping['pushkey']

    def newpushkeyhandler(*args, **kwargs):
        bundle2pushkey(origpushkeyhandler, *args, **kwargs)
    newpushkeyhandler.params = origpushkeyhandler.params
    bundle2.parthandlermapping['pushkey'] = newpushkeyhandler

    orighandlephasehandler = bundle2.parthandlermapping['phase-heads']
    newphaseheadshandler = lambda *args, **kwargs: \
        bundle2handlephases(orighandlephasehandler, *args, **kwargs)
    newphaseheadshandler.params = orighandlephasehandler.params
    bundle2.parthandlermapping['phase-heads'] = newphaseheadshandler

    extensions.wrapfunction(localrepo.localrepository, 'listkeys',
                            localrepolistkeys)
    wireproto.commands['lookup'] = (
        _lookupwrap(wireproto.commands['lookup'][0]), 'key')
    extensions.wrapfunction(exchange, 'getbundlechunks', getbundlechunks)

    extensions.wrapfunction(bundle2, 'processparts', processparts)

def clientextsetup(ui):
    entry = extensions.wrapcommand(commands.table, 'push', _push)

    entry[1].append(
        ('', 'bundle-store', None,
         _('force push to go to bundle store (EXPERIMENTAL)')))

    extensions.wrapcommand(commands.table, 'pull', _pull)

    extensions.wrapfunction(discovery, 'checkheads', _checkheads)

-   wireproto.wirepeer.listkeyspatterns = listkeyspatterns
+   wireprotov1peer.wirepeer.listkeyspatterns = listkeyspatterns

    partorder = exchange.b2partsgenorder
    index = partorder.index('changeset')
    partorder.insert(
        index, partorder.pop(partorder.index(scratchbranchparttype)))

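# Note on the change above: listkeyspatterns (defined below) is attached to the
# version 1 wire protocol peer class, which this changeset moves from
# wireproto.wirepeer to wireprotov1peer.wirepeer. Client code such as _pull and
# _push can then call other.listkeyspatterns('bookmarks', patterns=[...]) on
# any v1 peer.
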
def _checkheads(orig, pushop):
    if pushop.ui.configbool(experimental, configscratchpush, False):
        return
    return orig(pushop)

def wireprotolistkeyspatterns(repo, proto, namespace, patterns):
    patterns = wireprototypes.decodelist(patterns)
    d = repo.listkeys(encoding.tolocal(namespace), patterns).iteritems()
    return pushkey.encodekeys(d)

def localrepolistkeys(orig, self, namespace, patterns=None):
    if namespace == 'bookmarks' and patterns:
        index = self.bundlestore.index
        results = {}
        bookmarks = orig(self, namespace)
        for pattern in patterns:
            results.update(index.getbookmarks(pattern))
            if pattern.endswith('*'):
                pattern = 're:^' + pattern[:-1] + '.*'
            kind, pat, matcher = stringutil.stringmatcher(pattern)
            for bookmark, node in bookmarks.iteritems():
                if matcher(bookmark):
                    results[bookmark] = node
        return results
    else:
        return orig(self, namespace)

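# Illustrative example: for a listkeys('bookmarks', patterns=['infinitepush/*'])
# call, the trailing '*' is rewritten above to the regex 're:^infinitepush/.*',
# so the result combines bookmarks from the bundlestore index with any regular
# bookmarks whose names match the same pattern. The pattern value is only an
# example.
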
@peer.batchable
def listkeyspatterns(self, namespace, patterns):
    if not self.capable('pushkey'):
        yield {}, None
    f = peer.future()
    self.ui.debug('preparing listkeys for "%s" with pattern "%s"\n' %
                  (namespace, patterns))
    yield {
        'namespace': encoding.fromlocal(namespace),
        'patterns': wireprototypes.encodelist(patterns)
    }, f
    d = f.value
    self.ui.debug('received listkey for "%s": %i bytes\n'
                  % (namespace, len(d)))
    yield pushkey.decodekeys(d)

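# Sketch of the wire exchange performed by listkeyspatterns above (hedged; the
# exact encoding is delegated to wireprototypes and pushkey): the client sends
# the 'listkeyspatterns' command registered in commonsetup() with a 'namespace'
# argument and an encoded list of 'patterns', then decodes the reply with
# pushkey.decodekeys(), e.g.
#   peer.listkeyspatterns('bookmarks', patterns=['scratch/*'])
#   -> {'scratch/foo': '<40-hex node>', ...}
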
def _readbundlerevs(bundlerepo):
    return list(bundlerepo.revs('bundle()'))

def _includefilelogstobundle(bundlecaps, bundlerepo, bundlerevs, ui):
    '''Tells remotefilelog to include all changed files in the changegroup

    By default remotefilelog doesn't include file content in the changegroup.
    But we need to include it if we are fetching from the bundlestore.
    '''
    changedfiles = set()
    cl = bundlerepo.changelog
    for r in bundlerevs:
        # [3] means changed files
        changedfiles.update(cl.read(r)[3])
    if not changedfiles:
        return bundlecaps

    changedfiles = '\0'.join(changedfiles)
    newcaps = []
    appended = False
    for cap in (bundlecaps or []):
        if cap.startswith('excludepattern='):
            newcaps.append('\0'.join((cap, changedfiles)))
            appended = True
        else:
            newcaps.append(cap)
    if not appended:
        # excludepattern cap was not found. Just append it
        newcaps.append('excludepattern=' + changedfiles)

    return newcaps

def _rebundle(bundlerepo, bundleroots, unknownhead):
    '''
    A bundle may include more revisions than the user requested. For example,
    the user may ask for a single revision while the bundle also contains its
    descendants. This function filters out all revisions that the user did not
    request.
    '''
    parts = []

    version = '02'
    outgoing = discovery.outgoing(bundlerepo, commonheads=bundleroots,
                                  missingheads=[unknownhead])
    cgstream = changegroup.makestream(bundlerepo, outgoing, version, 'pull')
    cgstream = util.chunkbuffer(cgstream).read()
    cgpart = bundle2.bundlepart('changegroup', data=cgstream)
    cgpart.addparam('version', version)
    parts.append(cgpart)

    return parts

def _getbundleroots(oldrepo, bundlerepo, bundlerevs):
    cl = bundlerepo.changelog
    bundleroots = []
    for rev in bundlerevs:
        node = cl.node(rev)
        parents = cl.parents(node)
        for parent in parents:
            # include all revs that exist in the main repo
            # to make sure that bundle may apply client-side
            if parent in oldrepo:
                bundleroots.append(parent)
    return bundleroots

def _needsrebundling(head, bundlerepo):
    bundleheads = list(bundlerepo.revs('heads(bundle())'))
    return not (len(bundleheads) == 1 and
                bundlerepo[bundleheads[0]].node() == head)

def _generateoutputparts(head, bundlerepo, bundleroots, bundlefile):
    '''generates the bundle that will be sent to the user

    returns a tuple with the raw bundle string and the bundle type
    '''
    parts = []
    if not _needsrebundling(head, bundlerepo):
        with util.posixfile(bundlefile, "rb") as f:
            unbundler = exchange.readbundle(bundlerepo.ui, f, bundlefile)
            if isinstance(unbundler, changegroup.cg1unpacker):
                part = bundle2.bundlepart('changegroup',
                                          data=unbundler._stream.read())
                part.addparam('version', '01')
                parts.append(part)
            elif isinstance(unbundler, bundle2.unbundle20):
                haschangegroup = False
                for part in unbundler.iterparts():
                    if part.type == 'changegroup':
                        haschangegroup = True
                    newpart = bundle2.bundlepart(part.type, data=part.read())
                    for key, value in part.params.iteritems():
                        newpart.addparam(key, value)
                    parts.append(newpart)

                if not haschangegroup:
                    raise error.Abort(
                        'unexpected bundle without changegroup part, ' +
                        'head: %s' % hex(head),
                        hint='report to administrator')
            else:
                raise error.Abort('unknown bundle type')
    else:
        parts = _rebundle(bundlerepo, bundleroots, head)

    return parts

def getbundlechunks(orig, repo, source, heads=None, bundlecaps=None, **kwargs):
    heads = heads or []
    # newheads are parents of roots of scratch bundles that were requested
    newphases = {}
    scratchbundles = []
    newheads = []
    scratchheads = []
    nodestobundle = {}
    allbundlestocleanup = []
    try:
        for head in heads:
            if head not in repo.changelog.nodemap:
                if head not in nodestobundle:
                    newbundlefile = common.downloadbundle(repo, head)
                    bundlepath = "bundle:%s+%s" % (repo.root, newbundlefile)
                    bundlerepo = hg.repository(repo.ui, bundlepath)

                    allbundlestocleanup.append((bundlerepo, newbundlefile))
                    bundlerevs = set(_readbundlerevs(bundlerepo))
                    bundlecaps = _includefilelogstobundle(
                        bundlecaps, bundlerepo, bundlerevs, repo.ui)
                    cl = bundlerepo.changelog
                    bundleroots = _getbundleroots(repo, bundlerepo, bundlerevs)
                    for rev in bundlerevs:
                        node = cl.node(rev)
                        newphases[hex(node)] = str(phases.draft)
                        nodestobundle[node] = (bundlerepo, bundleroots,
                                               newbundlefile)

                scratchbundles.append(
                    _generateoutputparts(head, *nodestobundle[head]))
                newheads.extend(bundleroots)
                scratchheads.append(head)
    finally:
        for bundlerepo, bundlefile in allbundlestocleanup:
            bundlerepo.close()
            try:
                os.unlink(bundlefile)
            except (IOError, OSError):
                # if we can't cleanup the file then just ignore the error,
                # no need to fail
                pass

    pullfrombundlestore = bool(scratchbundles)
    wrappedchangegrouppart = False
    wrappedlistkeys = False
    oldchangegrouppart = exchange.getbundle2partsmapping['changegroup']
    try:
        def _changegrouppart(bundler, *args, **kwargs):
            # Order is important here. First add non-scratch part
            # and only then add parts with scratch bundles because
            # non-scratch part contains parents of roots of scratch bundles.
            result = oldchangegrouppart(bundler, *args, **kwargs)
            for bundle in scratchbundles:
                for part in bundle:
                    bundler.addpart(part)
            return result

        exchange.getbundle2partsmapping['changegroup'] = _changegrouppart
        wrappedchangegrouppart = True

        def _listkeys(orig, self, namespace):
            origvalues = orig(self, namespace)
            if namespace == 'phases' and pullfrombundlestore:
                if origvalues.get('publishing') == 'True':
                    # Make repo non-publishing to preserve draft phase
                    del origvalues['publishing']
                origvalues.update(newphases)
            return origvalues

        extensions.wrapfunction(localrepo.localrepository, 'listkeys',
                                _listkeys)
        wrappedlistkeys = True
        heads = list((set(newheads) | set(heads)) - set(scratchheads))
        result = orig(repo, source, heads=heads,
                      bundlecaps=bundlecaps, **kwargs)
    finally:
        if wrappedchangegrouppart:
            exchange.getbundle2partsmapping['changegroup'] = oldchangegrouppart
        if wrappedlistkeys:
            extensions.unwrapfunction(localrepo.localrepository, 'listkeys',
                                      _listkeys)
    return result

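# Summary of the wrapper above: for every requested head that is unknown to the
# main repository, the corresponding bundle is downloaded from the bundlestore,
# opened as a bundle repository, and its parts are appended after the regular
# changegroup part; the parents of the scratch roots are added to the requested
# heads, the scratch heads themselves are removed, and the scratch nodes are
# reported as draft through the wrapped listkeys('phases').
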
def _lookupwrap(orig):
    def _lookup(repo, proto, key):
        localkey = encoding.tolocal(key)

        if isinstance(localkey, str) and _scratchbranchmatcher(localkey):
            scratchnode = repo.bundlestore.index.getnode(localkey)
            if scratchnode:
                return "%s %s\n" % (1, scratchnode)
            else:
                return "%s %s\n" % (0, 'scratch branch %s not found' % localkey)
        else:
            try:
                r = hex(repo.lookup(localkey))
                return "%s %s\n" % (1, r)
            except Exception as inst:
                if repo.bundlestore.index.getbundle(localkey):
                    return "%s %s\n" % (1, localkey)
                else:
                    r = str(inst)
                    return "%s %s\n" % (0, r)
    return _lookup

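# Illustrative wire responses produced by the wrapped 'lookup' above:
#   "1 <40-hex node>\n"                     on success
#   "0 scratch branch <key> not found\n"    for an unknown scratch bookmark
# For scratch bookmarks the node comes from the bundlestore index rather than
# from the local changelog.
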
def _pull(orig, ui, repo, source="default", **opts):
    opts = pycompat.byteskwargs(opts)
    # Copy paste from `pull` command
    source, branches = hg.parseurl(ui.expandpath(source), opts.get('branch'))

    scratchbookmarks = {}
    unfi = repo.unfiltered()
    unknownnodes = []
    for rev in opts.get('rev', []):
        if rev not in unfi:
            unknownnodes.append(rev)
    if opts.get('bookmark'):
        bookmarks = []
        revs = opts.get('rev') or []
        for bookmark in opts.get('bookmark'):
            if _scratchbranchmatcher(bookmark):
                # rev is not known yet
                # it will be fetched with listkeyspatterns next
                scratchbookmarks[bookmark] = 'REVTOFETCH'
            else:
                bookmarks.append(bookmark)

        if scratchbookmarks:
            other = hg.peer(repo, opts, source)
            fetchedbookmarks = other.listkeyspatterns(
                'bookmarks', patterns=scratchbookmarks)
            for bookmark in scratchbookmarks:
                if bookmark not in fetchedbookmarks:
                    raise error.Abort('remote bookmark %s not found!' %
                                      bookmark)
                scratchbookmarks[bookmark] = fetchedbookmarks[bookmark]
                revs.append(fetchedbookmarks[bookmark])
        opts['bookmark'] = bookmarks
        opts['rev'] = revs

    if scratchbookmarks or unknownnodes:
        # Set anyincoming to True
        extensions.wrapfunction(discovery, 'findcommonincoming',
                                _findcommonincoming)
    try:
        # Remote scratch bookmarks will be deleted because remotenames doesn't
        # know about them. Let's save it before pull and restore after
        remotescratchbookmarks = _readscratchremotebookmarks(ui, repo, source)
        result = orig(ui, repo, source, **pycompat.strkwargs(opts))
        # TODO(stash): race condition is possible
        # if scratch bookmarks was updated right after orig.
        # But that's unlikely and shouldn't be harmful.
        if common.isremotebooksenabled(ui):
            remotescratchbookmarks.update(scratchbookmarks)
            _saveremotebookmarks(repo, remotescratchbookmarks, source)
        else:
            _savelocalbookmarks(repo, scratchbookmarks)
        return result
    finally:
        if scratchbookmarks:
            extensions.unwrapfunction(discovery, 'findcommonincoming')

def _readscratchremotebookmarks(ui, repo, other):
    if common.isremotebooksenabled(ui):
        remotenamesext = extensions.find('remotenames')
        remotepath = remotenamesext.activepath(repo.ui, other)
        result = {}
        # Let's refresh remotenames to make sure we have it up to date
        # Seems that `repo.names['remotebookmarks']` may return stale bookmarks
        # and it results in deleting scratch bookmarks. Our best guess how to
        # fix it is to use `clearnames()`
        repo._remotenames.clearnames()
        for remotebookmark in repo.names['remotebookmarks'].listnames(repo):
            path, bookname = remotenamesext.splitremotename(remotebookmark)
            if path == remotepath and _scratchbranchmatcher(bookname):
                nodes = repo.names['remotebookmarks'].nodes(repo,
                                                            remotebookmark)
                if nodes:
                    result[bookname] = hex(nodes[0])
        return result
    else:
        return {}

def _saveremotebookmarks(repo, newbookmarks, remote):
    remotenamesext = extensions.find('remotenames')
    remotepath = remotenamesext.activepath(repo.ui, remote)
    branches = collections.defaultdict(list)
    bookmarks = {}
    remotenames = remotenamesext.readremotenames(repo)
    for hexnode, nametype, remote, rname in remotenames:
        if remote != remotepath:
            continue
        if nametype == 'bookmarks':
            if rname in newbookmarks:
                # It's possible if we have a normal bookmark that matches
                # scratch branch pattern. In this case just use the current
                # bookmark node
                del newbookmarks[rname]
            bookmarks[rname] = hexnode
        elif nametype == 'branches':
            # saveremotenames expects 20 byte binary nodes for branches
            branches[rname].append(bin(hexnode))

    for bookmark, hexnode in newbookmarks.iteritems():
        bookmarks[bookmark] = hexnode
    remotenamesext.saveremotenames(repo, remotepath, branches, bookmarks)

def _savelocalbookmarks(repo, bookmarks):
    if not bookmarks:
        return
    with repo.wlock(), repo.lock(), repo.transaction('bookmark') as tr:
        changes = []
        for scratchbook, node in bookmarks.iteritems():
            changectx = repo[node]
            changes.append((scratchbook, changectx.node()))
        repo._bookmarks.applychanges(repo, tr, changes)

def _findcommonincoming(orig, *args, **kwargs):
    common, inc, remoteheads = orig(*args, **kwargs)
    return common, True, remoteheads

def _push(orig, ui, repo, dest=None, *args, **opts):

    bookmark = opts.get(r'bookmark')
    # we only support pushing one infinitepush bookmark at once
    if len(bookmark) == 1:
        bookmark = bookmark[0]
    else:
        bookmark = ''

    oldphasemove = None
    overrides = {(experimental, configbookmark): bookmark}

    with ui.configoverride(overrides, 'infinitepush'):
        scratchpush = opts.get('bundle_store')
        if _scratchbranchmatcher(bookmark):
            scratchpush = True
            # bundle2 can be sent back after push (for example, bundle2
            # containing `pushkey` part to update bookmarks)
            ui.setconfig(experimental, 'bundle2.pushback', True)

        if scratchpush:
            # this is an infinitepush, we don't want the bookmark to be applied
            # rather that should be stored in the bundlestore
            opts[r'bookmark'] = []
            ui.setconfig(experimental, configscratchpush, True)
            oldphasemove = extensions.wrapfunction(exchange,
                                                   '_localphasemove',
                                                   _phasemove)
        # Copy-paste from `push` command
        path = ui.paths.getpath(dest, default=('default-push', 'default'))
        if not path:
            raise error.Abort(_('default repository not configured!'),
                              hint=_("see 'hg help config.paths'"))
        destpath = path.pushloc or path.loc
        # Remote scratch bookmarks will be deleted because remotenames doesn't
        # know about them. Let's save it before push and restore after
        remotescratchbookmarks = _readscratchremotebookmarks(ui, repo, destpath)
        result = orig(ui, repo, dest, *args, **opts)
        if common.isremotebooksenabled(ui):
            if bookmark and scratchpush:
                other = hg.peer(repo, opts, destpath)
                fetchedbookmarks = other.listkeyspatterns('bookmarks',
                                                          patterns=[bookmark])
                remotescratchbookmarks.update(fetchedbookmarks)
            _saveremotebookmarks(repo, remotescratchbookmarks, destpath)
    if oldphasemove:
        exchange._localphasemove = oldphasemove
    return result

def _deleteinfinitepushbookmarks(ui, repo, path, names):
    """Prune remote names by removing the bookmarks we don't want anymore,
    then writing the result back to disk
    """
    remotenamesext = extensions.find('remotenames')

    # remotename format is:
    # (node, nametype ("branches" or "bookmarks"), remote, name)
    nametype_idx = 1
    remote_idx = 2
    name_idx = 3
    remotenames = [remotename for remotename in \
                   remotenamesext.readremotenames(repo) \
                   if remotename[remote_idx] == path]
    remote_bm_names = [remotename[name_idx] for remotename in \
                       remotenames if remotename[nametype_idx] == "bookmarks"]

    for name in names:
        if name not in remote_bm_names:
            raise error.Abort(_("infinitepush bookmark '{}' does not exist "
                                "in path '{}'").format(name, path))

    bookmarks = {}
    branches = collections.defaultdict(list)
    for node, nametype, remote, name in remotenames:
        if nametype == "bookmarks" and name not in names:
            bookmarks[name] = node
        elif nametype == "branches":
            # saveremotenames wants binary nodes for branches
            branches[name].append(bin(node))

    remotenamesext.saveremotenames(repo, path, branches, bookmarks)

def _phasemove(orig, pushop, nodes, phase=phases.public):
    """prevent commits from being marked public

    Since these are going to a scratch branch, they aren't really being
    published."""

    if phase != phases.public:
        orig(pushop, nodes, phase)

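# Consequence of _phasemove above: when pushing to the bundlestore, calls that
# would advance nodes to the public phase are dropped, so scratch commits stay
# draft on the client; non-public phase moves are still forwarded to the
# original implementation.
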
790 @exchange.b2partsgenerator(scratchbranchparttype)
791 @exchange.b2partsgenerator(scratchbranchparttype)
791 def partgen(pushop, bundler):
792 def partgen(pushop, bundler):
792 bookmark = pushop.ui.config(experimental, configbookmark)
793 bookmark = pushop.ui.config(experimental, configbookmark)
793 scratchpush = pushop.ui.configbool(experimental, configscratchpush)
794 scratchpush = pushop.ui.configbool(experimental, configscratchpush)
794 if 'changesets' in pushop.stepsdone or not scratchpush:
795 if 'changesets' in pushop.stepsdone or not scratchpush:
795 return
796 return
796
797
797 if scratchbranchparttype not in bundle2.bundle2caps(pushop.remote):
798 if scratchbranchparttype not in bundle2.bundle2caps(pushop.remote):
798 return
799 return
799
800
800 pushop.stepsdone.add('changesets')
801 pushop.stepsdone.add('changesets')
801 if not pushop.outgoing.missing:
802 if not pushop.outgoing.missing:
802 pushop.ui.status(_('no changes found\n'))
803 pushop.ui.status(_('no changes found\n'))
803 pushop.cgresult = 0
804 pushop.cgresult = 0
804 return
805 return
805
806
806 # This parameter tells the server that the following bundle is an
807 # This parameter tells the server that the following bundle is an
807 # infinitepush. This let's it switch the part processing to our infinitepush
808 # infinitepush. This let's it switch the part processing to our infinitepush
808 # code path.
809 # code path.
809 bundler.addparam("infinitepush", "True")
810 bundler.addparam("infinitepush", "True")
810
811
811 scratchparts = bundleparts.getscratchbranchparts(pushop.repo,
812 scratchparts = bundleparts.getscratchbranchparts(pushop.repo,
812 pushop.remote,
813 pushop.remote,
813 pushop.outgoing,
814 pushop.outgoing,
814 pushop.ui,
815 pushop.ui,
815 bookmark)
816 bookmark)
816
817
817 for scratchpart in scratchparts:
818 for scratchpart in scratchparts:
818 bundler.addpart(scratchpart)
819 bundler.addpart(scratchpart)
819
820
820 def handlereply(op):
821 def handlereply(op):
821 # server either succeeds or aborts; no code to read
822 # server either succeeds or aborts; no code to read
822 pushop.cgresult = 1
823 pushop.cgresult = 1
823
824
824 return handlereply
825 return handlereply
825
826
826 bundle2.capabilities[bundleparts.scratchbranchparttype] = ()
827 bundle2.capabilities[bundleparts.scratchbranchparttype] = ()
827
828
828 def _getrevs(bundle, oldnode, force, bookmark):
829 def _getrevs(bundle, oldnode, force, bookmark):
829 'extracts and validates the revs to be imported'
830 'extracts and validates the revs to be imported'
830 revs = [bundle[r] for r in bundle.revs('sort(bundle())')]
831 revs = [bundle[r] for r in bundle.revs('sort(bundle())')]
831
832
832 # new bookmark
833 # new bookmark
833 if oldnode is None:
834 if oldnode is None:
834 return revs
835 return revs
835
836
836 # Fast forward update
837 # Fast forward update
837 if oldnode in bundle and list(bundle.set('bundle() & %s::', oldnode)):
838 if oldnode in bundle and list(bundle.set('bundle() & %s::', oldnode)):
838 return revs
839 return revs
839
840
840 return revs
841 return revs
841
842
842 @contextlib.contextmanager
843 @contextlib.contextmanager
843 def logservicecall(logger, service, **kwargs):
844 def logservicecall(logger, service, **kwargs):
844 start = time.time()
845 start = time.time()
845 logger(service, eventtype='start', **kwargs)
846 logger(service, eventtype='start', **kwargs)
846 try:
847 try:
847 yield
848 yield
848 logger(service, eventtype='success',
849 logger(service, eventtype='success',
849 elapsedms=(time.time() - start) * 1000, **kwargs)
850 elapsedms=(time.time() - start) * 1000, **kwargs)
850 except Exception as e:
851 except Exception as e:
851 logger(service, eventtype='failure',
852 logger(service, eventtype='failure',
852 elapsedms=(time.time() - start) * 1000, errormsg=str(e),
853 elapsedms=(time.time() - start) * 1000, errormsg=str(e),
853 **kwargs)
854 **kwargs)
854 raise
855 raise
855
856
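# ---------------------------------------------------------------------------
# Editorial illustration (hedged, not part of this changeset): how the
# logservicecall() context manager above is meant to be used. The stub
# logger stands in for the functools.partial(ui.log, 'infinitepush', ...)
# logger that _getorcreateinfinitepushlogger() below builds.
def _examplelogger(service, **kwargs):
    print('%s %r' % (service, kwargs))

def _exampletimedcall():
    # logs 'start', runs the body, then logs 'success' with elapsedms, or
    # 'failure' with errormsg if the body raises (and re-raises)
    with logservicecall(_examplelogger, 'bundlestore', bundlesize=1024):
        pass
# ---------------------------------------------------------------------------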
856 def _getorcreateinfinitepushlogger(op):
857 def _getorcreateinfinitepushlogger(op):
857 logger = op.records['infinitepushlogger']
858 logger = op.records['infinitepushlogger']
858 if not logger:
859 if not logger:
859 ui = op.repo.ui
860 ui = op.repo.ui
860 try:
861 try:
861 username = procutil.getuser()
862 username = procutil.getuser()
862 except Exception:
863 except Exception:
863 username = 'unknown'
864 username = 'unknown'
864 # Generate random request id to be able to find all logged entries
865 # Generate random request id to be able to find all logged entries
865 # for the same request. Since requestid is pseudo-randomly generated it may
866 # for the same request. Since requestid is pseudo-randomly generated it may
866 # not be unique, but we assume that (hostname, username, requestid)
867 # not be unique, but we assume that (hostname, username, requestid)
867 # is unique.
868 # is unique.
868 random.seed()
869 random.seed()
869 requestid = random.randint(0, 2000000000)
870 requestid = random.randint(0, 2000000000)
870 hostname = socket.gethostname()
871 hostname = socket.gethostname()
871 logger = functools.partial(ui.log, 'infinitepush', user=username,
872 logger = functools.partial(ui.log, 'infinitepush', user=username,
872 requestid=requestid, hostname=hostname,
873 requestid=requestid, hostname=hostname,
873 reponame=ui.config('infinitepush',
874 reponame=ui.config('infinitepush',
874 'reponame'))
875 'reponame'))
875 op.records.add('infinitepushlogger', logger)
876 op.records.add('infinitepushlogger', logger)
876 else:
877 else:
877 logger = logger[0]
878 logger = logger[0]
878 return logger
879 return logger
879
880
880 def storetobundlestore(orig, repo, op, unbundler):
881 def storetobundlestore(orig, repo, op, unbundler):
881 """stores the incoming bundle coming from push command to the bundlestore
882 """stores the incoming bundle coming from push command to the bundlestore
882 instead of applying on the revlogs"""
883 instead of applying on the revlogs"""
883
884
884 repo.ui.status(_("storing changesets on the bundlestore\n"))
885 repo.ui.status(_("storing changesets on the bundlestore\n"))
885 bundler = bundle2.bundle20(repo.ui)
886 bundler = bundle2.bundle20(repo.ui)
886
887
887 # processing each part and storing it in bundler
888 # processing each part and storing it in bundler
888 with bundle2.partiterator(repo, op, unbundler) as parts:
889 with bundle2.partiterator(repo, op, unbundler) as parts:
889 for part in parts:
890 for part in parts:
890 bundlepart = None
891 bundlepart = None
891 if part.type == 'replycaps':
892 if part.type == 'replycaps':
892 # This configures the current operation to allow reply parts.
893 # This configures the current operation to allow reply parts.
893 bundle2._processpart(op, part)
894 bundle2._processpart(op, part)
894 else:
895 else:
895 bundlepart = bundle2.bundlepart(part.type, data=part.read())
896 bundlepart = bundle2.bundlepart(part.type, data=part.read())
896 for key, value in part.params.iteritems():
897 for key, value in part.params.iteritems():
897 bundlepart.addparam(key, value)
898 bundlepart.addparam(key, value)
898
899
899 # Certain parts require a response
900 # Certain parts require a response
900 if part.type in ('pushkey', 'changegroup'):
901 if part.type in ('pushkey', 'changegroup'):
901 if op.reply is not None:
902 if op.reply is not None:
902 rpart = op.reply.newpart('reply:%s' % part.type)
903 rpart = op.reply.newpart('reply:%s' % part.type)
903 rpart.addparam('in-reply-to', str(part.id),
904 rpart.addparam('in-reply-to', str(part.id),
904 mandatory=False)
905 mandatory=False)
905 rpart.addparam('return', '1', mandatory=False)
906 rpart.addparam('return', '1', mandatory=False)
906
907
907 op.records.add(part.type, {
908 op.records.add(part.type, {
908 'return': 1,
909 'return': 1,
909 })
910 })
910 if bundlepart:
911 if bundlepart:
911 bundler.addpart(bundlepart)
912 bundler.addpart(bundlepart)
912
913
913 # storing the bundle in the bundlestore
914 # storing the bundle in the bundlestore
914 buf = util.chunkbuffer(bundler.getchunks())
915 buf = util.chunkbuffer(bundler.getchunks())
915 fd, bundlefile = tempfile.mkstemp()
916 fd, bundlefile = tempfile.mkstemp()
916 try:
917 try:
917 try:
918 try:
918 fp = os.fdopen(fd, r'wb')
919 fp = os.fdopen(fd, r'wb')
919 fp.write(buf.read())
920 fp.write(buf.read())
920 finally:
921 finally:
921 fp.close()
922 fp.close()
922 storebundle(op, {}, bundlefile)
923 storebundle(op, {}, bundlefile)
923 finally:
924 finally:
924 try:
925 try:
925 os.unlink(bundlefile)
926 os.unlink(bundlefile)
926 except Exception:
927 except Exception:
927 # we would rather see the original exception
928 # we would rather see the original exception
928 pass
929 pass
929
930
930 def processparts(orig, repo, op, unbundler):
931 def processparts(orig, repo, op, unbundler):
931
932
932 # make sure we don't wrap processparts in case of `hg unbundle`
933 # make sure we don't wrap processparts in case of `hg unbundle`
933 if op.source == 'unbundle':
934 if op.source == 'unbundle':
934 return orig(repo, op, unbundler)
935 return orig(repo, op, unbundler)
935
936
936 # this server routes each push to the bundle store
937 # this server routes each push to the bundle store
937 if repo.ui.configbool('infinitepush', 'pushtobundlestore'):
938 if repo.ui.configbool('infinitepush', 'pushtobundlestore'):
938 return storetobundlestore(orig, repo, op, unbundler)
939 return storetobundlestore(orig, repo, op, unbundler)
939
940
940 if unbundler.params.get('infinitepush') != 'True':
941 if unbundler.params.get('infinitepush') != 'True':
941 return orig(repo, op, unbundler)
942 return orig(repo, op, unbundler)
942
943
943 handleallparts = repo.ui.configbool('infinitepush', 'storeallparts')
944 handleallparts = repo.ui.configbool('infinitepush', 'storeallparts')
944
945
945 bundler = bundle2.bundle20(repo.ui)
946 bundler = bundle2.bundle20(repo.ui)
946 cgparams = None
947 cgparams = None
947 with bundle2.partiterator(repo, op, unbundler) as parts:
948 with bundle2.partiterator(repo, op, unbundler) as parts:
948 for part in parts:
949 for part in parts:
949 bundlepart = None
950 bundlepart = None
950 if part.type == 'replycaps':
951 if part.type == 'replycaps':
951 # This configures the current operation to allow reply parts.
952 # This configures the current operation to allow reply parts.
952 bundle2._processpart(op, part)
953 bundle2._processpart(op, part)
953 elif part.type == bundleparts.scratchbranchparttype:
954 elif part.type == bundleparts.scratchbranchparttype:
954 # Scratch branch parts need to be converted to normal
955 # Scratch branch parts need to be converted to normal
955 # changegroup parts, and the extra parameters stored for later
956 # changegroup parts, and the extra parameters stored for later
956 # when we upload to the store. Eventually those parameters will
957 # when we upload to the store. Eventually those parameters will
957 # be put on the actual bundle instead of this part, then we can
958 # be put on the actual bundle instead of this part, then we can
958 # send a vanilla changegroup instead of the scratchbranch part.
959 # send a vanilla changegroup instead of the scratchbranch part.
959 cgversion = part.params.get('cgversion', '01')
960 cgversion = part.params.get('cgversion', '01')
960 bundlepart = bundle2.bundlepart('changegroup', data=part.read())
961 bundlepart = bundle2.bundlepart('changegroup', data=part.read())
961 bundlepart.addparam('version', cgversion)
962 bundlepart.addparam('version', cgversion)
962 cgparams = part.params
963 cgparams = part.params
963
964
964 # If we're not dumping all parts into the new bundle, we need to
965 # If we're not dumping all parts into the new bundle, we need to
965 # alert the future pushkey and phase-heads handler to skip
966 # alert the future pushkey and phase-heads handler to skip
966 # the part.
967 # the part.
967 if not handleallparts:
968 if not handleallparts:
968 op.records.add(scratchbranchparttype + '_skippushkey', True)
969 op.records.add(scratchbranchparttype + '_skippushkey', True)
969 op.records.add(scratchbranchparttype + '_skipphaseheads',
970 op.records.add(scratchbranchparttype + '_skipphaseheads',
970 True)
971 True)
971 else:
972 else:
972 if handleallparts:
973 if handleallparts:
973 # Ideally we would not process any parts, and instead just
974 # Ideally we would not process any parts, and instead just
974 # forward them to the bundle for storage, but since this
975 # forward them to the bundle for storage, but since this
975 # differs from previous behavior, we need to put it behind a
976 # differs from previous behavior, we need to put it behind a
976 # config flag for incremental rollout.
977 # config flag for incremental rollout.
977 bundlepart = bundle2.bundlepart(part.type, data=part.read())
978 bundlepart = bundle2.bundlepart(part.type, data=part.read())
978 for key, value in part.params.iteritems():
979 for key, value in part.params.iteritems():
979 bundlepart.addparam(key, value)
980 bundlepart.addparam(key, value)
980
981
981 # Certain parts require a response
982 # Certain parts require a response
982 if part.type == 'pushkey':
983 if part.type == 'pushkey':
983 if op.reply is not None:
984 if op.reply is not None:
984 rpart = op.reply.newpart('reply:pushkey')
985 rpart = op.reply.newpart('reply:pushkey')
985 rpart.addparam('in-reply-to', str(part.id),
986 rpart.addparam('in-reply-to', str(part.id),
986 mandatory=False)
987 mandatory=False)
987 rpart.addparam('return', '1', mandatory=False)
988 rpart.addparam('return', '1', mandatory=False)
988 else:
989 else:
989 bundle2._processpart(op, part)
990 bundle2._processpart(op, part)
990
991
991 if handleallparts:
992 if handleallparts:
992 op.records.add(part.type, {
993 op.records.add(part.type, {
993 'return': 1,
994 'return': 1,
994 })
995 })
995 if bundlepart:
996 if bundlepart:
996 bundler.addpart(bundlepart)
997 bundler.addpart(bundlepart)
997
998
998 # If commits were sent, store them
999 # If commits were sent, store them
999 if cgparams:
1000 if cgparams:
1000 buf = util.chunkbuffer(bundler.getchunks())
1001 buf = util.chunkbuffer(bundler.getchunks())
1001 fd, bundlefile = tempfile.mkstemp()
1002 fd, bundlefile = tempfile.mkstemp()
1002 try:
1003 try:
1003 try:
1004 try:
1004 fp = os.fdopen(fd, r'wb')
1005 fp = os.fdopen(fd, r'wb')
1005 fp.write(buf.read())
1006 fp.write(buf.read())
1006 finally:
1007 finally:
1007 fp.close()
1008 fp.close()
1008 storebundle(op, cgparams, bundlefile)
1009 storebundle(op, cgparams, bundlefile)
1009 finally:
1010 finally:
1010 try:
1011 try:
1011 os.unlink(bundlefile)
1012 os.unlink(bundlefile)
1012 except Exception:
1013 except Exception:
1013 # we would rather see the original exception
1014 # we would rather see the original exception
1014 pass
1015 pass
1015
1016
1016 def storebundle(op, params, bundlefile):
1017 def storebundle(op, params, bundlefile):
1017 log = _getorcreateinfinitepushlogger(op)
1018 log = _getorcreateinfinitepushlogger(op)
1018 parthandlerstart = time.time()
1019 parthandlerstart = time.time()
1019 log(scratchbranchparttype, eventtype='start')
1020 log(scratchbranchparttype, eventtype='start')
1020 index = op.repo.bundlestore.index
1021 index = op.repo.bundlestore.index
1021 store = op.repo.bundlestore.store
1022 store = op.repo.bundlestore.store
1022 op.records.add(scratchbranchparttype + '_skippushkey', True)
1023 op.records.add(scratchbranchparttype + '_skippushkey', True)
1023
1024
1024 bundle = None
1025 bundle = None
1025 try: # guards bundle
1026 try: # guards bundle
1026 bundlepath = "bundle:%s+%s" % (op.repo.root, bundlefile)
1027 bundlepath = "bundle:%s+%s" % (op.repo.root, bundlefile)
1027 bundle = hg.repository(op.repo.ui, bundlepath)
1028 bundle = hg.repository(op.repo.ui, bundlepath)
1028
1029
1029 bookmark = params.get('bookmark')
1030 bookmark = params.get('bookmark')
1030 bookprevnode = params.get('bookprevnode', '')
1031 bookprevnode = params.get('bookprevnode', '')
1031 force = params.get('force')
1032 force = params.get('force')
1032
1033
1033 if bookmark:
1034 if bookmark:
1034 oldnode = index.getnode(bookmark)
1035 oldnode = index.getnode(bookmark)
1035 else:
1036 else:
1036 oldnode = None
1037 oldnode = None
1037 bundleheads = bundle.revs('heads(bundle())')
1038 bundleheads = bundle.revs('heads(bundle())')
1038 if bookmark and len(bundleheads) > 1:
1039 if bookmark and len(bundleheads) > 1:
1039 raise error.Abort(
1040 raise error.Abort(
1040 _('cannot push more than one head to a scratch branch'))
1041 _('cannot push more than one head to a scratch branch'))
1041
1042
1042 revs = _getrevs(bundle, oldnode, force, bookmark)
1043 revs = _getrevs(bundle, oldnode, force, bookmark)
1043
1044
1044 # Notify the user of what is being pushed
1045 # Notify the user of what is being pushed
1045 plural = 's' if len(revs) > 1 else ''
1046 plural = 's' if len(revs) > 1 else ''
1046 op.repo.ui.warn(_("pushing %d commit%s:\n") % (len(revs), plural))
1047 op.repo.ui.warn(_("pushing %d commit%s:\n") % (len(revs), plural))
1047 maxoutput = 10
1048 maxoutput = 10
1048 for i in range(0, min(len(revs), maxoutput)):
1049 for i in range(0, min(len(revs), maxoutput)):
1049 firstline = bundle[revs[i]].description().split('\n')[0][:50]
1050 firstline = bundle[revs[i]].description().split('\n')[0][:50]
1050 op.repo.ui.warn((" %s %s\n") % (revs[i], firstline))
1051 op.repo.ui.warn((" %s %s\n") % (revs[i], firstline))
1051
1052
1052 if len(revs) > maxoutput + 1:
1053 if len(revs) > maxoutput + 1:
1053 op.repo.ui.warn((" ...\n"))
1054 op.repo.ui.warn((" ...\n"))
1054 firstline = bundle[revs[-1]].description().split('\n')[0][:50]
1055 firstline = bundle[revs[-1]].description().split('\n')[0][:50]
1055 op.repo.ui.warn((" %s %s\n") % (revs[-1], firstline))
1056 op.repo.ui.warn((" %s %s\n") % (revs[-1], firstline))
1056
1057
1057 nodesctx = [bundle[rev] for rev in revs]
1058 nodesctx = [bundle[rev] for rev in revs]
1058 inindex = lambda rev: bool(index.getbundle(bundle[rev].hex()))
1059 inindex = lambda rev: bool(index.getbundle(bundle[rev].hex()))
1059 if bundleheads:
1060 if bundleheads:
1060 newheadscount = sum(not inindex(rev) for rev in bundleheads)
1061 newheadscount = sum(not inindex(rev) for rev in bundleheads)
1061 else:
1062 else:
1062 newheadscount = 0
1063 newheadscount = 0
1063 # If there's a bookmark specified, there should be only one head,
1064 # If there's a bookmark specified, there should be only one head,
1064 # so we choose the last node, which will be that head.
1065 # so we choose the last node, which will be that head.
1065 # If a bug or malicious client allows there to be a bookmark
1066 # If a bug or malicious client allows there to be a bookmark
1066 # with multiple heads, we will place the bookmark on the last head.
1067 # with multiple heads, we will place the bookmark on the last head.
1067 bookmarknode = nodesctx[-1].hex() if nodesctx else None
1068 bookmarknode = nodesctx[-1].hex() if nodesctx else None
1068 key = None
1069 key = None
1069 if newheadscount:
1070 if newheadscount:
1070 with open(bundlefile, 'r') as f:
1071 with open(bundlefile, 'r') as f:
1071 bundledata = f.read()
1072 bundledata = f.read()
1072 with logservicecall(log, 'bundlestore',
1073 with logservicecall(log, 'bundlestore',
1073 bundlesize=len(bundledata)):
1074 bundlesize=len(bundledata)):
1074 bundlesizelimit = 100 * 1024 * 1024 # 100 MB
1075 bundlesizelimit = 100 * 1024 * 1024 # 100 MB
1075 if len(bundledata) > bundlesizelimit:
1076 if len(bundledata) > bundlesizelimit:
1076 error_msg = ('bundle is too big: %d bytes. ' +
1077 error_msg = ('bundle is too big: %d bytes. ' +
1077 'max allowed size is 100 MB')
1078 'max allowed size is 100 MB')
1078 raise error.Abort(error_msg % (len(bundledata),))
1079 raise error.Abort(error_msg % (len(bundledata),))
1079 key = store.write(bundledata)
1080 key = store.write(bundledata)
1080
1081
1081 with logservicecall(log, 'index', newheadscount=newheadscount), index:
1082 with logservicecall(log, 'index', newheadscount=newheadscount), index:
1082 if key:
1083 if key:
1083 index.addbundle(key, nodesctx)
1084 index.addbundle(key, nodesctx)
1084 if bookmark:
1085 if bookmark:
1085 index.addbookmark(bookmark, bookmarknode)
1086 index.addbookmark(bookmark, bookmarknode)
1086 _maybeaddpushbackpart(op, bookmark, bookmarknode,
1087 _maybeaddpushbackpart(op, bookmark, bookmarknode,
1087 bookprevnode, params)
1088 bookprevnode, params)
1088 log(scratchbranchparttype, eventtype='success',
1089 log(scratchbranchparttype, eventtype='success',
1089 elapsedms=(time.time() - parthandlerstart) * 1000)
1090 elapsedms=(time.time() - parthandlerstart) * 1000)
1090
1091
1091 except Exception as e:
1092 except Exception as e:
1092 log(scratchbranchparttype, eventtype='failure',
1093 log(scratchbranchparttype, eventtype='failure',
1093 elapsedms=(time.time() - parthandlerstart) * 1000,
1094 elapsedms=(time.time() - parthandlerstart) * 1000,
1094 errormsg=str(e))
1095 errormsg=str(e))
1095 raise
1096 raise
1096 finally:
1097 finally:
1097 if bundle:
1098 if bundle:
1098 bundle.close()
1099 bundle.close()
1099
1100
1100 @bundle2.parthandler(scratchbranchparttype,
1101 @bundle2.parthandler(scratchbranchparttype,
1101 ('bookmark', 'bookprevnode', 'force',
1102 ('bookmark', 'bookprevnode', 'force',
1102 'pushbackbookmarks', 'cgversion'))
1103 'pushbackbookmarks', 'cgversion'))
1103 def bundle2scratchbranch(op, part):
1104 def bundle2scratchbranch(op, part):
1104 '''unbundle a bundle2 part containing a changegroup to store'''
1105 '''unbundle a bundle2 part containing a changegroup to store'''
1105
1106
1106 bundler = bundle2.bundle20(op.repo.ui)
1107 bundler = bundle2.bundle20(op.repo.ui)
1107 cgversion = part.params.get('cgversion', '01')
1108 cgversion = part.params.get('cgversion', '01')
1108 cgpart = bundle2.bundlepart('changegroup', data=part.read())
1109 cgpart = bundle2.bundlepart('changegroup', data=part.read())
1109 cgpart.addparam('version', cgversion)
1110 cgpart.addparam('version', cgversion)
1110 bundler.addpart(cgpart)
1111 bundler.addpart(cgpart)
1111 buf = util.chunkbuffer(bundler.getchunks())
1112 buf = util.chunkbuffer(bundler.getchunks())
1112
1113
1113 fd, bundlefile = tempfile.mkstemp()
1114 fd, bundlefile = tempfile.mkstemp()
1114 try:
1115 try:
1115 try:
1116 try:
1116 fp = os.fdopen(fd, r'wb')
1117 fp = os.fdopen(fd, r'wb')
1117 fp.write(buf.read())
1118 fp.write(buf.read())
1118 finally:
1119 finally:
1119 fp.close()
1120 fp.close()
1120 storebundle(op, part.params, bundlefile)
1121 storebundle(op, part.params, bundlefile)
1121 finally:
1122 finally:
1122 try:
1123 try:
1123 os.unlink(bundlefile)
1124 os.unlink(bundlefile)
1124 except OSError as e:
1125 except OSError as e:
1125 if e.errno != errno.ENOENT:
1126 if e.errno != errno.ENOENT:
1126 raise
1127 raise
1127
1128
1128 return 1
1129 return 1
1129
1130
1130 def _maybeaddpushbackpart(op, bookmark, newnode, oldnode, params):
1131 def _maybeaddpushbackpart(op, bookmark, newnode, oldnode, params):
1131 if params.get('pushbackbookmarks'):
1132 if params.get('pushbackbookmarks'):
1132 if op.reply and 'pushback' in op.reply.capabilities:
1133 if op.reply and 'pushback' in op.reply.capabilities:
1133 params = {
1134 params = {
1134 'namespace': 'bookmarks',
1135 'namespace': 'bookmarks',
1135 'key': bookmark,
1136 'key': bookmark,
1136 'new': newnode,
1137 'new': newnode,
1137 'old': oldnode,
1138 'old': oldnode,
1138 }
1139 }
1139 op.reply.newpart('pushkey', mandatoryparams=params.iteritems())
1140 op.reply.newpart('pushkey', mandatoryparams=params.iteritems())
1140
1141
1141 def bundle2pushkey(orig, op, part):
1142 def bundle2pushkey(orig, op, part):
1142 '''Wrapper of bundle2.handlepushkey()
1143 '''Wrapper of bundle2.handlepushkey()
1143
1144
1144 The only goal is to skip calling the original function if the flag is set.
1145 The only goal is to skip calling the original function if the flag is set.
1145 It's set when an infinitepush push is happening.
1146 It's set when an infinitepush push is happening.
1146 '''
1147 '''
1147 if op.records[scratchbranchparttype + '_skippushkey']:
1148 if op.records[scratchbranchparttype + '_skippushkey']:
1148 if op.reply is not None:
1149 if op.reply is not None:
1149 rpart = op.reply.newpart('reply:pushkey')
1150 rpart = op.reply.newpart('reply:pushkey')
1150 rpart.addparam('in-reply-to', str(part.id), mandatory=False)
1151 rpart.addparam('in-reply-to', str(part.id), mandatory=False)
1151 rpart.addparam('return', '1', mandatory=False)
1152 rpart.addparam('return', '1', mandatory=False)
1152 return 1
1153 return 1
1153
1154
1154 return orig(op, part)
1155 return orig(op, part)
1155
1156
1156 def bundle2handlephases(orig, op, part):
1157 def bundle2handlephases(orig, op, part):
1157 '''Wrapper of bundle2.handlephases()
1158 '''Wrapper of bundle2.handlephases()
1158
1159
1159 The only goal is to skip calling the original function if the flag is set.
1160 The only goal is to skip calling the original function if the flag is set.
1160 It's set when an infinitepush push is happening.
1161 It's set when an infinitepush push is happening.
1161 '''
1162 '''
1162
1163
1163 if op.records[scratchbranchparttype + '_skipphaseheads']:
1164 if op.records[scratchbranchparttype + '_skipphaseheads']:
1164 return
1165 return
1165
1166
1166 return orig(op, part)
1167 return orig(op, part)
1167
1168
1168 def _asyncsavemetadata(root, nodes):
1169 def _asyncsavemetadata(root, nodes):
1169 '''starts a separate process that fills metadata for the nodes
1170 '''starts a separate process that fills metadata for the nodes
1170
1171
1171 This function creates a separate process and doesn't wait for its
1172 This function creates a separate process and doesn't wait for its
1172 completion. This was done to avoid slowing down pushes.
1173 completion. This was done to avoid slowing down pushes.
1173 '''
1174 '''
1174
1175
1175 maxnodes = 50
1176 maxnodes = 50
1176 if len(nodes) > maxnodes:
1177 if len(nodes) > maxnodes:
1177 return
1178 return
1178 nodesargs = []
1179 nodesargs = []
1179 for node in nodes:
1180 for node in nodes:
1180 nodesargs.append('--node')
1181 nodesargs.append('--node')
1181 nodesargs.append(node)
1182 nodesargs.append(node)
1182 with open(os.devnull, 'w+b') as devnull:
1183 with open(os.devnull, 'w+b') as devnull:
1183 cmdline = [util.hgexecutable(), 'debugfillinfinitepushmetadata',
1184 cmdline = [util.hgexecutable(), 'debugfillinfinitepushmetadata',
1184 '-R', root] + nodesargs
1185 '-R', root] + nodesargs
1185 # Process will run in background. We don't care about the return code
1186 # Process will run in background. We don't care about the return code
1186 subprocess.Popen(cmdline, close_fds=True, shell=False,
1187 subprocess.Popen(cmdline, close_fds=True, shell=False,
1187 stdin=devnull, stdout=devnull, stderr=devnull)
1188 stdin=devnull, stdout=devnull, stderr=devnull)
@@ -1,193 +1,193 b''
1 # Copyright 2011 Fog Creek Software
1 # Copyright 2011 Fog Creek Software
2 #
2 #
3 # This software may be used and distributed according to the terms of the
3 # This software may be used and distributed according to the terms of the
4 # GNU General Public License version 2 or any later version.
4 # GNU General Public License version 2 or any later version.
5 from __future__ import absolute_import
5 from __future__ import absolute_import
6
6
7 import os
7 import os
8 import re
8 import re
9
9
10 from mercurial.i18n import _
10 from mercurial.i18n import _
11
11
12 from mercurial import (
12 from mercurial import (
13 error,
13 error,
14 httppeer,
14 httppeer,
15 util,
15 util,
16 wireproto,
17 wireprototypes,
16 wireprototypes,
17 wireprotov1peer,
18 )
18 )
19
19
20 from . import (
20 from . import (
21 lfutil,
21 lfutil,
22 )
22 )
23
23
24 urlerr = util.urlerr
24 urlerr = util.urlerr
25 urlreq = util.urlreq
25 urlreq = util.urlreq
26
26
27 LARGEFILES_REQUIRED_MSG = ('\nThis repository uses the largefiles extension.'
27 LARGEFILES_REQUIRED_MSG = ('\nThis repository uses the largefiles extension.'
28 '\n\nPlease enable it in your Mercurial config '
28 '\n\nPlease enable it in your Mercurial config '
29 'file.\n')
29 'file.\n')
30
30
31 # these will all be replaced by largefiles.uisetup
31 # these will all be replaced by largefiles.uisetup
32 ssholdcallstream = None
32 ssholdcallstream = None
33 httpoldcallstream = None
33 httpoldcallstream = None
34
34
35 def putlfile(repo, proto, sha):
35 def putlfile(repo, proto, sha):
36 '''Server command for putting a largefile into a repository's local store
36 '''Server command for putting a largefile into a repository's local store
37 and into the user cache.'''
37 and into the user cache.'''
38 with proto.mayberedirectstdio() as output:
38 with proto.mayberedirectstdio() as output:
39 path = lfutil.storepath(repo, sha)
39 path = lfutil.storepath(repo, sha)
40 util.makedirs(os.path.dirname(path))
40 util.makedirs(os.path.dirname(path))
41 tmpfp = util.atomictempfile(path, createmode=repo.store.createmode)
41 tmpfp = util.atomictempfile(path, createmode=repo.store.createmode)
42
42
43 try:
43 try:
44 for p in proto.getpayload():
44 for p in proto.getpayload():
45 tmpfp.write(p)
45 tmpfp.write(p)
46 tmpfp._fp.seek(0)
46 tmpfp._fp.seek(0)
47 if sha != lfutil.hexsha1(tmpfp._fp):
47 if sha != lfutil.hexsha1(tmpfp._fp):
48 raise IOError(0, _('largefile contents do not match hash'))
48 raise IOError(0, _('largefile contents do not match hash'))
49 tmpfp.close()
49 tmpfp.close()
50 lfutil.linktousercache(repo, sha)
50 lfutil.linktousercache(repo, sha)
51 except IOError as e:
51 except IOError as e:
52 repo.ui.warn(_('largefiles: failed to put %s into store: %s\n') %
52 repo.ui.warn(_('largefiles: failed to put %s into store: %s\n') %
53 (sha, e.strerror))
53 (sha, e.strerror))
54 return wireprototypes.pushres(
54 return wireprototypes.pushres(
55 1, output.getvalue() if output else '')
55 1, output.getvalue() if output else '')
56 finally:
56 finally:
57 tmpfp.discard()
57 tmpfp.discard()
58
58
59 return wireprototypes.pushres(0, output.getvalue() if output else '')
59 return wireprototypes.pushres(0, output.getvalue() if output else '')
60
60
61 def getlfile(repo, proto, sha):
61 def getlfile(repo, proto, sha):
62 '''Server command for retrieving a largefile from the repository-local
62 '''Server command for retrieving a largefile from the repository-local
63 cache or user cache.'''
63 cache or user cache.'''
64 filename = lfutil.findfile(repo, sha)
64 filename = lfutil.findfile(repo, sha)
65 if not filename:
65 if not filename:
66 raise error.Abort(_('requested largefile %s not present in cache')
66 raise error.Abort(_('requested largefile %s not present in cache')
67 % sha)
67 % sha)
68 f = open(filename, 'rb')
68 f = open(filename, 'rb')
69 length = os.fstat(f.fileno())[6]
69 length = os.fstat(f.fileno())[6]
70
70
71 # Since we can't set an HTTP content-length header here, and
71 # Since we can't set an HTTP content-length header here, and
72 # Mercurial core provides no way to give the length of a streamres
72 # Mercurial core provides no way to give the length of a streamres
73 # (and reading the entire file into RAM would be ill-advised), we
73 # (and reading the entire file into RAM would be ill-advised), we
74 # just send the length on the first line of the response, like the
74 # just send the length on the first line of the response, like the
75 # ssh proto does for string responses.
75 # ssh proto does for string responses.
76 def generator():
76 def generator():
77 yield '%d\n' % length
77 yield '%d\n' % length
78 for chunk in util.filechunkiter(f):
78 for chunk in util.filechunkiter(f):
79 yield chunk
79 yield chunk
80 return wireprototypes.streamreslegacy(gen=generator())
80 return wireprototypes.streamreslegacy(gen=generator())
81
81
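# ---------------------------------------------------------------------------
# Editorial sketch (hedged, not part of this changeset): a minimal reader for
# the ad-hoc framing described above, where generator() sends the payload
# length on the first line followed by the raw bytes. Mercurial's own client
# side is the getlfile() method further down, which uses util.filechunkiter
# with limit=length; this standalone helper only illustrates the framing.
def _readlengthprefixed(stream, chunksize=131072):
    length = int(stream.readline())
    remaining = length
    while remaining > 0:
        chunk = stream.read(min(remaining, chunksize))
        if not chunk:
            raise IOError(0, 'stream ended before %d bytes were read' % length)
        remaining -= len(chunk)
        yield chunk
# ---------------------------------------------------------------------------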
82 def statlfile(repo, proto, sha):
82 def statlfile(repo, proto, sha):
83 '''Server command for checking if a largefile is present - returns '2\n' if
83 '''Server command for checking if a largefile is present - returns '2\n' if
84 the largefile is missing, '0\n' if it seems to be in good condition.
84 the largefile is missing, '0\n' if it seems to be in good condition.
85
85
86 The value 1 is reserved for mismatched checksum, but that is too expensive
86 The value 1 is reserved for mismatched checksum, but that is too expensive
87 to be verified on every stat and must be caught by running 'hg verify'
88 to be verified on every stat and must be caught by running 'hg verify'
88 server side.'''
88 server side.'''
89 filename = lfutil.findfile(repo, sha)
89 filename = lfutil.findfile(repo, sha)
90 if not filename:
90 if not filename:
91 return wireprototypes.bytesresponse('2\n')
91 return wireprototypes.bytesresponse('2\n')
92 return wireprototypes.bytesresponse('0\n')
92 return wireprototypes.bytesresponse('0\n')
93
93
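# ---------------------------------------------------------------------------
# Editorial illustration (hedged, not part of this changeset): the meaning of
# the statlfile status codes described in the docstring above. The helper
# name is hypothetical.
def _interpretstatcode(response):
    code = int(response.rstrip('\n'))
    return {0: 'present', 1: 'contents do not match hash', 2: 'missing'}.get(
        code, 'unknown status %d' % code)
# e.g. _interpretstatcode('2\n') == 'missing', _interpretstatcode('0\n') == 'present'
# ---------------------------------------------------------------------------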
94 def wirereposetup(ui, repo):
94 def wirereposetup(ui, repo):
95 class lfileswirerepository(repo.__class__):
95 class lfileswirerepository(repo.__class__):
96 def putlfile(self, sha, fd):
96 def putlfile(self, sha, fd):
97 # unfortunately, httprepository._callpush tries to convert its
97 # unfortunately, httprepository._callpush tries to convert its
98 # input file-like into a bundle before sending it, so we can't use
98 # input file-like into a bundle before sending it, so we can't use
99 # it ...
99 # it ...
100 if issubclass(self.__class__, httppeer.httppeer):
100 if issubclass(self.__class__, httppeer.httppeer):
101 res = self._call('putlfile', data=fd, sha=sha,
101 res = self._call('putlfile', data=fd, sha=sha,
102 headers={r'content-type': r'application/mercurial-0.1'})
102 headers={r'content-type': r'application/mercurial-0.1'})
103 try:
103 try:
104 d, output = res.split('\n', 1)
104 d, output = res.split('\n', 1)
105 for l in output.splitlines(True):
105 for l in output.splitlines(True):
106 self.ui.warn(_('remote: '), l) # assume l ends with \n
106 self.ui.warn(_('remote: '), l) # assume l ends with \n
107 return int(d)
107 return int(d)
108 except ValueError:
108 except ValueError:
109 self.ui.warn(_('unexpected putlfile response: %r\n') % res)
109 self.ui.warn(_('unexpected putlfile response: %r\n') % res)
110 return 1
110 return 1
111 # ... but we can't use sshrepository._call because the data=
111 # ... but we can't use sshrepository._call because the data=
112 # argument won't get sent, and _callpush does exactly what we want
112 # argument won't get sent, and _callpush does exactly what we want
113 # in this case: send the data straight through
113 # in this case: send the data straight through
114 else:
114 else:
115 try:
115 try:
116 ret, output = self._callpush("putlfile", fd, sha=sha)
116 ret, output = self._callpush("putlfile", fd, sha=sha)
117 if ret == "":
117 if ret == "":
118 raise error.ResponseError(_('putlfile failed:'),
118 raise error.ResponseError(_('putlfile failed:'),
119 output)
119 output)
120 return int(ret)
120 return int(ret)
121 except IOError:
121 except IOError:
122 return 1
122 return 1
123 except ValueError:
123 except ValueError:
124 raise error.ResponseError(
124 raise error.ResponseError(
125 _('putlfile failed (unexpected response):'), ret)
125 _('putlfile failed (unexpected response):'), ret)
126
126
127 def getlfile(self, sha):
127 def getlfile(self, sha):
128 """returns an iterable with the chunks of the file with sha sha"""
128 """returns an iterable with the chunks of the file with sha sha"""
129 stream = self._callstream("getlfile", sha=sha)
129 stream = self._callstream("getlfile", sha=sha)
130 length = stream.readline()
130 length = stream.readline()
131 try:
131 try:
132 length = int(length)
132 length = int(length)
133 except ValueError:
133 except ValueError:
134 self._abort(error.ResponseError(_("unexpected response:"),
134 self._abort(error.ResponseError(_("unexpected response:"),
135 length))
135 length))
136
136
137 # SSH streams will block if reading more than length
137 # SSH streams will block if reading more than length
138 for chunk in util.filechunkiter(stream, limit=length):
138 for chunk in util.filechunkiter(stream, limit=length):
139 yield chunk
139 yield chunk
140 # HTTP streams must hit the end to process the last empty
140 # HTTP streams must hit the end to process the last empty
141 # chunk of Chunked-Encoding so the connection can be reused.
141 # chunk of Chunked-Encoding so the connection can be reused.
142 if issubclass(self.__class__, httppeer.httppeer):
142 if issubclass(self.__class__, httppeer.httppeer):
143 chunk = stream.read(1)
143 chunk = stream.read(1)
144 if chunk:
144 if chunk:
145 self._abort(error.ResponseError(_("unexpected response:"),
145 self._abort(error.ResponseError(_("unexpected response:"),
146 chunk))
146 chunk))
147
147
148 @wireproto.batchable
148 @wireprotov1peer.batchable
149 def statlfile(self, sha):
149 def statlfile(self, sha):
150 f = wireproto.future()
150 f = wireprotov1peer.future()
151 result = {'sha': sha}
151 result = {'sha': sha}
152 yield result, f
152 yield result, f
153 try:
153 try:
154 yield int(f.value)
154 yield int(f.value)
155 except (ValueError, urlerr.httperror):
155 except (ValueError, urlerr.httperror):
156 # If the server returns anything but an integer followed by a
156 # If the server returns anything but an integer followed by a
157 # newline, it's not speaking our language; if we get
157 # newline, it's not speaking our language; if we get
158 # an HTTP error, we can't be sure the largefile is present;
158 # an HTTP error, we can't be sure the largefile is present;
159 # either way, consider it missing.
159 # either way, consider it missing.
160 yield 2
160 yield 2
161
161
162 repo.__class__ = lfileswirerepository
162 repo.__class__ = lfileswirerepository
163
163
164 # advertise the largefiles=serve capability
164 # advertise the largefiles=serve capability
165 def _capabilities(orig, repo, proto):
165 def _capabilities(orig, repo, proto):
166 '''announce largefile server capability'''
166 '''announce largefile server capability'''
167 caps = orig(repo, proto)
167 caps = orig(repo, proto)
168 caps.append('largefiles=serve')
168 caps.append('largefiles=serve')
169 return caps
169 return caps
170
170
171 def heads(orig, repo, proto):
171 def heads(orig, repo, proto):
172 '''Wrap server command - largefile capable clients will know to call
172 '''Wrap server command - largefile capable clients will know to call
173 lheads instead'''
173 lheads instead'''
174 if lfutil.islfilesrepo(repo):
174 if lfutil.islfilesrepo(repo):
175 return wireprototypes.ooberror(LARGEFILES_REQUIRED_MSG)
175 return wireprototypes.ooberror(LARGEFILES_REQUIRED_MSG)
176
176
177 return orig(repo, proto)
177 return orig(repo, proto)
178
178
179 def sshrepocallstream(self, cmd, **args):
179 def sshrepocallstream(self, cmd, **args):
180 if cmd == 'heads' and self.capable('largefiles'):
180 if cmd == 'heads' and self.capable('largefiles'):
181 cmd = 'lheads'
181 cmd = 'lheads'
182 if cmd == 'batch' and self.capable('largefiles'):
182 if cmd == 'batch' and self.capable('largefiles'):
183 args[r'cmds'] = args[r'cmds'].replace('heads ', 'lheads ')
183 args[r'cmds'] = args[r'cmds'].replace('heads ', 'lheads ')
184 return ssholdcallstream(self, cmd, **args)
184 return ssholdcallstream(self, cmd, **args)
185
185
186 headsre = re.compile(br'(^|;)heads\b')
186 headsre = re.compile(br'(^|;)heads\b')
187
187
188 def httprepocallstream(self, cmd, **args):
188 def httprepocallstream(self, cmd, **args):
189 if cmd == 'heads' and self.capable('largefiles'):
189 if cmd == 'heads' and self.capable('largefiles'):
190 cmd = 'lheads'
190 cmd = 'lheads'
191 if cmd == 'batch' and self.capable('largefiles'):
191 if cmd == 'batch' and self.capable('largefiles'):
192 args[r'cmds'] = headsre.sub('lheads', args[r'cmds'])
192 args[r'cmds'] = headsre.sub('lheads', args[r'cmds'])
193 return httpoldcallstream(self, cmd, **args)
193 return httpoldcallstream(self, cmd, **args)
@@ -1,816 +1,817 b''
1 # httppeer.py - HTTP repository proxy classes for mercurial
1 # httppeer.py - HTTP repository proxy classes for mercurial
2 #
2 #
3 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
5 #
5 #
6 # This software may be used and distributed according to the terms of the
6 # This software may be used and distributed according to the terms of the
7 # GNU General Public License version 2 or any later version.
7 # GNU General Public License version 2 or any later version.
8
8
9 from __future__ import absolute_import
9 from __future__ import absolute_import
10
10
11 import errno
11 import errno
12 import io
12 import io
13 import os
13 import os
14 import socket
14 import socket
15 import struct
15 import struct
16 import tempfile
16 import tempfile
17
17
18 from .i18n import _
18 from .i18n import _
19 from .thirdparty import (
19 from .thirdparty import (
20 cbor,
20 cbor,
21 )
21 )
22 from .thirdparty.zope import (
22 from .thirdparty.zope import (
23 interface as zi,
23 interface as zi,
24 )
24 )
25 from . import (
25 from . import (
26 bundle2,
26 bundle2,
27 error,
27 error,
28 httpconnection,
28 httpconnection,
29 pycompat,
29 pycompat,
30 repository,
30 repository,
31 statichttprepo,
31 statichttprepo,
32 url as urlmod,
32 url as urlmod,
33 util,
33 util,
34 wireproto,
34 wireproto,
35 wireprotoframing,
35 wireprotoframing,
36 wireprototypes,
36 wireprototypes,
37 wireprotov1peer,
37 wireprotov2server,
38 wireprotov2server,
38 )
39 )
39
40
40 httplib = util.httplib
41 httplib = util.httplib
41 urlerr = util.urlerr
42 urlerr = util.urlerr
42 urlreq = util.urlreq
43 urlreq = util.urlreq
43
44
44 def encodevalueinheaders(value, header, limit):
45 def encodevalueinheaders(value, header, limit):
45 """Encode a string value into multiple HTTP headers.
46 """Encode a string value into multiple HTTP headers.
46
47
47 ``value`` will be encoded into 1 or more HTTP headers with the names
48 ``value`` will be encoded into 1 or more HTTP headers with the names
48 ``header-<N>`` where ``<N>`` is an integer starting at 1. Each header
49 ``header-<N>`` where ``<N>`` is an integer starting at 1. Each header
49 name + value will be at most ``limit`` bytes long.
50 name + value will be at most ``limit`` bytes long.
50
51
51 Returns an iterable of 2-tuples consisting of header names and
52 Returns an iterable of 2-tuples consisting of header names and
52 values as native strings.
53 values as native strings.
53 """
54 """
54 # HTTP Headers are ASCII. Python 3 requires them to be unicodes,
55 # HTTP Headers are ASCII. Python 3 requires them to be unicodes,
55 # not bytes. This function always takes bytes in as arguments.
56 # not bytes. This function always takes bytes in as arguments.
56 fmt = pycompat.strurl(header) + r'-%s'
57 fmt = pycompat.strurl(header) + r'-%s'
57 # Note: it is *NOT* a bug that the last bit here is a bytestring
58 # Note: it is *NOT* a bug that the last bit here is a bytestring
58 # and not a unicode: we're just getting the encoded length anyway,
59 # and not a unicode: we're just getting the encoded length anyway,
59 # and using an r-string to make it portable between Python 2 and 3
60 # and using an r-string to make it portable between Python 2 and 3
60 # doesn't work because then the \r is a literal backslash-r
61 # doesn't work because then the \r is a literal backslash-r
61 # instead of a carriage return.
62 # instead of a carriage return.
62 valuelen = limit - len(fmt % r'000') - len(': \r\n')
63 valuelen = limit - len(fmt % r'000') - len(': \r\n')
63 result = []
64 result = []
64
65
65 n = 0
66 n = 0
66 for i in xrange(0, len(value), valuelen):
67 for i in xrange(0, len(value), valuelen):
67 n += 1
68 n += 1
68 result.append((fmt % str(n), pycompat.strurl(value[i:i + valuelen])))
69 result.append((fmt % str(n), pycompat.strurl(value[i:i + valuelen])))
69
70
70 return result
71 return result
71
72
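# ---------------------------------------------------------------------------
# Editorial sketch (hedged, not part of this changeset): a simplified,
# self-contained version of the header-splitting idea implemented by
# encodevalueinheaders() above. It ignores the bytes/str handling the real
# function does via pycompat, and the 80-byte limit below is made up.
def _splitintoheaders(value, header, limit):
    fmt = header + '-%s'
    # room for the value once the name, a three-digit counter and ': \r\n'
    # are accounted for, mirroring the valuelen computation above
    valuelen = limit - len(fmt % '000') - len(': \r\n')
    return [(fmt % (i // valuelen + 1), value[i:i + valuelen])
            for i in range(0, len(value), valuelen)]
# e.g. _splitintoheaders('cmd=heads&' + 'x' * 120, 'X-HgArg', 80)
# -> [('X-HgArg-1', ...), ('X-HgArg-2', ...)]
# ---------------------------------------------------------------------------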
72 def _wraphttpresponse(resp):
73 def _wraphttpresponse(resp):
73 """Wrap an HTTPResponse with common error handlers.
74 """Wrap an HTTPResponse with common error handlers.
74
75
75 This ensures that any I/O from any consumer raises the appropriate
76 This ensures that any I/O from any consumer raises the appropriate
76 error and messaging.
77 error and messaging.
77 """
78 """
78 origread = resp.read
79 origread = resp.read
79
80
80 class readerproxy(resp.__class__):
81 class readerproxy(resp.__class__):
81 def read(self, size=None):
82 def read(self, size=None):
82 try:
83 try:
83 return origread(size)
84 return origread(size)
84 except httplib.IncompleteRead as e:
85 except httplib.IncompleteRead as e:
85 # e.expected is an integer if length known or None otherwise.
86 # e.expected is an integer if length known or None otherwise.
86 if e.expected:
87 if e.expected:
87 msg = _('HTTP request error (incomplete response; '
88 msg = _('HTTP request error (incomplete response; '
88 'expected %d bytes got %d)') % (e.expected,
89 'expected %d bytes got %d)') % (e.expected,
89 len(e.partial))
90 len(e.partial))
90 else:
91 else:
91 msg = _('HTTP request error (incomplete response)')
92 msg = _('HTTP request error (incomplete response)')
92
93
93 raise error.PeerTransportError(
94 raise error.PeerTransportError(
94 msg,
95 msg,
95 hint=_('this may be an intermittent network failure; '
96 hint=_('this may be an intermittent network failure; '
96 'if the error persists, consider contacting the '
97 'if the error persists, consider contacting the '
97 'network or server operator'))
98 'network or server operator'))
98 except httplib.HTTPException as e:
99 except httplib.HTTPException as e:
99 raise error.PeerTransportError(
100 raise error.PeerTransportError(
100 _('HTTP request error (%s)') % e,
101 _('HTTP request error (%s)') % e,
101 hint=_('this may be an intermittent network failure; '
102 hint=_('this may be an intermittent network failure; '
102 'if the error persists, consider contacting the '
103 'if the error persists, consider contacting the '
103 'network or server operator'))
104 'network or server operator'))
104
105
105 resp.__class__ = readerproxy
106 resp.__class__ = readerproxy
106
107
107 class _multifile(object):
108 class _multifile(object):
108 def __init__(self, *fileobjs):
109 def __init__(self, *fileobjs):
109 for f in fileobjs:
110 for f in fileobjs:
110 if not util.safehasattr(f, 'length'):
111 if not util.safehasattr(f, 'length'):
111 raise ValueError(
112 raise ValueError(
112 '_multifile only supports file objects that '
113 '_multifile only supports file objects that '
113 'have a length but this one does not:', type(f), f)
114 'have a length but this one does not:', type(f), f)
114 self._fileobjs = fileobjs
115 self._fileobjs = fileobjs
115 self._index = 0
116 self._index = 0
116
117
117 @property
118 @property
118 def length(self):
119 def length(self):
119 return sum(f.length for f in self._fileobjs)
120 return sum(f.length for f in self._fileobjs)
120
121
121 def read(self, amt=None):
122 def read(self, amt=None):
122 if amt <= 0:
123 if amt <= 0:
123 return ''.join(f.read() for f in self._fileobjs)
124 return ''.join(f.read() for f in self._fileobjs)
124 parts = []
125 parts = []
125 while amt and self._index < len(self._fileobjs):
126 while amt and self._index < len(self._fileobjs):
126 parts.append(self._fileobjs[self._index].read(amt))
127 parts.append(self._fileobjs[self._index].read(amt))
127 got = len(parts[-1])
128 got = len(parts[-1])
128 if got < amt:
129 if got < amt:
129 self._index += 1
130 self._index += 1
130 amt -= got
131 amt -= got
131 return ''.join(parts)
132 return ''.join(parts)
132
133
133 def seek(self, offset, whence=os.SEEK_SET):
134 def seek(self, offset, whence=os.SEEK_SET):
134 if whence != os.SEEK_SET:
135 if whence != os.SEEK_SET:
135 raise NotImplementedError(
136 raise NotImplementedError(
136 '_multifile does not support anything other'
137 '_multifile does not support anything other'
137 ' than os.SEEK_SET for whence on seek()')
138 ' than os.SEEK_SET for whence on seek()')
138 if offset != 0:
139 if offset != 0:
139 raise NotImplementedError(
140 raise NotImplementedError(
140 '_multifile only supports seeking to start, but that '
141 '_multifile only supports seeking to start, but that '
141 'could be fixed if you need it')
142 'could be fixed if you need it')
142 for f in self._fileobjs:
143 for f in self._fileobjs:
143 f.seek(0)
144 f.seek(0)
144 self._index = 0
145 self._index = 0
145
146
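# ---------------------------------------------------------------------------
# Editorial usage sketch (hedged, not part of this changeset): how _multifile
# above chains an urlencoded argument blob and the command payload, mirroring
# what makev1commandrequest() below does when the server supports
# httppostargs. The literals are made up; io is imported at the top of this
# module.
def _examplemultifile():
    args = io.BytesIO(b'cmd=unbundle&heads=force')
    args.length = 24
    payload = io.BytesIO(b'<bundle bytes>')
    payload.length = 14
    body = _multifile(args, payload)
    # body.length is 38; read() returns both parts back to back, and
    # seek(0) rewinds every wrapped file so the request can be resent
    return body.read()
# ---------------------------------------------------------------------------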
146 def makev1commandrequest(ui, requestbuilder, caps, capablefn,
147 def makev1commandrequest(ui, requestbuilder, caps, capablefn,
147 repobaseurl, cmd, args):
148 repobaseurl, cmd, args):
148 """Make an HTTP request to run a command for a version 1 client.
149 """Make an HTTP request to run a command for a version 1 client.
149
150
150 ``caps`` is a set of known server capabilities. The value may be
151 ``caps`` is a set of known server capabilities. The value may be
151 None if capabilities are not yet known.
152 None if capabilities are not yet known.
152
153
153 ``capablefn`` is a function to evaluate a capability.
154 ``capablefn`` is a function to evaluate a capability.
154
155
155 ``cmd``, ``args``, and ``data`` define the command, its arguments, and
156 ``cmd``, ``args``, and ``data`` define the command, its arguments, and
156 raw data to pass to it.
157 raw data to pass to it.
157 """
158 """
158 if cmd == 'pushkey':
159 if cmd == 'pushkey':
159 args['data'] = ''
160 args['data'] = ''
160 data = args.pop('data', None)
161 data = args.pop('data', None)
161 headers = args.pop('headers', {})
162 headers = args.pop('headers', {})
162
163
163 ui.debug("sending %s command\n" % cmd)
164 ui.debug("sending %s command\n" % cmd)
164 q = [('cmd', cmd)]
165 q = [('cmd', cmd)]
165 headersize = 0
166 headersize = 0
166 # Important: don't use self.capable() here or else you end up
167 # Important: don't use self.capable() here or else you end up
167 # with infinite recursion when trying to look up capabilities
168 # with infinite recursion when trying to look up capabilities
168 # for the first time.
169 # for the first time.
169 postargsok = caps is not None and 'httppostargs' in caps
170 postargsok = caps is not None and 'httppostargs' in caps
170
171
171 # Send arguments via POST.
172 # Send arguments via POST.
172 if postargsok and args:
173 if postargsok and args:
173 strargs = urlreq.urlencode(sorted(args.items()))
174 strargs = urlreq.urlencode(sorted(args.items()))
174 if not data:
175 if not data:
175 data = strargs
176 data = strargs
176 else:
177 else:
177 if isinstance(data, bytes):
178 if isinstance(data, bytes):
178 i = io.BytesIO(data)
179 i = io.BytesIO(data)
179 i.length = len(data)
180 i.length = len(data)
180 data = i
181 data = i
181 argsio = io.BytesIO(strargs)
182 argsio = io.BytesIO(strargs)
182 argsio.length = len(strargs)
183 argsio.length = len(strargs)
183 data = _multifile(argsio, data)
184 data = _multifile(argsio, data)
184 headers[r'X-HgArgs-Post'] = len(strargs)
185 headers[r'X-HgArgs-Post'] = len(strargs)
185 elif args:
186 elif args:
186 # Calling self.capable() can loop infinitely if we are calling
187 # Calling self.capable() can loop infinitely if we are calling
187 # "capabilities". But that command should never accept wire
188 # "capabilities". But that command should never accept wire
188 # protocol arguments. So this should never happen.
189 # protocol arguments. So this should never happen.
189 assert cmd != 'capabilities'
190 assert cmd != 'capabilities'
190 httpheader = capablefn('httpheader')
191 httpheader = capablefn('httpheader')
191 if httpheader:
192 if httpheader:
192 headersize = int(httpheader.split(',', 1)[0])
193 headersize = int(httpheader.split(',', 1)[0])
193
194
194 # Send arguments via HTTP headers.
195 # Send arguments via HTTP headers.
195 if headersize > 0:
196 if headersize > 0:
196 # The headers can typically carry more data than the URL.
197 # The headers can typically carry more data than the URL.
197 encargs = urlreq.urlencode(sorted(args.items()))
198 encargs = urlreq.urlencode(sorted(args.items()))
198 for header, value in encodevalueinheaders(encargs, 'X-HgArg',
199 for header, value in encodevalueinheaders(encargs, 'X-HgArg',
199 headersize):
200 headersize):
200 headers[header] = value
201 headers[header] = value
201 # Send arguments via query string (Mercurial <1.9).
202 # Send arguments via query string (Mercurial <1.9).
202 else:
203 else:
203 q += sorted(args.items())
204 q += sorted(args.items())
204
205
205 qs = '?%s' % urlreq.urlencode(q)
206 qs = '?%s' % urlreq.urlencode(q)
206 cu = "%s%s" % (repobaseurl, qs)
207 cu = "%s%s" % (repobaseurl, qs)
207 size = 0
208 size = 0
208 if util.safehasattr(data, 'length'):
209 if util.safehasattr(data, 'length'):
209 size = data.length
210 size = data.length
210 elif data is not None:
211 elif data is not None:
211 size = len(data)
212 size = len(data)
212 if data is not None and r'Content-Type' not in headers:
213 if data is not None and r'Content-Type' not in headers:
213 headers[r'Content-Type'] = r'application/mercurial-0.1'
214 headers[r'Content-Type'] = r'application/mercurial-0.1'
214
215
215 # Tell the server we accept application/mercurial-0.2 and multiple
216 # Tell the server we accept application/mercurial-0.2 and multiple
216 # compression formats if the server is capable of emitting those
217 # compression formats if the server is capable of emitting those
217 # payloads.
218 # payloads.
218 # Note: Keep this set empty by default, as client advertisement of
219 # Note: Keep this set empty by default, as client advertisement of
219 # protocol parameters should only occur after the handshake.
220 # protocol parameters should only occur after the handshake.
220 protoparams = set()
221 protoparams = set()
221
222
222 mediatypes = set()
223 mediatypes = set()
223 if caps is not None:
224 if caps is not None:
224 mt = capablefn('httpmediatype')
225 mt = capablefn('httpmediatype')
225 if mt:
226 if mt:
226 protoparams.add('0.1')
227 protoparams.add('0.1')
227 mediatypes = set(mt.split(','))
228 mediatypes = set(mt.split(','))
228
229
229 protoparams.add('partial-pull')
230 protoparams.add('partial-pull')
230
231
231 if '0.2tx' in mediatypes:
232 if '0.2tx' in mediatypes:
232 protoparams.add('0.2')
233 protoparams.add('0.2')
233
234
234 if '0.2tx' in mediatypes and capablefn('compression'):
235 if '0.2tx' in mediatypes and capablefn('compression'):
235 # We /could/ compare supported compression formats and prune
236 # We /could/ compare supported compression formats and prune
236 # non-mutually supported or error if nothing is mutually supported.
237 # non-mutually supported or error if nothing is mutually supported.
237 # For now, send the full list to the server and have it error.
238 # For now, send the full list to the server and have it error.
238 comps = [e.wireprotosupport().name for e in
239 comps = [e.wireprotosupport().name for e in
239 util.compengines.supportedwireengines(util.CLIENTROLE)]
240 util.compengines.supportedwireengines(util.CLIENTROLE)]
240 protoparams.add('comp=%s' % ','.join(comps))
241 protoparams.add('comp=%s' % ','.join(comps))
241
242
242 if protoparams:
243 if protoparams:
243 protoheaders = encodevalueinheaders(' '.join(sorted(protoparams)),
244 protoheaders = encodevalueinheaders(' '.join(sorted(protoparams)),
244 'X-HgProto',
245 'X-HgProto',
245 headersize or 1024)
246 headersize or 1024)
246 for header, value in protoheaders:
247 for header, value in protoheaders:
247 headers[header] = value
248 headers[header] = value
248
249
249 varyheaders = []
250 varyheaders = []
250 for header in headers:
251 for header in headers:
251 if header.lower().startswith(r'x-hg'):
252 if header.lower().startswith(r'x-hg'):
252 varyheaders.append(header)
253 varyheaders.append(header)
253
254
254 if varyheaders:
255 if varyheaders:
255 headers[r'Vary'] = r','.join(sorted(varyheaders))
256 headers[r'Vary'] = r','.join(sorted(varyheaders))
256
257
257 req = requestbuilder(pycompat.strurl(cu), data, headers)
258 req = requestbuilder(pycompat.strurl(cu), data, headers)
258
259
259 if data is not None:
260 if data is not None:
260 ui.debug("sending %d bytes\n" % size)
261 ui.debug("sending %d bytes\n" % size)
261 req.add_unredirected_header(r'Content-Length', r'%d' % size)
262 req.add_unredirected_header(r'Content-Length', r'%d' % size)
262
263
263 return req, cu, qs
264 return req, cu, qs
264
265
265 def sendrequest(ui, opener, req):
266 def sendrequest(ui, opener, req):
266 """Send a prepared HTTP request.
267 """Send a prepared HTTP request.
267
268
268 Returns the response object.
269 Returns the response object.
269 """
270 """
270 if (ui.debugflag
271 if (ui.debugflag
271 and ui.configbool('devel', 'debug.peer-request')):
272 and ui.configbool('devel', 'debug.peer-request')):
272 dbg = ui.debug
273 dbg = ui.debug
273 line = 'devel-peer-request: %s\n'
274 line = 'devel-peer-request: %s\n'
274 dbg(line % '%s %s' % (req.get_method(), req.get_full_url()))
275 dbg(line % '%s %s' % (req.get_method(), req.get_full_url()))
275 hgargssize = None
276 hgargssize = None
276
277
277 for header, value in sorted(req.header_items()):
278 for header, value in sorted(req.header_items()):
278 if header.startswith('X-hgarg-'):
279 if header.startswith('X-hgarg-'):
279 if hgargssize is None:
280 if hgargssize is None:
280 hgargssize = 0
281 hgargssize = 0
281 hgargssize += len(value)
282 hgargssize += len(value)
282 else:
283 else:
283 dbg(line % ' %s %s' % (header, value))
284 dbg(line % ' %s %s' % (header, value))
284
285
285 if hgargssize is not None:
286 if hgargssize is not None:
286 dbg(line % ' %d bytes of commands arguments in headers'
287 dbg(line % ' %d bytes of commands arguments in headers'
287 % hgargssize)
288 % hgargssize)
288
289
289 if req.has_data():
290 if req.has_data():
290 data = req.get_data()
291 data = req.get_data()
291 length = getattr(data, 'length', None)
292 length = getattr(data, 'length', None)
292 if length is None:
293 if length is None:
293 length = len(data)
294 length = len(data)
294 dbg(line % ' %d bytes of data' % length)
295 dbg(line % ' %d bytes of data' % length)
295
296
296 start = util.timer()
297 start = util.timer()
297
298
298 try:
299 try:
299 res = opener.open(req)
300 res = opener.open(req)
300 except urlerr.httperror as inst:
301 except urlerr.httperror as inst:
301 if inst.code == 401:
302 if inst.code == 401:
302 raise error.Abort(_('authorization failed'))
303 raise error.Abort(_('authorization failed'))
303 raise
304 raise
304 except httplib.HTTPException as inst:
305 except httplib.HTTPException as inst:
305 ui.debug('http error requesting %s\n' %
306 ui.debug('http error requesting %s\n' %
306 util.hidepassword(req.get_full_url()))
307 util.hidepassword(req.get_full_url()))
307 ui.traceback()
308 ui.traceback()
308 raise IOError(None, inst)
309 raise IOError(None, inst)
309 finally:
310 finally:
310 if ui.configbool('devel', 'debug.peer-request'):
311 if ui.configbool('devel', 'debug.peer-request'):
311 dbg(line % ' finished in %.4f seconds (%s)'
312 dbg(line % ' finished in %.4f seconds (%s)'
312 % (util.timer() - start, res.code))
313 % (util.timer() - start, res.code))
313
314
314 # Insert error handlers for common I/O failures.
315 # Insert error handlers for common I/O failures.
315 _wraphttpresponse(res)
316 _wraphttpresponse(res)
316
317
317 return res
318 return res
318
319
319 def parsev1commandresponse(ui, baseurl, requrl, qs, resp, compressible,
320 def parsev1commandresponse(ui, baseurl, requrl, qs, resp, compressible,
320 allowcbor=False):
321 allowcbor=False):
321 # record the url we got redirected to
322 # record the url we got redirected to
322 respurl = pycompat.bytesurl(resp.geturl())
323 respurl = pycompat.bytesurl(resp.geturl())
323 if respurl.endswith(qs):
324 if respurl.endswith(qs):
324 respurl = respurl[:-len(qs)]
325 respurl = respurl[:-len(qs)]
325 if baseurl.rstrip('/') != respurl.rstrip('/'):
326 if baseurl.rstrip('/') != respurl.rstrip('/'):
326 if not ui.quiet:
327 if not ui.quiet:
327 ui.warn(_('real URL is %s\n') % respurl)
328 ui.warn(_('real URL is %s\n') % respurl)
328
329
329 try:
330 try:
330 proto = pycompat.bytesurl(resp.getheader(r'content-type', r''))
331 proto = pycompat.bytesurl(resp.getheader(r'content-type', r''))
331 except AttributeError:
332 except AttributeError:
332 proto = pycompat.bytesurl(resp.headers.get(r'content-type', r''))
333 proto = pycompat.bytesurl(resp.headers.get(r'content-type', r''))
333
334
334 safeurl = util.hidepassword(baseurl)
335 safeurl = util.hidepassword(baseurl)
335 if proto.startswith('application/hg-error'):
336 if proto.startswith('application/hg-error'):
336 raise error.OutOfBandError(resp.read())
337 raise error.OutOfBandError(resp.read())
337
338
338 # Pre 1.0 versions of Mercurial used text/plain and
339 # Pre 1.0 versions of Mercurial used text/plain and
339 # application/hg-changegroup. We don't support such old servers.
340 # application/hg-changegroup. We don't support such old servers.
340 if not proto.startswith('application/mercurial-'):
341 if not proto.startswith('application/mercurial-'):
341 ui.debug("requested URL: '%s'\n" % util.hidepassword(requrl))
342 ui.debug("requested URL: '%s'\n" % util.hidepassword(requrl))
342 raise error.RepoError(
343 raise error.RepoError(
343 _("'%s' does not appear to be an hg repository:\n"
344 _("'%s' does not appear to be an hg repository:\n"
344 "---%%<--- (%s)\n%s\n---%%<---\n")
345 "---%%<--- (%s)\n%s\n---%%<---\n")
345 % (safeurl, proto or 'no content-type', resp.read(1024)))
346 % (safeurl, proto or 'no content-type', resp.read(1024)))
346
347
347 try:
348 try:
348 subtype = proto.split('-', 1)[1]
349 subtype = proto.split('-', 1)[1]
349
350
350 # Unless we end up supporting CBOR in the legacy wire protocol,
351 # Unless we end up supporting CBOR in the legacy wire protocol,
351 # this should ONLY be encountered for the initial capabilities
352 # this should ONLY be encountered for the initial capabilities
352 # request during handshake.
353 # request during handshake.
353 if subtype == 'cbor':
354 if subtype == 'cbor':
354 if allowcbor:
355 if allowcbor:
355 return respurl, proto, resp
356 return respurl, proto, resp
356 else:
357 else:
357 raise error.RepoError(_('unexpected CBOR response from '
358 raise error.RepoError(_('unexpected CBOR response from '
358 'server'))
359 'server'))
359
360
360 version_info = tuple([int(n) for n in subtype.split('.')])
361 version_info = tuple([int(n) for n in subtype.split('.')])
361 except ValueError:
362 except ValueError:
362 raise error.RepoError(_("'%s' sent a broken Content-Type "
363 raise error.RepoError(_("'%s' sent a broken Content-Type "
363 "header (%s)") % (safeurl, proto))
364 "header (%s)") % (safeurl, proto))
364
365
365 # TODO consider switching to a decompression reader that uses
366 # TODO consider switching to a decompression reader that uses
366 # generators.
367 # generators.
367 if version_info == (0, 1):
368 if version_info == (0, 1):
368 if compressible:
369 if compressible:
369 resp = util.compengines['zlib'].decompressorreader(resp)
370 resp = util.compengines['zlib'].decompressorreader(resp)
370
371
371 elif version_info == (0, 2):
372 elif version_info == (0, 2):
372 # application/mercurial-0.2 always identifies the compression
373 # application/mercurial-0.2 always identifies the compression
373 # engine in the payload header.
374 # engine in the payload header.
374 elen = struct.unpack('B', resp.read(1))[0]
375 elen = struct.unpack('B', resp.read(1))[0]
375 ename = resp.read(elen)
376 ename = resp.read(elen)
376 engine = util.compengines.forwiretype(ename)
377 engine = util.compengines.forwiretype(ename)
377
378
378 resp = engine.decompressorreader(resp)
379 resp = engine.decompressorreader(resp)
379 else:
380 else:
380 raise error.RepoError(_("'%s' uses newer protocol %s") %
381 raise error.RepoError(_("'%s' uses newer protocol %s") %
381 (safeurl, subtype))
382 (safeurl, subtype))
382
383
383 return respurl, proto, resp
384 return respurl, proto, resp
384
385
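# A minimal standalone sketch of the ``application/mercurial-0.2`` framing
# handled above: one unsigned byte giving the length of the compression
# engine name, the name itself, then the compressed body. The engine name
# and payload bytes below are invented for the example.
import io
import struct

def _framev02(enginename, compressedbody):
    # Mirror of the parsing above: length byte, engine name, body.
    return struct.pack('B', len(enginename)) + enginename + compressedbody

def _readv02engine(fp):
    # Read the length byte, then exactly that many bytes of engine name.
    elen = struct.unpack('B', fp.read(1))[0]
    return fp.read(elen)

_framed = _framev02(b'zlib', b'<compressed payload>')
assert _readv02engine(io.BytesIO(_framed)) == b'zlib'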
385 class httppeer(wireproto.wirepeer):
386 class httppeer(wireprotov1peer.wirepeer):
386 def __init__(self, ui, path, url, opener, requestbuilder, caps):
387 def __init__(self, ui, path, url, opener, requestbuilder, caps):
387 self.ui = ui
388 self.ui = ui
388 self._path = path
389 self._path = path
389 self._url = url
390 self._url = url
390 self._caps = caps
391 self._caps = caps
391 self._urlopener = opener
392 self._urlopener = opener
392 self._requestbuilder = requestbuilder
393 self._requestbuilder = requestbuilder
393
394
394 def __del__(self):
395 def __del__(self):
395 for h in self._urlopener.handlers:
396 for h in self._urlopener.handlers:
396 h.close()
397 h.close()
397 getattr(h, "close_all", lambda: None)()
398 getattr(h, "close_all", lambda: None)()
398
399
399 # Begin of ipeerconnection interface.
400 # Begin of ipeerconnection interface.
400
401
401 def url(self):
402 def url(self):
402 return self._path
403 return self._path
403
404
404 def local(self):
405 def local(self):
405 return None
406 return None
406
407
407 def peer(self):
408 def peer(self):
408 return self
409 return self
409
410
410 def canpush(self):
411 def canpush(self):
411 return True
412 return True
412
413
413 def close(self):
414 def close(self):
414 pass
415 pass
415
416
416 # End of ipeerconnection interface.
417 # End of ipeerconnection interface.
417
418
418 # Begin of ipeercommands interface.
419 # Begin of ipeercommands interface.
419
420
420 def capabilities(self):
421 def capabilities(self):
421 return self._caps
422 return self._caps
422
423
423 # End of ipeercommands interface.
424 # End of ipeercommands interface.
424
425
425 # look up capabilities only when needed
426 # look up capabilities only when needed
426
427
427 def _callstream(self, cmd, _compressible=False, **args):
428 def _callstream(self, cmd, _compressible=False, **args):
428 args = pycompat.byteskwargs(args)
429 args = pycompat.byteskwargs(args)
429
430
430 req, cu, qs = makev1commandrequest(self.ui, self._requestbuilder,
431 req, cu, qs = makev1commandrequest(self.ui, self._requestbuilder,
431 self._caps, self.capable,
432 self._caps, self.capable,
432 self._url, cmd, args)
433 self._url, cmd, args)
433
434
434 resp = sendrequest(self.ui, self._urlopener, req)
435 resp = sendrequest(self.ui, self._urlopener, req)
435
436
436 self._url, ct, resp = parsev1commandresponse(self.ui, self._url, cu, qs,
437 self._url, ct, resp = parsev1commandresponse(self.ui, self._url, cu, qs,
437 resp, _compressible)
438 resp, _compressible)
438
439
439 return resp
440 return resp
440
441
441 def _call(self, cmd, **args):
442 def _call(self, cmd, **args):
442 fp = self._callstream(cmd, **args)
443 fp = self._callstream(cmd, **args)
443 try:
444 try:
444 return fp.read()
445 return fp.read()
445 finally:
446 finally:
446 # if using keepalive, allow connection to be reused
447 # if using keepalive, allow connection to be reused
447 fp.close()
448 fp.close()
448
449
449 def _callpush(self, cmd, cg, **args):
450 def _callpush(self, cmd, cg, **args):
450 # have to stream bundle to a temp file because we do not have
451 # have to stream bundle to a temp file because we do not have
451 # http 1.1 chunked transfer.
452 # http 1.1 chunked transfer.
452
453
453 types = self.capable('unbundle')
454 types = self.capable('unbundle')
454 try:
455 try:
455 types = types.split(',')
456 types = types.split(',')
456 except AttributeError:
457 except AttributeError:
457 # servers older than d1b16a746db6 will send 'unbundle' as a
458 # servers older than d1b16a746db6 will send 'unbundle' as a
458 # boolean capability. They only support headerless/uncompressed
459 # boolean capability. They only support headerless/uncompressed
459 # bundles.
460 # bundles.
460 types = [""]
461 types = [""]
461 for x in types:
462 for x in types:
462 if x in bundle2.bundletypes:
463 if x in bundle2.bundletypes:
463 type = x
464 type = x
464 break
465 break
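# For instance (hypothetical server reply), capable('unbundle') may return a
# comma-separated string such as 'HG10GZ,HG10BZ,HG10UN'; the loop above then
# picks the first entry that names a bundle type this client can write.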
465
466
466 tempname = bundle2.writebundle(self.ui, cg, None, type)
467 tempname = bundle2.writebundle(self.ui, cg, None, type)
467 fp = httpconnection.httpsendfile(self.ui, tempname, "rb")
468 fp = httpconnection.httpsendfile(self.ui, tempname, "rb")
468 headers = {r'Content-Type': r'application/mercurial-0.1'}
469 headers = {r'Content-Type': r'application/mercurial-0.1'}
469
470
470 try:
471 try:
471 r = self._call(cmd, data=fp, headers=headers, **args)
472 r = self._call(cmd, data=fp, headers=headers, **args)
472 vals = r.split('\n', 1)
473 vals = r.split('\n', 1)
473 if len(vals) < 2:
474 if len(vals) < 2:
474 raise error.ResponseError(_("unexpected response:"), r)
475 raise error.ResponseError(_("unexpected response:"), r)
475 return vals
476 return vals
476 except urlerr.httperror:
477 except urlerr.httperror:
477 # Catch and re-raise these so we don't try and treat them
478 # Catch and re-raise these so we don't try and treat them
478 # like generic socket errors. They lack any values in
479 # like generic socket errors. They lack any values in
479 # .args on Python 3 which breaks our socket.error block.
480 # .args on Python 3 which breaks our socket.error block.
480 raise
481 raise
481 except socket.error as err:
482 except socket.error as err:
482 if err.args[0] in (errno.ECONNRESET, errno.EPIPE):
483 if err.args[0] in (errno.ECONNRESET, errno.EPIPE):
483 raise error.Abort(_('push failed: %s') % err.args[1])
484 raise error.Abort(_('push failed: %s') % err.args[1])
484 raise error.Abort(err.args[1])
485 raise error.Abort(err.args[1])
485 finally:
486 finally:
486 fp.close()
487 fp.close()
487 os.unlink(tempname)
488 os.unlink(tempname)
488
489
489 def _calltwowaystream(self, cmd, fp, **args):
490 def _calltwowaystream(self, cmd, fp, **args):
490 fh = None
491 fh = None
491 fp_ = None
492 fp_ = None
492 filename = None
493 filename = None
493 try:
494 try:
494 # dump bundle to disk
495 # dump bundle to disk
495 fd, filename = tempfile.mkstemp(prefix="hg-bundle-", suffix=".hg")
496 fd, filename = tempfile.mkstemp(prefix="hg-bundle-", suffix=".hg")
496 fh = os.fdopen(fd, r"wb")
497 fh = os.fdopen(fd, r"wb")
497 d = fp.read(4096)
498 d = fp.read(4096)
498 while d:
499 while d:
499 fh.write(d)
500 fh.write(d)
500 d = fp.read(4096)
501 d = fp.read(4096)
501 fh.close()
502 fh.close()
502 # start http push
503 # start http push
503 fp_ = httpconnection.httpsendfile(self.ui, filename, "rb")
504 fp_ = httpconnection.httpsendfile(self.ui, filename, "rb")
504 headers = {r'Content-Type': r'application/mercurial-0.1'}
505 headers = {r'Content-Type': r'application/mercurial-0.1'}
505 return self._callstream(cmd, data=fp_, headers=headers, **args)
506 return self._callstream(cmd, data=fp_, headers=headers, **args)
506 finally:
507 finally:
507 if fp_ is not None:
508 if fp_ is not None:
508 fp_.close()
509 fp_.close()
509 if fh is not None:
510 if fh is not None:
510 fh.close()
511 fh.close()
511 os.unlink(filename)
512 os.unlink(filename)
512
513
513 def _callcompressable(self, cmd, **args):
514 def _callcompressable(self, cmd, **args):
514 return self._callstream(cmd, _compressible=True, **args)
515 return self._callstream(cmd, _compressible=True, **args)
515
516
516 def _abort(self, exception):
517 def _abort(self, exception):
517 raise exception
518 raise exception
518
519
519 # TODO implement interface for version 2 peers
520 # TODO implement interface for version 2 peers
520 @zi.implementer(repository.ipeerconnection, repository.ipeercapabilities)
521 @zi.implementer(repository.ipeerconnection, repository.ipeercapabilities)
521 class httpv2peer(object):
522 class httpv2peer(object):
522 def __init__(self, ui, repourl, apipath, opener, requestbuilder,
523 def __init__(self, ui, repourl, apipath, opener, requestbuilder,
523 apidescriptor):
524 apidescriptor):
524 self.ui = ui
525 self.ui = ui
525
526
526 if repourl.endswith('/'):
527 if repourl.endswith('/'):
527 repourl = repourl[:-1]
528 repourl = repourl[:-1]
528
529
529 self._url = repourl
530 self._url = repourl
530 self._apipath = apipath
531 self._apipath = apipath
531 self._opener = opener
532 self._opener = opener
532 self._requestbuilder = requestbuilder
533 self._requestbuilder = requestbuilder
533 self._descriptor = apidescriptor
534 self._descriptor = apidescriptor
534
535
535 # Start of ipeerconnection.
536 # Start of ipeerconnection.
536
537
537 def url(self):
538 def url(self):
538 return self._url
539 return self._url
539
540
540 def local(self):
541 def local(self):
541 return None
542 return None
542
543
543 def peer(self):
544 def peer(self):
544 return self
545 return self
545
546
546 def canpush(self):
547 def canpush(self):
547 # TODO change once implemented.
548 # TODO change once implemented.
548 return False
549 return False
549
550
550 def close(self):
551 def close(self):
551 pass
552 pass
552
553
553 # End of ipeerconnection.
554 # End of ipeerconnection.
554
555
555 # Start of ipeercapabilities.
556 # Start of ipeercapabilities.
556
557
557 def capable(self, name):
558 def capable(self, name):
558 # The capabilities used internally historically map to capabilities
559 # The capabilities used internally historically map to capabilities
559 # advertised from the "capabilities" wire protocol command. However,
560 # advertised from the "capabilities" wire protocol command. However,
560 # version 2 of that command works differently.
561 # version 2 of that command works differently.
561
562
562 # Maps to commands that are available.
563 # Maps to commands that are available.
563 if name in ('branchmap', 'getbundle', 'known', 'lookup', 'pushkey'):
564 if name in ('branchmap', 'getbundle', 'known', 'lookup', 'pushkey'):
564 return True
565 return True
565
566
566 # Other concepts.
567 # Other concepts.
567 if name in ('bundle2',):
568 if name in ('bundle2',):
568 return True
569 return True
569
570
570 return False
571 return False
571
572
572 def requirecap(self, name, purpose):
573 def requirecap(self, name, purpose):
573 if self.capable(name):
574 if self.capable(name):
574 return
575 return
575
576
576 raise error.CapabilityError(
577 raise error.CapabilityError(
577 _('cannot %s; client or remote repository does not support the %r '
578 _('cannot %s; client or remote repository does not support the %r '
578 'capability') % (purpose, name))
579 'capability') % (purpose, name))
579
580
580 # End of ipeercapabilities.
581 # End of ipeercapabilities.
581
582
582 # TODO require to be part of a batched primitive, use futures.
583 # TODO require to be part of a batched primitive, use futures.
583 def _call(self, name, **args):
584 def _call(self, name, **args):
584 """Call a wire protocol command with arguments."""
585 """Call a wire protocol command with arguments."""
585
586
586 # Having this early has a side-effect of importing wireprotov2server,
587 # Having this early has a side-effect of importing wireprotov2server,
587 # which has the side-effect of ensuring commands are registered.
588 # which has the side-effect of ensuring commands are registered.
588
589
589 # TODO modify user-agent to reflect v2.
590 # TODO modify user-agent to reflect v2.
590 headers = {
591 headers = {
591 r'Accept': wireprotov2server.FRAMINGTYPE,
592 r'Accept': wireprotov2server.FRAMINGTYPE,
592 r'Content-Type': wireprotov2server.FRAMINGTYPE,
593 r'Content-Type': wireprotov2server.FRAMINGTYPE,
593 }
594 }
594
595
595 # TODO permissions should come from capabilities results.
596 # TODO permissions should come from capabilities results.
596 permission = wireproto.commandsv2[name].permission
597 permission = wireproto.commandsv2[name].permission
597 if permission not in ('push', 'pull'):
598 if permission not in ('push', 'pull'):
598 raise error.ProgrammingError('unknown permission type: %s' %
599 raise error.ProgrammingError('unknown permission type: %s' %
599 permission)
600 permission)
600
601
601 permission = {
602 permission = {
602 'push': 'rw',
603 'push': 'rw',
603 'pull': 'ro',
604 'pull': 'ro',
604 }[permission]
605 }[permission]
605
606
606 url = '%s/%s/%s/%s' % (self._url, self._apipath, permission, name)
607 url = '%s/%s/%s/%s' % (self._url, self._apipath, permission, name)
607
608
608 # TODO this should be part of a generic peer for the frame-based
609 # TODO this should be part of a generic peer for the frame-based
609 # protocol.
610 # protocol.
610 reactor = wireprotoframing.clientreactor(hasmultiplesend=False,
611 reactor = wireprotoframing.clientreactor(hasmultiplesend=False,
611 buffersends=True)
612 buffersends=True)
612
613
613 request, action, meta = reactor.callcommand(name, args)
614 request, action, meta = reactor.callcommand(name, args)
614 assert action == 'noop'
615 assert action == 'noop'
615
616
616 action, meta = reactor.flushcommands()
617 action, meta = reactor.flushcommands()
617 assert action == 'sendframes'
618 assert action == 'sendframes'
618
619
619 body = b''.join(map(bytes, meta['framegen']))
620 body = b''.join(map(bytes, meta['framegen']))
620 req = self._requestbuilder(pycompat.strurl(url), body, headers)
621 req = self._requestbuilder(pycompat.strurl(url), body, headers)
621 req.add_unredirected_header(r'Content-Length', r'%d' % len(body))
622 req.add_unredirected_header(r'Content-Length', r'%d' % len(body))
622
623
623 # TODO unify this code with httppeer.
624 # TODO unify this code with httppeer.
624 try:
625 try:
625 res = self._opener.open(req)
626 res = self._opener.open(req)
626 except urlerr.httperror as e:
627 except urlerr.httperror as e:
627 if e.code == 401:
628 if e.code == 401:
628 raise error.Abort(_('authorization failed'))
629 raise error.Abort(_('authorization failed'))
629
630
630 raise
631 raise
631 except httplib.HTTPException as e:
632 except httplib.HTTPException as e:
632 self.ui.traceback()
633 self.ui.traceback()
633 raise IOError(None, e)
634 raise IOError(None, e)
634
635
635 # TODO validate response type, wrap response to handle I/O errors.
636 # TODO validate response type, wrap response to handle I/O errors.
636 # TODO more robust frame receiver.
637 # TODO more robust frame receiver.
637 results = []
638 results = []
638
639
639 while True:
640 while True:
640 frame = wireprotoframing.readframe(res)
641 frame = wireprotoframing.readframe(res)
641 if frame is None:
642 if frame is None:
642 break
643 break
643
644
644 self.ui.note(_('received %r\n') % frame)
645 self.ui.note(_('received %r\n') % frame)
645
646
646 action, meta = reactor.onframerecv(frame)
647 action, meta = reactor.onframerecv(frame)
647
648
648 if action == 'responsedata':
649 if action == 'responsedata':
649 if meta['cbor']:
650 if meta['cbor']:
650 payload = util.bytesio(meta['data'])
651 payload = util.bytesio(meta['data'])
651
652
652 decoder = cbor.CBORDecoder(payload)
653 decoder = cbor.CBORDecoder(payload)
653 while payload.tell() + 1 < len(meta['data']):
654 while payload.tell() + 1 < len(meta['data']):
654 results.append(decoder.decode())
655 results.append(decoder.decode())
655 else:
656 else:
656 results.append(meta['data'])
657 results.append(meta['data'])
657 else:
658 else:
658 raise error.ProgrammingError('unhandled action: %s' % action)
659 raise error.ProgrammingError('unhandled action: %s' % action)
659
660
660 return results
661 return results
661
662
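# To make the URL layout built by httpv2peer._call() above concrete, here is
# a small sketch of the string it produces. The repository URL and API path
# are hypothetical, and ``known`` is assumed to be a pull-permission command.
_repourl = 'https://example.com/repo'      # hypothetical repository URL
_apipath = 'api/exp-http-v2-0001'          # hypothetical apibase + service
_permission = {'push': 'rw', 'pull': 'ro'}['pull']
_cmdurl = '%s/%s/%s/%s' % (_repourl, _apipath, _permission, 'known')
# _cmdurl == 'https://example.com/repo/api/exp-http-v2-0001/ro/known'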
662 # Registry of API service names to metadata about peers that handle it.
663 # Registry of API service names to metadata about peers that handle it.
663 #
664 #
664 # The following keys are meaningful:
665 # The following keys are meaningful:
665 #
666 #
666 # init
667 # init
667 # Callable receiving (ui, repourl, servicepath, opener, requestbuilder,
668 # Callable receiving (ui, repourl, servicepath, opener, requestbuilder,
668 # apidescriptor) to create a peer.
669 # apidescriptor) to create a peer.
669 #
670 #
670 # priority
671 # priority
671 # Integer priority for the service. If we could choose from multiple
672 # Integer priority for the service. If we could choose from multiple
672 # services, we choose the one with the highest priority.
673 # services, we choose the one with the highest priority.
673 API_PEERS = {
674 API_PEERS = {
674 wireprototypes.HTTPV2: {
675 wireprototypes.HTTPV2: {
675 'init': httpv2peer,
676 'init': httpv2peer,
676 'priority': 50,
677 'priority': 50,
677 },
678 },
678 }
679 }
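# Purely as an illustration of the registry shape documented above, a
# hypothetical additional service could be described with the same 'init'
# and 'priority' keys. The service name and peer class here are invented.
class _examplev3peer(object):
    # Hypothetical peer type using the constructor signature described above.
    def __init__(self, ui, repourl, apipath, opener, requestbuilder,
                 apidescriptor):
        self.ui = ui
        self._url = repourl

EXAMPLE_API_PEERS = {
    b'exp-example-v3-0001': {
        'init': _examplev3peer,
        'priority': 10,   # lower than HTTPV2's 50, so HTTPV2 would be chosen
    },
}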
679
680
680 def performhandshake(ui, url, opener, requestbuilder):
681 def performhandshake(ui, url, opener, requestbuilder):
681 # The handshake is a request to the capabilities command.
682 # The handshake is a request to the capabilities command.
682
683
683 caps = None
684 caps = None
684 def capable(x):
685 def capable(x):
685 raise error.ProgrammingError('should not be called')
686 raise error.ProgrammingError('should not be called')
686
687
687 args = {}
688 args = {}
688
689
689 # The client advertises support for newer protocols by adding an
690 # The client advertises support for newer protocols by adding an
690 # X-HgUpgrade-* header with a list of supported APIs and an
691 # X-HgUpgrade-* header with a list of supported APIs and an
691 # X-HgProto-* header advertising which serializing formats it supports.
692 # X-HgProto-* header advertising which serializing formats it supports.
692 # We only support the HTTP version 2 transport and CBOR responses for
693 # We only support the HTTP version 2 transport and CBOR responses for
693 # now.
694 # now.
694 advertisev2 = ui.configbool('experimental', 'httppeer.advertise-v2')
695 advertisev2 = ui.configbool('experimental', 'httppeer.advertise-v2')
695
696
696 if advertisev2:
697 if advertisev2:
697 args['headers'] = {
698 args['headers'] = {
698 r'X-HgProto-1': r'cbor',
699 r'X-HgProto-1': r'cbor',
699 }
700 }
700
701
701 args['headers'].update(
702 args['headers'].update(
702 encodevalueinheaders(' '.join(sorted(API_PEERS)),
703 encodevalueinheaders(' '.join(sorted(API_PEERS)),
703 'X-HgUpgrade',
704 'X-HgUpgrade',
704 # We don't know the header limit this early.
705 # We don't know the header limit this early.
705 # So make it small.
706 # So make it small.
706 1024))
707 1024))
707
708
708 req, requrl, qs = makev1commandrequest(ui, requestbuilder, caps,
709 req, requrl, qs = makev1commandrequest(ui, requestbuilder, caps,
709 capable, url, 'capabilities',
710 capable, url, 'capabilities',
710 args)
711 args)
711
712
712 resp = sendrequest(ui, opener, req)
713 resp = sendrequest(ui, opener, req)
713
714
714 respurl, ct, resp = parsev1commandresponse(ui, url, requrl, qs, resp,
715 respurl, ct, resp = parsev1commandresponse(ui, url, requrl, qs, resp,
715 compressible=False,
716 compressible=False,
716 allowcbor=advertisev2)
717 allowcbor=advertisev2)
717
718
718 try:
719 try:
719 rawdata = resp.read()
720 rawdata = resp.read()
720 finally:
721 finally:
721 resp.close()
722 resp.close()
722
723
723 if not ct.startswith('application/mercurial-'):
724 if not ct.startswith('application/mercurial-'):
724 raise error.ProgrammingError('unexpected content-type: %s' % ct)
725 raise error.ProgrammingError('unexpected content-type: %s' % ct)
725
726
726 if advertisev2:
727 if advertisev2:
727 if ct == 'application/mercurial-cbor':
728 if ct == 'application/mercurial-cbor':
728 try:
729 try:
729 info = cbor.loads(rawdata)
730 info = cbor.loads(rawdata)
730 except cbor.CBORDecodeError:
731 except cbor.CBORDecodeError:
731 raise error.Abort(_('error decoding CBOR from remote server'),
732 raise error.Abort(_('error decoding CBOR from remote server'),
732 hint=_('try again and consider contacting '
733 hint=_('try again and consider contacting '
733 'the server operator'))
734 'the server operator'))
734
735
735 # We got a legacy response. That's fine.
736 # We got a legacy response. That's fine.
736 elif ct in ('application/mercurial-0.1', 'application/mercurial-0.2'):
737 elif ct in ('application/mercurial-0.1', 'application/mercurial-0.2'):
737 info = {
738 info = {
738 'v1capabilities': set(rawdata.split())
739 'v1capabilities': set(rawdata.split())
739 }
740 }
740
741
741 else:
742 else:
742 raise error.RepoError(
743 raise error.RepoError(
743 _('unexpected response type from server: %s') % ct)
744 _('unexpected response type from server: %s') % ct)
744 else:
745 else:
745 info = {
746 info = {
746 'v1capabilities': set(rawdata.split())
747 'v1capabilities': set(rawdata.split())
747 }
748 }
748
749
749 return respurl, info
750 return respurl, info
750
751
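# The encodevalueinheaders() call above spreads the advertised value across
# numbered X-HgUpgrade-* headers so that no single header exceeds the given
# size limit. The real helper is not shown here; the following is only an
# illustrative reimplementation of the idea, with an invented value and limit.
def _splitintoheaders(value, prefix, limit):
    # Chunk the value and emit ('<prefix>-<n>', chunk) pairs, 1-indexed.
    chunks = [value[i:i + limit] for i in range(0, len(value), limit)] or ['']
    return [('%s-%d' % (prefix, i + 1), chunk)
            for i, chunk in enumerate(chunks)]

assert _splitintoheaders('some-advertised-api-name', 'X-HgUpgrade', 10) == [
    ('X-HgUpgrade-1', 'some-adver'),
    ('X-HgUpgrade-2', 'tised-api-'),
    ('X-HgUpgrade-3', 'name'),
]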
751 def makepeer(ui, path, opener=None, requestbuilder=urlreq.request):
752 def makepeer(ui, path, opener=None, requestbuilder=urlreq.request):
752 """Construct an appropriate HTTP peer instance.
753 """Construct an appropriate HTTP peer instance.
753
754
754 ``opener`` is an ``url.opener`` that should be used to establish
755 ``opener`` is an ``url.opener`` that should be used to establish
755 connections and perform HTTP requests.
756 connections and perform HTTP requests.
756
757
757 ``requestbuilder`` is the type used for constructing HTTP requests.
758 ``requestbuilder`` is the type used for constructing HTTP requests.
758 It exists as an argument so extensions can override the default.
759 It exists as an argument so extensions can override the default.
759 """
760 """
760 u = util.url(path)
761 u = util.url(path)
761 if u.query or u.fragment:
762 if u.query or u.fragment:
762 raise error.Abort(_('unsupported URL component: "%s"') %
763 raise error.Abort(_('unsupported URL component: "%s"') %
763 (u.query or u.fragment))
764 (u.query or u.fragment))
764
765
765 # urllib cannot handle URLs with embedded user or passwd.
766 # urllib cannot handle URLs with embedded user or passwd.
766 url, authinfo = u.authinfo()
767 url, authinfo = u.authinfo()
767 ui.debug('using %s\n' % url)
768 ui.debug('using %s\n' % url)
768
769
769 opener = opener or urlmod.opener(ui, authinfo)
770 opener = opener or urlmod.opener(ui, authinfo)
770
771
771 respurl, info = performhandshake(ui, url, opener, requestbuilder)
772 respurl, info = performhandshake(ui, url, opener, requestbuilder)
772
773
773 # Given the intersection of APIs that both we and the server support,
774 # Given the intersection of APIs that both we and the server support,
774 # sort by their advertised priority and pick the first one.
775 # sort by their advertised priority and pick the first one.
775 #
776 #
776 # TODO consider making this request-based and interface driven. For
777 # TODO consider making this request-based and interface driven. For
777 # example, the caller could say "I want a peer that does X." It's quite
778 # example, the caller could say "I want a peer that does X." It's quite
778 # possible that not all peers would do that. Since we know the service
779 # possible that not all peers would do that. Since we know the service
779 # capabilities, we could filter out services not meeting the
780 # capabilities, we could filter out services not meeting the
780 # requirements. Possibly by consulting the interfaces defined by the
781 # requirements. Possibly by consulting the interfaces defined by the
781 # peer type.
782 # peer type.
782 apipeerchoices = set(info.get('apis', {}).keys()) & set(API_PEERS.keys())
783 apipeerchoices = set(info.get('apis', {}).keys()) & set(API_PEERS.keys())
783
784
784 preferredchoices = sorted(apipeerchoices,
785 preferredchoices = sorted(apipeerchoices,
785 key=lambda x: API_PEERS[x]['priority'],
786 key=lambda x: API_PEERS[x]['priority'],
786 reverse=True)
787 reverse=True)
787
788
788 for service in preferredchoices:
789 for service in preferredchoices:
789 apipath = '%s/%s' % (info['apibase'].rstrip('/'), service)
790 apipath = '%s/%s' % (info['apibase'].rstrip('/'), service)
790
791
791 return API_PEERS[service]['init'](ui, respurl, apipath, opener,
792 return API_PEERS[service]['init'](ui, respurl, apipath, opener,
792 requestbuilder,
793 requestbuilder,
793 info['apis'][service])
794 info['apis'][service])
794
795
795 # Failed to construct an API peer. Fall back to legacy.
796 # Failed to construct an API peer. Fall back to legacy.
796 return httppeer(ui, path, respurl, opener, requestbuilder,
797 return httppeer(ui, path, respurl, opener, requestbuilder,
797 info['v1capabilities'])
798 info['v1capabilities'])
798
799
799 def instance(ui, path, create):
800 def instance(ui, path, create):
800 if create:
801 if create:
801 raise error.Abort(_('cannot create new http repository'))
802 raise error.Abort(_('cannot create new http repository'))
802 try:
803 try:
803 if path.startswith('https:') and not urlmod.has_https:
804 if path.startswith('https:') and not urlmod.has_https:
804 raise error.Abort(_('Python support for SSL and HTTPS '
805 raise error.Abort(_('Python support for SSL and HTTPS '
805 'is not installed'))
806 'is not installed'))
806
807
807 inst = makepeer(ui, path)
808 inst = makepeer(ui, path)
808
809
809 return inst
810 return inst
810 except error.RepoError as httpexception:
811 except error.RepoError as httpexception:
811 try:
812 try:
812 r = statichttprepo.instance(ui, "static-" + path, create)
813 r = statichttprepo.instance(ui, "static-" + path, create)
813 ui.note(_('(falling back to static-http)\n'))
814 ui.note(_('(falling back to static-http)\n'))
814 return r
815 return r
815 except error.RepoError:
816 except error.RepoError:
816 raise httpexception # use the original http RepoError instead
817 raise httpexception # use the original http RepoError instead
@@ -1,635 +1,636 b''
1 # sshpeer.py - ssh repository proxy class for mercurial
1 # sshpeer.py - ssh repository proxy class for mercurial
2 #
2 #
3 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import re
10 import re
11 import uuid
11 import uuid
12
12
13 from .i18n import _
13 from .i18n import _
14 from . import (
14 from . import (
15 error,
15 error,
16 pycompat,
16 pycompat,
17 util,
17 util,
18 wireproto,
18 wireproto,
19 wireprotoserver,
19 wireprotoserver,
20 wireprototypes,
20 wireprototypes,
21 wireprotov1peer,
21 )
22 )
22 from .utils import (
23 from .utils import (
23 procutil,
24 procutil,
24 )
25 )
25
26
26 def _serverquote(s):
27 def _serverquote(s):
27 """quote a string for the remote shell ... which we assume is sh"""
28 """quote a string for the remote shell ... which we assume is sh"""
28 if not s:
29 if not s:
29 return s
30 return s
30 if re.match('[a-zA-Z0-9@%_+=:,./-]*$', s):
31 if re.match('[a-zA-Z0-9@%_+=:,./-]*$', s):
31 return s
32 return s
32 return "'%s'" % s.replace("'", "'\\''")
33 return "'%s'" % s.replace("'", "'\\''")
33
34
34 def _forwardoutput(ui, pipe):
35 def _forwardoutput(ui, pipe):
35 """display all data currently available on pipe as remote output.
36 """display all data currently available on pipe as remote output.
36
37
37 This is non blocking."""
38 This is non blocking."""
38 if pipe:
39 if pipe:
39 s = procutil.readpipe(pipe)
40 s = procutil.readpipe(pipe)
40 if s:
41 if s:
41 for l in s.splitlines():
42 for l in s.splitlines():
42 ui.status(_("remote: "), l, '\n')
43 ui.status(_("remote: "), l, '\n')
43
44
44 class doublepipe(object):
45 class doublepipe(object):
45 """Operate a side-channel pipe in addition of a main one
46 """Operate a side-channel pipe in addition of a main one
46
47
47 The side-channel pipe contains server output to be forwarded to the user
48 The side-channel pipe contains server output to be forwarded to the user
48 input. The double pipe will behave as the "main" pipe, but will ensure the
49 input. The double pipe will behave as the "main" pipe, but will ensure the
49 content of the "side" pipe is properly processed while we wait for a blocking
50 content of the "side" pipe is properly processed while we wait for a blocking
50 call on the "main" pipe.
51 call on the "main" pipe.
51
52
52 If large amounts of data are read from "main", the forward will cease after
53 If large amounts of data are read from "main", the forward will cease after
53 the first bytes start to appear. This simplifies the implementation
54 the first bytes start to appear. This simplifies the implementation
54 without affecting actual output of sshpeer too much as we rarely issue
55 without affecting actual output of sshpeer too much as we rarely issue
55 large reads for data not yet emitted by the server.
56 large reads for data not yet emitted by the server.
56
57
57 The main pipe is expected to be a 'bufferedinputpipe' from the util module
58 The main pipe is expected to be a 'bufferedinputpipe' from the util module
58 that handles all the OS-specific bits. This class lives in this module
59 that handles all the OS-specific bits. This class lives in this module
59 because it focuses on behavior specific to the ssh protocol."""
60 because it focuses on behavior specific to the ssh protocol."""
60
61
61 def __init__(self, ui, main, side):
62 def __init__(self, ui, main, side):
62 self._ui = ui
63 self._ui = ui
63 self._main = main
64 self._main = main
64 self._side = side
65 self._side = side
65
66
66 def _wait(self):
67 def _wait(self):
67 """wait until some data are available on main or side
68 """wait until some data are available on main or side
68
69
69 return a pair of boolean (ismainready, issideready)
70 return a pair of boolean (ismainready, issideready)
70
71
71 (This will only wait for data if the setup is supported by `util.poll`)
72 (This will only wait for data if the setup is supported by `util.poll`)
72 """
73 """
73 if (isinstance(self._main, util.bufferedinputpipe) and
74 if (isinstance(self._main, util.bufferedinputpipe) and
74 self._main.hasbuffer):
75 self._main.hasbuffer):
75 # Main has data. Assume side is worth poking at.
76 # Main has data. Assume side is worth poking at.
76 return True, True
77 return True, True
77
78
78 fds = [self._main.fileno(), self._side.fileno()]
79 fds = [self._main.fileno(), self._side.fileno()]
79 try:
80 try:
80 act = util.poll(fds)
81 act = util.poll(fds)
81 except NotImplementedError:
82 except NotImplementedError:
82 # not yet supported case; assume all have data.
83 # not yet supported case; assume all have data.
83 act = fds
84 act = fds
84 return (self._main.fileno() in act, self._side.fileno() in act)
85 return (self._main.fileno() in act, self._side.fileno() in act)
85
86
86 def write(self, data):
87 def write(self, data):
87 return self._call('write', data)
88 return self._call('write', data)
88
89
89 def read(self, size):
90 def read(self, size):
90 r = self._call('read', size)
91 r = self._call('read', size)
91 if size != 0 and not r:
92 if size != 0 and not r:
92 # We've observed a condition that indicates the
93 # We've observed a condition that indicates the
93 # stdout closed unexpectedly. Check stderr one
94 # stdout closed unexpectedly. Check stderr one
94 # more time and snag anything that's there before
95 # more time and snag anything that's there before
95 # letting anyone know the main part of the pipe
96 # letting anyone know the main part of the pipe
96 # closed prematurely.
97 # closed prematurely.
97 _forwardoutput(self._ui, self._side)
98 _forwardoutput(self._ui, self._side)
98 return r
99 return r
99
100
100 def readline(self):
101 def readline(self):
101 return self._call('readline')
102 return self._call('readline')
102
103
103 def _call(self, methname, data=None):
104 def _call(self, methname, data=None):
104 """call <methname> on "main", forward output of "side" while blocking
105 """call <methname> on "main", forward output of "side" while blocking
105 """
106 """
106 # data can be '' or 0
107 # data can be '' or 0
107 if (data is not None and not data) or self._main.closed:
108 if (data is not None and not data) or self._main.closed:
108 _forwardoutput(self._ui, self._side)
109 _forwardoutput(self._ui, self._side)
109 return ''
110 return ''
110 while True:
111 while True:
111 mainready, sideready = self._wait()
112 mainready, sideready = self._wait()
112 if sideready:
113 if sideready:
113 _forwardoutput(self._ui, self._side)
114 _forwardoutput(self._ui, self._side)
114 if mainready:
115 if mainready:
115 meth = getattr(self._main, methname)
116 meth = getattr(self._main, methname)
116 if data is None:
117 if data is None:
117 return meth()
118 return meth()
118 else:
119 else:
119 return meth(data)
120 return meth(data)
120
121
121 def close(self):
122 def close(self):
122 return self._main.close()
123 return self._main.close()
123
124
124 def flush(self):
125 def flush(self):
125 return self._main.flush()
126 return self._main.flush()
126
127
127 def _cleanuppipes(ui, pipei, pipeo, pipee):
128 def _cleanuppipes(ui, pipei, pipeo, pipee):
128 """Clean up pipes used by an SSH connection."""
129 """Clean up pipes used by an SSH connection."""
129 if pipeo:
130 if pipeo:
130 pipeo.close()
131 pipeo.close()
131 if pipei:
132 if pipei:
132 pipei.close()
133 pipei.close()
133
134
134 if pipee:
135 if pipee:
135 # Try to read from the err descriptor until EOF.
136 # Try to read from the err descriptor until EOF.
136 try:
137 try:
137 for l in pipee:
138 for l in pipee:
138 ui.status(_('remote: '), l)
139 ui.status(_('remote: '), l)
139 except (IOError, ValueError):
140 except (IOError, ValueError):
140 pass
141 pass
141
142
142 pipee.close()
143 pipee.close()
143
144
144 def _makeconnection(ui, sshcmd, args, remotecmd, path, sshenv=None):
145 def _makeconnection(ui, sshcmd, args, remotecmd, path, sshenv=None):
145 """Create an SSH connection to a server.
146 """Create an SSH connection to a server.
146
147
147 Returns a tuple of (process, stdin, stdout, stderr) for the
148 Returns a tuple of (process, stdin, stdout, stderr) for the
148 spawned process.
149 spawned process.
149 """
150 """
150 cmd = '%s %s %s' % (
151 cmd = '%s %s %s' % (
151 sshcmd,
152 sshcmd,
152 args,
153 args,
153 procutil.shellquote('%s -R %s serve --stdio' % (
154 procutil.shellquote('%s -R %s serve --stdio' % (
154 _serverquote(remotecmd), _serverquote(path))))
155 _serverquote(remotecmd), _serverquote(path))))
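# As a rough, hypothetical example of the command line assembled above: with
# sshcmd = 'ssh', args = '-p 2222', remotecmd = 'hg' and path = '/srv/repo',
# the spawned command is approximately
#
#     ssh -p 2222 'hg -R /srv/repo serve --stdio'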
155
156
156 ui.debug('running %s\n' % cmd)
157 ui.debug('running %s\n' % cmd)
157 cmd = procutil.quotecommand(cmd)
158 cmd = procutil.quotecommand(cmd)
158
159
159 # unbuffered pipes allow the use of 'select'
160 # unbuffered pipes allow the use of 'select'
160 # feel free to remove buffering and select usage when we ultimately
161 # feel free to remove buffering and select usage when we ultimately
161 # move to threading.
162 # move to threading.
162 stdin, stdout, stderr, proc = procutil.popen4(cmd, bufsize=0, env=sshenv)
163 stdin, stdout, stderr, proc = procutil.popen4(cmd, bufsize=0, env=sshenv)
163
164
164 return proc, stdin, stdout, stderr
165 return proc, stdin, stdout, stderr
165
166
166 def _clientcapabilities():
167 def _clientcapabilities():
167 """Return list of capabilities of this client.
168 """Return list of capabilities of this client.
168
169
169 Returns a list of capabilities that are supported by this client.
170 Returns a list of capabilities that are supported by this client.
170 """
171 """
171 protoparams = {'partial-pull'}
172 protoparams = {'partial-pull'}
172 comps = [e.wireprotosupport().name for e in
173 comps = [e.wireprotosupport().name for e in
173 util.compengines.supportedwireengines(util.CLIENTROLE)]
174 util.compengines.supportedwireengines(util.CLIENTROLE)]
174 protoparams.add('comp=%s' % ','.join(comps))
175 protoparams.add('comp=%s' % ','.join(comps))
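# On a typical build, the set returned below looks something like
# {'partial-pull', 'comp=zstd,zlib,none,bzip2'}; the exact engine list
# depends on which compression engines this client was built with.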
175 return protoparams
176 return protoparams
176
177
177 def _performhandshake(ui, stdin, stdout, stderr):
178 def _performhandshake(ui, stdin, stdout, stderr):
178 def badresponse():
179 def badresponse():
179 # Flush any output on stderr.
180 # Flush any output on stderr.
180 _forwardoutput(ui, stderr)
181 _forwardoutput(ui, stderr)
181
182
182 msg = _('no suitable response from remote hg')
183 msg = _('no suitable response from remote hg')
183 hint = ui.config('ui', 'ssherrorhint')
184 hint = ui.config('ui', 'ssherrorhint')
184 raise error.RepoError(msg, hint=hint)
185 raise error.RepoError(msg, hint=hint)
185
186
186 # The handshake consists of sending wire protocol commands in reverse
187 # The handshake consists of sending wire protocol commands in reverse
187 # order of protocol implementation and then sniffing for a response
188 # order of protocol implementation and then sniffing for a response
188 # to one of them.
189 # to one of them.
189 #
190 #
190 # Those commands (from oldest to newest) are:
191 # Those commands (from oldest to newest) are:
191 #
192 #
192 # ``between``
193 # ``between``
193 # Asks for the set of revisions between a pair of revisions. Command
194 # Asks for the set of revisions between a pair of revisions. Command
194 # present in all Mercurial server implementations.
195 # present in all Mercurial server implementations.
195 #
196 #
196 # ``hello``
197 # ``hello``
197 # Instructs the server to advertise its capabilities. Introduced in
198 # Instructs the server to advertise its capabilities. Introduced in
198 # Mercurial 0.9.1.
199 # Mercurial 0.9.1.
199 #
200 #
200 # ``upgrade``
201 # ``upgrade``
201 # Requests upgrade from default transport protocol version 1 to
202 # Requests upgrade from default transport protocol version 1 to
202 # a newer version. Introduced in Mercurial 4.6 as an experimental
203 # a newer version. Introduced in Mercurial 4.6 as an experimental
203 # feature.
204 # feature.
204 #
205 #
205 # The ``between`` command is issued with a request for the null
206 # The ``between`` command is issued with a request for the null
206 # range. If the remote is a Mercurial server, this request will
207 # range. If the remote is a Mercurial server, this request will
207 # generate a specific response: ``1\n\n``. This represents the
208 # generate a specific response: ``1\n\n``. This represents the
208 # wire protocol encoded value for ``\n``. We look for ``1\n\n``
209 # wire protocol encoded value for ``\n``. We look for ``1\n\n``
209 # in the output stream and know this is the response to ``between``
210 # in the output stream and know this is the response to ``between``
210 # and we're at the end of our handshake reply.
211 # and we're at the end of our handshake reply.
211 #
212 #
212 # The response to the ``hello`` command will be a line with the
213 # The response to the ``hello`` command will be a line with the
213 # length of the value returned by that command followed by that
214 # length of the value returned by that command followed by that
214 # value. If the server doesn't support ``hello`` (which should be
215 # value. If the server doesn't support ``hello`` (which should be
215 # rare), that line will be ``0\n``. Otherwise, the value will contain
216 # rare), that line will be ``0\n``. Otherwise, the value will contain
216 # RFC 822 like lines. Of these, the ``capabilities:`` line contains
217 # RFC 822 like lines. Of these, the ``capabilities:`` line contains
217 # the capabilities of the server.
218 # the capabilities of the server.
218 #
219 #
219 # The ``upgrade`` command isn't really a command in the traditional
220 # The ``upgrade`` command isn't really a command in the traditional
220 # sense of version 1 of the transport because it isn't using the
221 # sense of version 1 of the transport because it isn't using the
221 # proper mechanism for formatting requests: instead, it just encodes
222 # proper mechanism for formatting requests: instead, it just encodes
222 # arguments on the line, delimited by spaces.
223 # arguments on the line, delimited by spaces.
223 #
224 #
224 # The ``upgrade`` line looks like ``upgrade <token> <capabilities>``.
225 # The ``upgrade`` line looks like ``upgrade <token> <capabilities>``.
225 # If the server doesn't support protocol upgrades, it will reply to
226 # If the server doesn't support protocol upgrades, it will reply to
226 # this line with ``0\n``. Otherwise, it emits an
227 # this line with ``0\n``. Otherwise, it emits an
227 # ``upgraded <token> <protocol>`` line to both stdout and stderr.
228 # ``upgraded <token> <protocol>`` line to both stdout and stderr.
228 # Content immediately following this line describes additional
229 # Content immediately following this line describes additional
229 # protocol and server state.
230 # protocol and server state.
230 #
231 #
231 # In addition to the responses to our command requests, the server
232 # In addition to the responses to our command requests, the server
232 # may emit "banner" output on stdout. SSH servers are allowed to
233 # may emit "banner" output on stdout. SSH servers are allowed to
233 # print messages to stdout on login. Issuing commands on connection
234 # print messages to stdout on login. Issuing commands on connection
234 # allows us to flush this banner output from the server by scanning
235 # allows us to flush this banner output from the server by scanning
235 # for output to our well-known ``between`` command. Of course, if
236 # for output to our well-known ``between`` command. Of course, if
236 # the banner contains ``1\n\n``, this will throw off our detection.
237 # the banner contains ``1\n\n``, this will throw off our detection.
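# To make the exchange described above concrete, a plain version 1 handshake
# (no upgrade request) writes the following to the server; ``pairsarg`` is
# two runs of forty zeroes joined by a dash, hence the length of 81:
#
#     hello\n
#     between\n
#     pairs 81\n
#     0000...0000-0000...0000
#
# A server that understands ``hello`` replies with a length-prefixed
# capabilities line followed by the ``1\n\n`` answer to ``between``. The
# length and the capability names below are invented for the example:
#
#     64\n
#     capabilities: lookup branchmap pushkey known getbundle unbundle\n
#     1\n
#     \n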
237
238
238 requestlog = ui.configbool('devel', 'debug.peer-request')
239 requestlog = ui.configbool('devel', 'debug.peer-request')
239
240
240 # Generate a random token to help identify responses to version 2
241 # Generate a random token to help identify responses to version 2
241 # upgrade request.
242 # upgrade request.
242 token = pycompat.sysbytes(str(uuid.uuid4()))
243 token = pycompat.sysbytes(str(uuid.uuid4()))
243 upgradecaps = [
244 upgradecaps = [
244 ('proto', wireprotoserver.SSHV2),
245 ('proto', wireprotoserver.SSHV2),
245 ]
246 ]
246 upgradecaps = util.urlreq.urlencode(upgradecaps)
247 upgradecaps = util.urlreq.urlencode(upgradecaps)
247
248
248 try:
249 try:
249 pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
250 pairsarg = '%s-%s' % ('0' * 40, '0' * 40)
250 handshake = [
251 handshake = [
251 'hello\n',
252 'hello\n',
252 'between\n',
253 'between\n',
253 'pairs %d\n' % len(pairsarg),
254 'pairs %d\n' % len(pairsarg),
254 pairsarg,
255 pairsarg,
255 ]
256 ]
256
257
257 # Request upgrade to version 2 if configured.
258 # Request upgrade to version 2 if configured.
258 if ui.configbool('experimental', 'sshpeer.advertise-v2'):
259 if ui.configbool('experimental', 'sshpeer.advertise-v2'):
259 ui.debug('sending upgrade request: %s %s\n' % (token, upgradecaps))
260 ui.debug('sending upgrade request: %s %s\n' % (token, upgradecaps))
260 handshake.insert(0, 'upgrade %s %s\n' % (token, upgradecaps))
261 handshake.insert(0, 'upgrade %s %s\n' % (token, upgradecaps))
261
262
262 if requestlog:
263 if requestlog:
263 ui.debug('devel-peer-request: hello\n')
264 ui.debug('devel-peer-request: hello\n')
264 ui.debug('sending hello command\n')
265 ui.debug('sending hello command\n')
265 if requestlog:
266 if requestlog:
266 ui.debug('devel-peer-request: between\n')
267 ui.debug('devel-peer-request: between\n')
267 ui.debug('devel-peer-request: pairs: %d bytes\n' % len(pairsarg))
268 ui.debug('devel-peer-request: pairs: %d bytes\n' % len(pairsarg))
268 ui.debug('sending between command\n')
269 ui.debug('sending between command\n')
269
270
270 stdin.write(''.join(handshake))
271 stdin.write(''.join(handshake))
271 stdin.flush()
272 stdin.flush()
272 except IOError:
273 except IOError:
273 badresponse()
274 badresponse()
274
275
275 # Assume version 1 of wire protocol by default.
276 # Assume version 1 of wire protocol by default.
276 protoname = wireprototypes.SSHV1
277 protoname = wireprototypes.SSHV1
277 reupgraded = re.compile(b'^upgraded %s (.*)$' % re.escape(token))
278 reupgraded = re.compile(b'^upgraded %s (.*)$' % re.escape(token))
278
279
279 lines = ['', 'dummy']
280 lines = ['', 'dummy']
280 max_noise = 500
281 max_noise = 500
281 while lines[-1] and max_noise:
282 while lines[-1] and max_noise:
282 try:
283 try:
283 l = stdout.readline()
284 l = stdout.readline()
284 _forwardoutput(ui, stderr)
285 _forwardoutput(ui, stderr)
285
286
286 # Look for reply to protocol upgrade request. It has a token
287 # Look for reply to protocol upgrade request. It has a token
287 # in it, so there should be no false positives.
288 # in it, so there should be no false positives.
288 m = reupgraded.match(l)
289 m = reupgraded.match(l)
289 if m:
290 if m:
290 protoname = m.group(1)
291 protoname = m.group(1)
291 ui.debug('protocol upgraded to %s\n' % protoname)
292 ui.debug('protocol upgraded to %s\n' % protoname)
292 # If an upgrade was handled, the ``hello`` and ``between``
293 # If an upgrade was handled, the ``hello`` and ``between``
293 # requests are ignored. The next output belongs to the
294 # requests are ignored. The next output belongs to the
294 # protocol, so stop scanning lines.
295 # protocol, so stop scanning lines.
295 break
296 break
296
297
297 # Otherwise it could be a banner, ``0\n`` response if server
298 # Otherwise it could be a banner, ``0\n`` response if server
298 # doesn't support upgrade.
299 # doesn't support upgrade.
299
300
300 if lines[-1] == '1\n' and l == '\n':
301 if lines[-1] == '1\n' and l == '\n':
301 break
302 break
302 if l:
303 if l:
303 ui.debug('remote: ', l)
304 ui.debug('remote: ', l)
304 lines.append(l)
305 lines.append(l)
305 max_noise -= 1
306 max_noise -= 1
306 except IOError:
307 except IOError:
307 badresponse()
308 badresponse()
308 else:
309 else:
309 badresponse()
310 badresponse()
310
311
311 caps = set()
312 caps = set()
312
313
313 # For version 1, we should see a ``capabilities`` line in response to the
314 # For version 1, we should see a ``capabilities`` line in response to the
314 # ``hello`` command.
315 # ``hello`` command.
315 if protoname == wireprototypes.SSHV1:
316 if protoname == wireprototypes.SSHV1:
316 for l in reversed(lines):
317 for l in reversed(lines):
317 # Look for response to ``hello`` command. Scan from the back so
318 # Look for response to ``hello`` command. Scan from the back so
318 # we don't misinterpret banner output as the command reply.
319 # we don't misinterpret banner output as the command reply.
319 if l.startswith('capabilities:'):
320 if l.startswith('capabilities:'):
320 caps.update(l[:-1].split(':')[1].split())
321 caps.update(l[:-1].split(':')[1].split())
321 break
322 break
322 elif protoname == wireprotoserver.SSHV2:
323 elif protoname == wireprotoserver.SSHV2:
323 # We see a line with number of bytes to follow and then a value
324 # We see a line with number of bytes to follow and then a value
324 # looking like ``capabilities: *``.
325 # looking like ``capabilities: *``.
325 line = stdout.readline()
326 line = stdout.readline()
326 try:
327 try:
327 valuelen = int(line)
328 valuelen = int(line)
328 except ValueError:
329 except ValueError:
329 badresponse()
330 badresponse()
330
331
331 capsline = stdout.read(valuelen)
332 capsline = stdout.read(valuelen)
332 if not capsline.startswith('capabilities: '):
333 if not capsline.startswith('capabilities: '):
333 badresponse()
334 badresponse()
334
335
335 ui.debug('remote: %s\n' % capsline)
336 ui.debug('remote: %s\n' % capsline)
336
337
337 caps.update(capsline.split(':')[1].split())
338 caps.update(capsline.split(':')[1].split())
338 # Trailing newline.
339 # Trailing newline.
339 stdout.read(1)
340 stdout.read(1)
340
341
341 # Error if we couldn't find capabilities; this means:
342 # Error if we couldn't find capabilities; this means:
342 #
343 #
343 # 1. Remote isn't a Mercurial server
344 # 1. Remote isn't a Mercurial server
344 # 2. Remote is a <0.9.1 Mercurial server
345 # 2. Remote is a <0.9.1 Mercurial server
345 # 3. Remote is a future Mercurial server that dropped ``hello``
346 # 3. Remote is a future Mercurial server that dropped ``hello``
346 # and other attempted handshake mechanisms.
347 # and other attempted handshake mechanisms.
347 if not caps:
348 if not caps:
348 badresponse()
349 badresponse()
349
350
350 # Flush any output on stderr before proceeding.
351 # Flush any output on stderr before proceeding.
351 _forwardoutput(ui, stderr)
352 _forwardoutput(ui, stderr)
352
353
353 return protoname, caps
354 return protoname, caps
354
355
355 class sshv1peer(wireproto.wirepeer):
356 class sshv1peer(wireprotov1peer.wirepeer):
356 def __init__(self, ui, url, proc, stdin, stdout, stderr, caps,
357 def __init__(self, ui, url, proc, stdin, stdout, stderr, caps,
357 autoreadstderr=True):
358 autoreadstderr=True):
358 """Create a peer from an existing SSH connection.
359 """Create a peer from an existing SSH connection.
359
360
360 ``proc`` is a handle on the underlying SSH process.
361 ``proc`` is a handle on the underlying SSH process.
361 ``stdin``, ``stdout``, and ``stderr`` are handles on the stdio
362 ``stdin``, ``stdout``, and ``stderr`` are handles on the stdio
362 pipes for that process.
363 pipes for that process.
363 ``caps`` is a set of capabilities supported by the remote.
364 ``caps`` is a set of capabilities supported by the remote.
364 ``autoreadstderr`` denotes whether to automatically read from
365 ``autoreadstderr`` denotes whether to automatically read from
365 stderr and to forward its output.
366 stderr and to forward its output.
366 """
367 """
367 self._url = url
368 self._url = url
368 self.ui = ui
369 self.ui = ui
369 # self._subprocess is unused. Keeping a handle on the process
370 # self._subprocess is unused. Keeping a handle on the process
370 # holds a reference and prevents it from being garbage collected.
371 # holds a reference and prevents it from being garbage collected.
371 self._subprocess = proc
372 self._subprocess = proc
372
373
373 # And we hook up our "doublepipe" wrapper to allow querying
374 # And we hook up our "doublepipe" wrapper to allow querying
374 # stderr any time we perform I/O.
375 # stderr any time we perform I/O.
375 if autoreadstderr:
376 if autoreadstderr:
376 stdout = doublepipe(ui, util.bufferedinputpipe(stdout), stderr)
377 stdout = doublepipe(ui, util.bufferedinputpipe(stdout), stderr)
377 stdin = doublepipe(ui, stdin, stderr)
378 stdin = doublepipe(ui, stdin, stderr)
378
379
379 self._pipeo = stdin
380 self._pipeo = stdin
380 self._pipei = stdout
381 self._pipei = stdout
381 self._pipee = stderr
382 self._pipee = stderr
382 self._caps = caps
383 self._caps = caps
383 self._autoreadstderr = autoreadstderr
384 self._autoreadstderr = autoreadstderr
384
385
385 # Commands that have a "framed" response where the first line of the
386 # Commands that have a "framed" response where the first line of the
386 # response contains the length of that response.
387 # response contains the length of that response.
387 _FRAMED_COMMANDS = {
388 _FRAMED_COMMANDS = {
388 'batch',
389 'batch',
389 }
390 }
390
391
391 # Begin of ipeerconnection interface.
392 # Begin of ipeerconnection interface.
392
393
393 def url(self):
394 def url(self):
394 return self._url
395 return self._url
395
396
396 def local(self):
397 def local(self):
397 return None
398 return None
398
399
399 def peer(self):
400 def peer(self):
400 return self
401 return self
401
402
402 def canpush(self):
403 def canpush(self):
403 return True
404 return True
404
405
405 def close(self):
406 def close(self):
406 pass
407 pass
407
408
408 # End of ipeerconnection interface.
409 # End of ipeerconnection interface.
409
410
410 # Begin of ipeercommands interface.
411 # Begin of ipeercommands interface.
411
412
412 def capabilities(self):
413 def capabilities(self):
413 return self._caps
414 return self._caps
414
415
415 # End of ipeercommands interface.
416 # End of ipeercommands interface.
416
417
417 def _readerr(self):
418 def _readerr(self):
418 _forwardoutput(self.ui, self._pipee)
419 _forwardoutput(self.ui, self._pipee)
419
420
420 def _abort(self, exception):
421 def _abort(self, exception):
421 self._cleanup()
422 self._cleanup()
422 raise exception
423 raise exception
423
424
424 def _cleanup(self):
425 def _cleanup(self):
425 _cleanuppipes(self.ui, self._pipei, self._pipeo, self._pipee)
426 _cleanuppipes(self.ui, self._pipei, self._pipeo, self._pipee)
426
427
427 __del__ = _cleanup
428 __del__ = _cleanup
428
429
429 def _sendrequest(self, cmd, args, framed=False):
430 def _sendrequest(self, cmd, args, framed=False):
430 if (self.ui.debugflag
431 if (self.ui.debugflag
431 and self.ui.configbool('devel', 'debug.peer-request')):
432 and self.ui.configbool('devel', 'debug.peer-request')):
432 dbg = self.ui.debug
433 dbg = self.ui.debug
433 line = 'devel-peer-request: %s\n'
434 line = 'devel-peer-request: %s\n'
434 dbg(line % cmd)
435 dbg(line % cmd)
435 for key, value in sorted(args.items()):
436 for key, value in sorted(args.items()):
436 if not isinstance(value, dict):
437 if not isinstance(value, dict):
437 dbg(line % ' %s: %d bytes' % (key, len(value)))
438 dbg(line % ' %s: %d bytes' % (key, len(value)))
438 else:
439 else:
439 for dk, dv in sorted(value.items()):
440 for dk, dv in sorted(value.items()):
440 dbg(line % ' %s-%s: %d' % (key, dk, len(dv)))
441 dbg(line % ' %s-%s: %d' % (key, dk, len(dv)))
441 self.ui.debug("sending %s command\n" % cmd)
442 self.ui.debug("sending %s command\n" % cmd)
442 self._pipeo.write("%s\n" % cmd)
443 self._pipeo.write("%s\n" % cmd)
443 _func, names = wireproto.commands[cmd]
444 _func, names = wireproto.commands[cmd]
444 keys = names.split()
445 keys = names.split()
445 wireargs = {}
446 wireargs = {}
446 for k in keys:
447 for k in keys:
447 if k == '*':
448 if k == '*':
448 wireargs['*'] = args
449 wireargs['*'] = args
449 break
450 break
450 else:
451 else:
451 wireargs[k] = args[k]
452 wireargs[k] = args[k]
452 del args[k]
453 del args[k]
453 for k, v in sorted(wireargs.iteritems()):
454 for k, v in sorted(wireargs.iteritems()):
454 self._pipeo.write("%s %d\n" % (k, len(v)))
455 self._pipeo.write("%s %d\n" % (k, len(v)))
455 if isinstance(v, dict):
456 if isinstance(v, dict):
456 for dk, dv in v.iteritems():
457 for dk, dv in v.iteritems():
457 self._pipeo.write("%s %d\n" % (dk, len(dv)))
458 self._pipeo.write("%s %d\n" % (dk, len(dv)))
458 self._pipeo.write(dv)
459 self._pipeo.write(dv)
459 else:
460 else:
460 self._pipeo.write(v)
461 self._pipeo.write(v)
461 self._pipeo.flush()
462 self._pipeo.flush()
462
463
463 # We know exactly how many bytes are in the response. So return a proxy
464 # We know exactly how many bytes are in the response. So return a proxy
464 # around the raw output stream that allows reading exactly this many
465 # around the raw output stream that allows reading exactly this many
465 # bytes. Callers then can read() without fear of overrunning the
466 # bytes. Callers then can read() without fear of overrunning the
466 # response.
467 # response.
467 if framed:
468 if framed:
468 amount = self._getamount()
469 amount = self._getamount()
469 return util.cappedreader(self._pipei, amount)
470 return util.cappedreader(self._pipei, amount)
470
471
471 return self._pipei
472 return self._pipei
472
473
473 def _callstream(self, cmd, **args):
474 def _callstream(self, cmd, **args):
474 args = pycompat.byteskwargs(args)
475 args = pycompat.byteskwargs(args)
475 return self._sendrequest(cmd, args, framed=cmd in self._FRAMED_COMMANDS)
476 return self._sendrequest(cmd, args, framed=cmd in self._FRAMED_COMMANDS)
476
477
477 def _callcompressable(self, cmd, **args):
478 def _callcompressable(self, cmd, **args):
478 args = pycompat.byteskwargs(args)
479 args = pycompat.byteskwargs(args)
479 return self._sendrequest(cmd, args, framed=cmd in self._FRAMED_COMMANDS)
480 return self._sendrequest(cmd, args, framed=cmd in self._FRAMED_COMMANDS)
480
481
481 def _call(self, cmd, **args):
482 def _call(self, cmd, **args):
482 args = pycompat.byteskwargs(args)
483 args = pycompat.byteskwargs(args)
483 return self._sendrequest(cmd, args, framed=True).read()
484 return self._sendrequest(cmd, args, framed=True).read()
484
485
485 def _callpush(self, cmd, fp, **args):
486 def _callpush(self, cmd, fp, **args):
486 # The server responds with an empty frame if the client should
487 # The server responds with an empty frame if the client should
487 # continue submitting the payload.
488 # continue submitting the payload.
488 r = self._call(cmd, **args)
489 r = self._call(cmd, **args)
489 if r:
490 if r:
490 return '', r
491 return '', r
491
492
492 # The payload consists of frames with content followed by an empty
493 # The payload consists of frames with content followed by an empty
493 # frame.
494 # frame.
494 for d in iter(lambda: fp.read(4096), ''):
495 for d in iter(lambda: fp.read(4096), ''):
495 self._writeframed(d)
496 self._writeframed(d)
496 self._writeframed("", flush=True)
497 self._writeframed("", flush=True)
497
498
498 # In case of success, there is an empty frame and a frame containing
499 # In case of success, there is an empty frame and a frame containing
499 # the integer result (as a string).
500 # the integer result (as a string).
500 # In case of error, there is a non-empty frame containing the error.
501 # In case of error, there is a non-empty frame containing the error.
501 r = self._readframed()
502 r = self._readframed()
502 if r:
503 if r:
503 return '', r
504 return '', r
504 return self._readframed(), ''
505 return self._readframed(), ''
505
506
506 def _calltwowaystream(self, cmd, fp, **args):
507 def _calltwowaystream(self, cmd, fp, **args):
507 # The server responds with an empty frame if the client should
508 # The server responds with an empty frame if the client should
508 # continue submitting the payload.
509 # continue submitting the payload.
509 r = self._call(cmd, **args)
510 r = self._call(cmd, **args)
510 if r:
511 if r:
511 # XXX needs to be made better
512 # XXX needs to be made better
512 raise error.Abort(_('unexpected remote reply: %s') % r)
513 raise error.Abort(_('unexpected remote reply: %s') % r)
513
514
514 # The payload consists of frames with content followed by an empty
515 # The payload consists of frames with content followed by an empty
515 # frame.
516 # frame.
516 for d in iter(lambda: fp.read(4096), ''):
517 for d in iter(lambda: fp.read(4096), ''):
517 self._writeframed(d)
518 self._writeframed(d)
518 self._writeframed("", flush=True)
519 self._writeframed("", flush=True)
519
520
520 return self._pipei
521 return self._pipei
521
522
522 def _getamount(self):
523 def _getamount(self):
523 l = self._pipei.readline()
524 l = self._pipei.readline()
524 if l == '\n':
525 if l == '\n':
525 if self._autoreadstderr:
526 if self._autoreadstderr:
526 self._readerr()
527 self._readerr()
527 msg = _('check previous remote output')
528 msg = _('check previous remote output')
528 self._abort(error.OutOfBandError(hint=msg))
529 self._abort(error.OutOfBandError(hint=msg))
529 if self._autoreadstderr:
530 if self._autoreadstderr:
530 self._readerr()
531 self._readerr()
531 try:
532 try:
532 return int(l)
533 return int(l)
533 except ValueError:
534 except ValueError:
534 self._abort(error.ResponseError(_("unexpected response:"), l))
535 self._abort(error.ResponseError(_("unexpected response:"), l))
535
536
536 def _readframed(self):
537 def _readframed(self):
537 size = self._getamount()
538 size = self._getamount()
538 if not size:
539 if not size:
539 return b''
540 return b''
540
541
541 return self._pipei.read(size)
542 return self._pipei.read(size)
542
543
543 def _writeframed(self, data, flush=False):
544 def _writeframed(self, data, flush=False):
544 self._pipeo.write("%d\n" % len(data))
545 self._pipeo.write("%d\n" % len(data))
545 if data:
546 if data:
546 self._pipeo.write(data)
547 self._pipeo.write(data)
547 if flush:
548 if flush:
548 self._pipeo.flush()
549 self._pipeo.flush()
549 if self._autoreadstderr:
550 if self._autoreadstderr:
550 self._readerr()
551 self._readerr()
551
552
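The frame format used by _writeframed() and _readframed() above is an ASCII decimal length on its own line, followed by exactly that many bytes of payload; a zero-length frame ("0\n") terminates a payload. A minimal sketch of that encoding, with helper names that are illustrative rather than part of this module:

    # Illustrative helpers mirroring _writeframed/_readframed's wire format:
    # "<decimal length>\n" followed by that many payload bytes.
    def encodeframe(data):
        return b'%d\n' % len(data) + data

    def decodeframe(fh):
        size = int(fh.readline())    # the length line, e.g. b'5\n'
        return fh.read(size)

    # encodeframe(b'hello') -> b'5\nhello'
    # encodeframe(b'')      -> b'0\n'   (the empty frame that ends a payload)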
552 class sshv2peer(sshv1peer):
553 class sshv2peer(sshv1peer):
553 """A peer that speakers version 2 of the transport protocol."""
554 """A peer that speakers version 2 of the transport protocol."""
554 # Currently version 2 is identical to version 1 post handshake.
555 # Currently version 2 is identical to version 1 post handshake.
555 # And handshake is performed before the peer is instantiated. So
556 # And handshake is performed before the peer is instantiated. So
556 # we need no custom code.
557 # we need no custom code.
557
558
558 def makepeer(ui, path, proc, stdin, stdout, stderr, autoreadstderr=True):
559 def makepeer(ui, path, proc, stdin, stdout, stderr, autoreadstderr=True):
559 """Make a peer instance from existing pipes.
560 """Make a peer instance from existing pipes.
560
561
561 ``path`` and ``proc`` are stored on the eventual peer instance and may
562 ``path`` and ``proc`` are stored on the eventual peer instance and may
562 not be used for anything meaningful.
563 not be used for anything meaningful.
563
564
564 ``stdin``, ``stdout``, and ``stderr`` are the pipes connected to the
565 ``stdin``, ``stdout``, and ``stderr`` are the pipes connected to the
565 SSH server's stdio handles.
566 SSH server's stdio handles.
566
567
567 This function is factored out to allow creating peers that don't
568 This function is factored out to allow creating peers that don't
568 actually spawn a new process. It is useful for starting SSH protocol
569 actually spawn a new process. It is useful for starting SSH protocol
569 servers and clients via non-standard means, which can be useful for
570 servers and clients via non-standard means, which can be useful for
570 testing.
571 testing.
571 """
572 """
572 try:
573 try:
573 protoname, caps = _performhandshake(ui, stdin, stdout, stderr)
574 protoname, caps = _performhandshake(ui, stdin, stdout, stderr)
574 except Exception:
575 except Exception:
575 _cleanuppipes(ui, stdout, stdin, stderr)
576 _cleanuppipes(ui, stdout, stdin, stderr)
576 raise
577 raise
577
578
578 if protoname == wireprototypes.SSHV1:
579 if protoname == wireprototypes.SSHV1:
579 return sshv1peer(ui, path, proc, stdin, stdout, stderr, caps,
580 return sshv1peer(ui, path, proc, stdin, stdout, stderr, caps,
580 autoreadstderr=autoreadstderr)
581 autoreadstderr=autoreadstderr)
581 elif protoname == wireprototypes.SSHV2:
582 elif protoname == wireprototypes.SSHV2:
582 return sshv2peer(ui, path, proc, stdin, stdout, stderr, caps,
583 return sshv2peer(ui, path, proc, stdin, stdout, stderr, caps,
583 autoreadstderr=autoreadstderr)
584 autoreadstderr=autoreadstderr)
584 else:
585 else:
585 _cleanuppipes(ui, stdout, stdin, stderr)
586 _cleanuppipes(ui, stdout, stdin, stderr)
586 raise error.RepoError(_('unknown version of SSH protocol: %s') %
587 raise error.RepoError(_('unknown version of SSH protocol: %s') %
587 protoname)
588 protoname)
588
589
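Because makepeer() only needs a set of pipes, a test can point it at any process that speaks the protocol. A hedged sketch, assuming a local `hg serve --stdio` process and an existing `ui` object; the repository path and peer URL are placeholders:

    import subprocess

    proc = subprocess.Popen([b'hg', b'-R', b'/path/to/repo', b'serve', b'--stdio'],
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    # Hand the process's pipes straight to makepeer() instead of going
    # through instance(); the handshake happens inside makepeer().
    peer = makepeer(ui, b'ssh://placeholder/', proc,
                    proc.stdin, proc.stdout, proc.stderr)
    ui.write(b'%s\n' % b' '.join(sorted(peer.capabilities())))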
589 def instance(ui, path, create):
590 def instance(ui, path, create):
590 """Create an SSH peer.
591 """Create an SSH peer.
591
592
592 The returned object conforms to the ``wireproto.wirepeer`` interface.
593 The returned object conforms to the ``wireprotov1peer.wirepeer`` interface.
593 """
594 """
594 u = util.url(path, parsequery=False, parsefragment=False)
595 u = util.url(path, parsequery=False, parsefragment=False)
595 if u.scheme != 'ssh' or not u.host or u.path is None:
596 if u.scheme != 'ssh' or not u.host or u.path is None:
596 raise error.RepoError(_("couldn't parse location %s") % path)
597 raise error.RepoError(_("couldn't parse location %s") % path)
597
598
598 util.checksafessh(path)
599 util.checksafessh(path)
599
600
600 if u.passwd is not None:
601 if u.passwd is not None:
601 raise error.RepoError(_('password in URL not supported'))
602 raise error.RepoError(_('password in URL not supported'))
602
603
603 sshcmd = ui.config('ui', 'ssh')
604 sshcmd = ui.config('ui', 'ssh')
604 remotecmd = ui.config('ui', 'remotecmd')
605 remotecmd = ui.config('ui', 'remotecmd')
605 sshaddenv = dict(ui.configitems('sshenv'))
606 sshaddenv = dict(ui.configitems('sshenv'))
606 sshenv = procutil.shellenviron(sshaddenv)
607 sshenv = procutil.shellenviron(sshaddenv)
607 remotepath = u.path or '.'
608 remotepath = u.path or '.'
608
609
609 args = procutil.sshargs(sshcmd, u.host, u.user, u.port)
610 args = procutil.sshargs(sshcmd, u.host, u.user, u.port)
610
611
611 if create:
612 if create:
612 cmd = '%s %s %s' % (sshcmd, args,
613 cmd = '%s %s %s' % (sshcmd, args,
613 procutil.shellquote('%s init %s' %
614 procutil.shellquote('%s init %s' %
614 (_serverquote(remotecmd), _serverquote(remotepath))))
615 (_serverquote(remotecmd), _serverquote(remotepath))))
615 ui.debug('running %s\n' % cmd)
616 ui.debug('running %s\n' % cmd)
616 res = ui.system(cmd, blockedtag='sshpeer', environ=sshenv)
617 res = ui.system(cmd, blockedtag='sshpeer', environ=sshenv)
617 if res != 0:
618 if res != 0:
618 raise error.RepoError(_('could not create remote repo'))
619 raise error.RepoError(_('could not create remote repo'))
619
620
620 proc, stdin, stdout, stderr = _makeconnection(ui, sshcmd, args, remotecmd,
621 proc, stdin, stdout, stderr = _makeconnection(ui, sshcmd, args, remotecmd,
621 remotepath, sshenv)
622 remotepath, sshenv)
622
623
623 peer = makepeer(ui, path, proc, stdin, stdout, stderr)
624 peer = makepeer(ui, path, proc, stdin, stdout, stderr)
624
625
625 # Finally, if supported by the server, notify it about our own
626 # Finally, if supported by the server, notify it about our own
626 # capabilities.
627 # capabilities.
627 if 'protocaps' in peer.capabilities():
628 if 'protocaps' in peer.capabilities():
628 try:
629 try:
629 peer._call("protocaps",
630 peer._call("protocaps",
630 caps=' '.join(sorted(_clientcapabilities())))
631 caps=' '.join(sorted(_clientcapabilities())))
631 except IOError:
632 except IOError:
632 peer._cleanup()
633 peer._cleanup()
633 raise error.RepoError(_('capability exchange failed'))
634 raise error.RepoError(_('capability exchange failed'))
634
635
635 return peer
636 return peer
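The options read above map to ordinary configuration; a hedged hgrc sketch (the ssh flags and environment values are illustrative):

    [ui]
    # command used to reach the remote host; extra flags are allowed
    ssh = ssh -C
    # name of the hg executable on the remote side
    remotecmd = hg

    [sshenv]
    # extra environment variables for the spawned ssh process
    LC_ALL = C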
@@ -1,1265 +1,866 @@
1 # wireproto.py - generic wire protocol support functions
1 # wireproto.py - generic wire protocol support functions
2 #
2 #
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import hashlib
11 import os
10 import os
12 import tempfile
11 import tempfile
13
12
14 from .i18n import _
13 from .i18n import _
15 from .node import (
14 from .node import (
16 bin,
17 hex,
15 hex,
18 nullid,
16 nullid,
19 )
17 )
20
18
21 from . import (
19 from . import (
22 bundle2,
20 bundle2,
23 changegroup as changegroupmod,
21 changegroup as changegroupmod,
24 discovery,
22 discovery,
25 encoding,
23 encoding,
26 error,
24 error,
27 exchange,
25 exchange,
28 peer,
29 pushkey as pushkeymod,
26 pushkey as pushkeymod,
30 pycompat,
27 pycompat,
31 repository,
32 streamclone,
28 streamclone,
33 util,
29 util,
34 wireprototypes,
30 wireprototypes,
35 )
31 )
36
32
37 from .utils import (
33 from .utils import (
38 procutil,
34 procutil,
39 stringutil,
35 stringutil,
40 )
36 )
41
37
42 urlerr = util.urlerr
38 urlerr = util.urlerr
43 urlreq = util.urlreq
39 urlreq = util.urlreq
44
40
45 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
41 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
46 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
42 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
47 'IncompatibleClient')
43 'IncompatibleClient')
48 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
44 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
49
45
50 class remoteiterbatcher(peer.iterbatcher):
51 def __init__(self, remote):
52 super(remoteiterbatcher, self).__init__()
53 self._remote = remote
54
55 def __getattr__(self, name):
56 # Validate this method is batchable, since submit() only supports
57 # batchable methods.
58 fn = getattr(self._remote, name)
59 if not getattr(fn, 'batchable', None):
60 raise error.ProgrammingError('Attempted to batch a non-batchable '
61 'call to %r' % name)
62
63 return super(remoteiterbatcher, self).__getattr__(name)
64
65 def submit(self):
66 """Break the batch request into many patch calls and pipeline them.
67
68 This is mostly valuable over http where request sizes can be
69 limited, but can be used in other places as well.
70 """
71 # 2-tuple of (command, arguments) that represents what will be
72 # sent over the wire.
73 requests = []
74
75 # 4-tuple of (command, final future, @batchable generator, remote
76 # future).
77 results = []
78
79 for command, args, opts, finalfuture in self.calls:
80 mtd = getattr(self._remote, command)
81 batchable = mtd.batchable(mtd.__self__, *args, **opts)
82
83 commandargs, fremote = next(batchable)
84 assert fremote
85 requests.append((command, commandargs))
86 results.append((command, finalfuture, batchable, fremote))
87
88 if requests:
89 self._resultiter = self._remote._submitbatch(requests)
90
91 self._results = results
92
93 def results(self):
94 for command, finalfuture, batchable, remotefuture in self._results:
95 # Get the raw result, set it in the remote future, feed it
96 # back into the @batchable generator so it can be decoded, and
97 # set the result on the final future to this value.
98 remoteresult = next(self._resultiter)
99 remotefuture.set(remoteresult)
100 finalfuture.set(next(batchable))
101
102 # Verify our @batchable generators only emit 2 values.
103 try:
104 next(batchable)
105 except StopIteration:
106 pass
107 else:
108 raise error.ProgrammingError('%s @batchable generator emitted '
109 'unexpected value count' % command)
110
111 yield finalfuture.value
112
113 # Forward a couple of names from peer to make wireproto interactions
114 # slightly more sensible.
115 batchable = peer.batchable
116 future = peer.future
117
118
119 def encodebatchcmds(req):
120 """Return a ``cmds`` argument value for the ``batch`` command."""
121 escapearg = wireprototypes.escapebatcharg
122
123 cmds = []
124 for op, argsdict in req:
125 # Old servers didn't properly unescape argument names. So prevent
126 # the sending of argument names that may not be decoded properly by
127 # servers.
128 assert all(escapearg(k) == k for k in argsdict)
129
130 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
131 for k, v in argsdict.iteritems())
132 cmds.append('%s %s' % (op, args))
133
134 return ';'.join(cmds)
135
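For example, a two-command batch is encoded as a single ';'-separated string; the node value below is illustrative and shortened:

    # Roughly what encodebatchcmds() produces for a two-command batch.
    req = [(b'heads', {}),
           (b'known', {b'nodes': b'abcdef0123456789'})]
    # encodebatchcmds(req) ->
    #     b'heads ;known nodes=abcdef0123456789'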
136 def clientcompressionsupport(proto):
46 def clientcompressionsupport(proto):
137 """Returns a list of compression methods supported by the client.
47 """Returns a list of compression methods supported by the client.
138
48
139 Returns a list of the compression methods supported by the client
49 Returns a list of the compression methods supported by the client
140 according to the protocol capabilities. If no such capability has
50 according to the protocol capabilities. If no such capability has
141 been announced, fallback to the default of zlib and uncompressed.
51 been announced, fallback to the default of zlib and uncompressed.
142 """
52 """
143 for cap in proto.getprotocaps():
53 for cap in proto.getprotocaps():
144 if cap.startswith('comp='):
54 if cap.startswith('comp='):
145 return cap[5:].split(',')
55 return cap[5:].split(',')
146 return ['zlib', 'none']
56 return ['zlib', 'none']
147
57
148 # client side
149
150 class wirepeer(repository.legacypeer):
151 """Client-side interface for communicating with a peer repository.
152
153 Methods commonly call wire protocol commands of the same name.
154
155 See also httppeer.py and sshpeer.py for protocol-specific
156 implementations of this interface.
157 """
158 # Begin of ipeercommands interface.
159
160 def iterbatch(self):
161 return remoteiterbatcher(self)
162
163 @batchable
164 def lookup(self, key):
165 self.requirecap('lookup', _('look up remote revision'))
166 f = future()
167 yield {'key': encoding.fromlocal(key)}, f
168 d = f.value
169 success, data = d[:-1].split(" ", 1)
170 if int(success):
171 yield bin(data)
172 else:
173 self._abort(error.RepoError(data))
174
175 @batchable
176 def heads(self):
177 f = future()
178 yield {}, f
179 d = f.value
180 try:
181 yield wireprototypes.decodelist(d[:-1])
182 except ValueError:
183 self._abort(error.ResponseError(_("unexpected response:"), d))
184
185 @batchable
186 def known(self, nodes):
187 f = future()
188 yield {'nodes': wireprototypes.encodelist(nodes)}, f
189 d = f.value
190 try:
191 yield [bool(int(b)) for b in d]
192 except ValueError:
193 self._abort(error.ResponseError(_("unexpected response:"), d))
194
195 @batchable
196 def branchmap(self):
197 f = future()
198 yield {}, f
199 d = f.value
200 try:
201 branchmap = {}
202 for branchpart in d.splitlines():
203 branchname, branchheads = branchpart.split(' ', 1)
204 branchname = encoding.tolocal(urlreq.unquote(branchname))
205 branchheads = wireprototypes.decodelist(branchheads)
206 branchmap[branchname] = branchheads
207 yield branchmap
208 except TypeError:
209 self._abort(error.ResponseError(_("unexpected response:"), d))
210
211 @batchable
212 def listkeys(self, namespace):
213 if not self.capable('pushkey'):
214 yield {}, None
215 f = future()
216 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
217 yield {'namespace': encoding.fromlocal(namespace)}, f
218 d = f.value
219 self.ui.debug('received listkey for "%s": %i bytes\n'
220 % (namespace, len(d)))
221 yield pushkeymod.decodekeys(d)
222
223 @batchable
224 def pushkey(self, namespace, key, old, new):
225 if not self.capable('pushkey'):
226 yield False, None
227 f = future()
228 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
229 yield {'namespace': encoding.fromlocal(namespace),
230 'key': encoding.fromlocal(key),
231 'old': encoding.fromlocal(old),
232 'new': encoding.fromlocal(new)}, f
233 d = f.value
234 d, output = d.split('\n', 1)
235 try:
236 d = bool(int(d))
237 except ValueError:
238 raise error.ResponseError(
239 _('push failed (unexpected response):'), d)
240 for l in output.splitlines(True):
241 self.ui.status(_('remote: '), l)
242 yield d
243
244 def stream_out(self):
245 return self._callstream('stream_out')
246
247 def getbundle(self, source, **kwargs):
248 kwargs = pycompat.byteskwargs(kwargs)
249 self.requirecap('getbundle', _('look up remote changes'))
250 opts = {}
251 bundlecaps = kwargs.get('bundlecaps') or set()
252 for key, value in kwargs.iteritems():
253 if value is None:
254 continue
255 keytype = wireprototypes.GETBUNDLE_ARGUMENTS.get(key)
256 if keytype is None:
257 raise error.ProgrammingError(
258 'Unexpectedly None keytype for key %s' % key)
259 elif keytype == 'nodes':
260 value = wireprototypes.encodelist(value)
261 elif keytype == 'csv':
262 value = ','.join(value)
263 elif keytype == 'scsv':
264 value = ','.join(sorted(value))
265 elif keytype == 'boolean':
266 value = '%i' % bool(value)
267 elif keytype != 'plain':
268 raise KeyError('unknown getbundle option type %s'
269 % keytype)
270 opts[key] = value
271 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
272 if any((cap.startswith('HG2') for cap in bundlecaps)):
273 return bundle2.getunbundler(self.ui, f)
274 else:
275 return changegroupmod.cg1unpacker(f, 'UN')
276
277 def unbundle(self, cg, heads, url):
278 '''Send cg (a readable file-like object representing the
279 changegroup to push, typically a chunkbuffer object) to the
280 remote server as a bundle.
281
282 When pushing a bundle10 stream, return an integer indicating the
283 result of the push (see changegroup.apply()).
284
285 When pushing a bundle20 stream, return a bundle20 stream.
286
287 `url` is the url the client thinks it's pushing to, which is
288 visible to hooks.
289 '''
290
291 if heads != ['force'] and self.capable('unbundlehash'):
292 heads = wireprototypes.encodelist(
293 ['hashed', hashlib.sha1(''.join(sorted(heads))).digest()])
294 else:
295 heads = wireprototypes.encodelist(heads)
296
297 if util.safehasattr(cg, 'deltaheader'):
298 # this is a bundle10, do the old style call sequence
299 ret, output = self._callpush("unbundle", cg, heads=heads)
300 if ret == "":
301 raise error.ResponseError(
302 _('push failed:'), output)
303 try:
304 ret = int(ret)
305 except ValueError:
306 raise error.ResponseError(
307 _('push failed (unexpected response):'), ret)
308
309 for l in output.splitlines(True):
310 self.ui.status(_('remote: '), l)
311 else:
312 # bundle2 push. Send a stream, fetch a stream.
313 stream = self._calltwowaystream('unbundle', cg, heads=heads)
314 ret = bundle2.getunbundler(self.ui, stream)
315 return ret
316
317 # End of ipeercommands interface.
318
319 # Begin of ipeerlegacycommands interface.
320
321 def branches(self, nodes):
322 n = wireprototypes.encodelist(nodes)
323 d = self._call("branches", nodes=n)
324 try:
325 br = [tuple(wireprototypes.decodelist(b)) for b in d.splitlines()]
326 return br
327 except ValueError:
328 self._abort(error.ResponseError(_("unexpected response:"), d))
329
330 def between(self, pairs):
331 batch = 8 # avoid giant requests
332 r = []
333 for i in xrange(0, len(pairs), batch):
334 n = " ".join([wireprototypes.encodelist(p, '-')
335 for p in pairs[i:i + batch]])
336 d = self._call("between", pairs=n)
337 try:
338 r.extend(l and wireprototypes.decodelist(l) or []
339 for l in d.splitlines())
340 except ValueError:
341 self._abort(error.ResponseError(_("unexpected response:"), d))
342 return r
343
344 def changegroup(self, nodes, kind):
345 n = wireprototypes.encodelist(nodes)
346 f = self._callcompressable("changegroup", roots=n)
347 return changegroupmod.cg1unpacker(f, 'UN')
348
349 def changegroupsubset(self, bases, heads, kind):
350 self.requirecap('changegroupsubset', _('look up remote changes'))
351 bases = wireprototypes.encodelist(bases)
352 heads = wireprototypes.encodelist(heads)
353 f = self._callcompressable("changegroupsubset",
354 bases=bases, heads=heads)
355 return changegroupmod.cg1unpacker(f, 'UN')
356
357 # End of ipeerlegacycommands interface.
358
359 def _submitbatch(self, req):
360 """run batch request <req> on the server
361
362 Returns an iterator of the raw responses from the server.
363 """
364 ui = self.ui
365 if ui.debugflag and ui.configbool('devel', 'debug.peer-request'):
366 ui.debug('devel-peer-request: batched-content\n')
367 for op, args in req:
368 msg = 'devel-peer-request: - %s (%d arguments)\n'
369 ui.debug(msg % (op, len(args)))
370
371 unescapearg = wireprototypes.unescapebatcharg
372
373 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
374 chunk = rsp.read(1024)
375 work = [chunk]
376 while chunk:
377 while ';' not in chunk and chunk:
378 chunk = rsp.read(1024)
379 work.append(chunk)
380 merged = ''.join(work)
381 while ';' in merged:
382 one, merged = merged.split(';', 1)
383 yield unescapearg(one)
384 chunk = rsp.read(1024)
385 work = [merged, chunk]
386 yield unescapearg(''.join(work))
387
388 def _submitone(self, op, args):
389 return self._call(op, **pycompat.strkwargs(args))
390
391 def debugwireargs(self, one, two, three=None, four=None, five=None):
392 # don't pass optional arguments left at their default value
393 opts = {}
394 if three is not None:
395 opts[r'three'] = three
396 if four is not None:
397 opts[r'four'] = four
398 return self._call('debugwireargs', one=one, two=two, **opts)
399
400 def _call(self, cmd, **args):
401 """execute <cmd> on the server
402
403 The command is expected to return a simple string.
404
405 returns the server reply as a string."""
406 raise NotImplementedError()
407
408 def _callstream(self, cmd, **args):
409 """execute <cmd> on the server
410
411 The command is expected to return a stream. Note that if the
412 command doesn't return a stream, _callstream behaves
413 differently for ssh and http peers.
414
415 returns the server reply as a file like object.
416 """
417 raise NotImplementedError()
418
419 def _callcompressable(self, cmd, **args):
420 """execute <cmd> on the server
421
422 The command is expected to return a stream.
423
424 The stream may have been compressed in some implementations. This
425 function takes care of the decompression. This is the only difference
426 with _callstream.
427
428 returns the server reply as a file like object.
429 """
430 raise NotImplementedError()
431
432 def _callpush(self, cmd, fp, **args):
433 """execute a <cmd> on server
434
435 The command is expected to be related to a push. Push has a special
436 return method.
437
438 returns the server reply as a (ret, output) tuple. ret is either
439 empty (error) or a stringified int.
440 """
441 raise NotImplementedError()
442
443 def _calltwowaystream(self, cmd, fp, **args):
444 """execute <cmd> on server
445
446 The command will send a stream to the server and get a stream in reply.
447 """
448 raise NotImplementedError()
449
450 def _abort(self, exception):
451 """clearly abort the wire protocol connection and raise the exception
452 """
453 raise NotImplementedError()
454
455 # server side
456
457 # A wire protocol command can either return a string or one of these classes.
58 # A wire protocol command can either return a string or one of these classes.
458
59
459 def getdispatchrepo(repo, proto, command):
60 def getdispatchrepo(repo, proto, command):
460 """Obtain the repo used for processing wire protocol commands.
61 """Obtain the repo used for processing wire protocol commands.
461
62
462 The intent of this function is to serve as a monkeypatch point for
63 The intent of this function is to serve as a monkeypatch point for
463 extensions that need commands to operate on different repo views under
64 extensions that need commands to operate on different repo views under
464 specialized circumstances.
65 specialized circumstances.
465 """
66 """
466 return repo.filtered('served')
67 return repo.filtered('served')
467
68
468 def dispatch(repo, proto, command):
69 def dispatch(repo, proto, command):
469 repo = getdispatchrepo(repo, proto, command)
70 repo = getdispatchrepo(repo, proto, command)
470
71
471 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
72 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
472 commandtable = commandsv2 if transportversion == 2 else commands
73 commandtable = commandsv2 if transportversion == 2 else commands
473 func, spec = commandtable[command]
74 func, spec = commandtable[command]
474
75
475 args = proto.getargs(spec)
76 args = proto.getargs(spec)
476
77
477 # Version 1 protocols define arguments as a list. Version 2 uses a dict.
78 # Version 1 protocols define arguments as a list. Version 2 uses a dict.
478 if isinstance(args, list):
79 if isinstance(args, list):
479 return func(repo, proto, *args)
80 return func(repo, proto, *args)
480 elif isinstance(args, dict):
81 elif isinstance(args, dict):
481 return func(repo, proto, **args)
82 return func(repo, proto, **args)
482 else:
83 else:
483 raise error.ProgrammingError('unexpected type returned from '
84 raise error.ProgrammingError('unexpected type returned from '
484 'proto.getargs(): %s' % type(args))
85 'proto.getargs(): %s' % type(args))
485
86
486 def options(cmd, keys, others):
87 def options(cmd, keys, others):
487 opts = {}
88 opts = {}
488 for k in keys:
89 for k in keys:
489 if k in others:
90 if k in others:
490 opts[k] = others[k]
91 opts[k] = others[k]
491 del others[k]
92 del others[k]
492 if others:
93 if others:
493 procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
94 procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
494 % (cmd, ",".join(others)))
95 % (cmd, ",".join(others)))
495 return opts
96 return opts
496
97
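For instance, a caller that passes an argument the command does not declare gets it stripped and warned about (the values here are illustrative):

    # options() keeps only declared keys and warns about the rest.
    others = {b'heads': b'0123abcd', b'bogus': b'1'}
    opts = options(b'getbundle', [b'heads', b'common'], others)
    # opts   -> {b'heads': b'0123abcd'}
    # others -> {b'bogus': b'1'}, and a warning naming "bogus" is written
    #           to stderr by the code above.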
497 def bundle1allowed(repo, action):
98 def bundle1allowed(repo, action):
498 """Whether a bundle1 operation is allowed from the server.
99 """Whether a bundle1 operation is allowed from the server.
499
100
500 Priority is:
101 Priority is:
501
102
502 1. server.bundle1gd.<action> (if generaldelta active)
103 1. server.bundle1gd.<action> (if generaldelta active)
503 2. server.bundle1.<action>
104 2. server.bundle1.<action>
504 3. server.bundle1gd (if generaldelta active)
105 3. server.bundle1gd (if generaldelta active)
505 4. server.bundle1
106 4. server.bundle1
506 """
107 """
507 ui = repo.ui
108 ui = repo.ui
508 gd = 'generaldelta' in repo.requirements
109 gd = 'generaldelta' in repo.requirements
509
110
510 if gd:
111 if gd:
511 v = ui.configbool('server', 'bundle1gd.%s' % action)
112 v = ui.configbool('server', 'bundle1gd.%s' % action)
512 if v is not None:
113 if v is not None:
513 return v
114 return v
514
115
515 v = ui.configbool('server', 'bundle1.%s' % action)
116 v = ui.configbool('server', 'bundle1.%s' % action)
516 if v is not None:
117 if v is not None:
517 return v
118 return v
518
119
519 if gd:
120 if gd:
520 v = ui.configbool('server', 'bundle1gd')
121 v = ui.configbool('server', 'bundle1gd')
521 if v is not None:
122 if v is not None:
522 return v
123 return v
523
124
524 return ui.configbool('server', 'bundle1')
125 return ui.configbool('server', 'bundle1')
525
126
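In hgrc terms the priority above means a narrowly scoped knob beats a broad one. A hedged example that refuses bundle1 pushes to generaldelta repositories while leaving everything else allowed:

    [server]
    # most specific: generaldelta repos, push action only
    bundle1gd.push = false
    # broad fallback
    bundle1 = true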
526 def supportedcompengines(ui, role):
127 def supportedcompengines(ui, role):
527 """Obtain the list of supported compression engines for a request."""
128 """Obtain the list of supported compression engines for a request."""
528 assert role in (util.CLIENTROLE, util.SERVERROLE)
129 assert role in (util.CLIENTROLE, util.SERVERROLE)
529
130
530 compengines = util.compengines.supportedwireengines(role)
131 compengines = util.compengines.supportedwireengines(role)
531
132
532 # Allow config to override default list and ordering.
133 # Allow config to override default list and ordering.
533 if role == util.SERVERROLE:
134 if role == util.SERVERROLE:
534 configengines = ui.configlist('server', 'compressionengines')
135 configengines = ui.configlist('server', 'compressionengines')
535 config = 'server.compressionengines'
136 config = 'server.compressionengines'
536 else:
137 else:
537 # This is currently implemented mainly to facilitate testing. In most
138 # This is currently implemented mainly to facilitate testing. In most
538 # cases, the server should be in charge of choosing a compression engine
139 # cases, the server should be in charge of choosing a compression engine
539 # because a server has the most to lose from a sub-optimal choice. (e.g.
140 # because a server has the most to lose from a sub-optimal choice. (e.g.
540 # CPU DoS due to an expensive engine or a network DoS due to poor
141 # CPU DoS due to an expensive engine or a network DoS due to poor
541 # compression ratio).
142 # compression ratio).
542 configengines = ui.configlist('experimental',
143 configengines = ui.configlist('experimental',
543 'clientcompressionengines')
144 'clientcompressionengines')
544 config = 'experimental.clientcompressionengines'
145 config = 'experimental.clientcompressionengines'
545
146
546 # No explicit config. Filter out the ones that aren't supposed to be
147 # No explicit config. Filter out the ones that aren't supposed to be
547 # advertised and return default ordering.
148 # advertised and return default ordering.
548 if not configengines:
149 if not configengines:
549 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
150 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
550 return [e for e in compengines
151 return [e for e in compengines
551 if getattr(e.wireprotosupport(), attr) > 0]
152 if getattr(e.wireprotosupport(), attr) > 0]
552
153
553 # If compression engines are listed in the config, assume there is a good
154 # If compression engines are listed in the config, assume there is a good
554 # reason for it (like server operators wanting to achieve specific
155 # reason for it (like server operators wanting to achieve specific
555 # performance characteristics). So fail fast if the config references
156 # performance characteristics). So fail fast if the config references
556 # unusable compression engines.
157 # unusable compression engines.
557 validnames = set(e.name() for e in compengines)
158 validnames = set(e.name() for e in compengines)
558 invalidnames = set(e for e in configengines if e not in validnames)
159 invalidnames = set(e for e in configengines if e not in validnames)
559 if invalidnames:
160 if invalidnames:
560 raise error.Abort(_('invalid compression engine defined in %s: %s') %
161 raise error.Abort(_('invalid compression engine defined in %s: %s') %
561 (config, ', '.join(sorted(invalidnames))))
162 (config, ', '.join(sorted(invalidnames))))
562
163
563 compengines = [e for e in compengines if e.name() in configengines]
164 compengines = [e for e in compengines if e.name() in configengines]
564 compengines = sorted(compengines,
165 compengines = sorted(compengines,
565 key=lambda e: configengines.index(e.name()))
166 key=lambda e: configengines.index(e.name()))
566
167
567 if not compengines:
168 if not compengines:
568 raise error.Abort(_('%s config option does not specify any known '
169 raise error.Abort(_('%s config option does not specify any known '
569 'compression engines') % config,
170 'compression engines') % config,
570 hint=_('usable compression engines: %s') %
171 hint=_('usable compression engines: %s') %
571 ', '.join(sorted(validnames)))
172 ', '.join(sorted(validnames)))
572
173
573 return compengines
174 return compengines
574
175
575 class commandentry(object):
176 class commandentry(object):
576 """Represents a declared wire protocol command."""
177 """Represents a declared wire protocol command."""
577 def __init__(self, func, args='', transports=None,
178 def __init__(self, func, args='', transports=None,
578 permission='push'):
179 permission='push'):
579 self.func = func
180 self.func = func
580 self.args = args
181 self.args = args
581 self.transports = transports or set()
182 self.transports = transports or set()
582 self.permission = permission
183 self.permission = permission
583
184
584 def _merge(self, func, args):
185 def _merge(self, func, args):
585 """Merge this instance with an incoming 2-tuple.
186 """Merge this instance with an incoming 2-tuple.
586
187
587 This is called when a caller using the old 2-tuple API attempts
188 This is called when a caller using the old 2-tuple API attempts
588 to replace an instance. The incoming values are merged with
189 to replace an instance. The incoming values are merged with
589 data not captured by the 2-tuple and a new instance containing
190 data not captured by the 2-tuple and a new instance containing
590 the union of the two objects is returned.
191 the union of the two objects is returned.
591 """
192 """
592 return commandentry(func, args=args, transports=set(self.transports),
193 return commandentry(func, args=args, transports=set(self.transports),
593 permission=self.permission)
194 permission=self.permission)
594
195
595 # Old code treats instances as 2-tuples. So expose that interface.
196 # Old code treats instances as 2-tuples. So expose that interface.
596 def __iter__(self):
197 def __iter__(self):
597 yield self.func
198 yield self.func
598 yield self.args
199 yield self.args
599
200
600 def __getitem__(self, i):
201 def __getitem__(self, i):
601 if i == 0:
202 if i == 0:
602 return self.func
203 return self.func
603 elif i == 1:
204 elif i == 1:
604 return self.args
205 return self.args
605 else:
206 else:
606 raise IndexError('can only access elements 0 and 1')
207 raise IndexError('can only access elements 0 and 1')
607
208
608 class commanddict(dict):
209 class commanddict(dict):
609 """Container for registered wire protocol commands.
210 """Container for registered wire protocol commands.
610
211
611 It behaves like a dict. But __setitem__ is overwritten to allow silent
212 It behaves like a dict. But __setitem__ is overwritten to allow silent
612 coercion of values from 2-tuples for API compatibility.
213 coercion of values from 2-tuples for API compatibility.
613 """
214 """
614 def __setitem__(self, k, v):
215 def __setitem__(self, k, v):
615 if isinstance(v, commandentry):
216 if isinstance(v, commandentry):
616 pass
217 pass
617 # Cast 2-tuples to commandentry instances.
218 # Cast 2-tuples to commandentry instances.
618 elif isinstance(v, tuple):
219 elif isinstance(v, tuple):
619 if len(v) != 2:
220 if len(v) != 2:
620 raise ValueError('command tuples must have exactly 2 elements')
221 raise ValueError('command tuples must have exactly 2 elements')
621
222
622 # It is common for extensions to wrap wire protocol commands via
223 # It is common for extensions to wrap wire protocol commands via
623 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
224 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
624 # doing this aren't aware of the new API that uses objects to store
225 # doing this aren't aware of the new API that uses objects to store
625 # command entries, we automatically merge old state with new.
226 # command entries, we automatically merge old state with new.
626 if k in self:
227 if k in self:
627 v = self[k]._merge(v[0], v[1])
228 v = self[k]._merge(v[0], v[1])
628 else:
229 else:
629 # Use default values from @wireprotocommand.
230 # Use default values from @wireprotocommand.
630 v = commandentry(v[0], args=v[1],
231 v = commandentry(v[0], args=v[1],
631 transports=set(wireprototypes.TRANSPORTS),
232 transports=set(wireprototypes.TRANSPORTS),
632 permission='push')
233 permission='push')
633 else:
234 else:
634 raise ValueError('command entries must be commandentry instances '
235 raise ValueError('command entries must be commandentry instances '
635 'or 2-tuples')
236 'or 2-tuples')
636
237
637 return super(commanddict, self).__setitem__(k, v)
238 return super(commanddict, self).__setitem__(k, v)
638
239
639 def commandavailable(self, command, proto):
240 def commandavailable(self, command, proto):
640 """Determine if a command is available for the requested protocol."""
241 """Determine if a command is available for the requested protocol."""
641 assert proto.name in wireprototypes.TRANSPORTS
242 assert proto.name in wireprototypes.TRANSPORTS
642
243
643 entry = self.get(command)
244 entry = self.get(command)
644
245
645 if not entry:
246 if not entry:
646 return False
247 return False
647
248
648 if proto.name not in entry.transports:
249 if proto.name not in entry.transports:
649 return False
250 return False
650
251
651 return True
252 return True
652
253
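A sketch of the legacy 2-tuple assignment that __setitem__ above still coerces, using 'heads' as the wrapped command purely for illustration:

    # The entry iterates as (func, args), so the old unpacking idiom works.
    origfunc, spec = commands[b'heads']

    def headswrapper(repo, proto):
        repo.ui.debug(b'heads requested over the wire\n')
        return origfunc(repo, proto)

    # Assigning a 2-tuple is silently merged back into a commandentry,
    # keeping the original transports and permission.
    commands[b'heads'] = (headswrapper, spec)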
653 # Constants specifying which transports a wire protocol command should be
254 # Constants specifying which transports a wire protocol command should be
654 # available on. For use with @wireprotocommand.
255 # available on. For use with @wireprotocommand.
655 POLICY_V1_ONLY = 'v1-only'
256 POLICY_V1_ONLY = 'v1-only'
656 POLICY_V2_ONLY = 'v2-only'
257 POLICY_V2_ONLY = 'v2-only'
657
258
658 # For version 1 transports.
259 # For version 1 transports.
659 commands = commanddict()
260 commands = commanddict()
660
261
661 # For version 2 transports.
262 # For version 2 transports.
662 commandsv2 = commanddict()
263 commandsv2 = commanddict()
663
264
664 def wireprotocommand(name, args=None, transportpolicy=POLICY_V1_ONLY,
265 def wireprotocommand(name, args=None, transportpolicy=POLICY_V1_ONLY,
665 permission='push'):
266 permission='push'):
666 """Decorator to declare a wire protocol command.
267 """Decorator to declare a wire protocol command.
667
268
668 ``name`` is the name of the wire protocol command being provided.
269 ``name`` is the name of the wire protocol command being provided.
669
270
670 ``args`` defines the named arguments accepted by the command. It is
271 ``args`` defines the named arguments accepted by the command. It is
671 ideally a dict mapping argument names to their types. For backwards
272 ideally a dict mapping argument names to their types. For backwards
672 compatibility, it can be a space-delimited list of argument names. For
273 compatibility, it can be a space-delimited list of argument names. For
673 version 1 transports, ``*`` denotes a special value that says to accept
274 version 1 transports, ``*`` denotes a special value that says to accept
674 all named arguments.
275 all named arguments.
675
276
676 ``transportpolicy`` is a POLICY_* constant denoting which transports
277 ``transportpolicy`` is a POLICY_* constant denoting which transports
677 this wire protocol command should be exposed to. By default, commands
278 this wire protocol command should be exposed to. By default, commands
678 are exposed to all wire protocol transports.
279 are exposed to all wire protocol transports.
679
280
680 ``permission`` defines the permission type needed to run this command.
281 ``permission`` defines the permission type needed to run this command.
681 Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
282 Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
682 respectively. Default is to assume command requires ``push`` permissions
283 respectively. Default is to assume command requires ``push`` permissions
683 because otherwise commands not declaring their permissions could modify
284 because otherwise commands not declaring their permissions could modify
684 a repository that is supposed to be read-only.
285 a repository that is supposed to be read-only.
685 """
286 """
686 if transportpolicy == POLICY_V1_ONLY:
287 if transportpolicy == POLICY_V1_ONLY:
687 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
288 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
688 if v['version'] == 1}
289 if v['version'] == 1}
689 transportversion = 1
290 transportversion = 1
690 elif transportpolicy == POLICY_V2_ONLY:
291 elif transportpolicy == POLICY_V2_ONLY:
691 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
292 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
692 if v['version'] == 2}
293 if v['version'] == 2}
693 transportversion = 2
294 transportversion = 2
694 else:
295 else:
695 raise error.ProgrammingError('invalid transport policy value: %s' %
296 raise error.ProgrammingError('invalid transport policy value: %s' %
696 transportpolicy)
297 transportpolicy)
697
298
698 # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
299 # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
699 # SSHv2.
300 # SSHv2.
700 # TODO undo this hack when SSH is using the unified frame protocol.
301 # TODO undo this hack when SSH is using the unified frame protocol.
701 if name == b'batch':
302 if name == b'batch':
702 transports.add(wireprototypes.SSHV2)
303 transports.add(wireprototypes.SSHV2)
703
304
704 if permission not in ('push', 'pull'):
305 if permission not in ('push', 'pull'):
705 raise error.ProgrammingError('invalid wire protocol permission; '
306 raise error.ProgrammingError('invalid wire protocol permission; '
706 'got %s; expected "push" or "pull"' %
307 'got %s; expected "push" or "pull"' %
707 permission)
308 permission)
708
309
709 if transportversion == 1:
310 if transportversion == 1:
710 if args is None:
311 if args is None:
711 args = ''
312 args = ''
712
313
713 if not isinstance(args, bytes):
314 if not isinstance(args, bytes):
714 raise error.ProgrammingError('arguments for version 1 commands '
315 raise error.ProgrammingError('arguments for version 1 commands '
715 'must be declared as bytes')
316 'must be declared as bytes')
716 elif transportversion == 2:
317 elif transportversion == 2:
717 if args is None:
318 if args is None:
718 args = {}
319 args = {}
719
320
720 if not isinstance(args, dict):
321 if not isinstance(args, dict):
721 raise error.ProgrammingError('arguments for version 2 commands '
322 raise error.ProgrammingError('arguments for version 2 commands '
722 'must be declared as dicts')
323 'must be declared as dicts')
723
324
724 def register(func):
325 def register(func):
725 if transportversion == 1:
326 if transportversion == 1:
726 if name in commands:
327 if name in commands:
727 raise error.ProgrammingError('%s command already registered '
328 raise error.ProgrammingError('%s command already registered '
728 'for version 1' % name)
329 'for version 1' % name)
729 commands[name] = commandentry(func, args=args,
330 commands[name] = commandentry(func, args=args,
730 transports=transports,
331 transports=transports,
731 permission=permission)
332 permission=permission)
732 elif transportversion == 2:
333 elif transportversion == 2:
733 if name in commandsv2:
334 if name in commandsv2:
734 raise error.ProgrammingError('%s command already registered '
335 raise error.ProgrammingError('%s command already registered '
735 'for version 2' % name)
336 'for version 2' % name)
736
337
737 commandsv2[name] = commandentry(func, args=args,
338 commandsv2[name] = commandentry(func, args=args,
738 transports=transports,
339 transports=transports,
739 permission=permission)
340 permission=permission)
740 else:
341 else:
741 raise error.ProgrammingError('unhandled transport version: %d' %
342 raise error.ProgrammingError('unhandled transport version: %d' %
742 transportversion)
343 transportversion)
743
344
744 return func
345 return func
745 return register
346 return register
746
347
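A hedged sketch of registering a brand-new version 1 command with the decorator above; the command name and body are invented for illustration:

    @wireprotocommand(b'listbookmarknames', b'', permission='pull',
                      transportpolicy=POLICY_V1_ONLY)
    def listbookmarknames(repo, proto):
        # Hypothetical read-only command: return one bookmark name per line.
        names = sorted(repo._bookmarks)
        return wireprototypes.bytesresponse(b'\n'.join(names))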
747 # TODO define a more appropriate permissions type to use for this.
348 # TODO define a more appropriate permissions type to use for this.
748 @wireprotocommand('batch', 'cmds *', permission='pull',
349 @wireprotocommand('batch', 'cmds *', permission='pull',
749 transportpolicy=POLICY_V1_ONLY)
350 transportpolicy=POLICY_V1_ONLY)
750 def batch(repo, proto, cmds, others):
351 def batch(repo, proto, cmds, others):
751 unescapearg = wireprototypes.unescapebatcharg
352 unescapearg = wireprototypes.unescapebatcharg
752 repo = repo.filtered("served")
353 repo = repo.filtered("served")
753 res = []
354 res = []
754 for pair in cmds.split(';'):
355 for pair in cmds.split(';'):
755 op, args = pair.split(' ', 1)
356 op, args = pair.split(' ', 1)
756 vals = {}
357 vals = {}
757 for a in args.split(','):
358 for a in args.split(','):
758 if a:
359 if a:
759 n, v = a.split('=')
360 n, v = a.split('=')
760 vals[unescapearg(n)] = unescapearg(v)
361 vals[unescapearg(n)] = unescapearg(v)
761 func, spec = commands[op]
362 func, spec = commands[op]
762
363
763 # Validate that client has permissions to perform this command.
364 # Validate that client has permissions to perform this command.
764 perm = commands[op].permission
365 perm = commands[op].permission
765 assert perm in ('push', 'pull')
366 assert perm in ('push', 'pull')
766 proto.checkperm(perm)
367 proto.checkperm(perm)
767
368
768 if spec:
369 if spec:
769 keys = spec.split()
370 keys = spec.split()
770 data = {}
371 data = {}
771 for k in keys:
372 for k in keys:
772 if k == '*':
373 if k == '*':
773 star = {}
374 star = {}
774 for key in vals.keys():
375 for key in vals.keys():
775 if key not in keys:
376 if key not in keys:
776 star[key] = vals[key]
377 star[key] = vals[key]
777 data['*'] = star
378 data['*'] = star
778 else:
379 else:
779 data[k] = vals[k]
380 data[k] = vals[k]
780 result = func(repo, proto, *[data[k] for k in keys])
381 result = func(repo, proto, *[data[k] for k in keys])
781 else:
382 else:
782 result = func(repo, proto)
383 result = func(repo, proto)
783 if isinstance(result, wireprototypes.ooberror):
384 if isinstance(result, wireprototypes.ooberror):
784 return result
385 return result
785
386
786 # For now, all batchable commands must return bytesresponse or
387 # For now, all batchable commands must return bytesresponse or
787 # raw bytes (for backwards compatibility).
388 # raw bytes (for backwards compatibility).
788 assert isinstance(result, (wireprototypes.bytesresponse, bytes))
389 assert isinstance(result, (wireprototypes.bytesresponse, bytes))
789 if isinstance(result, wireprototypes.bytesresponse):
390 if isinstance(result, wireprototypes.bytesresponse):
790 result = result.data
391 result = result.data
791 res.append(wireprototypes.escapebatcharg(result))
392 res.append(wireprototypes.escapebatcharg(result))
792
393
793 return wireprototypes.bytesresponse(';'.join(res))
394 return wireprototypes.bytesresponse(';'.join(res))
794
395
795 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
396 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
796 permission='pull')
397 permission='pull')
797 def between(repo, proto, pairs):
398 def between(repo, proto, pairs):
798 pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")]
399 pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")]
799 r = []
400 r = []
800 for b in repo.between(pairs):
401 for b in repo.between(pairs):
801 r.append(wireprototypes.encodelist(b) + "\n")
402 r.append(wireprototypes.encodelist(b) + "\n")
802
403
803 return wireprototypes.bytesresponse(''.join(r))
404 return wireprototypes.bytesresponse(''.join(r))
804
405
805 @wireprotocommand('branchmap', permission='pull',
406 @wireprotocommand('branchmap', permission='pull',
806 transportpolicy=POLICY_V1_ONLY)
407 transportpolicy=POLICY_V1_ONLY)
807 def branchmap(repo, proto):
408 def branchmap(repo, proto):
808 branchmap = repo.branchmap()
409 branchmap = repo.branchmap()
809 heads = []
410 heads = []
810 for branch, nodes in branchmap.iteritems():
411 for branch, nodes in branchmap.iteritems():
811 branchname = urlreq.quote(encoding.fromlocal(branch))
412 branchname = urlreq.quote(encoding.fromlocal(branch))
812 branchnodes = wireprototypes.encodelist(nodes)
413 branchnodes = wireprototypes.encodelist(nodes)
813 heads.append('%s %s' % (branchname, branchnodes))
414 heads.append('%s %s' % (branchname, branchnodes))
814
415
815 return wireprototypes.bytesresponse('\n'.join(heads))
416 return wireprototypes.bytesresponse('\n'.join(heads))
816
417
817 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
418 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
818 permission='pull')
419 permission='pull')
819 def branches(repo, proto, nodes):
420 def branches(repo, proto, nodes):
820 nodes = wireprototypes.decodelist(nodes)
421 nodes = wireprototypes.decodelist(nodes)
821 r = []
422 r = []
822 for b in repo.branches(nodes):
423 for b in repo.branches(nodes):
823 r.append(wireprototypes.encodelist(b) + "\n")
424 r.append(wireprototypes.encodelist(b) + "\n")
824
425
825 return wireprototypes.bytesresponse(''.join(r))
426 return wireprototypes.bytesresponse(''.join(r))
826
427
827 @wireprotocommand('clonebundles', '', permission='pull',
428 @wireprotocommand('clonebundles', '', permission='pull',
828 transportpolicy=POLICY_V1_ONLY)
429 transportpolicy=POLICY_V1_ONLY)
829 def clonebundles(repo, proto):
430 def clonebundles(repo, proto):
830 """Server command for returning info for available bundles to seed clones.
431 """Server command for returning info for available bundles to seed clones.
831
432
832 Clients will parse this response and determine what bundle to fetch.
433 Clients will parse this response and determine what bundle to fetch.
833
434
834 Extensions may wrap this command to filter or dynamically emit data
435 Extensions may wrap this command to filter or dynamically emit data
835 depending on the request. e.g. you could advertise URLs for the closest
436 depending on the request. e.g. you could advertise URLs for the closest
836 data center given the client's IP address.
437 data center given the client's IP address.
837 """
438 """
838 return wireprototypes.bytesresponse(
439 return wireprototypes.bytesresponse(
839 repo.vfs.tryread('clonebundles.manifest'))
440 repo.vfs.tryread('clonebundles.manifest'))
840
441
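The manifest read above is a plain-text file under the repository's .hg directory, one bundle URL per line, optionally followed by space-separated attributes; a hedged example with illustrative URLs and attributes:

    https://cdn.example.com/repo-full.hg.gz BUNDLESPEC=gzip-v2
    https://cdn.example.com/repo-full.hg.zst BUNDLESPEC=zstd-v2 REQUIRESNI=true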
841 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
442 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
842 'known', 'getbundle', 'unbundlehash']
443 'known', 'getbundle', 'unbundlehash']
843
444
844 def _capabilities(repo, proto):
445 def _capabilities(repo, proto):
845 """return a list of capabilities for a repo
446 """return a list of capabilities for a repo
846
447
847 This function exists to allow extensions to easily wrap capabilities
448 This function exists to allow extensions to easily wrap capabilities
848 computation
449 computation
849
450
850 - returns a list: easy to alter
451 - returns a list: easy to alter
851 - changes done here will be propagated to both the `capabilities` and `hello`
452 - changes done here will be propagated to both the `capabilities` and `hello`
852 commands without any other action needed.
453 commands without any other action needed.
853 """
454 """
854 # copy to prevent modification of the global list
455 # copy to prevent modification of the global list
855 caps = list(wireprotocaps)
456 caps = list(wireprotocaps)
856
457
857 # Command of same name as capability isn't exposed to version 1 of
458 # Command of same name as capability isn't exposed to version 1 of
858 # transports. So conditionally add it.
459 # transports. So conditionally add it.
859 if commands.commandavailable('changegroupsubset', proto):
460 if commands.commandavailable('changegroupsubset', proto):
860 caps.append('changegroupsubset')
461 caps.append('changegroupsubset')
861
462
862 if streamclone.allowservergeneration(repo):
463 if streamclone.allowservergeneration(repo):
863 if repo.ui.configbool('server', 'preferuncompressed'):
464 if repo.ui.configbool('server', 'preferuncompressed'):
864 caps.append('stream-preferred')
465 caps.append('stream-preferred')
865 requiredformats = repo.requirements & repo.supportedformats
466 requiredformats = repo.requirements & repo.supportedformats
866 # if our local revlogs are just revlogv1, add 'stream' cap
467 # if our local revlogs are just revlogv1, add 'stream' cap
867 if not requiredformats - {'revlogv1'}:
468 if not requiredformats - {'revlogv1'}:
868 caps.append('stream')
469 caps.append('stream')
869 # otherwise, add 'streamreqs' detailing our local revlog format
470 # otherwise, add 'streamreqs' detailing our local revlog format
870 else:
471 else:
871 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
472 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
872 if repo.ui.configbool('experimental', 'bundle2-advertise'):
473 if repo.ui.configbool('experimental', 'bundle2-advertise'):
873 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
474 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
874 caps.append('bundle2=' + urlreq.quote(capsblob))
475 caps.append('bundle2=' + urlreq.quote(capsblob))
875 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
476 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
876
477
877 return proto.addcapabilities(repo, caps)
478 return proto.addcapabilities(repo, caps)
878
479
879 # If you are writing an extension and considering wrapping this function, wrap
480 # If you are writing an extension and considering wrapping this function, wrap
880 # `_capabilities` instead.
481 # `_capabilities` instead.
881 @wireprotocommand('capabilities', permission='pull',
482 @wireprotocommand('capabilities', permission='pull',
882 transportpolicy=POLICY_V1_ONLY)
483 transportpolicy=POLICY_V1_ONLY)
883 def capabilities(repo, proto):
484 def capabilities(repo, proto):
884 caps = _capabilities(repo, proto)
485 caps = _capabilities(repo, proto)
885 return wireprototypes.bytesresponse(' '.join(sorted(caps)))
486 return wireprototypes.bytesresponse(' '.join(sorted(caps)))
886
487
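As the comment above says, extensions should wrap _capabilities rather than the capabilities command itself. A minimal sketch of such an extension, assuming the server-side module is importable as wireproto (the exact module name depends on the Mercurial version; later releases split this into wireprotov1server):

    from mercurial import extensions, wireproto

    def _mycaps(orig, repo, proto):
        # run the original computation, then advertise one extra token
        caps = orig(repo, proto)
        caps.append('mycustomcap')
        return caps

    def extsetup(ui):
        extensions.wrapfunction(wireproto, '_capabilities', _mycaps)

Because _capabilities returns a plain list, the extra token automatically shows up in both the capabilities and hello responses.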
887 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
488 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
888 permission='pull')
489 permission='pull')
889 def changegroup(repo, proto, roots):
490 def changegroup(repo, proto, roots):
890 nodes = wireprototypes.decodelist(roots)
491 nodes = wireprototypes.decodelist(roots)
891 outgoing = discovery.outgoing(repo, missingroots=nodes,
492 outgoing = discovery.outgoing(repo, missingroots=nodes,
892 missingheads=repo.heads())
493 missingheads=repo.heads())
893 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
494 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
894 gen = iter(lambda: cg.read(32768), '')
495 gen = iter(lambda: cg.read(32768), '')
895 return wireprototypes.streamres(gen=gen)
496 return wireprototypes.streamres(gen=gen)
896
497
897 @wireprotocommand('changegroupsubset', 'bases heads',
498 @wireprotocommand('changegroupsubset', 'bases heads',
898 transportpolicy=POLICY_V1_ONLY,
499 transportpolicy=POLICY_V1_ONLY,
899 permission='pull')
500 permission='pull')
900 def changegroupsubset(repo, proto, bases, heads):
501 def changegroupsubset(repo, proto, bases, heads):
901 bases = wireprototypes.decodelist(bases)
502 bases = wireprototypes.decodelist(bases)
902 heads = wireprototypes.decodelist(heads)
503 heads = wireprototypes.decodelist(heads)
903 outgoing = discovery.outgoing(repo, missingroots=bases,
504 outgoing = discovery.outgoing(repo, missingroots=bases,
904 missingheads=heads)
505 missingheads=heads)
905 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
506 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
906 gen = iter(lambda: cg.read(32768), '')
507 gen = iter(lambda: cg.read(32768), '')
907 return wireprototypes.streamres(gen=gen)
508 return wireprototypes.streamres(gen=gen)
908
509
909 @wireprotocommand('debugwireargs', 'one two *',
510 @wireprotocommand('debugwireargs', 'one two *',
910 permission='pull', transportpolicy=POLICY_V1_ONLY)
511 permission='pull', transportpolicy=POLICY_V1_ONLY)
911 def debugwireargs(repo, proto, one, two, others):
512 def debugwireargs(repo, proto, one, two, others):
912 # only accept optional args from the known set
513 # only accept optional args from the known set
913 opts = options('debugwireargs', ['three', 'four'], others)
514 opts = options('debugwireargs', ['three', 'four'], others)
914 return wireprototypes.bytesresponse(repo.debugwireargs(
515 return wireprototypes.bytesresponse(repo.debugwireargs(
915 one, two, **pycompat.strkwargs(opts)))
516 one, two, **pycompat.strkwargs(opts)))
916
517
917 def find_pullbundle(repo, proto, opts, clheads, heads, common):
518 def find_pullbundle(repo, proto, opts, clheads, heads, common):
918 """Return a file object for the first matching pullbundle.
519 """Return a file object for the first matching pullbundle.
919
520
920 Pullbundles are specified in .hg/pullbundles.manifest similar to
521 Pullbundles are specified in .hg/pullbundles.manifest similar to
921 clonebundles.
522 clonebundles.
922 For each entry, the bundle specification is checked for compatibility:
523 For each entry, the bundle specification is checked for compatibility:
923 - Client features vs the BUNDLESPEC.
524 - Client features vs the BUNDLESPEC.
924 - Revisions shared with the clients vs base revisions of the bundle.
525 - Revisions shared with the clients vs base revisions of the bundle.
925 A bundle can be applied only if all its base revisions are known by
526 A bundle can be applied only if all its base revisions are known by
926 the client.
527 the client.
927 - At least one leaf of the bundle's DAG is missing on the client.
528 - At least one leaf of the bundle's DAG is missing on the client.
928 - Every leaf of the bundle's DAG is part of the node set the client wants.
529 - Every leaf of the bundle's DAG is part of the node set the client wants.
929 E.g. do not send a bundle of all changes if the client wants only
530 E.g. do not send a bundle of all changes if the client wants only
930 one specific branch of many.
531 one specific branch of many.
931 """
532 """
932 def decodehexstring(s):
533 def decodehexstring(s):
933 return set([h.decode('hex') for h in s.split(';')])
534 return set([h.decode('hex') for h in s.split(';')])
934
535
935 manifest = repo.vfs.tryread('pullbundles.manifest')
536 manifest = repo.vfs.tryread('pullbundles.manifest')
936 if not manifest:
537 if not manifest:
937 return None
538 return None
938 res = exchange.parseclonebundlesmanifest(repo, manifest)
539 res = exchange.parseclonebundlesmanifest(repo, manifest)
939 res = exchange.filterclonebundleentries(repo, res)
540 res = exchange.filterclonebundleentries(repo, res)
940 if not res:
541 if not res:
941 return None
542 return None
942 cl = repo.changelog
543 cl = repo.changelog
943 heads_anc = cl.ancestors([cl.rev(rev) for rev in heads], inclusive=True)
544 heads_anc = cl.ancestors([cl.rev(rev) for rev in heads], inclusive=True)
944 common_anc = cl.ancestors([cl.rev(rev) for rev in common], inclusive=True)
545 common_anc = cl.ancestors([cl.rev(rev) for rev in common], inclusive=True)
945 compformats = clientcompressionsupport(proto)
546 compformats = clientcompressionsupport(proto)
946 for entry in res:
547 for entry in res:
947 if 'COMPRESSION' in entry and entry['COMPRESSION'] not in compformats:
548 if 'COMPRESSION' in entry and entry['COMPRESSION'] not in compformats:
948 continue
549 continue
949 # No test yet for VERSION, since V2 is supported by any client
550 # No test yet for VERSION, since V2 is supported by any client
950 # that advertises partial pulls
551 # that advertises partial pulls
951 if 'heads' in entry:
552 if 'heads' in entry:
952 try:
553 try:
953 bundle_heads = decodehexstring(entry['heads'])
554 bundle_heads = decodehexstring(entry['heads'])
954 except TypeError:
555 except TypeError:
955 # Bad heads entry
556 # Bad heads entry
956 continue
557 continue
957 if bundle_heads.issubset(common):
558 if bundle_heads.issubset(common):
958 continue # Nothing new
559 continue # Nothing new
959 if all(cl.rev(rev) in common_anc for rev in bundle_heads):
560 if all(cl.rev(rev) in common_anc for rev in bundle_heads):
960 continue # Still nothing new
561 continue # Still nothing new
961 if any(cl.rev(rev) not in heads_anc and
562 if any(cl.rev(rev) not in heads_anc and
962 cl.rev(rev) not in common_anc for rev in bundle_heads):
563 cl.rev(rev) not in common_anc for rev in bundle_heads):
963 continue
564 continue
964 if 'bases' in entry:
565 if 'bases' in entry:
965 try:
566 try:
966 bundle_bases = decodehexstring(entry['bases'])
567 bundle_bases = decodehexstring(entry['bases'])
967 except TypeError:
568 except TypeError:
968 # Bad bases entry
569 # Bad bases entry
969 continue
570 continue
970 if not all(cl.rev(rev) in common_anc for rev in bundle_bases):
571 if not all(cl.rev(rev) in common_anc for rev in bundle_bases):
971 continue
572 continue
972 path = entry['URL']
573 path = entry['URL']
973 repo.ui.debug('sending pullbundle "%s"\n' % path)
574 repo.ui.debug('sending pullbundle "%s"\n' % path)
974 try:
575 try:
975 return repo.vfs.open(path)
576 return repo.vfs.open(path)
976 except IOError:
577 except IOError:
977 repo.ui.debug('pullbundle "%s" not accessible\n' % path)
578 repo.ui.debug('pullbundle "%s" not accessible\n' % path)
978 continue
579 continue
979 return None
580 return None
980
581
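As a sketch of what find_pullbundle consumes: .hg/pullbundles.manifest uses the same line format as clonebundles.manifest, and the optional heads and bases attributes are ';'-separated hex node ids checked against the client's common/wanted sets as shown above. The file names and node placeholders below are made up:

    recent.hg BUNDLESPEC=gzip-v2 heads=<hex-node>;<hex-node> bases=<hex-node>
    everything.hg BUNDLESPEC=gzip-v2

Paths are opened through repo.vfs, so they are relative to the .hg directory.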
981 @wireprotocommand('getbundle', '*', permission='pull',
582 @wireprotocommand('getbundle', '*', permission='pull',
982 transportpolicy=POLICY_V1_ONLY)
583 transportpolicy=POLICY_V1_ONLY)
983 def getbundle(repo, proto, others):
584 def getbundle(repo, proto, others):
984 opts = options('getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(),
585 opts = options('getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(),
985 others)
586 others)
986 for k, v in opts.iteritems():
587 for k, v in opts.iteritems():
987 keytype = wireprototypes.GETBUNDLE_ARGUMENTS[k]
588 keytype = wireprototypes.GETBUNDLE_ARGUMENTS[k]
988 if keytype == 'nodes':
589 if keytype == 'nodes':
989 opts[k] = wireprototypes.decodelist(v)
590 opts[k] = wireprototypes.decodelist(v)
990 elif keytype == 'csv':
591 elif keytype == 'csv':
991 opts[k] = list(v.split(','))
592 opts[k] = list(v.split(','))
992 elif keytype == 'scsv':
593 elif keytype == 'scsv':
993 opts[k] = set(v.split(','))
594 opts[k] = set(v.split(','))
994 elif keytype == 'boolean':
595 elif keytype == 'boolean':
995 # Client should serialize False as '0', which is a non-empty string
596 # Client should serialize False as '0', which is a non-empty string
996 # so it evaluates as a True bool.
597 # so it evaluates as a True bool.
997 if v == '0':
598 if v == '0':
998 opts[k] = False
599 opts[k] = False
999 else:
600 else:
1000 opts[k] = bool(v)
601 opts[k] = bool(v)
1001 elif keytype != 'plain':
602 elif keytype != 'plain':
1002 raise KeyError('unknown getbundle option type %s'
603 raise KeyError('unknown getbundle option type %s'
1003 % keytype)
604 % keytype)
1004
605
1005 if not bundle1allowed(repo, 'pull'):
606 if not bundle1allowed(repo, 'pull'):
1006 if not exchange.bundle2requested(opts.get('bundlecaps')):
607 if not exchange.bundle2requested(opts.get('bundlecaps')):
1007 if proto.name == 'http-v1':
608 if proto.name == 'http-v1':
1008 return wireprototypes.ooberror(bundle2required)
609 return wireprototypes.ooberror(bundle2required)
1009 raise error.Abort(bundle2requiredmain,
610 raise error.Abort(bundle2requiredmain,
1010 hint=bundle2requiredhint)
611 hint=bundle2requiredhint)
1011
612
1012 prefercompressed = True
613 prefercompressed = True
1013
614
1014 try:
615 try:
1015 clheads = set(repo.changelog.heads())
616 clheads = set(repo.changelog.heads())
1016 heads = set(opts.get('heads', set()))
617 heads = set(opts.get('heads', set()))
1017 common = set(opts.get('common', set()))
618 common = set(opts.get('common', set()))
1018 common.discard(nullid)
619 common.discard(nullid)
1019 if (repo.ui.configbool('server', 'pullbundle') and
620 if (repo.ui.configbool('server', 'pullbundle') and
1020 'partial-pull' in proto.getprotocaps()):
621 'partial-pull' in proto.getprotocaps()):
1021 # Check if a pre-built bundle covers this request.
622 # Check if a pre-built bundle covers this request.
1022 bundle = find_pullbundle(repo, proto, opts, clheads, heads, common)
623 bundle = find_pullbundle(repo, proto, opts, clheads, heads, common)
1023 if bundle:
624 if bundle:
1024 return wireprototypes.streamres(gen=util.filechunkiter(bundle),
625 return wireprototypes.streamres(gen=util.filechunkiter(bundle),
1025 prefer_uncompressed=True)
626 prefer_uncompressed=True)
1026
627
1027 if repo.ui.configbool('server', 'disablefullbundle'):
628 if repo.ui.configbool('server', 'disablefullbundle'):
1028 # Check to see if this is a full clone.
629 # Check to see if this is a full clone.
1029 changegroup = opts.get('cg', True)
630 changegroup = opts.get('cg', True)
1030 if changegroup and not common and clheads == heads:
631 if changegroup and not common and clheads == heads:
1031 raise error.Abort(
632 raise error.Abort(
1032 _('server has pull-based clones disabled'),
633 _('server has pull-based clones disabled'),
1033 hint=_('remove --pull if specified or upgrade Mercurial'))
634 hint=_('remove --pull if specified or upgrade Mercurial'))
1034
635
1035 info, chunks = exchange.getbundlechunks(repo, 'serve',
636 info, chunks = exchange.getbundlechunks(repo, 'serve',
1036 **pycompat.strkwargs(opts))
637 **pycompat.strkwargs(opts))
1037 prefercompressed = info.get('prefercompressed', True)
638 prefercompressed = info.get('prefercompressed', True)
1038 except error.Abort as exc:
639 except error.Abort as exc:
1039 # cleanly forward Abort error to the client
640 # cleanly forward Abort error to the client
1040 if not exchange.bundle2requested(opts.get('bundlecaps')):
641 if not exchange.bundle2requested(opts.get('bundlecaps')):
1041 if proto.name == 'http-v1':
642 if proto.name == 'http-v1':
1042 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
643 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
1043 raise # cannot do better for bundle1 + ssh
644 raise # cannot do better for bundle1 + ssh
1044 # a bundle2 request expects a bundle2 reply
645 # a bundle2 request expects a bundle2 reply
1045 bundler = bundle2.bundle20(repo.ui)
646 bundler = bundle2.bundle20(repo.ui)
1046 manargs = [('message', pycompat.bytestr(exc))]
647 manargs = [('message', pycompat.bytestr(exc))]
1047 advargs = []
648 advargs = []
1048 if exc.hint is not None:
649 if exc.hint is not None:
1049 advargs.append(('hint', exc.hint))
650 advargs.append(('hint', exc.hint))
1050 bundler.addpart(bundle2.bundlepart('error:abort',
651 bundler.addpart(bundle2.bundlepart('error:abort',
1051 manargs, advargs))
652 manargs, advargs))
1052 chunks = bundler.getchunks()
653 chunks = bundler.getchunks()
1053 prefercompressed = False
654 prefercompressed = False
1054
655
1055 return wireprototypes.streamres(
656 return wireprototypes.streamres(
1056 gen=chunks, prefer_uncompressed=not prefercompressed)
657 gen=chunks, prefer_uncompressed=not prefercompressed)
1057
658
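One subtlety in the decoding above: 'boolean' getbundle arguments travel as strings, and the client deliberately sends '0' for False so the value does not evaluate as a truthy non-empty string. A minimal illustration using the real 'cg' (changegroup) argument:

    # client-side encoding (see the peer getbundle method later in this change)
    '%i' % bool(False)                        # -> '0'
    # server-side decoding (the loop above)
    v = '0'
    value = False if v == '0' else bool(v)    # -> False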
1058 @wireprotocommand('heads', permission='pull', transportpolicy=POLICY_V1_ONLY)
659 @wireprotocommand('heads', permission='pull', transportpolicy=POLICY_V1_ONLY)
1059 def heads(repo, proto):
660 def heads(repo, proto):
1060 h = repo.heads()
661 h = repo.heads()
1061 return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n')
662 return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n')
1062
663
1063 @wireprotocommand('hello', permission='pull', transportpolicy=POLICY_V1_ONLY)
664 @wireprotocommand('hello', permission='pull', transportpolicy=POLICY_V1_ONLY)
1064 def hello(repo, proto):
665 def hello(repo, proto):
1065 """Called as part of SSH handshake to obtain server info.
666 """Called as part of SSH handshake to obtain server info.
1066
667
1067 Returns a list of lines describing interesting things about the
668 Returns a list of lines describing interesting things about the
1068 server, in an RFC822-like format.
669 server, in an RFC822-like format.
1069
670
1070 Currently, the only one defined is ``capabilities``, which consists of a
671 Currently, the only one defined is ``capabilities``, which consists of a
1071 line of space separated tokens describing server abilities:
672 line of space separated tokens describing server abilities:
1072
673
1073 capabilities: <token0> <token1> <token2>
674 capabilities: <token0> <token1> <token2>
1074 """
675 """
1075 caps = capabilities(repo, proto).data
676 caps = capabilities(repo, proto).data
1076 return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
677 return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
1077
678
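Concretely, with just the base tokens from the wireprotocaps list plus 'changegroupsubset', the hello response body would be a single line like the one below; real servers append further tokens (bundle2, stream, unbundle, transport-specific capabilities) as computed in _capabilities:

    capabilities: branchmap changegroupsubset getbundle known lookup pushkey unbundlehash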
1078 @wireprotocommand('listkeys', 'namespace', permission='pull',
679 @wireprotocommand('listkeys', 'namespace', permission='pull',
1079 transportpolicy=POLICY_V1_ONLY)
680 transportpolicy=POLICY_V1_ONLY)
1080 def listkeys(repo, proto, namespace):
681 def listkeys(repo, proto, namespace):
1081 d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
682 d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
1082 return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
683 return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
1083
684
1084 @wireprotocommand('lookup', 'key', permission='pull',
685 @wireprotocommand('lookup', 'key', permission='pull',
1085 transportpolicy=POLICY_V1_ONLY)
686 transportpolicy=POLICY_V1_ONLY)
1086 def lookup(repo, proto, key):
687 def lookup(repo, proto, key):
1087 try:
688 try:
1088 k = encoding.tolocal(key)
689 k = encoding.tolocal(key)
1089 n = repo.lookup(k)
690 n = repo.lookup(k)
1090 r = hex(n)
691 r = hex(n)
1091 success = 1
692 success = 1
1092 except Exception as inst:
693 except Exception as inst:
1093 r = stringutil.forcebytestr(inst)
694 r = stringutil.forcebytestr(inst)
1094 success = 0
695 success = 0
1095 return wireprototypes.bytesresponse('%d %s\n' % (success, r))
696 return wireprototypes.bytesresponse('%d %s\n' % (success, r))
1096
697
1097 @wireprotocommand('known', 'nodes *', permission='pull',
698 @wireprotocommand('known', 'nodes *', permission='pull',
1098 transportpolicy=POLICY_V1_ONLY)
699 transportpolicy=POLICY_V1_ONLY)
1099 def known(repo, proto, nodes, others):
700 def known(repo, proto, nodes, others):
1100 v = ''.join(b and '1' or '0'
701 v = ''.join(b and '1' or '0'
1101 for b in repo.known(wireprototypes.decodelist(nodes)))
702 for b in repo.known(wireprototypes.decodelist(nodes)))
1102 return wireprototypes.bytesresponse(v)
703 return wireprototypes.bytesresponse(v)
1103
704
1104 @wireprotocommand('protocaps', 'caps', permission='pull',
705 @wireprotocommand('protocaps', 'caps', permission='pull',
1105 transportpolicy=POLICY_V1_ONLY)
706 transportpolicy=POLICY_V1_ONLY)
1106 def protocaps(repo, proto, caps):
707 def protocaps(repo, proto, caps):
1107 if proto.name == wireprototypes.SSHV1:
708 if proto.name == wireprototypes.SSHV1:
1108 proto._protocaps = set(caps.split(' '))
709 proto._protocaps = set(caps.split(' '))
1109 return wireprototypes.bytesresponse('OK')
710 return wireprototypes.bytesresponse('OK')
1110
711
1111 @wireprotocommand('pushkey', 'namespace key old new', permission='push',
712 @wireprotocommand('pushkey', 'namespace key old new', permission='push',
1112 transportpolicy=POLICY_V1_ONLY)
713 transportpolicy=POLICY_V1_ONLY)
1113 def pushkey(repo, proto, namespace, key, old, new):
714 def pushkey(repo, proto, namespace, key, old, new):
1114 # compatibility with pre-1.8 clients which were accidentally
715 # compatibility with pre-1.8 clients which were accidentally
1115 # sending raw binary nodes rather than utf-8-encoded hex
716 # sending raw binary nodes rather than utf-8-encoded hex
1116 if len(new) == 20 and stringutil.escapestr(new) != new:
717 if len(new) == 20 and stringutil.escapestr(new) != new:
1117 # looks like it could be a binary node
718 # looks like it could be a binary node
1118 try:
719 try:
1119 new.decode('utf-8')
720 new.decode('utf-8')
1120 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
721 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
1121 except UnicodeDecodeError:
722 except UnicodeDecodeError:
1122 pass # binary, leave unmodified
723 pass # binary, leave unmodified
1123 else:
724 else:
1124 new = encoding.tolocal(new) # normal path
725 new = encoding.tolocal(new) # normal path
1125
726
1126 with proto.mayberedirectstdio() as output:
727 with proto.mayberedirectstdio() as output:
1127 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
728 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
1128 encoding.tolocal(old), new) or False
729 encoding.tolocal(old), new) or False
1129
730
1130 output = output.getvalue() if output else ''
731 output = output.getvalue() if output else ''
1131 return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
732 return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
1132
733
1133 @wireprotocommand('stream_out', permission='pull',
734 @wireprotocommand('stream_out', permission='pull',
1134 transportpolicy=POLICY_V1_ONLY)
735 transportpolicy=POLICY_V1_ONLY)
1135 def stream(repo, proto):
736 def stream(repo, proto):
1136 '''If the server supports streaming clone, it advertises the "stream"
737 '''If the server supports streaming clone, it advertises the "stream"
1137 capability with a value representing the version and flags of the repo
738 capability with a value representing the version and flags of the repo
1138 it is serving. Client checks to see if it understands the format.
739 it is serving. Client checks to see if it understands the format.
1139 '''
740 '''
1140 return wireprototypes.streamreslegacy(
741 return wireprototypes.streamreslegacy(
1141 streamclone.generatev1wireproto(repo))
742 streamclone.generatev1wireproto(repo))
1142
743
1143 @wireprotocommand('unbundle', 'heads', permission='push',
744 @wireprotocommand('unbundle', 'heads', permission='push',
1144 transportpolicy=POLICY_V1_ONLY)
745 transportpolicy=POLICY_V1_ONLY)
1145 def unbundle(repo, proto, heads):
746 def unbundle(repo, proto, heads):
1146 their_heads = wireprototypes.decodelist(heads)
747 their_heads = wireprototypes.decodelist(heads)
1147
748
1148 with proto.mayberedirectstdio() as output:
749 with proto.mayberedirectstdio() as output:
1149 try:
750 try:
1150 exchange.check_heads(repo, their_heads, 'preparing changes')
751 exchange.check_heads(repo, their_heads, 'preparing changes')
1151 cleanup = lambda: None
752 cleanup = lambda: None
1152 try:
753 try:
1153 payload = proto.getpayload()
754 payload = proto.getpayload()
1154 if repo.ui.configbool('server', 'streamunbundle'):
755 if repo.ui.configbool('server', 'streamunbundle'):
1155 def cleanup():
756 def cleanup():
1156 # Ensure that the full payload is consumed, so
757 # Ensure that the full payload is consumed, so
1157 # that the connection doesn't contain trailing garbage.
758 # that the connection doesn't contain trailing garbage.
1158 for p in payload:
759 for p in payload:
1159 pass
760 pass
1160 fp = util.chunkbuffer(payload)
761 fp = util.chunkbuffer(payload)
1161 else:
762 else:
1162 # write bundle data to temporary file as it can be big
763 # write bundle data to temporary file as it can be big
1163 fp, tempname = None, None
764 fp, tempname = None, None
1164 def cleanup():
765 def cleanup():
1165 if fp:
766 if fp:
1166 fp.close()
767 fp.close()
1167 if tempname:
768 if tempname:
1168 os.unlink(tempname)
769 os.unlink(tempname)
1169 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
770 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1170 repo.ui.debug('redirecting incoming bundle to %s\n' %
771 repo.ui.debug('redirecting incoming bundle to %s\n' %
1171 tempname)
772 tempname)
1172 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
773 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
1173 r = 0
774 r = 0
1174 for p in payload:
775 for p in payload:
1175 fp.write(p)
776 fp.write(p)
1176 fp.seek(0)
777 fp.seek(0)
1177
778
1178 gen = exchange.readbundle(repo.ui, fp, None)
779 gen = exchange.readbundle(repo.ui, fp, None)
1179 if (isinstance(gen, changegroupmod.cg1unpacker)
780 if (isinstance(gen, changegroupmod.cg1unpacker)
1180 and not bundle1allowed(repo, 'push')):
781 and not bundle1allowed(repo, 'push')):
1181 if proto.name == 'http-v1':
782 if proto.name == 'http-v1':
1182 # need to special-case http because stderr does not get to
783 # need to special-case http because stderr does not get to
1183 # the http client on a failed push, so we need to abuse
784 # the http client on a failed push, so we need to abuse
1184 # some other error type to make sure the message gets to
785 # some other error type to make sure the message gets to
1185 # the user.
786 # the user.
1186 return wireprototypes.ooberror(bundle2required)
787 return wireprototypes.ooberror(bundle2required)
1187 raise error.Abort(bundle2requiredmain,
788 raise error.Abort(bundle2requiredmain,
1188 hint=bundle2requiredhint)
789 hint=bundle2requiredhint)
1189
790
1190 r = exchange.unbundle(repo, gen, their_heads, 'serve',
791 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1191 proto.client())
792 proto.client())
1192 if util.safehasattr(r, 'addpart'):
793 if util.safehasattr(r, 'addpart'):
1193 # The return looks streamable, we are in the bundle2 case
794 # The return looks streamable, we are in the bundle2 case
1194 # and should return a stream.
795 # and should return a stream.
1195 return wireprototypes.streamreslegacy(gen=r.getchunks())
796 return wireprototypes.streamreslegacy(gen=r.getchunks())
1196 return wireprototypes.pushres(
797 return wireprototypes.pushres(
1197 r, output.getvalue() if output else '')
798 r, output.getvalue() if output else '')
1198
799
1199 finally:
800 finally:
1200 cleanup()
801 cleanup()
1201
802
1202 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
803 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1203 # handle non-bundle2 case first
804 # handle non-bundle2 case first
1204 if not getattr(exc, 'duringunbundle2', False):
805 if not getattr(exc, 'duringunbundle2', False):
1205 try:
806 try:
1206 raise
807 raise
1207 except error.Abort:
808 except error.Abort:
1208 # The old code we moved used procutil.stderr directly.
809 # The old code we moved used procutil.stderr directly.
1209 # We did not change it to minimise code change.
810 # We did not change it to minimise code change.
1210 # This needs to be moved to something proper.
811 # This needs to be moved to something proper.
1211 # Feel free to do it.
812 # Feel free to do it.
1212 procutil.stderr.write("abort: %s\n" % exc)
813 procutil.stderr.write("abort: %s\n" % exc)
1213 if exc.hint is not None:
814 if exc.hint is not None:
1214 procutil.stderr.write("(%s)\n" % exc.hint)
815 procutil.stderr.write("(%s)\n" % exc.hint)
1215 procutil.stderr.flush()
816 procutil.stderr.flush()
1216 return wireprototypes.pushres(
817 return wireprototypes.pushres(
1217 0, output.getvalue() if output else '')
818 0, output.getvalue() if output else '')
1218 except error.PushRaced:
819 except error.PushRaced:
1219 return wireprototypes.pusherr(
820 return wireprototypes.pusherr(
1220 pycompat.bytestr(exc),
821 pycompat.bytestr(exc),
1221 output.getvalue() if output else '')
822 output.getvalue() if output else '')
1222
823
1223 bundler = bundle2.bundle20(repo.ui)
824 bundler = bundle2.bundle20(repo.ui)
1224 for out in getattr(exc, '_bundle2salvagedoutput', ()):
825 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1225 bundler.addpart(out)
826 bundler.addpart(out)
1226 try:
827 try:
1227 try:
828 try:
1228 raise
829 raise
1229 except error.PushkeyFailed as exc:
830 except error.PushkeyFailed as exc:
1230 # check client caps
831 # check client caps
1231 remotecaps = getattr(exc, '_replycaps', None)
832 remotecaps = getattr(exc, '_replycaps', None)
1232 if (remotecaps is not None
833 if (remotecaps is not None
1233 and 'pushkey' not in remotecaps.get('error', ())):
834 and 'pushkey' not in remotecaps.get('error', ())):
1234 # not supported by the remote side; fall back to the Abort handler.
835 # not supported by the remote side; fall back to the Abort handler.
1235 raise
836 raise
1236 part = bundler.newpart('error:pushkey')
837 part = bundler.newpart('error:pushkey')
1237 part.addparam('in-reply-to', exc.partid)
838 part.addparam('in-reply-to', exc.partid)
1238 if exc.namespace is not None:
839 if exc.namespace is not None:
1239 part.addparam('namespace', exc.namespace,
840 part.addparam('namespace', exc.namespace,
1240 mandatory=False)
841 mandatory=False)
1241 if exc.key is not None:
842 if exc.key is not None:
1242 part.addparam('key', exc.key, mandatory=False)
843 part.addparam('key', exc.key, mandatory=False)
1243 if exc.new is not None:
844 if exc.new is not None:
1244 part.addparam('new', exc.new, mandatory=False)
845 part.addparam('new', exc.new, mandatory=False)
1245 if exc.old is not None:
846 if exc.old is not None:
1246 part.addparam('old', exc.old, mandatory=False)
847 part.addparam('old', exc.old, mandatory=False)
1247 if exc.ret is not None:
848 if exc.ret is not None:
1248 part.addparam('ret', exc.ret, mandatory=False)
849 part.addparam('ret', exc.ret, mandatory=False)
1249 except error.BundleValueError as exc:
850 except error.BundleValueError as exc:
1250 errpart = bundler.newpart('error:unsupportedcontent')
851 errpart = bundler.newpart('error:unsupportedcontent')
1251 if exc.parttype is not None:
852 if exc.parttype is not None:
1252 errpart.addparam('parttype', exc.parttype)
853 errpart.addparam('parttype', exc.parttype)
1253 if exc.params:
854 if exc.params:
1254 errpart.addparam('params', '\0'.join(exc.params))
855 errpart.addparam('params', '\0'.join(exc.params))
1255 except error.Abort as exc:
856 except error.Abort as exc:
1256 manargs = [('message', stringutil.forcebytestr(exc))]
857 manargs = [('message', stringutil.forcebytestr(exc))]
1257 advargs = []
858 advargs = []
1258 if exc.hint is not None:
859 if exc.hint is not None:
1259 advargs.append(('hint', exc.hint))
860 advargs.append(('hint', exc.hint))
1260 bundler.addpart(bundle2.bundlepart('error:abort',
861 bundler.addpart(bundle2.bundlepart('error:abort',
1261 manargs, advargs))
862 manargs, advargs))
1262 except error.PushRaced as exc:
863 except error.PushRaced as exc:
1263 bundler.newpart('error:pushraced',
864 bundler.newpart('error:pushraced',
1264 [('message', stringutil.forcebytestr(exc))])
865 [('message', stringutil.forcebytestr(exc))])
1265 return wireprototypes.streamreslegacy(gen=bundler.getchunks())
866 return wireprototypes.streamreslegacy(gen=bundler.getchunks())
This diff has been collapsed as it changes many lines (847 lines changed).
@@ -1,1265 +1,420 b''
1 # wireproto.py - generic wire protocol support functions
1 # wireprotov1peer.py - Client-side functionality for wire protocol version 1.
2 #
2 #
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import hashlib
10 import hashlib
11 import os
12 import tempfile
13
11
14 from .i18n import _
12 from .i18n import _
15 from .node import (
13 from .node import (
16 bin,
14 bin,
17 hex,
18 nullid,
19 )
15 )
20
16
21 from . import (
17 from . import (
22 bundle2,
18 bundle2,
23 changegroup as changegroupmod,
19 changegroup as changegroupmod,
24 discovery,
25 encoding,
20 encoding,
26 error,
21 error,
27 exchange,
28 peer,
22 peer,
29 pushkey as pushkeymod,
23 pushkey as pushkeymod,
30 pycompat,
24 pycompat,
31 repository,
25 repository,
32 streamclone,
33 util,
26 util,
34 wireprototypes,
27 wireprototypes,
35 )
28 )
36
29
37 from .utils import (
38 procutil,
39 stringutil,
40 )
41
42 urlerr = util.urlerr
43 urlreq = util.urlreq
30 urlreq = util.urlreq
44
31
45 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
46 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
47 'IncompatibleClient')
48 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
49
50 class remoteiterbatcher(peer.iterbatcher):
32 class remoteiterbatcher(peer.iterbatcher):
51 def __init__(self, remote):
33 def __init__(self, remote):
52 super(remoteiterbatcher, self).__init__()
34 super(remoteiterbatcher, self).__init__()
53 self._remote = remote
35 self._remote = remote
54
36
55 def __getattr__(self, name):
37 def __getattr__(self, name):
56 # Validate this method is batchable, since submit() only supports
38 # Validate this method is batchable, since submit() only supports
57 # batchable methods.
39 # batchable methods.
58 fn = getattr(self._remote, name)
40 fn = getattr(self._remote, name)
59 if not getattr(fn, 'batchable', None):
41 if not getattr(fn, 'batchable', None):
60 raise error.ProgrammingError('Attempted to batch a non-batchable '
42 raise error.ProgrammingError('Attempted to batch a non-batchable '
61 'call to %r' % name)
43 'call to %r' % name)
62
44
63 return super(remoteiterbatcher, self).__getattr__(name)
45 return super(remoteiterbatcher, self).__getattr__(name)
64
46
65 def submit(self):
47 def submit(self):
66 """Break the batch request into many patch calls and pipeline them.
48 """Break the batch request into many patch calls and pipeline them.
67 """Break the batch request into many separate calls and pipeline them.
49 """Break the batch request into many separate calls and pipeline them.
68 This is mostly valuable over http where request sizes can be
50 This is mostly valuable over http where request sizes can be
69 limited, but can be used in other places as well.
51 limited, but can be used in other places as well.
70 """
52 """
71 # 2-tuple of (command, arguments) that represents what will be
53 # 2-tuple of (command, arguments) that represents what will be
72 # sent over the wire.
54 # sent over the wire.
73 requests = []
55 requests = []
74
56
75 # 4-tuple of (command, final future, @batchable generator, remote
57 # 4-tuple of (command, final future, @batchable generator, remote
76 # future).
58 # future).
77 results = []
59 results = []
78
60
79 for command, args, opts, finalfuture in self.calls:
61 for command, args, opts, finalfuture in self.calls:
80 mtd = getattr(self._remote, command)
62 mtd = getattr(self._remote, command)
81 batchable = mtd.batchable(mtd.__self__, *args, **opts)
63 batchable = mtd.batchable(mtd.__self__, *args, **opts)
82
64
83 commandargs, fremote = next(batchable)
65 commandargs, fremote = next(batchable)
84 assert fremote
66 assert fremote
85 requests.append((command, commandargs))
67 requests.append((command, commandargs))
86 results.append((command, finalfuture, batchable, fremote))
68 results.append((command, finalfuture, batchable, fremote))
87
69
88 if requests:
70 if requests:
89 self._resultiter = self._remote._submitbatch(requests)
71 self._resultiter = self._remote._submitbatch(requests)
90
72
91 self._results = results
73 self._results = results
92
74
93 def results(self):
75 def results(self):
94 for command, finalfuture, batchable, remotefuture in self._results:
76 for command, finalfuture, batchable, remotefuture in self._results:
95 # Get the raw result, set it in the remote future, feed it
77 # Get the raw result, set it in the remote future, feed it
96 # back into the @batchable generator so it can be decoded, and
78 # back into the @batchable generator so it can be decoded, and
97 # set the result on the final future to this value.
79 # set the result on the final future to this value.
98 remoteresult = next(self._resultiter)
80 remoteresult = next(self._resultiter)
99 remotefuture.set(remoteresult)
81 remotefuture.set(remoteresult)
100 finalfuture.set(next(batchable))
82 finalfuture.set(next(batchable))
101
83
102 # Verify our @batchable generators only emit 2 values.
84 # Verify our @batchable generators only emit 2 values.
103 try:
85 try:
104 next(batchable)
86 next(batchable)
105 except StopIteration:
87 except StopIteration:
106 pass
88 pass
107 else:
89 else:
108 raise error.ProgrammingError('%s @batchable generator emitted '
90 raise error.ProgrammingError('%s @batchable generator emitted '
109 'unexpected value count' % command)
91 'unexpected value count' % command)
110
92
111 yield finalfuture.value
93 yield finalfuture.value
112
94
113 # Forward a couple of names from peer to make wireproto interactions
95 # Forward a couple of names from peer to make wireproto interactions
114 # slightly more sensible.
96 # slightly more sensible.
115 batchable = peer.batchable
97 batchable = peer.batchable
116 future = peer.future
98 future = peer.future
117
99
118
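A sketch of how calling code drives this batching machinery: only @batchable methods may be queued, submit() sends them all as a single 'batch' wire command, and results() yields the decoded values in call order. The variable names here are illustrative:

    # assuming `remote` is a wirepeer obtained via hg.peer(...)
    b = remote.iterbatch()
    b.heads()
    b.known(nodes)              # `nodes` is a list of binary node ids
    b.submit()                  # one round trip for both commands
    heads, known = list(b.results())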
119 def encodebatchcmds(req):
100 def encodebatchcmds(req):
120 """Return a ``cmds`` argument value for the ``batch`` command."""
101 """Return a ``cmds`` argument value for the ``batch`` command."""
121 escapearg = wireprototypes.escapebatcharg
102 escapearg = wireprototypes.escapebatcharg
122
103
123 cmds = []
104 cmds = []
124 for op, argsdict in req:
105 for op, argsdict in req:
125 # Old servers didn't properly unescape argument names. So prevent
106 # Old servers didn't properly unescape argument names. So prevent
126 # the sending of argument names that may not be decoded properly by
107 # the sending of argument names that may not be decoded properly by
127 # servers.
108 # servers.
128 assert all(escapearg(k) == k for k in argsdict)
109 assert all(escapearg(k) == k for k in argsdict)
129
110
130 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
111 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
131 for k, v in argsdict.iteritems())
112 for k, v in argsdict.iteritems())
132 cmds.append('%s %s' % (op, args))
113 cmds.append('%s %s' % (op, args))
133
114
134 return ';'.join(cmds)
115 return ';'.join(cmds)
135
116
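For illustration, the wire encoding this produces: each command becomes 'name arg=value,arg=value' (names and values escaped first) and commands are joined with ';'. Note the trailing space after an argument-less command:

    encodebatchcmds([('heads', {}), ('listkeys', {'namespace': 'phases'})])
    # -> 'heads ;listkeys namespace=phases'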
136 def clientcompressionsupport(proto):
137 """Returns a list of compression methods supported by the client.
138
139 Returns a list of the compression methods supported by the client
140 according to the protocol capabilities. If no such capability has
141 been announced, fallback to the default of zlib and uncompressed.
142 """
143 for cap in proto.getprotocaps():
144 if cap.startswith('comp='):
145 return cap[5:].split(',')
146 return ['zlib', 'none']
147
148 # client side
149
150 class wirepeer(repository.legacypeer):
117 class wirepeer(repository.legacypeer):
151 """Client-side interface for communicating with a peer repository.
118 """Client-side interface for communicating with a peer repository.
152
119
153 Methods commonly call wire protocol commands of the same name.
120 Methods commonly call wire protocol commands of the same name.
154
121
155 See also httppeer.py and sshpeer.py for protocol-specific
122 See also httppeer.py and sshpeer.py for protocol-specific
156 implementations of this interface.
123 implementations of this interface.
157 """
124 """
158 # Begin of ipeercommands interface.
125 # Begin of ipeercommands interface.
159
126
160 def iterbatch(self):
127 def iterbatch(self):
161 return remoteiterbatcher(self)
128 return remoteiterbatcher(self)
162
129
163 @batchable
130 @batchable
164 def lookup(self, key):
131 def lookup(self, key):
165 self.requirecap('lookup', _('look up remote revision'))
132 self.requirecap('lookup', _('look up remote revision'))
166 f = future()
133 f = future()
167 yield {'key': encoding.fromlocal(key)}, f
134 yield {'key': encoding.fromlocal(key)}, f
168 d = f.value
135 d = f.value
169 success, data = d[:-1].split(" ", 1)
136 success, data = d[:-1].split(" ", 1)
170 if int(success):
137 if int(success):
171 yield bin(data)
138 yield bin(data)
172 else:
139 else:
173 self._abort(error.RepoError(data))
140 self._abort(error.RepoError(data))
174
141
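The @batchable pattern used by lookup and the methods that follow is a two-step generator protocol: first yield the wire arguments together with a future, then, once the batching machinery has filled the future with the raw response, yield the decoded result. A minimal sketch for a hypothetical 'echo' command that is not part of the real protocol:

    @batchable
    def echo(self, text):
        f = future()
        # step 1: what to send, and where the raw reply should land
        yield {'text': encoding.fromlocal(text)}, f
        # step 2: decode the raw reply stored in f.value
        yield encoding.tolocal(f.value)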
175 @batchable
142 @batchable
176 def heads(self):
143 def heads(self):
177 f = future()
144 f = future()
178 yield {}, f
145 yield {}, f
179 d = f.value
146 d = f.value
180 try:
147 try:
181 yield wireprototypes.decodelist(d[:-1])
148 yield wireprototypes.decodelist(d[:-1])
182 except ValueError:
149 except ValueError:
183 self._abort(error.ResponseError(_("unexpected response:"), d))
150 self._abort(error.ResponseError(_("unexpected response:"), d))
184
151
185 @batchable
152 @batchable
186 def known(self, nodes):
153 def known(self, nodes):
187 f = future()
154 f = future()
188 yield {'nodes': wireprototypes.encodelist(nodes)}, f
155 yield {'nodes': wireprototypes.encodelist(nodes)}, f
189 d = f.value
156 d = f.value
190 try:
157 try:
191 yield [bool(int(b)) for b in d]
158 yield [bool(int(b)) for b in d]
192 except ValueError:
159 except ValueError:
193 self._abort(error.ResponseError(_("unexpected response:"), d))
160 self._abort(error.ResponseError(_("unexpected response:"), d))
194
161
195 @batchable
162 @batchable
196 def branchmap(self):
163 def branchmap(self):
197 f = future()
164 f = future()
198 yield {}, f
165 yield {}, f
199 d = f.value
166 d = f.value
200 try:
167 try:
201 branchmap = {}
168 branchmap = {}
202 for branchpart in d.splitlines():
169 for branchpart in d.splitlines():
203 branchname, branchheads = branchpart.split(' ', 1)
170 branchname, branchheads = branchpart.split(' ', 1)
204 branchname = encoding.tolocal(urlreq.unquote(branchname))
171 branchname = encoding.tolocal(urlreq.unquote(branchname))
205 branchheads = wireprototypes.decodelist(branchheads)
172 branchheads = wireprototypes.decodelist(branchheads)
206 branchmap[branchname] = branchheads
173 branchmap[branchname] = branchheads
207 yield branchmap
174 yield branchmap
208 except TypeError:
175 except TypeError:
209 self._abort(error.ResponseError(_("unexpected response:"), d))
176 self._abort(error.ResponseError(_("unexpected response:"), d))
210
177
211 @batchable
178 @batchable
212 def listkeys(self, namespace):
179 def listkeys(self, namespace):
213 if not self.capable('pushkey'):
180 if not self.capable('pushkey'):
214 yield {}, None
181 yield {}, None
215 f = future()
182 f = future()
216 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
183 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
217 yield {'namespace': encoding.fromlocal(namespace)}, f
184 yield {'namespace': encoding.fromlocal(namespace)}, f
218 d = f.value
185 d = f.value
219 self.ui.debug('received listkey for "%s": %i bytes\n'
186 self.ui.debug('received listkey for "%s": %i bytes\n'
220 % (namespace, len(d)))
187 % (namespace, len(d)))
221 yield pushkeymod.decodekeys(d)
188 yield pushkeymod.decodekeys(d)
222
189
223 @batchable
190 @batchable
224 def pushkey(self, namespace, key, old, new):
191 def pushkey(self, namespace, key, old, new):
225 if not self.capable('pushkey'):
192 if not self.capable('pushkey'):
226 yield False, None
193 yield False, None
227 f = future()
194 f = future()
228 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
195 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
229 yield {'namespace': encoding.fromlocal(namespace),
196 yield {'namespace': encoding.fromlocal(namespace),
230 'key': encoding.fromlocal(key),
197 'key': encoding.fromlocal(key),
231 'old': encoding.fromlocal(old),
198 'old': encoding.fromlocal(old),
232 'new': encoding.fromlocal(new)}, f
199 'new': encoding.fromlocal(new)}, f
233 d = f.value
200 d = f.value
234 d, output = d.split('\n', 1)
201 d, output = d.split('\n', 1)
235 try:
202 try:
236 d = bool(int(d))
203 d = bool(int(d))
237 except ValueError:
204 except ValueError:
238 raise error.ResponseError(
205 raise error.ResponseError(
239 _('push failed (unexpected response):'), d)
206 _('push failed (unexpected response):'), d)
240 for l in output.splitlines(True):
207 for l in output.splitlines(True):
241 self.ui.status(_('remote: '), l)
208 self.ui.status(_('remote: '), l)
242 yield d
209 yield d
243
210
244 def stream_out(self):
211 def stream_out(self):
245 return self._callstream('stream_out')
212 return self._callstream('stream_out')
246
213
247 def getbundle(self, source, **kwargs):
214 def getbundle(self, source, **kwargs):
248 kwargs = pycompat.byteskwargs(kwargs)
215 kwargs = pycompat.byteskwargs(kwargs)
249 self.requirecap('getbundle', _('look up remote changes'))
216 self.requirecap('getbundle', _('look up remote changes'))
250 opts = {}
217 opts = {}
251 bundlecaps = kwargs.get('bundlecaps') or set()
218 bundlecaps = kwargs.get('bundlecaps') or set()
252 for key, value in kwargs.iteritems():
219 for key, value in kwargs.iteritems():
253 if value is None:
220 if value is None:
254 continue
221 continue
255 keytype = wireprototypes.GETBUNDLE_ARGUMENTS.get(key)
222 keytype = wireprototypes.GETBUNDLE_ARGUMENTS.get(key)
256 if keytype is None:
223 if keytype is None:
257 raise error.ProgrammingError(
224 raise error.ProgrammingError(
258 'Unexpectedly None keytype for key %s' % key)
225 'Unexpectedly None keytype for key %s' % key)
259 elif keytype == 'nodes':
226 elif keytype == 'nodes':
260 value = wireprototypes.encodelist(value)
227 value = wireprototypes.encodelist(value)
261 elif keytype == 'csv':
228 elif keytype == 'csv':
262 value = ','.join(value)
229 value = ','.join(value)
263 elif keytype == 'scsv':
230 elif keytype == 'scsv':
264 value = ','.join(sorted(value))
231 value = ','.join(sorted(value))
265 elif keytype == 'boolean':
232 elif keytype == 'boolean':
266 value = '%i' % bool(value)
233 value = '%i' % bool(value)
267 elif keytype != 'plain':
234 elif keytype != 'plain':
268 raise KeyError('unknown getbundle option type %s'
235 raise KeyError('unknown getbundle option type %s'
269 % keytype)
236 % keytype)
270 opts[key] = value
237 opts[key] = value
271 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
238 f = self._callcompressable("getbundle", **pycompat.strkwargs(opts))
272 if any((cap.startswith('HG2') for cap in bundlecaps)):
239 if any((cap.startswith('HG2') for cap in bundlecaps)):
273 return bundle2.getunbundler(self.ui, f)
240 return bundle2.getunbundler(self.ui, f)
274 else:
241 else:
275 return changegroupmod.cg1unpacker(f, 'UN')
242 return changegroupmod.cg1unpacker(f, 'UN')
276
243
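A sketch of how exchange code typically invokes this method; the keyword names must be keys of wireprototypes.GETBUNDLE_ARGUMENTS (heads, common and bundlecaps are real argument names, while the node list variables are placeholders):

    bundle = remote.getbundle('pull',
                              common=commonnodes,   # binary nodes both sides have
                              heads=wantedheads,    # binary nodes the client wants
                              bundlecaps={'HG20'})
    # returns a bundle2 unbundler when an HG2* cap is requested,
    # otherwise a cg1unpacker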
277 def unbundle(self, cg, heads, url):
244 def unbundle(self, cg, heads, url):
278 '''Send cg (a readable file-like object representing the
245 '''Send cg (a readable file-like object representing the
279 changegroup to push, typically a chunkbuffer object) to the
246 changegroup to push, typically a chunkbuffer object) to the
280 remote server as a bundle.
247 remote server as a bundle.
281
248
282 When pushing a bundle10 stream, return an integer indicating the
249 When pushing a bundle10 stream, return an integer indicating the
283 result of the push (see changegroup.apply()).
250 result of the push (see changegroup.apply()).
284
251
285 When pushing a bundle20 stream, return a bundle20 stream.
252 When pushing a bundle20 stream, return a bundle20 stream.
286
253
287 `url` is the url the client thinks it's pushing to, which is
254 `url` is the url the client thinks it's pushing to, which is
288 visible to hooks.
255 visible to hooks.
289 '''
256 '''
290
257
291 if heads != ['force'] and self.capable('unbundlehash'):
258 if heads != ['force'] and self.capable('unbundlehash'):
292 heads = wireprototypes.encodelist(
259 heads = wireprototypes.encodelist(
293 ['hashed', hashlib.sha1(''.join(sorted(heads))).digest()])
260 ['hashed', hashlib.sha1(''.join(sorted(heads))).digest()])
294 else:
261 else:
295 heads = wireprototypes.encodelist(heads)
262 heads = wireprototypes.encodelist(heads)
296
263
297 if util.safehasattr(cg, 'deltaheader'):
264 if util.safehasattr(cg, 'deltaheader'):
298 # this a bundle10, do the old style call sequence
265 # this a bundle10, do the old style call sequence
299 ret, output = self._callpush("unbundle", cg, heads=heads)
266 ret, output = self._callpush("unbundle", cg, heads=heads)
300 if ret == "":
267 if ret == "":
301 raise error.ResponseError(
268 raise error.ResponseError(
302 _('push failed:'), output)
269 _('push failed:'), output)
303 try:
270 try:
304 ret = int(ret)
271 ret = int(ret)
305 except ValueError:
272 except ValueError:
306 raise error.ResponseError(
273 raise error.ResponseError(
307 _('push failed (unexpected response):'), ret)
274 _('push failed (unexpected response):'), ret)
308
275
309 for l in output.splitlines(True):
276 for l in output.splitlines(True):
310 self.ui.status(_('remote: '), l)
277 self.ui.status(_('remote: '), l)
311 else:
278 else:
312 # bundle2 push. Send a stream, fetch a stream.
279 # bundle2 push. Send a stream, fetch a stream.
313 stream = self._calltwowaystream('unbundle', cg, heads=heads)
280 stream = self._calltwowaystream('unbundle', cg, heads=heads)
314 ret = bundle2.getunbundler(self.ui, stream)
281 ret = bundle2.getunbundler(self.ui, stream)
315 return ret
282 return ret
316
283
317 # End of ipeercommands interface.
284 # End of ipeercommands interface.
318
285
319 # Begin of ipeerlegacycommands interface.
286 # Begin of ipeerlegacycommands interface.
320
287
321 def branches(self, nodes):
288 def branches(self, nodes):
322 n = wireprototypes.encodelist(nodes)
289 n = wireprototypes.encodelist(nodes)
323 d = self._call("branches", nodes=n)
290 d = self._call("branches", nodes=n)
324 try:
291 try:
325 br = [tuple(wireprototypes.decodelist(b)) for b in d.splitlines()]
292 br = [tuple(wireprototypes.decodelist(b)) for b in d.splitlines()]
326 return br
293 return br
327 except ValueError:
294 except ValueError:
328 self._abort(error.ResponseError(_("unexpected response:"), d))
295 self._abort(error.ResponseError(_("unexpected response:"), d))
329
296
330 def between(self, pairs):
297 def between(self, pairs):
331 batch = 8 # avoid giant requests
298 batch = 8 # avoid giant requests
332 r = []
299 r = []
333 for i in xrange(0, len(pairs), batch):
300 for i in xrange(0, len(pairs), batch):
334 n = " ".join([wireprototypes.encodelist(p, '-')
301 n = " ".join([wireprototypes.encodelist(p, '-')
335 for p in pairs[i:i + batch]])
302 for p in pairs[i:i + batch]])
336 d = self._call("between", pairs=n)
303 d = self._call("between", pairs=n)
337 try:
304 try:
338 r.extend(l and wireprototypes.decodelist(l) or []
305 r.extend(l and wireprototypes.decodelist(l) or []
339 for l in d.splitlines())
306 for l in d.splitlines())
340 except ValueError:
307 except ValueError:
341 self._abort(error.ResponseError(_("unexpected response:"), d))
308 self._abort(error.ResponseError(_("unexpected response:"), d))
342 return r
309 return r
343
310
344 def changegroup(self, nodes, kind):
311 def changegroup(self, nodes, kind):
345 n = wireprototypes.encodelist(nodes)
312 n = wireprototypes.encodelist(nodes)
346 f = self._callcompressable("changegroup", roots=n)
313 f = self._callcompressable("changegroup", roots=n)
347 return changegroupmod.cg1unpacker(f, 'UN')
314 return changegroupmod.cg1unpacker(f, 'UN')
348
315
349 def changegroupsubset(self, bases, heads, kind):
316 def changegroupsubset(self, bases, heads, kind):
350 self.requirecap('changegroupsubset', _('look up remote changes'))
317 self.requirecap('changegroupsubset', _('look up remote changes'))
351 bases = wireprototypes.encodelist(bases)
318 bases = wireprototypes.encodelist(bases)
352 heads = wireprototypes.encodelist(heads)
319 heads = wireprototypes.encodelist(heads)
353 f = self._callcompressable("changegroupsubset",
320 f = self._callcompressable("changegroupsubset",
354 bases=bases, heads=heads)
321 bases=bases, heads=heads)
355 return changegroupmod.cg1unpacker(f, 'UN')
322 return changegroupmod.cg1unpacker(f, 'UN')
356
323
357 # End of ipeerlegacycommands interface.
324 # End of ipeerlegacycommands interface.
358
325
359 def _submitbatch(self, req):
326 def _submitbatch(self, req):
360 """run batch request <req> on the server
327 """run batch request <req> on the server
361
328
362 Returns an iterator of the raw responses from the server.
329 Returns an iterator of the raw responses from the server.
363 """
330 """
364 ui = self.ui
331 ui = self.ui
365 if ui.debugflag and ui.configbool('devel', 'debug.peer-request'):
332 if ui.debugflag and ui.configbool('devel', 'debug.peer-request'):
366 ui.debug('devel-peer-request: batched-content\n')
333 ui.debug('devel-peer-request: batched-content\n')
367 for op, args in req:
334 for op, args in req:
368 msg = 'devel-peer-request: - %s (%d arguments)\n'
335 msg = 'devel-peer-request: - %s (%d arguments)\n'
369 ui.debug(msg % (op, len(args)))
336 ui.debug(msg % (op, len(args)))
370
337
371 unescapearg = wireprototypes.unescapebatcharg
338 unescapearg = wireprototypes.unescapebatcharg
372
339
373 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
340 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
374 chunk = rsp.read(1024)
341 chunk = rsp.read(1024)
375 work = [chunk]
342 work = [chunk]
376 while chunk:
343 while chunk:
377 while ';' not in chunk and chunk:
344 while ';' not in chunk and chunk:
378 chunk = rsp.read(1024)
345 chunk = rsp.read(1024)
379 work.append(chunk)
346 work.append(chunk)
380 merged = ''.join(work)
347 merged = ''.join(work)
381 while ';' in merged:
348 while ';' in merged:
382 one, merged = merged.split(';', 1)
349 one, merged = merged.split(';', 1)
383 yield unescapearg(one)
350 yield unescapearg(one)
384 chunk = rsp.read(1024)
351 chunk = rsp.read(1024)
385 work = [merged, chunk]
352 work = [merged, chunk]
386 yield unescapearg(''.join(work))
353 yield unescapearg(''.join(work))
387
354
388 def _submitone(self, op, args):
355 def _submitone(self, op, args):
389 return self._call(op, **pycompat.strkwargs(args))
356 return self._call(op, **pycompat.strkwargs(args))
390
357
391 def debugwireargs(self, one, two, three=None, four=None, five=None):
358 def debugwireargs(self, one, two, three=None, four=None, five=None):
392 # don't pass optional arguments left at their default value
359 # don't pass optional arguments left at their default value
393 opts = {}
360 opts = {}
394 if three is not None:
361 if three is not None:
395 opts[r'three'] = three
362 opts[r'three'] = three
396 if four is not None:
363 if four is not None:
397 opts[r'four'] = four
364 opts[r'four'] = four
398 return self._call('debugwireargs', one=one, two=two, **opts)
365 return self._call('debugwireargs', one=one, two=two, **opts)
399
366
400 def _call(self, cmd, **args):
367 def _call(self, cmd, **args):
401 """execute <cmd> on the server
368 """execute <cmd> on the server
402
369
403 The command is expected to return a simple string.
370 The command is expected to return a simple string.
404
371
405 returns the server reply as a string."""
372 returns the server reply as a string."""
406 raise NotImplementedError()
373 raise NotImplementedError()
407
374
408 def _callstream(self, cmd, **args):
375 def _callstream(self, cmd, **args):
409 """execute <cmd> on the server
376 """execute <cmd> on the server
410
377
411 The command is expected to return a stream. Note that if the
378 The command is expected to return a stream. Note that if the
412 command doesn't return a stream, _callstream behaves
379 command doesn't return a stream, _callstream behaves
413 differently for ssh and http peers.
380 differently for ssh and http peers.
414
381
415 returns the server reply as a file like object.
382 returns the server reply as a file like object.
416 """
383 """
417 raise NotImplementedError()
384 raise NotImplementedError()
418
385
419 def _callcompressable(self, cmd, **args):
386 def _callcompressable(self, cmd, **args):
420 """execute <cmd> on the server
387 """execute <cmd> on the server
421
388
422 The command is expected to return a stream.
389 The command is expected to return a stream.
423
390
424 The stream may have been compressed in some implementations. This
391 The stream may have been compressed in some implementations. This
425 function takes care of the decompression. This is the only difference
392 function takes care of the decompression. This is the only difference
426 with _callstream.
393 with _callstream.
427
394
428 returns the server reply as a file like object.
395 returns the server reply as a file like object.
429 """
396 """
430 raise NotImplementedError()
397 raise NotImplementedError()
431
398
432 def _callpush(self, cmd, fp, **args):
399 def _callpush(self, cmd, fp, **args):
433 """execute a <cmd> on server
400 """execute a <cmd> on server
434
401
435 The command is expected to be related to a push. Push has a special
402 The command is expected to be related to a push. Push has a special
436 return method.
403 return method.
437
404
438 returns the server reply as a (ret, output) tuple. ret is either
405 returns the server reply as a (ret, output) tuple. ret is either
439 empty (error) or a stringified int.
406 empty (error) or a stringified int.
440 """
407 """
441 raise NotImplementedError()
408 raise NotImplementedError()
442
409
443 def _calltwowaystream(self, cmd, fp, **args):
410 def _calltwowaystream(self, cmd, fp, **args):
444 """execute <cmd> on server
411 """execute <cmd> on server
445
412
446 The command will send a stream to the server and get a stream in reply.
413 The command will send a stream to the server and get a stream in reply.
447 """
414 """
448 raise NotImplementedError()
415 raise NotImplementedError()
449
416
450 def _abort(self, exception):
417 def _abort(self, exception):
451 """clearly abort the wire protocol connection and raise the exception
418 """clearly abort the wire protocol connection and raise the exception
452 """
419 """
453 raise NotImplementedError()
420 raise NotImplementedError()
454
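Taken together, these ``_call*`` hooks are the whole transport contract a version 1 peer has to satisfy. A minimal in-process sketch, assuming the enclosing ``wirepeer`` base class and ``util.stringio``; the ``dummypeer`` name is illustrative and not part of the module:

    class dummypeer(wirepeer):
        # Answers every command with a canned byte string instead of
        # talking to a real server; useful only as an illustration.
        def _call(self, cmd, **args):
            return b'result of ' + cmd

        def _callstream(self, cmd, **args):
            # The simple response wrapped in a file-like object.
            return util.stringio(self._call(cmd, **args))

        def _callcompressable(self, cmd, **args):
            # Nothing is compressed in-process, so reuse _callstream.
            return self._callstream(cmd, **args)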
455 # server side
456
457 # wire protocol command can either return a string or one of these classes.
458
459 def getdispatchrepo(repo, proto, command):
460 """Obtain the repo used for processing wire protocol commands.
461
462 The intent of this function is to serve as a monkeypatch point for
463 extensions that need commands to operate on different repo views under
464 specialized circumstances.
465 """
466 return repo.filtered('served')
467
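Because the docstring above promises a monkeypatch point, an extension can simply rebind the module attribute; ``dispatch()`` below looks the name up at call time. A hedged sketch (the ``internal-debug`` command name is made up):

    from mercurial import wireproto

    _origgetdispatchrepo = wireproto.getdispatchrepo

    def getdispatchrepo(repo, proto, command):
        # Hypothetically serve the unfiltered view to one internal command.
        if command == b'internal-debug':
            return repo.unfiltered()
        return _origgetdispatchrepo(repo, proto, command)

    wireproto.getdispatchrepo = getdispatchrepo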
468 def dispatch(repo, proto, command):
469 repo = getdispatchrepo(repo, proto, command)
470
471 transportversion = wireprototypes.TRANSPORTS[proto.name]['version']
472 commandtable = commandsv2 if transportversion == 2 else commands
473 func, spec = commandtable[command]
474
475 args = proto.getargs(spec)
476
477 # Version 1 protocols define arguments as a list. Version 2 uses a dict.
478 if isinstance(args, list):
479 return func(repo, proto, *args)
480 elif isinstance(args, dict):
481 return func(repo, proto, **args)
482 else:
483 raise error.ProgrammingError('unexpected type returned from '
484 'proto.getargs(): %s' % type(args))
485
486 def options(cmd, keys, others):
487 opts = {}
488 for k in keys:
489 if k in others:
490 opts[k] = others[k]
491 del others[k]
492 if others:
493 procutil.stderr.write("warning: %s ignored unexpected arguments %s\n"
494 % (cmd, ",".join(others)))
495 return opts
496
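In other words, ``options()`` keeps the keys the command declared and leaves the rest behind with a warning; roughly:

    others = {'three': b'3', 'bogus': b'x'}
    opts = options('debugwireargs', ['three', 'four'], others)
    # opts == {'three': b'3'}
    # others == {'bogus': b'x'}, and a warning naming 'bogus' went to stderr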
497 def bundle1allowed(repo, action):
498 """Whether a bundle1 operation is allowed from the server.
499
500 Priority is:
501
502 1. server.bundle1gd.<action> (if generaldelta active)
503 2. server.bundle1.<action>
504 3. server.bundle1gd (if generaldelta active)
505 4. server.bundle1
506 """
507 ui = repo.ui
508 gd = 'generaldelta' in repo.requirements
509
510 if gd:
511 v = ui.configbool('server', 'bundle1gd.%s' % action)
512 if v is not None:
513 return v
514
515 v = ui.configbool('server', 'bundle1.%s' % action)
516 if v is not None:
517 return v
518
519 if gd:
520 v = ui.configbool('server', 'bundle1gd')
521 if v is not None:
522 return v
523
524 return ui.configbool('server', 'bundle1')
525
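Expressed as configuration, the narrower knob always wins over the broader one; an illustrative hgrc (values are examples, not recommendations):

    [server]
    # 1. generaldelta repos, per action
    bundle1gd.push = false
    # 2. any repo, per action
    bundle1.push = false
    # 3. generaldelta repos, any action
    bundle1gd = true
    # 4. global fallback
    bundle1 = true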
526 def supportedcompengines(ui, role):
527 """Obtain the list of supported compression engines for a request."""
528 assert role in (util.CLIENTROLE, util.SERVERROLE)
529
530 compengines = util.compengines.supportedwireengines(role)
531
532 # Allow config to override default list and ordering.
533 if role == util.SERVERROLE:
534 configengines = ui.configlist('server', 'compressionengines')
535 config = 'server.compressionengines'
536 else:
537 # This is currently implemented mainly to facilitate testing. In most
538 # cases, the server should be in charge of choosing a compression engine
539 # because a server has the most to lose from a sub-optimal choice. (e.g.
540 # CPU DoS due to an expensive engine or a network DoS due to poor
541 # compression ratio).
542 configengines = ui.configlist('experimental',
543 'clientcompressionengines')
544 config = 'experimental.clientcompressionengines'
545
546 # No explicit config. Filter out the ones that aren't supposed to be
547 # advertised and return default ordering.
548 if not configengines:
549 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
550 return [e for e in compengines
551 if getattr(e.wireprotosupport(), attr) > 0]
552
553 # If compression engines are listed in the config, assume there is a good
554 # reason for it (like server operators wanting to achieve specific
555 # performance characteristics). So fail fast if the config references
556 # unusable compression engines.
557 validnames = set(e.name() for e in compengines)
558 invalidnames = set(e for e in configengines if e not in validnames)
559 if invalidnames:
560 raise error.Abort(_('invalid compression engine defined in %s: %s') %
561 (config, ', '.join(sorted(invalidnames))))
562
563 compengines = [e for e in compengines if e.name() in configengines]
564 compengines = sorted(compengines,
565 key=lambda e: configengines.index(e.name()))
566
567 if not compengines:
568 raise error.Abort(_('%s config option does not specify any known '
569 'compression engines') % config,
570 hint=_('usable compression engines: %s') %
571 ', '.join(sorted(validnames)))
572
573 return compengines
574
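For example, an operator could restrict and reorder the advertised engines explicitly (engine names must be known to ``util.compengines``; ``zstd`` availability depends on the build):

    [server]
    compressionengines = zstd, zlib

    # client side, mainly for testing, as noted above
    [experimental]
    clientcompressionengines = zlib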
575 class commandentry(object):
576 """Represents a declared wire protocol command."""
577 def __init__(self, func, args='', transports=None,
578 permission='push'):
579 self.func = func
580 self.args = args
581 self.transports = transports or set()
582 self.permission = permission
583
584 def _merge(self, func, args):
585 """Merge this instance with an incoming 2-tuple.
586
587 This is called when a caller using the old 2-tuple API attempts
588 to replace an instance. The incoming values are merged with
589 data not captured by the 2-tuple and a new instance containing
590 the union of the two objects is returned.
591 """
592 return commandentry(func, args=args, transports=set(self.transports),
593 permission=self.permission)
594
595 # Old code treats instances as 2-tuples. So expose that interface.
596 def __iter__(self):
597 yield self.func
598 yield self.args
599
600 def __getitem__(self, i):
601 if i == 0:
602 return self.func
603 elif i == 1:
604 return self.args
605 else:
606 raise IndexError('can only access elements 0 and 1')
607
608 class commanddict(dict):
609 """Container for registered wire protocol commands.
610
611 It behaves like a dict. But __setitem__ is overwritten to allow silent
612 coercion of values from 2-tuples for API compatibility.
613 """
614 def __setitem__(self, k, v):
615 if isinstance(v, commandentry):
616 pass
617 # Cast 2-tuples to commandentry instances.
618 elif isinstance(v, tuple):
619 if len(v) != 2:
620 raise ValueError('command tuples must have exactly 2 elements')
621
622 # It is common for extensions to wrap wire protocol commands via
623 # e.g. ``wireproto.commands[x] = (newfn, args)``. Because callers
624 # doing this aren't aware of the new API that uses objects to store
625 # command entries, we automatically merge old state with new.
626 if k in self:
627 v = self[k]._merge(v[0], v[1])
628 else:
629 # Use default values from @wireprotocommand.
630 v = commandentry(v[0], args=v[1],
631 transports=set(wireprototypes.TRANSPORTS),
632 permission='push')
633 else:
634 raise ValueError('command entries must be commandentry instances '
635 'or 2-tuples')
636
637 return super(commanddict, self).__setitem__(k, v)
638
639 def commandavailable(self, command, proto):
640 """Determine if a command is available for the requested protocol."""
641 assert proto.name in wireprototypes.TRANSPORTS
642
643 entry = self.get(command)
644
645 if not entry:
646 return False
647
648 if proto.name not in entry.transports:
649 return False
650
651 return True
652
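The tuple coercion above is what keeps the traditional wrapping idiom working; a sketch of that idiom with a hypothetical logging wrapper:

    from mercurial import wireproto

    _origheads, _headsspec = wireproto.commands['heads']

    def _loggingheads(repo, proto):
        repo.ui.log('wireproto', 'heads requested\n')
        return _origheads(repo, proto)

    # Assigning a 2-tuple is silently merged into a commandentry, keeping
    # the original transports and permission of the 'heads' command.
    wireproto.commands['heads'] = (_loggingheads, _headsspec)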
653 # Constants specifying which transports a wire protocol command should be
654 # available on. For use with @wireprotocommand.
655 POLICY_V1_ONLY = 'v1-only'
656 POLICY_V2_ONLY = 'v2-only'
657
658 # For version 1 transports.
659 commands = commanddict()
660
661 # For version 2 transports.
662 commandsv2 = commanddict()
663
664 def wireprotocommand(name, args=None, transportpolicy=POLICY_V1_ONLY,
665 permission='push'):
666 """Decorator to declare a wire protocol command.
667
668 ``name`` is the name of the wire protocol command being provided.
669
670 ``args`` defines the named arguments accepted by the command. It is
671 ideally a dict mapping argument names to their types. For backwards
672 compatibility, it can be a space-delimited list of argument names. For
673 version 1 transports, ``*`` denotes a special value that says to accept
674 all named arguments.
675
676 ``transportpolicy`` is a POLICY_* constant denoting which transports
677 this wire protocol command should be exposed to. By default, commands
678 are exposed to all wire protocol transports.
679
680 ``permission`` defines the permission type needed to run this command.
681 Can be ``push`` or ``pull``. These roughly map to read-write and read-only,
682 respectively. Default is to assume command requires ``push`` permissions
683 because otherwise commands not declaring their permissions could modify
684 a repository that is supposed to be read-only.
685 """
686 if transportpolicy == POLICY_V1_ONLY:
687 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
688 if v['version'] == 1}
689 transportversion = 1
690 elif transportpolicy == POLICY_V2_ONLY:
691 transports = {k for k, v in wireprototypes.TRANSPORTS.items()
692 if v['version'] == 2}
693 transportversion = 2
694 else:
695 raise error.ProgrammingError('invalid transport policy value: %s' %
696 transportpolicy)
697
698 # Because SSHv2 is a mirror of SSHv1, we allow "batch" commands through to
699 # SSHv2.
700 # TODO undo this hack when SSH is using the unified frame protocol.
701 if name == b'batch':
702 transports.add(wireprototypes.SSHV2)
703
704 if permission not in ('push', 'pull'):
705 raise error.ProgrammingError('invalid wire protocol permission; '
706 'got %s; expected "push" or "pull"' %
707 permission)
708
709 if transportversion == 1:
710 if args is None:
711 args = ''
712
713 if not isinstance(args, bytes):
714 raise error.ProgrammingError('arguments for version 1 commands '
715 'must be declared as bytes')
716 elif transportversion == 2:
717 if args is None:
718 args = {}
719
720 if not isinstance(args, dict):
721 raise error.ProgrammingError('arguments for version 2 commands '
722 'must be declared as dicts')
723
724 def register(func):
725 if transportversion == 1:
726 if name in commands:
727 raise error.ProgrammingError('%s command already registered '
728 'for version 1' % name)
729 commands[name] = commandentry(func, args=args,
730 transports=transports,
731 permission=permission)
732 elif transportversion == 2:
733 if name in commandsv2:
734 raise error.ProgrammingError('%s command already registered '
735 'for version 2' % name)
736
737 commandsv2[name] = commandentry(func, args=args,
738 transports=transports,
739 permission=permission)
740 else:
741 raise error.ProgrammingError('unhandled transport version: %d' %
742 transportversion)
743
744 return func
745 return register
746
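A new version 1 command is then declared by decorating a function with the expected ``(repo, proto, ...)`` signature; a hypothetical example (the command name is illustrative):

    @wireprotocommand('listbookmarknames', 'pattern', permission='pull',
                      transportpolicy=POLICY_V1_ONLY)
    def listbookmarknames(repo, proto, pattern):
        names = [b for b in repo._bookmarks if b.startswith(pattern)]
        return wireprototypes.bytesresponse(b'\n'.join(sorted(names)))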
747 # TODO define a more appropriate permissions type to use for this.
748 @wireprotocommand('batch', 'cmds *', permission='pull',
749 transportpolicy=POLICY_V1_ONLY)
750 def batch(repo, proto, cmds, others):
751 unescapearg = wireprototypes.unescapebatcharg
752 repo = repo.filtered("served")
753 res = []
754 for pair in cmds.split(';'):
755 op, args = pair.split(' ', 1)
756 vals = {}
757 for a in args.split(','):
758 if a:
759 n, v = a.split('=')
760 vals[unescapearg(n)] = unescapearg(v)
761 func, spec = commands[op]
762
763 # Validate that client has permissions to perform this command.
764 perm = commands[op].permission
765 assert perm in ('push', 'pull')
766 proto.checkperm(perm)
767
768 if spec:
769 keys = spec.split()
770 data = {}
771 for k in keys:
772 if k == '*':
773 star = {}
774 for key in vals.keys():
775 if key not in keys:
776 star[key] = vals[key]
777 data['*'] = star
778 else:
779 data[k] = vals[k]
780 result = func(repo, proto, *[data[k] for k in keys])
781 else:
782 result = func(repo, proto)
783 if isinstance(result, wireprototypes.ooberror):
784 return result
785
786 # For now, all batchable commands must return bytesresponse or
787 # raw bytes (for backwards compatibility).
788 assert isinstance(result, (wireprototypes.bytesresponse, bytes))
789 if isinstance(result, wireprototypes.bytesresponse):
790 result = result.data
791 res.append(wireprototypes.escapebatcharg(result))
792
793 return wireprototypes.bytesresponse(';'.join(res))
794
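On the wire, ``cmds`` is a ';'-separated list of ``<op> <name>=<value>,...`` pairs, with ':', ',', ';' and '=' escaped inside values. A sketch of a two-command batch (the node value is illustrative):

    cmds = (b'heads ;'
            b'known nodes=cb9a9f314b8b07ba71012fcdbc544b5a4d82ff5b')
    # batch() splits on ';', then on the first space, then on ',' and '=',
    # unescaping each name and value via wireprototypes.unescapebatcharg.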
795 @wireprotocommand('between', 'pairs', transportpolicy=POLICY_V1_ONLY,
796 permission='pull')
797 def between(repo, proto, pairs):
798 pairs = [wireprototypes.decodelist(p, '-') for p in pairs.split(" ")]
799 r = []
800 for b in repo.between(pairs):
801 r.append(wireprototypes.encodelist(b) + "\n")
802
803 return wireprototypes.bytesresponse(''.join(r))
804
805 @wireprotocommand('branchmap', permission='pull',
806 transportpolicy=POLICY_V1_ONLY)
807 def branchmap(repo, proto):
808 branchmap = repo.branchmap()
809 heads = []
810 for branch, nodes in branchmap.iteritems():
811 branchname = urlreq.quote(encoding.fromlocal(branch))
812 branchnodes = wireprototypes.encodelist(nodes)
813 heads.append('%s %s' % (branchname, branchnodes))
814
815 return wireprototypes.bytesresponse('\n'.join(heads))
816
817 @wireprotocommand('branches', 'nodes', transportpolicy=POLICY_V1_ONLY,
818 permission='pull')
819 def branches(repo, proto, nodes):
820 nodes = wireprototypes.decodelist(nodes)
821 r = []
822 for b in repo.branches(nodes):
823 r.append(wireprototypes.encodelist(b) + "\n")
824
825 return wireprototypes.bytesresponse(''.join(r))
826
827 @wireprotocommand('clonebundles', '', permission='pull',
828 transportpolicy=POLICY_V1_ONLY)
829 def clonebundles(repo, proto):
830 """Server command for returning info for available bundles to seed clones.
831
832 Clients will parse this response and determine what bundle to fetch.
833
834 Extensions may wrap this command to filter or dynamically emit data
835 depending on the request. e.g. you could advertise URLs for the closest
836 data center given the client's IP address.
837 """
838 return wireprototypes.bytesresponse(
839 repo.vfs.tryread('clonebundles.manifest'))
840
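The manifest itself lists one bundle per line: a URL followed by optional KEY=VALUE attributes. An illustrative .hg/clonebundles.manifest:

    https://mirror.example.com/repo-full.hg.gz BUNDLESPEC=gzip-v2 REQUIRESNI=true
    https://mirror.example.com/repo-full.hg.zst BUNDLESPEC=zstd-v2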
841 wireprotocaps = ['lookup', 'branchmap', 'pushkey',
842 'known', 'getbundle', 'unbundlehash']
843
844 def _capabilities(repo, proto):
845 """return a list of capabilities for a repo
846
847 This function exists to allow extensions to easily wrap capabilities
848 computation
849
850 - returns a list: easy to alter
851 - changes done here will be propagated to both `capabilities` and `hello`
852 commands without any other action needed.
853 """
854 # copy to prevent modification of the global list
855 caps = list(wireprotocaps)
856
857 # Command of same name as capability isn't exposed to version 1 of
858 # transports. So conditionally add it.
859 if commands.commandavailable('changegroupsubset', proto):
860 caps.append('changegroupsubset')
861
862 if streamclone.allowservergeneration(repo):
863 if repo.ui.configbool('server', 'preferuncompressed'):
864 caps.append('stream-preferred')
865 requiredformats = repo.requirements & repo.supportedformats
866 # if our local revlogs are just revlogv1, add 'stream' cap
867 if not requiredformats - {'revlogv1'}:
868 caps.append('stream')
869 # otherwise, add 'streamreqs' detailing our local revlog format
870 else:
871 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
872 if repo.ui.configbool('experimental', 'bundle2-advertise'):
873 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo, role='server'))
874 caps.append('bundle2=' + urlreq.quote(capsblob))
875 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
876
877 return proto.addcapabilities(repo, caps)
878
879 # If you are writing an extension and are considering wrapping this function,
880 # wrap `_capabilities` instead.
881 @wireprotocommand('capabilities', permission='pull',
882 transportpolicy=POLICY_V1_ONLY)
883 def capabilities(repo, proto):
884 caps = _capabilities(repo, proto)
885 return wireprototypes.bytesresponse(' '.join(sorted(caps)))
886
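An extension that wants to advertise an extra token therefore wraps ``_capabilities`` rather than the command itself; a sketch (the capability name is made up):

    from mercurial import extensions, wireproto

    def _extracaps(orig, repo, proto):
        caps = orig(repo, proto)
        caps.append('exp-hypothetical-feature')
        return caps

    def extsetup(ui):
        extensions.wrapfunction(wireproto, '_capabilities', _extracaps)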
887 @wireprotocommand('changegroup', 'roots', transportpolicy=POLICY_V1_ONLY,
888 permission='pull')
889 def changegroup(repo, proto, roots):
890 nodes = wireprototypes.decodelist(roots)
891 outgoing = discovery.outgoing(repo, missingroots=nodes,
892 missingheads=repo.heads())
893 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
894 gen = iter(lambda: cg.read(32768), '')
895 return wireprototypes.streamres(gen=gen)
896
897 @wireprotocommand('changegroupsubset', 'bases heads',
898 transportpolicy=POLICY_V1_ONLY,
899 permission='pull')
900 def changegroupsubset(repo, proto, bases, heads):
901 bases = wireprototypes.decodelist(bases)
902 heads = wireprototypes.decodelist(heads)
903 outgoing = discovery.outgoing(repo, missingroots=bases,
904 missingheads=heads)
905 cg = changegroupmod.makechangegroup(repo, outgoing, '01', 'serve')
906 gen = iter(lambda: cg.read(32768), '')
907 return wireprototypes.streamres(gen=gen)
908
909 @wireprotocommand('debugwireargs', 'one two *',
910 permission='pull', transportpolicy=POLICY_V1_ONLY)
911 def debugwireargs(repo, proto, one, two, others):
912 # only accept optional args from the known set
913 opts = options('debugwireargs', ['three', 'four'], others)
914 return wireprototypes.bytesresponse(repo.debugwireargs(
915 one, two, **pycompat.strkwargs(opts)))
916
917 def find_pullbundle(repo, proto, opts, clheads, heads, common):
918 """Return a file object for the first matching pullbundle.
919
920 Pullbundles are specified in .hg/pullbundles.manifest similar to
921 clonebundles.
922 For each entry, the bundle specification is checked for compatibility:
923 - Client features vs the BUNDLESPEC.
924 - Revisions shared with the clients vs base revisions of the bundle.
925 A bundle can be applied only if all its base revisions are known by
926 the client.
927 - At least one leaf of the bundle's DAG is missing on the client.
928 - Every leaf of the bundle's DAG is part of the node set the client wants.
929 E.g. do not send a bundle of all changes if the client wants only
930 one specific branch of many.
931 """
932 def decodehexstring(s):
933 return set([h.decode('hex') for h in s.split(';')])
934
935 manifest = repo.vfs.tryread('pullbundles.manifest')
936 if not manifest:
937 return None
938 res = exchange.parseclonebundlesmanifest(repo, manifest)
939 res = exchange.filterclonebundleentries(repo, res)
940 if not res:
941 return None
942 cl = repo.changelog
943 heads_anc = cl.ancestors([cl.rev(rev) for rev in heads], inclusive=True)
944 common_anc = cl.ancestors([cl.rev(rev) for rev in common], inclusive=True)
945 compformats = clientcompressionsupport(proto)
946 for entry in res:
947 if 'COMPRESSION' in entry and entry['COMPRESSION'] not in compformats:
948 continue
949 # No test yet for VERSION, since V2 is supported by any client
950 # that advertises partial pulls
951 if 'heads' in entry:
952 try:
953 bundle_heads = decodehexstring(entry['heads'])
954 except TypeError:
955 # Bad heads entry
956 continue
957 if bundle_heads.issubset(common):
958 continue # Nothing new
959 if all(cl.rev(rev) in common_anc for rev in bundle_heads):
960 continue # Still nothing new
961 if any(cl.rev(rev) not in heads_anc and
962 cl.rev(rev) not in common_anc for rev in bundle_heads):
963 continue
964 if 'bases' in entry:
965 try:
966 bundle_bases = decodehexstring(entry['bases'])
967 except TypeError:
968 # Bad bases entry
969 continue
970 if not all(cl.rev(rev) in common_anc for rev in bundle_bases):
971 continue
972 path = entry['URL']
973 repo.ui.debug('sending pullbundle "%s"\n' % path)
974 try:
975 return repo.vfs.open(path)
976 except IOError:
977 repo.ui.debug('pullbundle "%s" not accessible\n' % path)
978 continue
979 return None
980
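A .hg/pullbundles.manifest thus reuses the clonebundles layout, with optional 'heads' and 'bases' attributes holding ';'-separated hex nodes. An illustrative entry (path and nodes are placeholders):

    partial-000-100.hg BUNDLESPEC=gzip-v2 heads=<hexnode>;<hexnode> bases=<hexnode>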
981 @wireprotocommand('getbundle', '*', permission='pull',
982 transportpolicy=POLICY_V1_ONLY)
983 def getbundle(repo, proto, others):
984 opts = options('getbundle', wireprototypes.GETBUNDLE_ARGUMENTS.keys(),
985 others)
986 for k, v in opts.iteritems():
987 keytype = wireprototypes.GETBUNDLE_ARGUMENTS[k]
988 if keytype == 'nodes':
989 opts[k] = wireprototypes.decodelist(v)
990 elif keytype == 'csv':
991 opts[k] = list(v.split(','))
992 elif keytype == 'scsv':
993 opts[k] = set(v.split(','))
994 elif keytype == 'boolean':
995 # Client should serialize False as '0', which is a non-empty string
996 # so it evaluates as a True bool.
997 if v == '0':
998 opts[k] = False
999 else:
1000 opts[k] = bool(v)
1001 elif keytype != 'plain':
1002 raise KeyError('unknown getbundle option type %s'
1003 % keytype)
1004
1005 if not bundle1allowed(repo, 'pull'):
1006 if not exchange.bundle2requested(opts.get('bundlecaps')):
1007 if proto.name == 'http-v1':
1008 return wireprototypes.ooberror(bundle2required)
1009 raise error.Abort(bundle2requiredmain,
1010 hint=bundle2requiredhint)
1011
1012 prefercompressed = True
1013
1014 try:
1015 clheads = set(repo.changelog.heads())
1016 heads = set(opts.get('heads', set()))
1017 common = set(opts.get('common', set()))
1018 common.discard(nullid)
1019 if (repo.ui.configbool('server', 'pullbundle') and
1020 'partial-pull' in proto.getprotocaps()):
1021 # Check if a pre-built bundle covers this request.
1022 bundle = find_pullbundle(repo, proto, opts, clheads, heads, common)
1023 if bundle:
1024 return wireprototypes.streamres(gen=util.filechunkiter(bundle),
1025 prefer_uncompressed=True)
1026
1027 if repo.ui.configbool('server', 'disablefullbundle'):
1028 # Check to see if this is a full clone.
1029 changegroup = opts.get('cg', True)
1030 if changegroup and not common and clheads == heads:
1031 raise error.Abort(
1032 _('server has pull-based clones disabled'),
1033 hint=_('remove --pull if specified or upgrade Mercurial'))
1034
1035 info, chunks = exchange.getbundlechunks(repo, 'serve',
1036 **pycompat.strkwargs(opts))
1037 prefercompressed = info.get('prefercompressed', True)
1038 except error.Abort as exc:
1039 # cleanly forward Abort error to the client
1040 if not exchange.bundle2requested(opts.get('bundlecaps')):
1041 if proto.name == 'http-v1':
1042 return wireprototypes.ooberror(pycompat.bytestr(exc) + '\n')
1043 raise # cannot do better for bundle1 + ssh
1044 # bundle2 request expect a bundle2 reply
1045 bundler = bundle2.bundle20(repo.ui)
1046 manargs = [('message', pycompat.bytestr(exc))]
1047 advargs = []
1048 if exc.hint is not None:
1049 advargs.append(('hint', exc.hint))
1050 bundler.addpart(bundle2.bundlepart('error:abort',
1051 manargs, advargs))
1052 chunks = bundler.getchunks()
1053 prefercompressed = False
1054
1055 return wireprototypes.streamres(
1056 gen=chunks, prefer_uncompressed=not prefercompressed)
1057
1058 @wireprotocommand('heads', permission='pull', transportpolicy=POLICY_V1_ONLY)
1059 def heads(repo, proto):
1060 h = repo.heads()
1061 return wireprototypes.bytesresponse(wireprototypes.encodelist(h) + '\n')
1062
1063 @wireprotocommand('hello', permission='pull', transportpolicy=POLICY_V1_ONLY)
1064 def hello(repo, proto):
1065 """Called as part of SSH handshake to obtain server info.
1066
1067 Returns a list of lines describing interesting things about the
1068 server, in an RFC822-like format.
1069
1070 Currently, the only one defined is ``capabilities``, which consists of a
1071 line of space separated tokens describing server abilities:
1072
1073 capabilities: <token0> <token1> <token2>
1074 """
1075 caps = capabilities(repo, proto).data
1076 return wireprototypes.bytesresponse('capabilities: %s\n' % caps)
1077
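So the handshake reply is a single RFC822-style line; an abridged, illustrative response from a server advertising the default capability set:

    capabilities: lookup branchmap pushkey known getbundle unbundlehash batch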
1078 @wireprotocommand('listkeys', 'namespace', permission='pull',
1079 transportpolicy=POLICY_V1_ONLY)
1080 def listkeys(repo, proto, namespace):
1081 d = sorted(repo.listkeys(encoding.tolocal(namespace)).items())
1082 return wireprototypes.bytesresponse(pushkeymod.encodekeys(d))
1083
1084 @wireprotocommand('lookup', 'key', permission='pull',
1085 transportpolicy=POLICY_V1_ONLY)
1086 def lookup(repo, proto, key):
1087 try:
1088 k = encoding.tolocal(key)
1089 n = repo.lookup(k)
1090 r = hex(n)
1091 success = 1
1092 except Exception as inst:
1093 r = stringutil.forcebytestr(inst)
1094 success = 0
1095 return wireprototypes.bytesresponse('%d %s\n' % (success, r))
1096
1097 @wireprotocommand('known', 'nodes *', permission='pull',
1098 transportpolicy=POLICY_V1_ONLY)
1099 def known(repo, proto, nodes, others):
1100 v = ''.join(b and '1' or '0'
1101 for b in repo.known(wireprototypes.decodelist(nodes)))
1102 return wireprototypes.bytesresponse(v)
1103
1104 @wireprotocommand('protocaps', 'caps', permission='pull',
1105 transportpolicy=POLICY_V1_ONLY)
1106 def protocaps(repo, proto, caps):
1107 if proto.name == wireprototypes.SSHV1:
1108 proto._protocaps = set(caps.split(' '))
1109 return wireprototypes.bytesresponse('OK')
1110
1111 @wireprotocommand('pushkey', 'namespace key old new', permission='push',
1112 transportpolicy=POLICY_V1_ONLY)
1113 def pushkey(repo, proto, namespace, key, old, new):
1114 # compatibility with pre-1.8 clients which were accidentally
1115 # sending raw binary nodes rather than utf-8-encoded hex
1116 if len(new) == 20 and stringutil.escapestr(new) != new:
1117 # looks like it could be a binary node
1118 try:
1119 new.decode('utf-8')
1120 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
1121 except UnicodeDecodeError:
1122 pass # binary, leave unmodified
1123 else:
1124 new = encoding.tolocal(new) # normal path
1125
1126 with proto.mayberedirectstdio() as output:
1127 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
1128 encoding.tolocal(old), new) or False
1129
1130 output = output.getvalue() if output else ''
1131 return wireprototypes.bytesresponse('%d\n%s' % (int(r), output))
1132
1133 @wireprotocommand('stream_out', permission='pull',
1134 transportpolicy=POLICY_V1_ONLY)
1135 def stream(repo, proto):
1136 '''If the server supports streaming clone, it advertises the "stream"
1137 capability with a value representing the version and flags of the repo
1138 it is serving. Client checks to see if it understands the format.
1139 '''
1140 return wireprototypes.streamreslegacy(
1141 streamclone.generatev1wireproto(repo))
1142
1143 @wireprotocommand('unbundle', 'heads', permission='push',
1144 transportpolicy=POLICY_V1_ONLY)
1145 def unbundle(repo, proto, heads):
1146 their_heads = wireprototypes.decodelist(heads)
1147
1148 with proto.mayberedirectstdio() as output:
1149 try:
1150 exchange.check_heads(repo, their_heads, 'preparing changes')
1151 cleanup = lambda: None
1152 try:
1153 payload = proto.getpayload()
1154 if repo.ui.configbool('server', 'streamunbundle'):
1155 def cleanup():
1156 # Ensure that the full payload is consumed, so
1157 # that the connection doesn't contain trailing garbage.
1158 for p in payload:
1159 pass
1160 fp = util.chunkbuffer(payload)
1161 else:
1162 # write bundle data to temporary file as it can be big
1163 fp, tempname = None, None
1164 def cleanup():
1165 if fp:
1166 fp.close()
1167 if tempname:
1168 os.unlink(tempname)
1169 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
1170 repo.ui.debug('redirecting incoming bundle to %s\n' %
1171 tempname)
1172 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
1173 r = 0
1174 for p in payload:
1175 fp.write(p)
1176 fp.seek(0)
1177
1178 gen = exchange.readbundle(repo.ui, fp, None)
1179 if (isinstance(gen, changegroupmod.cg1unpacker)
1180 and not bundle1allowed(repo, 'push')):
1181 if proto.name == 'http-v1':
1182 # need to special case http because stderr does not get to
1183 # the http client on failed push, so we need to abuse
1184 # some other error type to make sure the message gets to
1185 # the user.
1186 return wireprototypes.ooberror(bundle2required)
1187 raise error.Abort(bundle2requiredmain,
1188 hint=bundle2requiredhint)
1189
1190 r = exchange.unbundle(repo, gen, their_heads, 'serve',
1191 proto.client())
1192 if util.safehasattr(r, 'addpart'):
1193 # The return looks streamable, we are in the bundle2 case
1194 # and should return a stream.
1195 return wireprototypes.streamreslegacy(gen=r.getchunks())
1196 return wireprototypes.pushres(
1197 r, output.getvalue() if output else '')
1198
1199 finally:
1200 cleanup()
1201
1202 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1203 # handle non-bundle2 case first
1204 if not getattr(exc, 'duringunbundle2', False):
1205 try:
1206 raise
1207 except error.Abort:
1208 # The old code we moved used procutil.stderr directly.
1209 # We did not change it to minimise code change.
1210 # This needs to be moved to something proper.
1211 # Feel free to do it.
1212 procutil.stderr.write("abort: %s\n" % exc)
1213 if exc.hint is not None:
1214 procutil.stderr.write("(%s)\n" % exc.hint)
1215 procutil.stderr.flush()
1216 return wireprototypes.pushres(
1217 0, output.getvalue() if output else '')
1218 except error.PushRaced:
1219 return wireprototypes.pusherr(
1220 pycompat.bytestr(exc),
1221 output.getvalue() if output else '')
1222
1223 bundler = bundle2.bundle20(repo.ui)
1224 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1225 bundler.addpart(out)
1226 try:
1227 try:
1228 raise
1229 except error.PushkeyFailed as exc:
1230 # check client caps
1231 remotecaps = getattr(exc, '_replycaps', None)
1232 if (remotecaps is not None
1233 and 'pushkey' not in remotecaps.get('error', ())):
1234 # no support on the remote side, fall back to Abort handler.
1235 raise
1236 part = bundler.newpart('error:pushkey')
1237 part.addparam('in-reply-to', exc.partid)
1238 if exc.namespace is not None:
1239 part.addparam('namespace', exc.namespace,
1240 mandatory=False)
1241 if exc.key is not None:
1242 part.addparam('key', exc.key, mandatory=False)
1243 if exc.new is not None:
1244 part.addparam('new', exc.new, mandatory=False)
1245 if exc.old is not None:
1246 part.addparam('old', exc.old, mandatory=False)
1247 if exc.ret is not None:
1248 part.addparam('ret', exc.ret, mandatory=False)
1249 except error.BundleValueError as exc:
1250 errpart = bundler.newpart('error:unsupportedcontent')
1251 if exc.parttype is not None:
1252 errpart.addparam('parttype', exc.parttype)
1253 if exc.params:
1254 errpart.addparam('params', '\0'.join(exc.params))
1255 except error.Abort as exc:
1256 manargs = [('message', stringutil.forcebytestr(exc))]
1257 advargs = []
1258 if exc.hint is not None:
1259 advargs.append(('hint', exc.hint))
1260 bundler.addpart(bundle2.bundlepart('error:abort',
1261 manargs, advargs))
1262 except error.PushRaced as exc:
1263 bundler.newpart('error:pushraced',
1264 [('message', stringutil.forcebytestr(exc))])
1265 return wireprototypes.streamreslegacy(gen=bundler.getchunks())
@@ -1,206 +1,206 b''
1 # test-batching.py - tests for transparent command batching
1 # test-batching.py - tests for transparent command batching
2 #
2 #
3 # Copyright 2011 Peter Arrenbrecht <peter@arrenbrecht.ch>
3 # Copyright 2011 Peter Arrenbrecht <peter@arrenbrecht.ch>
4 #
4 #
5 # This software may be used and distributed according to the terms of the
5 # This software may be used and distributed according to the terms of the
6 # GNU General Public License version 2 or any later version.
6 # GNU General Public License version 2 or any later version.
7
7
8 from __future__ import absolute_import, print_function
8 from __future__ import absolute_import, print_function
9
9
10 from mercurial import (
10 from mercurial import (
11 error,
11 error,
12 peer,
12 peer,
13 util,
13 util,
14 wireproto,
14 wireprotov1peer,
15 )
15 )
16
16
17 # equivalent of repo.repository
17 # equivalent of repo.repository
18 class thing(object):
18 class thing(object):
19 def hello(self):
19 def hello(self):
20 return "Ready."
20 return "Ready."
21
21
22 # equivalent of localrepo.localrepository
22 # equivalent of localrepo.localrepository
23 class localthing(thing):
23 class localthing(thing):
24 def foo(self, one, two=None):
24 def foo(self, one, two=None):
25 if one:
25 if one:
26 return "%s and %s" % (one, two,)
26 return "%s and %s" % (one, two,)
27 return "Nope"
27 return "Nope"
28 def bar(self, b, a):
28 def bar(self, b, a):
29 return "%s und %s" % (b, a,)
29 return "%s und %s" % (b, a,)
30 def greet(self, name=None):
30 def greet(self, name=None):
31 return "Hello, %s" % name
31 return "Hello, %s" % name
32 def batchiter(self):
32 def batchiter(self):
33 '''Support for local batching.'''
33 '''Support for local batching.'''
34 return peer.localiterbatcher(self)
34 return peer.localiterbatcher(self)
35
35
36 # usage of "thing" interface
36 # usage of "thing" interface
37 def use(it):
37 def use(it):
38
38
39 # Direct call to base method shared between client and server.
39 # Direct call to base method shared between client and server.
40 print(it.hello())
40 print(it.hello())
41
41
42 # Direct calls to proxied methods. They cause individual roundtrips.
42 # Direct calls to proxied methods. They cause individual roundtrips.
43 print(it.foo("Un", two="Deux"))
43 print(it.foo("Un", two="Deux"))
44 print(it.bar("Eins", "Zwei"))
44 print(it.bar("Eins", "Zwei"))
45
45
46 # Batched call to a couple of proxied methods.
46 # Batched call to a couple of proxied methods.
47 batch = it.batchiter()
47 batch = it.batchiter()
48 # The calls return futures to eventually hold results.
48 # The calls return futures to eventually hold results.
49 foo = batch.foo(one="One", two="Two")
49 foo = batch.foo(one="One", two="Two")
50 bar = batch.bar("Eins", "Zwei")
50 bar = batch.bar("Eins", "Zwei")
51 bar2 = batch.bar(b="Uno", a="Due")
51 bar2 = batch.bar(b="Uno", a="Due")
52
52
53 # Future shouldn't be set until we submit().
53 # Future shouldn't be set until we submit().
54 assert isinstance(foo, peer.future)
54 assert isinstance(foo, peer.future)
55 assert not util.safehasattr(foo, 'value')
55 assert not util.safehasattr(foo, 'value')
56 assert not util.safehasattr(bar, 'value')
56 assert not util.safehasattr(bar, 'value')
57 batch.submit()
57 batch.submit()
58 # Call results() to obtain results as a generator.
58 # Call results() to obtain results as a generator.
59 results = batch.results()
59 results = batch.results()
60
60
61 # Future results shouldn't be set until we consume a value.
61 # Future results shouldn't be set until we consume a value.
62 assert not util.safehasattr(foo, 'value')
62 assert not util.safehasattr(foo, 'value')
63 foovalue = next(results)
63 foovalue = next(results)
64 assert util.safehasattr(foo, 'value')
64 assert util.safehasattr(foo, 'value')
65 assert foovalue == foo.value
65 assert foovalue == foo.value
66 print(foo.value)
66 print(foo.value)
67 next(results)
67 next(results)
68 print(bar.value)
68 print(bar.value)
69 next(results)
69 next(results)
70 print(bar2.value)
70 print(bar2.value)
71
71
72 # We should be at the end of the results generator.
72 # We should be at the end of the results generator.
73 try:
73 try:
74 next(results)
74 next(results)
75 except StopIteration:
75 except StopIteration:
76 print('proper end of results generator')
76 print('proper end of results generator')
77 else:
77 else:
78 print('extra emitted element!')
78 print('extra emitted element!')
79
79
80 # Attempting to call a non-batchable method inside a batch fails.
80 # Attempting to call a non-batchable method inside a batch fails.
81 batch = it.batchiter()
81 batch = it.batchiter()
82 try:
82 try:
83 batch.greet(name='John Smith')
83 batch.greet(name='John Smith')
84 except error.ProgrammingError as e:
84 except error.ProgrammingError as e:
85 print(e)
85 print(e)
86
86
87 # Attempting to call a local method inside a batch fails.
87 # Attempting to call a local method inside a batch fails.
88 batch = it.batchiter()
88 batch = it.batchiter()
89 try:
89 try:
90 batch.hello()
90 batch.hello()
91 except error.ProgrammingError as e:
91 except error.ProgrammingError as e:
92 print(e)
92 print(e)
93
93
94 # local usage
94 # local usage
95 mylocal = localthing()
95 mylocal = localthing()
96 print()
96 print()
97 print("== Local")
97 print("== Local")
98 use(mylocal)
98 use(mylocal)
99
99
100 # demo remoting; mimics what wireproto and HTTP/SSH do
100 # demo remoting; mimics what wireproto and HTTP/SSH do
101
101
102 # shared
102 # shared
103
103
104 def escapearg(plain):
104 def escapearg(plain):
105 return (plain
105 return (plain
106 .replace(':', '::')
106 .replace(':', '::')
107 .replace(',', ':,')
107 .replace(',', ':,')
108 .replace(';', ':;')
108 .replace(';', ':;')
109 .replace('=', ':='))
109 .replace('=', ':='))
110 def unescapearg(escaped):
110 def unescapearg(escaped):
111 return (escaped
111 return (escaped
112 .replace(':=', '=')
112 .replace(':=', '=')
113 .replace(':;', ';')
113 .replace(':;', ';')
114 .replace(':,', ',')
114 .replace(':,', ',')
115 .replace('::', ':'))
115 .replace('::', ':'))
116
116
117 # server side
117 # server side
118
118
119 # equivalent of wireproto's global functions
119 # equivalent of wireproto's global functions
120 class server(object):
120 class server(object):
121 def __init__(self, local):
121 def __init__(self, local):
122 self.local = local
122 self.local = local
123 def _call(self, name, args):
123 def _call(self, name, args):
124 args = dict(arg.split('=', 1) for arg in args)
124 args = dict(arg.split('=', 1) for arg in args)
125 return getattr(self, name)(**args)
125 return getattr(self, name)(**args)
126 def perform(self, req):
126 def perform(self, req):
127 print("REQ:", req)
127 print("REQ:", req)
128 name, args = req.split('?', 1)
128 name, args = req.split('?', 1)
129 args = args.split('&')
129 args = args.split('&')
130 vals = dict(arg.split('=', 1) for arg in args)
130 vals = dict(arg.split('=', 1) for arg in args)
131 res = getattr(self, name)(**vals)
131 res = getattr(self, name)(**vals)
132 print(" ->", res)
132 print(" ->", res)
133 return res
133 return res
134 def batch(self, cmds):
134 def batch(self, cmds):
135 res = []
135 res = []
136 for pair in cmds.split(';'):
136 for pair in cmds.split(';'):
137 name, args = pair.split(':', 1)
137 name, args = pair.split(':', 1)
138 vals = {}
138 vals = {}
139 for a in args.split(','):
139 for a in args.split(','):
140 if a:
140 if a:
141 n, v = a.split('=')
141 n, v = a.split('=')
142 vals[n] = unescapearg(v)
142 vals[n] = unescapearg(v)
143 res.append(escapearg(getattr(self, name)(**vals)))
143 res.append(escapearg(getattr(self, name)(**vals)))
144 return ';'.join(res)
144 return ';'.join(res)
145 def foo(self, one, two):
145 def foo(self, one, two):
146 return mangle(self.local.foo(unmangle(one), unmangle(two)))
146 return mangle(self.local.foo(unmangle(one), unmangle(two)))
147 def bar(self, b, a):
147 def bar(self, b, a):
148 return mangle(self.local.bar(unmangle(b), unmangle(a)))
148 return mangle(self.local.bar(unmangle(b), unmangle(a)))
149 def greet(self, name):
149 def greet(self, name):
150 return mangle(self.local.greet(unmangle(name)))
150 return mangle(self.local.greet(unmangle(name)))
151 myserver = server(mylocal)
151 myserver = server(mylocal)
152
152
153 # local side
153 # local side
154
154
155 # equivalent of wireproto.encode/decodelist, that is, type-specific marshalling
155 # equivalent of wireproto.encode/decodelist, that is, type-specific marshalling
156 # here we just transform the strings a bit to check we're properly en-/decoding
156 # here we just transform the strings a bit to check we're properly en-/decoding
157 def mangle(s):
157 def mangle(s):
158 return ''.join(chr(ord(c) + 1) for c in s)
158 return ''.join(chr(ord(c) + 1) for c in s)
159 def unmangle(s):
159 def unmangle(s):
160 return ''.join(chr(ord(c) - 1) for c in s)
160 return ''.join(chr(ord(c) - 1) for c in s)
161
161
162 # equivalent of wireproto.wirerepository and something like http's wire format
162 # equivalent of wireproto.wirerepository and something like http's wire format
163 class remotething(thing):
163 class remotething(thing):
164 def __init__(self, server):
164 def __init__(self, server):
165 self.server = server
165 self.server = server
166 def _submitone(self, name, args):
166 def _submitone(self, name, args):
167 req = name + '?' + '&'.join(['%s=%s' % (n, v) for n, v in args])
167 req = name + '?' + '&'.join(['%s=%s' % (n, v) for n, v in args])
168 return self.server.perform(req)
168 return self.server.perform(req)
169 def _submitbatch(self, cmds):
169 def _submitbatch(self, cmds):
170 req = []
170 req = []
171 for name, args in cmds:
171 for name, args in cmds:
172 args = ','.join(n + '=' + escapearg(v) for n, v in args)
172 args = ','.join(n + '=' + escapearg(v) for n, v in args)
173 req.append(name + ':' + args)
173 req.append(name + ':' + args)
174 req = ';'.join(req)
174 req = ';'.join(req)
175 res = self._submitone('batch', [('cmds', req,)])
175 res = self._submitone('batch', [('cmds', req,)])
176 for r in res.split(';'):
176 for r in res.split(';'):
177 yield r
177 yield r
178
178
179 def batchiter(self):
179 def batchiter(self):
180 return wireproto.remoteiterbatcher(self)
180 return wireprotov1peer.remoteiterbatcher(self)
181
181
182 @peer.batchable
182 @peer.batchable
183 def foo(self, one, two=None):
183 def foo(self, one, two=None):
184 encargs = [('one', mangle(one),), ('two', mangle(two),)]
184 encargs = [('one', mangle(one),), ('two', mangle(two),)]
185 encresref = peer.future()
185 encresref = peer.future()
186 yield encargs, encresref
186 yield encargs, encresref
187 yield unmangle(encresref.value)
187 yield unmangle(encresref.value)
188
188
189 @peer.batchable
189 @peer.batchable
190 def bar(self, b, a):
190 def bar(self, b, a):
191 encresref = peer.future()
191 encresref = peer.future()
192 yield [('b', mangle(b),), ('a', mangle(a),)], encresref
192 yield [('b', mangle(b),), ('a', mangle(a),)], encresref
193 yield unmangle(encresref.value)
193 yield unmangle(encresref.value)
194
194
195 # greet is coded directly. It therefore does not support batching. If it
195 # greet is coded directly. It therefore does not support batching. If it
196 # does appear in a batch, the batch is split around greet, and the call to
196 # does appear in a batch, the batch is split around greet, and the call to
197 # greet is done in its own roundtrip.
197 # greet is done in its own roundtrip.
198 def greet(self, name=None):
198 def greet(self, name=None):
199 return unmangle(self._submitone('greet', [('name', mangle(name),)]))
199 return unmangle(self._submitone('greet', [('name', mangle(name),)]))
200
200
201 # demo remote usage
201 # demo remote usage
202
202
203 myproxy = remotething(myserver)
203 myproxy = remotething(myserver)
204 print()
204 print()
205 print("== Remote")
205 print("== Remote")
206 use(myproxy)
206 use(myproxy)
@@ -1,98 +1,99 b''
1 from __future__ import absolute_import, print_function
1 from __future__ import absolute_import, print_function
2
2
3 from mercurial import (
3 from mercurial import (
4 error,
4 error,
5 pycompat,
5 pycompat,
6 ui as uimod,
6 ui as uimod,
7 util,
7 util,
8 wireproto,
8 wireproto,
9 wireprototypes,
9 wireprototypes,
10 wireprotov1peer,
10 )
11 )
11 stringio = util.stringio
12 stringio = util.stringio
12
13
13 class proto(object):
14 class proto(object):
14 def __init__(self, args):
15 def __init__(self, args):
15 self.args = args
16 self.args = args
16 self.name = 'dummyproto'
17 self.name = 'dummyproto'
17
18
18 def getargs(self, spec):
19 def getargs(self, spec):
19 args = self.args
20 args = self.args
20 args.setdefault(b'*', {})
21 args.setdefault(b'*', {})
21 names = spec.split()
22 names = spec.split()
22 return [args[n] for n in names]
23 return [args[n] for n in names]
23
24
24 def checkperm(self, perm):
25 def checkperm(self, perm):
25 pass
26 pass
26
27
27 wireprototypes.TRANSPORTS['dummyproto'] = {
28 wireprototypes.TRANSPORTS['dummyproto'] = {
28 'transport': 'dummy',
29 'transport': 'dummy',
29 'version': 1,
30 'version': 1,
30 }
31 }
31
32
32 class clientpeer(wireproto.wirepeer):
33 class clientpeer(wireprotov1peer.wirepeer):
33 def __init__(self, serverrepo, ui):
34 def __init__(self, serverrepo, ui):
34 self.serverrepo = serverrepo
35 self.serverrepo = serverrepo
35 self.ui = ui
36 self.ui = ui
36
37
37 def url(self):
38 def url(self):
38 return b'test'
39 return b'test'
39
40
40 def local(self):
41 def local(self):
41 return None
42 return None
42
43
43 def peer(self):
44 def peer(self):
44 return self
45 return self
45
46
46 def canpush(self):
47 def canpush(self):
47 return True
48 return True
48
49
49 def close(self):
50 def close(self):
50 pass
51 pass
51
52
52 def capabilities(self):
53 def capabilities(self):
53 return [b'batch']
54 return [b'batch']
54
55
55 def _call(self, cmd, **args):
56 def _call(self, cmd, **args):
56 args = pycompat.byteskwargs(args)
57 args = pycompat.byteskwargs(args)
57 res = wireproto.dispatch(self.serverrepo, proto(args), cmd)
58 res = wireproto.dispatch(self.serverrepo, proto(args), cmd)
58 if isinstance(res, wireprototypes.bytesresponse):
59 if isinstance(res, wireprototypes.bytesresponse):
59 return res.data
60 return res.data
60 elif isinstance(res, bytes):
61 elif isinstance(res, bytes):
61 return res
62 return res
62 else:
63 else:
63 raise error.Abort('dummy client does not support response type')
64 raise error.Abort('dummy client does not support response type')
64
65
65 def _callstream(self, cmd, **args):
66 def _callstream(self, cmd, **args):
66 return stringio(self._call(cmd, **args))
67 return stringio(self._call(cmd, **args))
67
68
68 @wireproto.batchable
69 @wireprotov1peer.batchable
69 def greet(self, name):
70 def greet(self, name):
70 f = wireproto.future()
71 f = wireprotov1peer.future()
71 yield {b'name': mangle(name)}, f
72 yield {b'name': mangle(name)}, f
72 yield unmangle(f.value)
73 yield unmangle(f.value)
73
74
74 class serverrepo(object):
75 class serverrepo(object):
75 def greet(self, name):
76 def greet(self, name):
76 return b"Hello, " + name
77 return b"Hello, " + name
77
78
78 def filtered(self, name):
79 def filtered(self, name):
79 return self
80 return self
80
81
81 def mangle(s):
82 def mangle(s):
82 return b''.join(pycompat.bytechr(ord(c) + 1) for c in pycompat.bytestr(s))
83 return b''.join(pycompat.bytechr(ord(c) + 1) for c in pycompat.bytestr(s))
83 def unmangle(s):
84 def unmangle(s):
84 return b''.join(pycompat.bytechr(ord(c) - 1) for c in pycompat.bytestr(s))
85 return b''.join(pycompat.bytechr(ord(c) - 1) for c in pycompat.bytestr(s))
85
86
86 def greet(repo, proto, name):
87 def greet(repo, proto, name):
87 return mangle(repo.greet(unmangle(name)))
88 return mangle(repo.greet(unmangle(name)))
88
89
89 wireproto.commands[b'greet'] = (greet, b'name',)
90 wireproto.commands[b'greet'] = (greet, b'name',)
90
91
91 srv = serverrepo()
92 srv = serverrepo()
92 clt = clientpeer(srv, uimod.ui())
93 clt = clientpeer(srv, uimod.ui())
93
94
94 print(clt.greet(b"Foobar"))
95 print(clt.greet(b"Foobar"))
95 b = clt.iterbatch()
96 b = clt.iterbatch()
96 list(map(b.greet, (b'Fo, =;:<o', b'Bar')))
97 list(map(b.greet, (b'Fo, =;:<o', b'Bar')))
97 b.submit()
98 b.submit()
98 print([r for r in b.results()])
99 print([r for r in b.results()])