##// END OF EJS Templates
wireproto: overhaul iterating batcher code (API)...
Gregory Szorc -
r33761:4c706037 default
parent child Browse files
Show More
@@ -1,152 +1,153 b''
1 1 # peer.py - repository base classes for mercurial
2 2 #
3 3 # Copyright 2005, 2006 Matt Mackall <mpm@selenic.com>
4 4 # Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
5 5 #
6 6 # This software may be used and distributed according to the terms of the
7 7 # GNU General Public License version 2 or any later version.
8 8
9 9 from __future__ import absolute_import
10 10
11 11 from .i18n import _
12 12 from . import (
13 13 error,
14 14 util,
15 15 )
16 16
17 17 # abstract batching support
18 18
19 19 class future(object):
20 20 '''placeholder for a value to be set later'''
21 21 def set(self, value):
22 22 if util.safehasattr(self, 'value'):
23 23 raise error.RepoError("future is already set")
24 24 self.value = value
25 25
26 26 class batcher(object):
27 27 '''base class for batches of commands submittable in a single request
28 28
29 29 All methods invoked on instances of this class are simply queued and
30 30 return a a future for the result. Once you call submit(), all the queued
31 31 calls are performed and the results set in their respective futures.
32 32 '''
33 33 def __init__(self):
34 34 self.calls = []
35 35 def __getattr__(self, name):
36 36 def call(*args, **opts):
37 37 resref = future()
38 38 self.calls.append((name, args, opts, resref,))
39 39 return resref
40 40 return call
41 41 def submit(self):
42 42 raise NotImplementedError()
43 43
44 44 class iterbatcher(batcher):
45 45
46 46 def submit(self):
47 47 raise NotImplementedError()
48 48
49 49 def results(self):
50 50 raise NotImplementedError()
51 51
52 52 class localbatch(batcher):
53 53 '''performs the queued calls directly'''
54 54 def __init__(self, local):
55 55 batcher.__init__(self)
56 56 self.local = local
57 57 def submit(self):
58 58 for name, args, opts, resref in self.calls:
59 59 resref.set(getattr(self.local, name)(*args, **opts))
60 60
61 61 class localiterbatcher(iterbatcher):
62 62 def __init__(self, local):
63 63 super(iterbatcher, self).__init__()
64 64 self.local = local
65 65
66 66 def submit(self):
67 67 # submit for a local iter batcher is a noop
68 68 pass
69 69
70 70 def results(self):
71 71 for name, args, opts, resref in self.calls:
72 yield getattr(self.local, name)(*args, **opts)
72 resref.set(getattr(self.local, name)(*args, **opts))
73 yield resref.value
73 74
74 75 def batchable(f):
75 76 '''annotation for batchable methods
76 77
77 78 Such methods must implement a coroutine as follows:
78 79
79 80 @batchable
80 81 def sample(self, one, two=None):
81 82 # Build list of encoded arguments suitable for your wire protocol:
82 83 encargs = [('one', encode(one),), ('two', encode(two),)]
83 84 # Create future for injection of encoded result:
84 85 encresref = future()
85 86 # Return encoded arguments and future:
86 87 yield encargs, encresref
87 88 # Assuming the future to be filled with the result from the batched
88 89 # request now. Decode it:
89 90 yield decode(encresref.value)
90 91
91 92 The decorator returns a function which wraps this coroutine as a plain
92 93 method, but adds the original method as an attribute called "batchable",
93 94 which is used by remotebatch to split the call into separate encoding and
94 95 decoding phases.
95 96 '''
96 97 def plain(*args, **opts):
97 98 batchable = f(*args, **opts)
98 99 encargsorres, encresref = next(batchable)
99 100 if not encresref:
100 101 return encargsorres # a local result in this case
101 102 self = args[0]
102 103 encresref.set(self._submitone(f.func_name, encargsorres))
103 104 return next(batchable)
104 105 setattr(plain, 'batchable', f)
105 106 return plain
106 107
107 108 class peerrepository(object):
108 109
109 110 def batch(self):
110 111 return localbatch(self)
111 112
112 113 def iterbatch(self):
113 114 """Batch requests but allow iterating over the results.
114 115
115 116 This is to allow interleaving responses with things like
116 117 progress updates for clients.
117 118 """
118 119 return localiterbatcher(self)
119 120
120 121 def capable(self, name):
121 122 '''tell whether repo supports named capability.
122 123 return False if not supported.
123 124 if boolean capability, return True.
124 125 if string capability, return string.'''
125 126 caps = self._capabilities()
126 127 if name in caps:
127 128 return True
128 129 name_eq = name + '='
129 130 for cap in caps:
130 131 if cap.startswith(name_eq):
131 132 return cap[len(name_eq):]
132 133 return False
133 134
134 135 def requirecap(self, name, purpose):
135 136 '''raise an exception if the given capability is not present'''
136 137 if not self.capable(name):
137 138 raise error.CapabilityError(
138 139 _('cannot %s; remote repository does not '
139 140 'support the %r capability') % (purpose, name))
140 141
141 142 def local(self):
142 143 '''return peer as a localrepo, or None'''
143 144 return None
144 145
145 146 def peer(self):
146 147 return self
147 148
148 149 def canpush(self):
149 150 return True
150 151
151 152 def close(self):
152 153 pass
@@ -1,1064 +1,1088 b''
1 1 # wireproto.py - generic wire protocol support functions
2 2 #
3 3 # Copyright 2005-2010 Matt Mackall <mpm@selenic.com>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import hashlib
11 11 import itertools
12 12 import os
13 13 import tempfile
14 14
15 15 from .i18n import _
16 16 from .node import (
17 17 bin,
18 18 hex,
19 19 nullid,
20 20 )
21 21
22 22 from . import (
23 23 bundle2,
24 24 changegroup as changegroupmod,
25 25 encoding,
26 26 error,
27 27 exchange,
28 28 peer,
29 29 pushkey as pushkeymod,
30 30 pycompat,
31 31 streamclone,
32 32 util,
33 33 )
34 34
35 35 urlerr = util.urlerr
36 36 urlreq = util.urlreq
37 37
38 38 bundle2requiredmain = _('incompatible Mercurial client; bundle2 required')
39 39 bundle2requiredhint = _('see https://www.mercurial-scm.org/wiki/'
40 40 'IncompatibleClient')
41 41 bundle2required = '%s\n(%s)\n' % (bundle2requiredmain, bundle2requiredhint)
42 42
43 43 class abstractserverproto(object):
44 44 """abstract class that summarizes the protocol API
45 45
46 46 Used as reference and documentation.
47 47 """
48 48
49 49 def getargs(self, args):
50 50 """return the value for arguments in <args>
51 51
52 52 returns a list of values (same order as <args>)"""
53 53 raise NotImplementedError()
54 54
55 55 def getfile(self, fp):
56 56 """write the whole content of a file into a file like object
57 57
58 58 The file is in the form::
59 59
60 60 (<chunk-size>\n<chunk>)+0\n
61 61
62 62 chunk size is the ascii version of the int.
63 63 """
64 64 raise NotImplementedError()
65 65
66 66 def redirect(self):
67 67 """may setup interception for stdout and stderr
68 68
69 69 See also the `restore` method."""
70 70 raise NotImplementedError()
71 71
72 72 # If the `redirect` function does install interception, the `restore`
73 73 # function MUST be defined. If interception is not used, this function
74 74 # MUST NOT be defined.
75 75 #
76 76 # left commented here on purpose
77 77 #
78 78 #def restore(self):
79 79 # """reinstall previous stdout and stderr and return intercepted stdout
80 80 # """
81 81 # raise NotImplementedError()
82 82
83 83 class remotebatch(peer.batcher):
84 84 '''batches the queued calls; uses as few roundtrips as possible'''
85 85 def __init__(self, remote):
86 86 '''remote must support _submitbatch(encbatch) and
87 87 _submitone(op, encargs)'''
88 88 peer.batcher.__init__(self)
89 89 self.remote = remote
90 90 def submit(self):
91 91 req, rsp = [], []
92 92 for name, args, opts, resref in self.calls:
93 93 mtd = getattr(self.remote, name)
94 94 batchablefn = getattr(mtd, 'batchable', None)
95 95 if batchablefn is not None:
96 96 batchable = batchablefn(mtd.im_self, *args, **opts)
97 97 encargsorres, encresref = next(batchable)
98 98 assert encresref
99 99 req.append((name, encargsorres,))
100 100 rsp.append((batchable, encresref, resref,))
101 101 else:
102 102 if req:
103 103 self._submitreq(req, rsp)
104 104 req, rsp = [], []
105 105 resref.set(mtd(*args, **opts))
106 106 if req:
107 107 self._submitreq(req, rsp)
108 108 def _submitreq(self, req, rsp):
109 109 encresults = self.remote._submitbatch(req)
110 110 for encres, r in zip(encresults, rsp):
111 111 batchable, encresref, resref = r
112 112 encresref.set(encres)
113 113 resref.set(next(batchable))
114 114
115 115 class remoteiterbatcher(peer.iterbatcher):
116 116 def __init__(self, remote):
117 117 super(remoteiterbatcher, self).__init__()
118 118 self._remote = remote
119 119
120 120 def __getattr__(self, name):
121 121 # Validate this method is batchable, since submit() only supports
122 122 # batchable methods.
123 123 fn = getattr(self._remote, name)
124 124 if not getattr(fn, 'batchable', None):
125 125 raise error.ProgrammingError('Attempted to batch a non-batchable '
126 126 'call to %r' % name)
127 127
128 128 return super(remoteiterbatcher, self).__getattr__(name)
129 129
130 130 def submit(self):
131 131 """Break the batch request into many patch calls and pipeline them.
132 132
133 133 This is mostly valuable over http where request sizes can be
134 134 limited, but can be used in other places as well.
135 135 """
136 req, rsp = [], []
137 for name, args, opts, resref in self.calls:
138 mtd = getattr(self._remote, name)
136 # 2-tuple of (command, arguments) that represents what will be
137 # sent over the wire.
138 requests = []
139
140 # 4-tuple of (command, final future, @batchable generator, remote
141 # future).
142 results = []
143
144 for command, args, opts, finalfuture in self.calls:
145 mtd = getattr(self._remote, command)
139 146 batchable = mtd.batchable(mtd.im_self, *args, **opts)
140 encargsorres, encresref = next(batchable)
141 assert encresref
142 req.append((name, encargsorres))
143 rsp.append((batchable, encresref))
144 if req:
145 self._resultiter = self._remote._submitbatch(req)
146 self._rsp = rsp
147
148 commandargs, fremote = next(batchable)
149 assert fremote
150 requests.append((command, commandargs))
151 results.append((command, finalfuture, batchable, fremote))
152
153 if requests:
154 self._resultiter = self._remote._submitbatch(requests)
155
156 self._results = results
147 157
148 158 def results(self):
149 for (batchable, encresref), encres in itertools.izip(
150 self._rsp, self._resultiter):
151 encresref.set(encres)
152 yield next(batchable)
159 for command, finalfuture, batchable, remotefuture in self._results:
160 # Get the raw result, set it in the remote future, feed it
161 # back into the @batchable generator so it can be decoded, and
162 # set the result on the final future to this value.
163 remoteresult = next(self._resultiter)
164 remotefuture.set(remoteresult)
165 finalfuture.set(next(batchable))
166
167 # Verify our @batchable generators only emit 2 values.
168 try:
169 next(batchable)
170 except StopIteration:
171 pass
172 else:
173 raise error.ProgrammingError('%s @batchable generator emitted '
174 'unexpected value count' % command)
175
176 yield finalfuture.value
153 177
154 178 # Forward a couple of names from peer to make wireproto interactions
155 179 # slightly more sensible.
156 180 batchable = peer.batchable
157 181 future = peer.future
158 182
159 183 # list of nodes encoding / decoding
160 184
161 185 def decodelist(l, sep=' '):
162 186 if l:
163 187 return map(bin, l.split(sep))
164 188 return []
165 189
166 190 def encodelist(l, sep=' '):
167 191 try:
168 192 return sep.join(map(hex, l))
169 193 except TypeError:
170 194 raise
171 195
172 196 # batched call argument encoding
173 197
174 198 def escapearg(plain):
175 199 return (plain
176 200 .replace(':', ':c')
177 201 .replace(',', ':o')
178 202 .replace(';', ':s')
179 203 .replace('=', ':e'))
180 204
181 205 def unescapearg(escaped):
182 206 return (escaped
183 207 .replace(':e', '=')
184 208 .replace(':s', ';')
185 209 .replace(':o', ',')
186 210 .replace(':c', ':'))
187 211
188 212 def encodebatchcmds(req):
189 213 """Return a ``cmds`` argument value for the ``batch`` command."""
190 214 cmds = []
191 215 for op, argsdict in req:
192 216 # Old servers didn't properly unescape argument names. So prevent
193 217 # the sending of argument names that may not be decoded properly by
194 218 # servers.
195 219 assert all(escapearg(k) == k for k in argsdict)
196 220
197 221 args = ','.join('%s=%s' % (escapearg(k), escapearg(v))
198 222 for k, v in argsdict.iteritems())
199 223 cmds.append('%s %s' % (op, args))
200 224
201 225 return ';'.join(cmds)
202 226
203 227 # mapping of options accepted by getbundle and their types
204 228 #
205 229 # Meant to be extended by extensions. It is extensions responsibility to ensure
206 230 # such options are properly processed in exchange.getbundle.
207 231 #
208 232 # supported types are:
209 233 #
210 234 # :nodes: list of binary nodes
211 235 # :csv: list of comma-separated values
212 236 # :scsv: list of comma-separated values return as set
213 237 # :plain: string with no transformation needed.
214 238 gboptsmap = {'heads': 'nodes',
215 239 'common': 'nodes',
216 240 'obsmarkers': 'boolean',
217 241 'bundlecaps': 'scsv',
218 242 'listkeys': 'csv',
219 243 'cg': 'boolean',
220 244 'cbattempted': 'boolean'}
221 245
222 246 # client side
223 247
224 248 class wirepeer(peer.peerrepository):
225 249 """Client-side interface for communicating with a peer repository.
226 250
227 251 Methods commonly call wire protocol commands of the same name.
228 252
229 253 See also httppeer.py and sshpeer.py for protocol-specific
230 254 implementations of this interface.
231 255 """
232 256 def batch(self):
233 257 if self.capable('batch'):
234 258 return remotebatch(self)
235 259 else:
236 260 return peer.localbatch(self)
237 261 def _submitbatch(self, req):
238 262 """run batch request <req> on the server
239 263
240 264 Returns an iterator of the raw responses from the server.
241 265 """
242 266 rsp = self._callstream("batch", cmds=encodebatchcmds(req))
243 267 chunk = rsp.read(1024)
244 268 work = [chunk]
245 269 while chunk:
246 270 while ';' not in chunk and chunk:
247 271 chunk = rsp.read(1024)
248 272 work.append(chunk)
249 273 merged = ''.join(work)
250 274 while ';' in merged:
251 275 one, merged = merged.split(';', 1)
252 276 yield unescapearg(one)
253 277 chunk = rsp.read(1024)
254 278 work = [merged, chunk]
255 279 yield unescapearg(''.join(work))
256 280
257 281 def _submitone(self, op, args):
258 282 return self._call(op, **args)
259 283
260 284 def iterbatch(self):
261 285 return remoteiterbatcher(self)
262 286
263 287 @batchable
264 288 def lookup(self, key):
265 289 self.requirecap('lookup', _('look up remote revision'))
266 290 f = future()
267 291 yield {'key': encoding.fromlocal(key)}, f
268 292 d = f.value
269 293 success, data = d[:-1].split(" ", 1)
270 294 if int(success):
271 295 yield bin(data)
272 296 self._abort(error.RepoError(data))
273 297
274 298 @batchable
275 299 def heads(self):
276 300 f = future()
277 301 yield {}, f
278 302 d = f.value
279 303 try:
280 304 yield decodelist(d[:-1])
281 305 except ValueError:
282 306 self._abort(error.ResponseError(_("unexpected response:"), d))
283 307
284 308 @batchable
285 309 def known(self, nodes):
286 310 f = future()
287 311 yield {'nodes': encodelist(nodes)}, f
288 312 d = f.value
289 313 try:
290 314 yield [bool(int(b)) for b in d]
291 315 except ValueError:
292 316 self._abort(error.ResponseError(_("unexpected response:"), d))
293 317
294 318 @batchable
295 319 def branchmap(self):
296 320 f = future()
297 321 yield {}, f
298 322 d = f.value
299 323 try:
300 324 branchmap = {}
301 325 for branchpart in d.splitlines():
302 326 branchname, branchheads = branchpart.split(' ', 1)
303 327 branchname = encoding.tolocal(urlreq.unquote(branchname))
304 328 branchheads = decodelist(branchheads)
305 329 branchmap[branchname] = branchheads
306 330 yield branchmap
307 331 except TypeError:
308 332 self._abort(error.ResponseError(_("unexpected response:"), d))
309 333
310 334 def branches(self, nodes):
311 335 n = encodelist(nodes)
312 336 d = self._call("branches", nodes=n)
313 337 try:
314 338 br = [tuple(decodelist(b)) for b in d.splitlines()]
315 339 return br
316 340 except ValueError:
317 341 self._abort(error.ResponseError(_("unexpected response:"), d))
318 342
319 343 def between(self, pairs):
320 344 batch = 8 # avoid giant requests
321 345 r = []
322 346 for i in xrange(0, len(pairs), batch):
323 347 n = " ".join([encodelist(p, '-') for p in pairs[i:i + batch]])
324 348 d = self._call("between", pairs=n)
325 349 try:
326 350 r.extend(l and decodelist(l) or [] for l in d.splitlines())
327 351 except ValueError:
328 352 self._abort(error.ResponseError(_("unexpected response:"), d))
329 353 return r
330 354
331 355 @batchable
332 356 def pushkey(self, namespace, key, old, new):
333 357 if not self.capable('pushkey'):
334 358 yield False, None
335 359 f = future()
336 360 self.ui.debug('preparing pushkey for "%s:%s"\n' % (namespace, key))
337 361 yield {'namespace': encoding.fromlocal(namespace),
338 362 'key': encoding.fromlocal(key),
339 363 'old': encoding.fromlocal(old),
340 364 'new': encoding.fromlocal(new)}, f
341 365 d = f.value
342 366 d, output = d.split('\n', 1)
343 367 try:
344 368 d = bool(int(d))
345 369 except ValueError:
346 370 raise error.ResponseError(
347 371 _('push failed (unexpected response):'), d)
348 372 for l in output.splitlines(True):
349 373 self.ui.status(_('remote: '), l)
350 374 yield d
351 375
352 376 @batchable
353 377 def listkeys(self, namespace):
354 378 if not self.capable('pushkey'):
355 379 yield {}, None
356 380 f = future()
357 381 self.ui.debug('preparing listkeys for "%s"\n' % namespace)
358 382 yield {'namespace': encoding.fromlocal(namespace)}, f
359 383 d = f.value
360 384 self.ui.debug('received listkey for "%s": %i bytes\n'
361 385 % (namespace, len(d)))
362 386 yield pushkeymod.decodekeys(d)
363 387
364 388 def stream_out(self):
365 389 return self._callstream('stream_out')
366 390
367 391 def changegroup(self, nodes, kind):
368 392 n = encodelist(nodes)
369 393 f = self._callcompressable("changegroup", roots=n)
370 394 return changegroupmod.cg1unpacker(f, 'UN')
371 395
372 396 def changegroupsubset(self, bases, heads, kind):
373 397 self.requirecap('changegroupsubset', _('look up remote changes'))
374 398 bases = encodelist(bases)
375 399 heads = encodelist(heads)
376 400 f = self._callcompressable("changegroupsubset",
377 401 bases=bases, heads=heads)
378 402 return changegroupmod.cg1unpacker(f, 'UN')
379 403
380 404 def getbundle(self, source, **kwargs):
381 405 self.requirecap('getbundle', _('look up remote changes'))
382 406 opts = {}
383 407 bundlecaps = kwargs.get('bundlecaps')
384 408 if bundlecaps is not None:
385 409 kwargs['bundlecaps'] = sorted(bundlecaps)
386 410 else:
387 411 bundlecaps = () # kwargs could have it to None
388 412 for key, value in kwargs.iteritems():
389 413 if value is None:
390 414 continue
391 415 keytype = gboptsmap.get(key)
392 416 if keytype is None:
393 417 assert False, 'unexpected'
394 418 elif keytype == 'nodes':
395 419 value = encodelist(value)
396 420 elif keytype in ('csv', 'scsv'):
397 421 value = ','.join(value)
398 422 elif keytype == 'boolean':
399 423 value = '%i' % bool(value)
400 424 elif keytype != 'plain':
401 425 raise KeyError('unknown getbundle option type %s'
402 426 % keytype)
403 427 opts[key] = value
404 428 f = self._callcompressable("getbundle", **opts)
405 429 if any((cap.startswith('HG2') for cap in bundlecaps)):
406 430 return bundle2.getunbundler(self.ui, f)
407 431 else:
408 432 return changegroupmod.cg1unpacker(f, 'UN')
409 433
410 434 def unbundle(self, cg, heads, url):
411 435 '''Send cg (a readable file-like object representing the
412 436 changegroup to push, typically a chunkbuffer object) to the
413 437 remote server as a bundle.
414 438
415 439 When pushing a bundle10 stream, return an integer indicating the
416 440 result of the push (see changegroup.apply()).
417 441
418 442 When pushing a bundle20 stream, return a bundle20 stream.
419 443
420 444 `url` is the url the client thinks it's pushing to, which is
421 445 visible to hooks.
422 446 '''
423 447
424 448 if heads != ['force'] and self.capable('unbundlehash'):
425 449 heads = encodelist(['hashed',
426 450 hashlib.sha1(''.join(sorted(heads))).digest()])
427 451 else:
428 452 heads = encodelist(heads)
429 453
430 454 if util.safehasattr(cg, 'deltaheader'):
431 455 # this a bundle10, do the old style call sequence
432 456 ret, output = self._callpush("unbundle", cg, heads=heads)
433 457 if ret == "":
434 458 raise error.ResponseError(
435 459 _('push failed:'), output)
436 460 try:
437 461 ret = int(ret)
438 462 except ValueError:
439 463 raise error.ResponseError(
440 464 _('push failed (unexpected response):'), ret)
441 465
442 466 for l in output.splitlines(True):
443 467 self.ui.status(_('remote: '), l)
444 468 else:
445 469 # bundle2 push. Send a stream, fetch a stream.
446 470 stream = self._calltwowaystream('unbundle', cg, heads=heads)
447 471 ret = bundle2.getunbundler(self.ui, stream)
448 472 return ret
449 473
450 474 def debugwireargs(self, one, two, three=None, four=None, five=None):
451 475 # don't pass optional arguments left at their default value
452 476 opts = {}
453 477 if three is not None:
454 478 opts['three'] = three
455 479 if four is not None:
456 480 opts['four'] = four
457 481 return self._call('debugwireargs', one=one, two=two, **opts)
458 482
459 483 def _call(self, cmd, **args):
460 484 """execute <cmd> on the server
461 485
462 486 The command is expected to return a simple string.
463 487
464 488 returns the server reply as a string."""
465 489 raise NotImplementedError()
466 490
467 491 def _callstream(self, cmd, **args):
468 492 """execute <cmd> on the server
469 493
470 494 The command is expected to return a stream. Note that if the
471 495 command doesn't return a stream, _callstream behaves
472 496 differently for ssh and http peers.
473 497
474 498 returns the server reply as a file like object.
475 499 """
476 500 raise NotImplementedError()
477 501
478 502 def _callcompressable(self, cmd, **args):
479 503 """execute <cmd> on the server
480 504
481 505 The command is expected to return a stream.
482 506
483 507 The stream may have been compressed in some implementations. This
484 508 function takes care of the decompression. This is the only difference
485 509 with _callstream.
486 510
487 511 returns the server reply as a file like object.
488 512 """
489 513 raise NotImplementedError()
490 514
491 515 def _callpush(self, cmd, fp, **args):
492 516 """execute a <cmd> on server
493 517
494 518 The command is expected to be related to a push. Push has a special
495 519 return method.
496 520
497 521 returns the server reply as a (ret, output) tuple. ret is either
498 522 empty (error) or a stringified int.
499 523 """
500 524 raise NotImplementedError()
501 525
502 526 def _calltwowaystream(self, cmd, fp, **args):
503 527 """execute <cmd> on server
504 528
505 529 The command will send a stream to the server and get a stream in reply.
506 530 """
507 531 raise NotImplementedError()
508 532
509 533 def _abort(self, exception):
510 534 """clearly abort the wire protocol connection and raise the exception
511 535 """
512 536 raise NotImplementedError()
513 537
514 538 # server side
515 539
516 540 # wire protocol command can either return a string or one of these classes.
517 541 class streamres(object):
518 542 """wireproto reply: binary stream
519 543
520 544 The call was successful and the result is a stream.
521 545
522 546 Accepts either a generator or an object with a ``read(size)`` method.
523 547
524 548 ``v1compressible`` indicates whether this data can be compressed to
525 549 "version 1" clients (technically: HTTP peers using
526 550 application/mercurial-0.1 media type). This flag should NOT be used on
527 551 new commands because new clients should support a more modern compression
528 552 mechanism.
529 553 """
530 554 def __init__(self, gen=None, reader=None, v1compressible=False):
531 555 self.gen = gen
532 556 self.reader = reader
533 557 self.v1compressible = v1compressible
534 558
535 559 class pushres(object):
536 560 """wireproto reply: success with simple integer return
537 561
538 562 The call was successful and returned an integer contained in `self.res`.
539 563 """
540 564 def __init__(self, res):
541 565 self.res = res
542 566
543 567 class pusherr(object):
544 568 """wireproto reply: failure
545 569
546 570 The call failed. The `self.res` attribute contains the error message.
547 571 """
548 572 def __init__(self, res):
549 573 self.res = res
550 574
551 575 class ooberror(object):
552 576 """wireproto reply: failure of a batch of operation
553 577
554 578 Something failed during a batch call. The error message is stored in
555 579 `self.message`.
556 580 """
557 581 def __init__(self, message):
558 582 self.message = message
559 583
560 584 def getdispatchrepo(repo, proto, command):
561 585 """Obtain the repo used for processing wire protocol commands.
562 586
563 587 The intent of this function is to serve as a monkeypatch point for
564 588 extensions that need commands to operate on different repo views under
565 589 specialized circumstances.
566 590 """
567 591 return repo.filtered('served')
568 592
569 593 def dispatch(repo, proto, command):
570 594 repo = getdispatchrepo(repo, proto, command)
571 595 func, spec = commands[command]
572 596 args = proto.getargs(spec)
573 597 return func(repo, proto, *args)
574 598
575 599 def options(cmd, keys, others):
576 600 opts = {}
577 601 for k in keys:
578 602 if k in others:
579 603 opts[k] = others[k]
580 604 del others[k]
581 605 if others:
582 606 util.stderr.write("warning: %s ignored unexpected arguments %s\n"
583 607 % (cmd, ",".join(others)))
584 608 return opts
585 609
586 610 def bundle1allowed(repo, action):
587 611 """Whether a bundle1 operation is allowed from the server.
588 612
589 613 Priority is:
590 614
591 615 1. server.bundle1gd.<action> (if generaldelta active)
592 616 2. server.bundle1.<action>
593 617 3. server.bundle1gd (if generaldelta active)
594 618 4. server.bundle1
595 619 """
596 620 ui = repo.ui
597 621 gd = 'generaldelta' in repo.requirements
598 622
599 623 if gd:
600 624 v = ui.configbool('server', 'bundle1gd.%s' % action, None)
601 625 if v is not None:
602 626 return v
603 627
604 628 v = ui.configbool('server', 'bundle1.%s' % action, None)
605 629 if v is not None:
606 630 return v
607 631
608 632 if gd:
609 633 v = ui.configbool('server', 'bundle1gd')
610 634 if v is not None:
611 635 return v
612 636
613 637 return ui.configbool('server', 'bundle1')
614 638
615 639 def supportedcompengines(ui, proto, role):
616 640 """Obtain the list of supported compression engines for a request."""
617 641 assert role in (util.CLIENTROLE, util.SERVERROLE)
618 642
619 643 compengines = util.compengines.supportedwireengines(role)
620 644
621 645 # Allow config to override default list and ordering.
622 646 if role == util.SERVERROLE:
623 647 configengines = ui.configlist('server', 'compressionengines')
624 648 config = 'server.compressionengines'
625 649 else:
626 650 # This is currently implemented mainly to facilitate testing. In most
627 651 # cases, the server should be in charge of choosing a compression engine
628 652 # because a server has the most to lose from a sub-optimal choice. (e.g.
629 653 # CPU DoS due to an expensive engine or a network DoS due to poor
630 654 # compression ratio).
631 655 configengines = ui.configlist('experimental',
632 656 'clientcompressionengines')
633 657 config = 'experimental.clientcompressionengines'
634 658
635 659 # No explicit config. Filter out the ones that aren't supposed to be
636 660 # advertised and return default ordering.
637 661 if not configengines:
638 662 attr = 'serverpriority' if role == util.SERVERROLE else 'clientpriority'
639 663 return [e for e in compengines
640 664 if getattr(e.wireprotosupport(), attr) > 0]
641 665
642 666 # If compression engines are listed in the config, assume there is a good
643 667 # reason for it (like server operators wanting to achieve specific
644 668 # performance characteristics). So fail fast if the config references
645 669 # unusable compression engines.
646 670 validnames = set(e.name() for e in compengines)
647 671 invalidnames = set(e for e in configengines if e not in validnames)
648 672 if invalidnames:
649 673 raise error.Abort(_('invalid compression engine defined in %s: %s') %
650 674 (config, ', '.join(sorted(invalidnames))))
651 675
652 676 compengines = [e for e in compengines if e.name() in configengines]
653 677 compengines = sorted(compengines,
654 678 key=lambda e: configengines.index(e.name()))
655 679
656 680 if not compengines:
657 681 raise error.Abort(_('%s config option does not specify any known '
658 682 'compression engines') % config,
659 683 hint=_('usable compression engines: %s') %
660 684 ', '.sorted(validnames))
661 685
662 686 return compengines
663 687
664 688 # list of commands
665 689 commands = {}
666 690
667 691 def wireprotocommand(name, args=''):
668 692 """decorator for wire protocol command"""
669 693 def register(func):
670 694 commands[name] = (func, args)
671 695 return func
672 696 return register
673 697
674 698 @wireprotocommand('batch', 'cmds *')
675 699 def batch(repo, proto, cmds, others):
676 700 repo = repo.filtered("served")
677 701 res = []
678 702 for pair in cmds.split(';'):
679 703 op, args = pair.split(' ', 1)
680 704 vals = {}
681 705 for a in args.split(','):
682 706 if a:
683 707 n, v = a.split('=')
684 708 vals[unescapearg(n)] = unescapearg(v)
685 709 func, spec = commands[op]
686 710 if spec:
687 711 keys = spec.split()
688 712 data = {}
689 713 for k in keys:
690 714 if k == '*':
691 715 star = {}
692 716 for key in vals.keys():
693 717 if key not in keys:
694 718 star[key] = vals[key]
695 719 data['*'] = star
696 720 else:
697 721 data[k] = vals[k]
698 722 result = func(repo, proto, *[data[k] for k in keys])
699 723 else:
700 724 result = func(repo, proto)
701 725 if isinstance(result, ooberror):
702 726 return result
703 727 res.append(escapearg(result))
704 728 return ';'.join(res)
705 729
706 730 @wireprotocommand('between', 'pairs')
707 731 def between(repo, proto, pairs):
708 732 pairs = [decodelist(p, '-') for p in pairs.split(" ")]
709 733 r = []
710 734 for b in repo.between(pairs):
711 735 r.append(encodelist(b) + "\n")
712 736 return "".join(r)
713 737
714 738 @wireprotocommand('branchmap')
715 739 def branchmap(repo, proto):
716 740 branchmap = repo.branchmap()
717 741 heads = []
718 742 for branch, nodes in branchmap.iteritems():
719 743 branchname = urlreq.quote(encoding.fromlocal(branch))
720 744 branchnodes = encodelist(nodes)
721 745 heads.append('%s %s' % (branchname, branchnodes))
722 746 return '\n'.join(heads)
723 747
724 748 @wireprotocommand('branches', 'nodes')
725 749 def branches(repo, proto, nodes):
726 750 nodes = decodelist(nodes)
727 751 r = []
728 752 for b in repo.branches(nodes):
729 753 r.append(encodelist(b) + "\n")
730 754 return "".join(r)
731 755
732 756 @wireprotocommand('clonebundles', '')
733 757 def clonebundles(repo, proto):
734 758 """Server command for returning info for available bundles to seed clones.
735 759
736 760 Clients will parse this response and determine what bundle to fetch.
737 761
738 762 Extensions may wrap this command to filter or dynamically emit data
739 763 depending on the request. e.g. you could advertise URLs for the closest
740 764 data center given the client's IP address.
741 765 """
742 766 return repo.vfs.tryread('clonebundles.manifest')
743 767
744 768 wireprotocaps = ['lookup', 'changegroupsubset', 'branchmap', 'pushkey',
745 769 'known', 'getbundle', 'unbundlehash', 'batch']
746 770
747 771 def _capabilities(repo, proto):
748 772 """return a list of capabilities for a repo
749 773
750 774 This function exists to allow extensions to easily wrap capabilities
751 775 computation
752 776
753 777 - returns a lists: easy to alter
754 778 - change done here will be propagated to both `capabilities` and `hello`
755 779 command without any other action needed.
756 780 """
757 781 # copy to prevent modification of the global list
758 782 caps = list(wireprotocaps)
759 783 if streamclone.allowservergeneration(repo):
760 784 if repo.ui.configbool('server', 'preferuncompressed'):
761 785 caps.append('stream-preferred')
762 786 requiredformats = repo.requirements & repo.supportedformats
763 787 # if our local revlogs are just revlogv1, add 'stream' cap
764 788 if not requiredformats - {'revlogv1'}:
765 789 caps.append('stream')
766 790 # otherwise, add 'streamreqs' detailing our local revlog format
767 791 else:
768 792 caps.append('streamreqs=%s' % ','.join(sorted(requiredformats)))
769 793 if repo.ui.configbool('experimental', 'bundle2-advertise'):
770 794 capsblob = bundle2.encodecaps(bundle2.getrepocaps(repo))
771 795 caps.append('bundle2=' + urlreq.quote(capsblob))
772 796 caps.append('unbundle=%s' % ','.join(bundle2.bundlepriority))
773 797
774 798 if proto.name == 'http':
775 799 caps.append('httpheader=%d' %
776 800 repo.ui.configint('server', 'maxhttpheaderlen'))
777 801 if repo.ui.configbool('experimental', 'httppostargs'):
778 802 caps.append('httppostargs')
779 803
780 804 # FUTURE advertise 0.2rx once support is implemented
781 805 # FUTURE advertise minrx and mintx after consulting config option
782 806 caps.append('httpmediatype=0.1rx,0.1tx,0.2tx')
783 807
784 808 compengines = supportedcompengines(repo.ui, proto, util.SERVERROLE)
785 809 if compengines:
786 810 comptypes = ','.join(urlreq.quote(e.wireprotosupport().name)
787 811 for e in compengines)
788 812 caps.append('compression=%s' % comptypes)
789 813
790 814 return caps
791 815
792 816 # If you are writing an extension and consider wrapping this function. Wrap
793 817 # `_capabilities` instead.
794 818 @wireprotocommand('capabilities')
795 819 def capabilities(repo, proto):
796 820 return ' '.join(_capabilities(repo, proto))
797 821
798 822 @wireprotocommand('changegroup', 'roots')
799 823 def changegroup(repo, proto, roots):
800 824 nodes = decodelist(roots)
801 825 cg = changegroupmod.changegroup(repo, nodes, 'serve')
802 826 return streamres(reader=cg, v1compressible=True)
803 827
804 828 @wireprotocommand('changegroupsubset', 'bases heads')
805 829 def changegroupsubset(repo, proto, bases, heads):
806 830 bases = decodelist(bases)
807 831 heads = decodelist(heads)
808 832 cg = changegroupmod.changegroupsubset(repo, bases, heads, 'serve')
809 833 return streamres(reader=cg, v1compressible=True)
810 834
811 835 @wireprotocommand('debugwireargs', 'one two *')
812 836 def debugwireargs(repo, proto, one, two, others):
813 837 # only accept optional args from the known set
814 838 opts = options('debugwireargs', ['three', 'four'], others)
815 839 return repo.debugwireargs(one, two, **opts)
816 840
817 841 @wireprotocommand('getbundle', '*')
818 842 def getbundle(repo, proto, others):
819 843 opts = options('getbundle', gboptsmap.keys(), others)
820 844 for k, v in opts.iteritems():
821 845 keytype = gboptsmap[k]
822 846 if keytype == 'nodes':
823 847 opts[k] = decodelist(v)
824 848 elif keytype == 'csv':
825 849 opts[k] = list(v.split(','))
826 850 elif keytype == 'scsv':
827 851 opts[k] = set(v.split(','))
828 852 elif keytype == 'boolean':
829 853 # Client should serialize False as '0', which is a non-empty string
830 854 # so it evaluates as a True bool.
831 855 if v == '0':
832 856 opts[k] = False
833 857 else:
834 858 opts[k] = bool(v)
835 859 elif keytype != 'plain':
836 860 raise KeyError('unknown getbundle option type %s'
837 861 % keytype)
838 862
839 863 if not bundle1allowed(repo, 'pull'):
840 864 if not exchange.bundle2requested(opts.get('bundlecaps')):
841 865 if proto.name == 'http':
842 866 return ooberror(bundle2required)
843 867 raise error.Abort(bundle2requiredmain,
844 868 hint=bundle2requiredhint)
845 869
846 870 try:
847 871 if repo.ui.configbool('server', 'disablefullbundle'):
848 872 # Check to see if this is a full clone.
849 873 clheads = set(repo.changelog.heads())
850 874 heads = set(opts.get('heads', set()))
851 875 common = set(opts.get('common', set()))
852 876 common.discard(nullid)
853 877 if not common and clheads == heads:
854 878 raise error.Abort(
855 879 _('server has pull-based clones disabled'),
856 880 hint=_('remove --pull if specified or upgrade Mercurial'))
857 881
858 882 chunks = exchange.getbundlechunks(repo, 'serve', **opts)
859 883 except error.Abort as exc:
860 884 # cleanly forward Abort error to the client
861 885 if not exchange.bundle2requested(opts.get('bundlecaps')):
862 886 if proto.name == 'http':
863 887 return ooberror(str(exc) + '\n')
864 888 raise # cannot do better for bundle1 + ssh
865 889 # bundle2 request expect a bundle2 reply
866 890 bundler = bundle2.bundle20(repo.ui)
867 891 manargs = [('message', str(exc))]
868 892 advargs = []
869 893 if exc.hint is not None:
870 894 advargs.append(('hint', exc.hint))
871 895 bundler.addpart(bundle2.bundlepart('error:abort',
872 896 manargs, advargs))
873 897 return streamres(gen=bundler.getchunks(), v1compressible=True)
874 898 return streamres(gen=chunks, v1compressible=True)
875 899
876 900 @wireprotocommand('heads')
877 901 def heads(repo, proto):
878 902 h = repo.heads()
879 903 return encodelist(h) + "\n"
880 904
881 905 @wireprotocommand('hello')
882 906 def hello(repo, proto):
883 907 '''the hello command returns a set of lines describing various
884 908 interesting things about the server, in an RFC822-like format.
885 909 Currently the only one defined is "capabilities", which
886 910 consists of a line in the form:
887 911
888 912 capabilities: space separated list of tokens
889 913 '''
890 914 return "capabilities: %s\n" % (capabilities(repo, proto))
891 915
892 916 @wireprotocommand('listkeys', 'namespace')
893 917 def listkeys(repo, proto, namespace):
894 918 d = repo.listkeys(encoding.tolocal(namespace)).items()
895 919 return pushkeymod.encodekeys(d)
896 920
897 921 @wireprotocommand('lookup', 'key')
898 922 def lookup(repo, proto, key):
899 923 try:
900 924 k = encoding.tolocal(key)
901 925 c = repo[k]
902 926 r = c.hex()
903 927 success = 1
904 928 except Exception as inst:
905 929 r = str(inst)
906 930 success = 0
907 931 return "%s %s\n" % (success, r)
908 932
909 933 @wireprotocommand('known', 'nodes *')
910 934 def known(repo, proto, nodes, others):
911 935 return ''.join(b and "1" or "0" for b in repo.known(decodelist(nodes)))
912 936
913 937 @wireprotocommand('pushkey', 'namespace key old new')
914 938 def pushkey(repo, proto, namespace, key, old, new):
915 939 # compatibility with pre-1.8 clients which were accidentally
916 940 # sending raw binary nodes rather than utf-8-encoded hex
917 941 if len(new) == 20 and util.escapestr(new) != new:
918 942 # looks like it could be a binary node
919 943 try:
920 944 new.decode('utf-8')
921 945 new = encoding.tolocal(new) # but cleanly decodes as UTF-8
922 946 except UnicodeDecodeError:
923 947 pass # binary, leave unmodified
924 948 else:
925 949 new = encoding.tolocal(new) # normal path
926 950
927 951 if util.safehasattr(proto, 'restore'):
928 952
929 953 proto.redirect()
930 954
931 955 try:
932 956 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
933 957 encoding.tolocal(old), new) or False
934 958 except error.Abort:
935 959 r = False
936 960
937 961 output = proto.restore()
938 962
939 963 return '%s\n%s' % (int(r), output)
940 964
941 965 r = repo.pushkey(encoding.tolocal(namespace), encoding.tolocal(key),
942 966 encoding.tolocal(old), new)
943 967 return '%s\n' % int(r)
944 968
945 969 @wireprotocommand('stream_out')
946 970 def stream(repo, proto):
947 971 '''If the server supports streaming clone, it advertises the "stream"
948 972 capability with a value representing the version and flags of the repo
949 973 it is serving. Client checks to see if it understands the format.
950 974 '''
951 975 if not streamclone.allowservergeneration(repo):
952 976 return '1\n'
953 977
954 978 def getstream(it):
955 979 yield '0\n'
956 980 for chunk in it:
957 981 yield chunk
958 982
959 983 try:
960 984 # LockError may be raised before the first result is yielded. Don't
961 985 # emit output until we're sure we got the lock successfully.
962 986 it = streamclone.generatev1wireproto(repo)
963 987 return streamres(gen=getstream(it))
964 988 except error.LockError:
965 989 return '2\n'
966 990
967 991 @wireprotocommand('unbundle', 'heads')
968 992 def unbundle(repo, proto, heads):
969 993 their_heads = decodelist(heads)
970 994
971 995 try:
972 996 proto.redirect()
973 997
974 998 exchange.check_heads(repo, their_heads, 'preparing changes')
975 999
976 1000 # write bundle data to temporary file because it can be big
977 1001 fd, tempname = tempfile.mkstemp(prefix='hg-unbundle-')
978 1002 fp = os.fdopen(fd, pycompat.sysstr('wb+'))
979 1003 r = 0
980 1004 try:
981 1005 proto.getfile(fp)
982 1006 fp.seek(0)
983 1007 gen = exchange.readbundle(repo.ui, fp, None)
984 1008 if (isinstance(gen, changegroupmod.cg1unpacker)
985 1009 and not bundle1allowed(repo, 'push')):
986 1010 if proto.name == 'http':
987 1011 # need to special case http because stderr do not get to
988 1012 # the http client on failed push so we need to abuse some
989 1013 # other error type to make sure the message get to the
990 1014 # user.
991 1015 return ooberror(bundle2required)
992 1016 raise error.Abort(bundle2requiredmain,
993 1017 hint=bundle2requiredhint)
994 1018
995 1019 r = exchange.unbundle(repo, gen, their_heads, 'serve',
996 1020 proto._client())
997 1021 if util.safehasattr(r, 'addpart'):
998 1022 # The return looks streamable, we are in the bundle2 case and
999 1023 # should return a stream.
1000 1024 return streamres(gen=r.getchunks())
1001 1025 return pushres(r)
1002 1026
1003 1027 finally:
1004 1028 fp.close()
1005 1029 os.unlink(tempname)
1006 1030
1007 1031 except (error.BundleValueError, error.Abort, error.PushRaced) as exc:
1008 1032 # handle non-bundle2 case first
1009 1033 if not getattr(exc, 'duringunbundle2', False):
1010 1034 try:
1011 1035 raise
1012 1036 except error.Abort:
1013 1037 # The old code we moved used util.stderr directly.
1014 1038 # We did not change it to minimise code change.
1015 1039 # This need to be moved to something proper.
1016 1040 # Feel free to do it.
1017 1041 util.stderr.write("abort: %s\n" % exc)
1018 1042 if exc.hint is not None:
1019 1043 util.stderr.write("(%s)\n" % exc.hint)
1020 1044 return pushres(0)
1021 1045 except error.PushRaced:
1022 1046 return pusherr(str(exc))
1023 1047
1024 1048 bundler = bundle2.bundle20(repo.ui)
1025 1049 for out in getattr(exc, '_bundle2salvagedoutput', ()):
1026 1050 bundler.addpart(out)
1027 1051 try:
1028 1052 try:
1029 1053 raise
1030 1054 except error.PushkeyFailed as exc:
1031 1055 # check client caps
1032 1056 remotecaps = getattr(exc, '_replycaps', None)
1033 1057 if (remotecaps is not None
1034 1058 and 'pushkey' not in remotecaps.get('error', ())):
1035 1059 # no support remote side, fallback to Abort handler.
1036 1060 raise
1037 1061 part = bundler.newpart('error:pushkey')
1038 1062 part.addparam('in-reply-to', exc.partid)
1039 1063 if exc.namespace is not None:
1040 1064 part.addparam('namespace', exc.namespace, mandatory=False)
1041 1065 if exc.key is not None:
1042 1066 part.addparam('key', exc.key, mandatory=False)
1043 1067 if exc.new is not None:
1044 1068 part.addparam('new', exc.new, mandatory=False)
1045 1069 if exc.old is not None:
1046 1070 part.addparam('old', exc.old, mandatory=False)
1047 1071 if exc.ret is not None:
1048 1072 part.addparam('ret', exc.ret, mandatory=False)
1049 1073 except error.BundleValueError as exc:
1050 1074 errpart = bundler.newpart('error:unsupportedcontent')
1051 1075 if exc.parttype is not None:
1052 1076 errpart.addparam('parttype', exc.parttype)
1053 1077 if exc.params:
1054 1078 errpart.addparam('params', '\0'.join(exc.params))
1055 1079 except error.Abort as exc:
1056 1080 manargs = [('message', str(exc))]
1057 1081 advargs = []
1058 1082 if exc.hint is not None:
1059 1083 advargs.append(('hint', exc.hint))
1060 1084 bundler.addpart(bundle2.bundlepart('error:abort',
1061 1085 manargs, advargs))
1062 1086 except error.PushRaced as exc:
1063 1087 bundler.newpart('error:pushraced', [('message', str(exc))])
1064 1088 return streamres(gen=bundler.getchunks())
@@ -1,176 +1,206 b''
1 1 # test-batching.py - tests for transparent command batching
2 2 #
3 3 # Copyright 2011 Peter Arrenbrecht <peter@arrenbrecht.ch>
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import, print_function
9 9
10 10 from mercurial import (
11 error,
11 12 peer,
13 util,
12 14 wireproto,
13 15 )
14 16
15 17 # equivalent of repo.repository
16 18 class thing(object):
17 19 def hello(self):
18 20 return "Ready."
19 21
20 22 # equivalent of localrepo.localrepository
21 23 class localthing(thing):
22 24 def foo(self, one, two=None):
23 25 if one:
24 26 return "%s and %s" % (one, two,)
25 27 return "Nope"
26 28 def bar(self, b, a):
27 29 return "%s und %s" % (b, a,)
28 30 def greet(self, name=None):
29 31 return "Hello, %s" % name
30 def batch(self):
32 def batchiter(self):
31 33 '''Support for local batching.'''
32 return peer.localbatch(self)
34 return peer.localiterbatcher(self)
33 35
34 36 # usage of "thing" interface
35 37 def use(it):
36 38
37 39 # Direct call to base method shared between client and server.
38 40 print(it.hello())
39 41
40 42 # Direct calls to proxied methods. They cause individual roundtrips.
41 43 print(it.foo("Un", two="Deux"))
42 44 print(it.bar("Eins", "Zwei"))
43 45
44 # Batched call to a couple of (possibly proxied) methods.
45 batch = it.batch()
46 # Batched call to a couple of proxied methods.
47 batch = it.batchiter()
46 48 # The calls return futures to eventually hold results.
47 49 foo = batch.foo(one="One", two="Two")
48 50 bar = batch.bar("Eins", "Zwei")
49 # We can call non-batchable proxy methods, but the break the current batch
50 # request and cause additional roundtrips.
51 greet = batch.greet(name="John Smith")
52 # We can also add local methods into the mix, but they break the batch too.
53 hello = batch.hello()
54 51 bar2 = batch.bar(b="Uno", a="Due")
55 # Only now are all the calls executed in sequence, with as few roundtrips
56 # as possible.
52
53 # Future shouldn't be set until we submit().
54 assert isinstance(foo, peer.future)
55 assert not util.safehasattr(foo, 'value')
56 assert not util.safehasattr(bar, 'value')
57 57 batch.submit()
58 # After the call to submit, the futures actually contain values.
58 # Call results() to obtain results as a generator.
59 results = batch.results()
60
61 # Future results shouldn't be set until we consume a value.
62 assert not util.safehasattr(foo, 'value')
63 foovalue = next(results)
64 assert util.safehasattr(foo, 'value')
65 assert foovalue == foo.value
59 66 print(foo.value)
67 next(results)
60 68 print(bar.value)
61 print(greet.value)
62 print(hello.value)
69 next(results)
63 70 print(bar2.value)
64 71
72 # We should be at the end of the results generator.
73 try:
74 next(results)
75 except StopIteration:
76 print('proper end of results generator')
77 else:
78 print('extra emitted element!')
79
80 # Attempting to call a non-batchable method inside a batch fails.
81 batch = it.batchiter()
82 try:
83 batch.greet(name='John Smith')
84 except error.ProgrammingError as e:
85 print(e)
86
87 # Attempting to call a local method inside a batch fails.
88 batch = it.batchiter()
89 try:
90 batch.hello()
91 except error.ProgrammingError as e:
92 print(e)
93
65 94 # local usage
66 95 mylocal = localthing()
67 96 print()
68 97 print("== Local")
69 98 use(mylocal)
70 99
71 100 # demo remoting; mimicks what wireproto and HTTP/SSH do
72 101
73 102 # shared
74 103
75 104 def escapearg(plain):
76 105 return (plain
77 106 .replace(':', '::')
78 107 .replace(',', ':,')
79 108 .replace(';', ':;')
80 109 .replace('=', ':='))
81 110 def unescapearg(escaped):
82 111 return (escaped
83 112 .replace(':=', '=')
84 113 .replace(':;', ';')
85 114 .replace(':,', ',')
86 115 .replace('::', ':'))
87 116
88 117 # server side
89 118
90 119 # equivalent of wireproto's global functions
91 120 class server(object):
92 121 def __init__(self, local):
93 122 self.local = local
94 123 def _call(self, name, args):
95 124 args = dict(arg.split('=', 1) for arg in args)
96 125 return getattr(self, name)(**args)
97 126 def perform(self, req):
98 127 print("REQ:", req)
99 128 name, args = req.split('?', 1)
100 129 args = args.split('&')
101 130 vals = dict(arg.split('=', 1) for arg in args)
102 131 res = getattr(self, name)(**vals)
103 132 print(" ->", res)
104 133 return res
105 134 def batch(self, cmds):
106 135 res = []
107 136 for pair in cmds.split(';'):
108 137 name, args = pair.split(':', 1)
109 138 vals = {}
110 139 for a in args.split(','):
111 140 if a:
112 141 n, v = a.split('=')
113 142 vals[n] = unescapearg(v)
114 143 res.append(escapearg(getattr(self, name)(**vals)))
115 144 return ';'.join(res)
116 145 def foo(self, one, two):
117 146 return mangle(self.local.foo(unmangle(one), unmangle(two)))
118 147 def bar(self, b, a):
119 148 return mangle(self.local.bar(unmangle(b), unmangle(a)))
120 149 def greet(self, name):
121 150 return mangle(self.local.greet(unmangle(name)))
122 151 myserver = server(mylocal)
123 152
124 153 # local side
125 154
126 155 # equivalent of wireproto.encode/decodelist, that is, type-specific marshalling
127 156 # here we just transform the strings a bit to check we're properly en-/decoding
128 157 def mangle(s):
129 158 return ''.join(chr(ord(c) + 1) for c in s)
130 159 def unmangle(s):
131 160 return ''.join(chr(ord(c) - 1) for c in s)
132 161
133 162 # equivalent of wireproto.wirerepository and something like http's wire format
134 163 class remotething(thing):
135 164 def __init__(self, server):
136 165 self.server = server
137 166 def _submitone(self, name, args):
138 167 req = name + '?' + '&'.join(['%s=%s' % (n, v) for n, v in args])
139 168 return self.server.perform(req)
140 169 def _submitbatch(self, cmds):
141 170 req = []
142 171 for name, args in cmds:
143 172 args = ','.join(n + '=' + escapearg(v) for n, v in args)
144 173 req.append(name + ':' + args)
145 174 req = ';'.join(req)
146 175 res = self._submitone('batch', [('cmds', req,)])
147 return res.split(';')
176 for r in res.split(';'):
177 yield r
148 178
149 def batch(self):
150 return wireproto.remotebatch(self)
179 def batchiter(self):
180 return wireproto.remoteiterbatcher(self)
151 181
152 182 @peer.batchable
153 183 def foo(self, one, two=None):
154 184 encargs = [('one', mangle(one),), ('two', mangle(two),)]
155 185 encresref = peer.future()
156 186 yield encargs, encresref
157 187 yield unmangle(encresref.value)
158 188
159 189 @peer.batchable
160 190 def bar(self, b, a):
161 191 encresref = peer.future()
162 192 yield [('b', mangle(b),), ('a', mangle(a),)], encresref
163 193 yield unmangle(encresref.value)
164 194
165 195 # greet is coded directly. It therefore does not support batching. If it
166 196 # does appear in a batch, the batch is split around greet, and the call to
167 197 # greet is done in its own roundtrip.
168 198 def greet(self, name=None):
169 199 return unmangle(self._submitone('greet', [('name', mangle(name),)]))
170 200
171 201 # demo remote usage
172 202
173 203 myproxy = remotething(myserver)
174 204 print()
175 205 print("== Remote")
176 206 use(myproxy)
@@ -1,30 +1,26 b''
1 1
2 2 == Local
3 3 Ready.
4 4 Un and Deux
5 5 Eins und Zwei
6 6 One and Two
7 7 Eins und Zwei
8 Hello, John Smith
9 Ready.
10 8 Uno und Due
9 proper end of results generator
11 10
12 11 == Remote
13 12 Ready.
14 13 REQ: foo?one=Vo&two=Efvy
15 14 -> Vo!boe!Efvy
16 15 Un and Deux
17 16 REQ: bar?b=Fjot&a=[xfj
18 17 -> Fjot!voe![xfj
19 18 Eins und Zwei
20 REQ: batch?cmds=foo:one=Pof,two=Uxp;bar:b=Fjot,a=[xfj
21 -> Pof!boe!Uxp;Fjot!voe![xfj
22 REQ: greet?name=Kpio!Tnjui
23 -> Ifmmp-!Kpio!Tnjui
24 REQ: batch?cmds=bar:b=Vop,a=Evf
25 -> Vop!voe!Evf
19 REQ: batch?cmds=foo:one=Pof,two=Uxp;bar:b=Fjot,a=[xfj;bar:b=Vop,a=Evf
20 -> Pof!boe!Uxp;Fjot!voe![xfj;Vop!voe!Evf
26 21 One and Two
27 22 Eins und Zwei
28 Hello, John Smith
29 Ready.
30 23 Uno und Due
24 proper end of results generator
25 Attempted to batch a non-batchable call to 'greet'
26 Attempted to batch a non-batchable call to 'hello'
General Comments 0
You need to be logged in to leave comments. Login now