py3: fix a bytes vs str issue in remotefilelog extension...
Kyle Lippincott
r44287:94670e12 default
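
The one-line change sits at the bottom of this diff, in logstacktrace(): traceback.format_stack() returns a list of str on both Python 2 and Python 3, so joining the frames with a bytes separator breaks under Python 3. A minimal sketch of the failure and of the fix's shape (pycompat.sysbytes is roughly s.encode('utf-8') on Python 3; the values here are illustrative):

    import traceback

    frames = traceback.format_stack()   # list of str, even on Python 3

    # Broken on Python 3 -- bytes separator, str items:
    #   b''.join(frames)
    #   TypeError: sequence item 0: expected a bytes-like object, str found

    # The fix encodes each frame before joining, as the commit does via
    # pycompat.sysbytes:
    stack = b''.join(s.encode('utf-8') for s in frames)
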
@@ -1,667 +1,667 @@
1 1 # fileserverclient.py - client for communicating with the cache process
2 2 #
3 3 # Copyright 2013 Facebook, Inc.
4 4 #
5 5 # This software may be used and distributed according to the terms of the
6 6 # GNU General Public License version 2 or any later version.
7 7
8 8 from __future__ import absolute_import
9 9
10 10 import hashlib
11 11 import io
12 12 import os
13 13 import threading
14 14 import time
15 15 import zlib
16 16
17 17 from mercurial.i18n import _
18 18 from mercurial.node import bin, hex, nullid
19 19 from mercurial import (
20 20 error,
21 21 node,
22 22 pycompat,
23 23 revlog,
24 24 sshpeer,
25 25 util,
26 26 wireprotov1peer,
27 27 )
28 28 from mercurial.utils import procutil
29 29
30 30 from . import (
31 31 constants,
32 32 contentstore,
33 33 metadatastore,
34 34 )
35 35
36 36 _sshv1peer = sshpeer.sshv1peer
37 37
38 38 # Statistics for debugging
39 39 fetchcost = 0
40 40 fetches = 0
41 41 fetched = 0
42 42 fetchmisses = 0
43 43
44 44 _lfsmod = None
45 45
46 46
47 47 def getcachekey(reponame, file, id):
48 48 pathhash = node.hex(hashlib.sha1(file).digest())
49 49 return os.path.join(reponame, pathhash[:2], pathhash[2:], id)
50 50
51 51
52 52 def getlocalkey(file, id):
53 53 pathhash = node.hex(hashlib.sha1(file).digest())
54 54 return os.path.join(pathhash, id)
55 55
56 56
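
For reference, the two helpers above derive cache paths from the SHA-1 of the file name. A quick illustration of what they compute, with made-up inputs (any repo name, path, and 40-hex node would do):

    import hashlib
    import os

    reponame = b'myrepo'            # illustrative values
    file = b'foo/bar.txt'
    id = b'ab' * 20                 # stand-in for a 40-hex file node

    pathhash = hashlib.sha1(file).hexdigest().encode('ascii')
    cachekey = os.path.join(reponame, pathhash[:2], pathhash[2:], id)
    # -> myrepo/<2 hex digits>/<38 hex digits>/<node>
    localkey = os.path.join(pathhash, id)
    # -> <40 hex digits>/<node>
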
57 57 def peersetup(ui, peer):
58 58 class remotefilepeer(peer.__class__):
59 59 @wireprotov1peer.batchable
60 60 def x_rfl_getfile(self, file, node):
61 61 if not self.capable(b'x_rfl_getfile'):
62 62 raise error.Abort(
63 63 b'configured remotefile server does not support getfile'
64 64 )
65 65 f = wireprotov1peer.future()
66 66 yield {b'file': file, b'node': node}, f
67 67 code, data = f.value.split(b'\0', 1)
68 68 if int(code):
69 69 raise error.LookupError(file, node, data)
70 70 yield data
71 71
72 72 @wireprotov1peer.batchable
73 73 def x_rfl_getflogheads(self, path):
74 74 if not self.capable(b'x_rfl_getflogheads'):
75 75 raise error.Abort(
76 76 b'configured remotefile server does not '
77 77 b'support getflogheads'
78 78 )
79 79 f = wireprotov1peer.future()
80 80 yield {b'path': path}, f
81 81 heads = f.value.split(b'\n') if f.value else []
82 82 yield heads
83 83
84 84 def _updatecallstreamopts(self, command, opts):
85 85 if command != b'getbundle':
86 86 return
87 87 if (
88 88 constants.NETWORK_CAP_LEGACY_SSH_GETFILES
89 89 not in self.capabilities()
90 90 ):
91 91 return
92 92 if not util.safehasattr(self, '_localrepo'):
93 93 return
94 94 if (
95 95 constants.SHALLOWREPO_REQUIREMENT
96 96 not in self._localrepo.requirements
97 97 ):
98 98 return
99 99
100 100 bundlecaps = opts.get(b'bundlecaps')
101 101 if bundlecaps:
102 102 bundlecaps = [bundlecaps]
103 103 else:
104 104 bundlecaps = []
105 105
106 106 # shallow, includepattern, and excludepattern are a hacky way of
107 107 # carrying over data from the local repo to this getbundle
108 108 # command. We need to do it this way because bundle1 getbundle
109 109 # doesn't provide any other place we can hook in to manipulate
110 110 # getbundle args before it goes across the wire. Once we get rid
111 111 # of bundle1, we can use bundle2's _pullbundle2extraprepare to
112 112 # do this more cleanly.
113 113 bundlecaps.append(constants.BUNDLE2_CAPABLITY)
114 114 if self._localrepo.includepattern:
115 115 patterns = b'\0'.join(self._localrepo.includepattern)
116 116 includecap = b"includepattern=" + patterns
117 117 bundlecaps.append(includecap)
118 118 if self._localrepo.excludepattern:
119 119 patterns = b'\0'.join(self._localrepo.excludepattern)
120 120 excludecap = b"excludepattern=" + patterns
121 121 bundlecaps.append(excludecap)
122 122 opts[b'bundlecaps'] = b','.join(bundlecaps)
123 123
124 124 def _sendrequest(self, command, args, **opts):
125 125 self._updatecallstreamopts(command, args)
126 126 return super(remotefilepeer, self)._sendrequest(
127 127 command, args, **opts
128 128 )
129 129
130 130 def _callstream(self, command, **opts):
131 131 supertype = super(remotefilepeer, self)
132 132 if not util.safehasattr(supertype, '_sendrequest'):
133 133 self._updatecallstreamopts(command, pycompat.byteskwargs(opts))
134 134 return super(remotefilepeer, self)._callstream(command, **opts)
135 135
136 136 peer.__class__ = remotefilepeer
137 137
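
The @wireprotov1peer.batchable methods above follow a two-step generator protocol: the first yield hands the encoded arguments plus a future to the batching machinery, which performs the wire round trip and fills in the future's value; resuming the generator then yields the decoded result (or raises, e.g. LookupError in x_rfl_getfile). A self-contained toy showing just that shape -- this is not Mercurial's actual batch executor:

    class _fakefuture(object):
        value = None

    def _toygetfile(path):
        # Same two-step shape as x_rfl_getfile above.
        f = _fakefuture()
        yield {b'path': path}, f       # 1) request args + future
        yield f.value.upper()          # 2) decode the filled-in reply

    def _drive(gen, roundtrip):
        args, fut = next(gen)          # collect the request
        fut.value = roundtrip(args)    # perform the wire round trip
        return next(gen)               # resume for the decoded result

    # _drive(_toygetfile(b'a.txt'), lambda args: b'reply') -> b'REPLY'
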
138 138
139 139 class cacheconnection(object):
140 140 """The connection for communicating with the remote cache. Performs
141 141 gets and sets by communicating with an external process that has the
142 142 cache-specific implementation.
143 143 """
144 144
145 145 def __init__(self):
146 146 self.pipeo = self.pipei = self.pipee = None
147 147 self.subprocess = None
148 148 self.connected = False
149 149
150 150 def connect(self, cachecommand):
151 151 if self.pipeo:
152 152 raise error.Abort(_(b"cache connection already open"))
153 153 self.pipei, self.pipeo, self.pipee, self.subprocess = procutil.popen4(
154 154 cachecommand
155 155 )
156 156 self.connected = True
157 157
158 158 def close(self):
159 159 def tryclose(pipe):
160 160 try:
161 161 pipe.close()
162 162 except Exception:
163 163 pass
164 164
165 165 if self.connected:
166 166 try:
167 167 self.pipei.write(b"exit\n")
168 168 except Exception:
169 169 pass
170 170 tryclose(self.pipei)
171 171 self.pipei = None
172 172 tryclose(self.pipeo)
173 173 self.pipeo = None
174 174 tryclose(self.pipee)
175 175 self.pipee = None
176 176 try:
177 177 # Wait for process to terminate, making sure to avoid deadlock.
178 178 # See https://docs.python.org/2/library/subprocess.html for
179 179 # warnings about wait() and deadlocking.
180 180 self.subprocess.communicate()
181 181 except Exception:
182 182 pass
183 183 self.subprocess = None
184 184 self.connected = False
185 185
186 186 def request(self, request, flush=True):
187 187 if self.connected:
188 188 try:
189 189 self.pipei.write(request)
190 190 if flush:
191 191 self.pipei.flush()
192 192 except IOError:
193 193 self.close()
194 194
195 195 def receiveline(self):
196 196 if not self.connected:
197 197 return None
198 198 try:
199 199 result = self.pipeo.readline()[:-1]
200 200 if not result:
201 201 self.close()
202 202 except IOError:
203 203 self.close()
204 204
205 205 return result
206 206
207 207
208 208 def _getfilesbatch(
209 209 remote, receivemissing, progresstick, missed, idmap, batchsize
210 210 ):
211 211 # Over http(s), iterbatch is a streamy method and we can start
212 212 # looking at results early. This means we send one (potentially
213 213 # large) request, but then we show nice progress as we process
214 214 # file results, rather than showing chunks of $batchsize in
215 215 # progress.
216 216 #
217 217 # Over ssh, iterbatch isn't streamy because batch() wasn't
218 218 # explicitly designed as a streaming method. In the future we
219 219 # should probably introduce a streambatch() method upstream and
220 220 # use that for this.
221 221 with remote.commandexecutor() as e:
222 222 futures = []
223 223 for m in missed:
224 224 futures.append(
225 225 e.callcommand(
226 226 b'x_rfl_getfile', {b'file': idmap[m], b'node': m[-40:]}
227 227 )
228 228 )
229 229
230 230 for i, m in enumerate(missed):
231 231 r = futures[i].result()
232 232 futures[i] = None # release memory
233 233 file_ = idmap[m]
234 234 node = m[-40:]
235 235 receivemissing(io.BytesIO(b'%d\n%s' % (len(r), r)), file_, node)
236 236 progresstick()
237 237
238 238
239 239 def _getfiles_optimistic(
240 240 remote, receivemissing, progresstick, missed, idmap, step
241 241 ):
242 242 remote._callstream(b"x_rfl_getfiles")
243 243 i = 0
244 244 pipeo = remote._pipeo
245 245 pipei = remote._pipei
246 246 while i < len(missed):
247 247 # issue a batch of requests
248 248 start = i
249 249 end = min(len(missed), start + step)
250 250 i = end
251 251 for missingid in missed[start:end]:
252 252 # issue new request
253 253 versionid = missingid[-40:]
254 254 file = idmap[missingid]
255 255 sshrequest = b"%s%s\n" % (versionid, file)
256 256 pipeo.write(sshrequest)
257 257 pipeo.flush()
258 258
259 259 # receive batch results
260 260 for missingid in missed[start:end]:
261 261 versionid = missingid[-40:]
262 262 file = idmap[missingid]
263 263 receivemissing(pipei, file, versionid)
264 264 progresstick()
265 265
266 266 # End the command
267 267 pipeo.write(b'\n')
268 268 pipeo.flush()
269 269
270 270
271 271 def _getfiles_threaded(
272 272 remote, receivemissing, progresstick, missed, idmap, step
273 273 ):
274 274 remote._callstream(b"getfiles")
275 275 pipeo = remote._pipeo
276 276 pipei = remote._pipei
277 277
278 278 def writer():
279 279 for missingid in missed:
280 280 versionid = missingid[-40:]
281 281 file = idmap[missingid]
282 282 sshrequest = b"%s%s\n" % (versionid, file)
283 283 pipeo.write(sshrequest)
284 284 pipeo.flush()
285 285
286 286 writerthread = threading.Thread(target=writer)
287 287 writerthread.daemon = True
288 288 writerthread.start()
289 289
290 290 for missingid in missed:
291 291 versionid = missingid[-40:]
292 292 file = idmap[missingid]
293 293 receivemissing(pipei, file, versionid)
294 294 progresstick()
295 295
296 296 writerthread.join()
297 297 # End the command
298 298 pipeo.write(b'\n')
299 299 pipeo.flush()
300 300
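
Both getfiles variants share one framing, visible in the writer loops above and in receivemissing() below: each request line is the 40-hex file node immediately followed by the path and a newline, a bare newline ends the command, and each reply is a decimal byte count on its own line followed by exactly that many bytes of zlib-compressed file data. A sketch of reading a single reply frame (mirrors receivemissing, minus the store write):

    import zlib

    def readframe(pipe):
        header = pipe.readline()[:-1]     # '<size>' with newline stripped
        if not header:
            raise IOError('connection closed early')
        size = int(header)
        payload = pipe.read(size)
        if len(payload) != size:
            raise IOError('only received %d of %d bytes' % (len(payload), size))
        return zlib.decompress(payload)   # raw file blob
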
301 301
302 302 class fileserverclient(object):
303 303 """A client for requesting files from the remote file server.
304 304 """
305 305
306 306 def __init__(self, repo):
307 307 ui = repo.ui
308 308 self.repo = repo
309 309 self.ui = ui
310 310 self.cacheprocess = ui.config(b"remotefilelog", b"cacheprocess")
311 311 if self.cacheprocess:
312 312 self.cacheprocess = util.expandpath(self.cacheprocess)
313 313
314 314 # This option causes remotefilelog to pass the full file path to the
315 315 # cacheprocess instead of a hashed key.
316 316 self.cacheprocesspasspath = ui.configbool(
317 317 b"remotefilelog", b"cacheprocess.includepath"
318 318 )
319 319
320 320 self.debugoutput = ui.configbool(b"remotefilelog", b"debug")
321 321
322 322 self.remotecache = cacheconnection()
323 323
324 324 def setstore(self, datastore, historystore, writedata, writehistory):
325 325 self.datastore = datastore
326 326 self.historystore = historystore
327 327 self.writedata = writedata
328 328 self.writehistory = writehistory
329 329
330 330 def _connect(self):
331 331 return self.repo.connectionpool.get(self.repo.fallbackpath)
332 332
333 333 def request(self, fileids):
334 334 """Takes a list of filename/node pairs and fetches them from the
335 335 server. Files are stored in the local cache.
336 336 A list of nodes that the server couldn't find is returned.
337 337 If the connection fails, an exception is raised.
338 338 """
339 339 if not self.remotecache.connected:
340 340 self.connect()
341 341 cache = self.remotecache
342 342 writedata = self.writedata
343 343
344 344 repo = self.repo
345 345 total = len(fileids)
346 346 request = b"get\n%d\n" % total
347 347 idmap = {}
348 348 reponame = repo.name
349 349 for file, id in fileids:
350 350 fullid = getcachekey(reponame, file, id)
351 351 if self.cacheprocesspasspath:
352 352 request += file + b'\0'
353 353 request += fullid + b"\n"
354 354 idmap[fullid] = file
355 355
356 356 cache.request(request)
357 357
358 358 progress = self.ui.makeprogress(_(b'downloading'), total=total)
359 359 progress.update(0)
360 360
361 361 missed = []
362 362 while True:
363 363 missingid = cache.receiveline()
364 364 if not missingid:
365 365 missedset = set(missed)
366 366 for missingid in idmap:
367 367 if not missingid in missedset:
368 368 missed.append(missingid)
369 369 self.ui.warn(
370 370 _(
371 371 b"warning: cache connection closed early - "
372 372 + b"falling back to server\n"
373 373 )
374 374 )
375 375 break
376 376 if missingid == b"0":
377 377 break
378 378 if missingid.startswith(b"_hits_"):
379 379 # receive progress reports
380 380 parts = missingid.split(b"_")
381 381 progress.increment(int(parts[2]))
382 382 continue
383 383
384 384 missed.append(missingid)
385 385
386 386 global fetchmisses
387 387 fetchmisses += len(missed)
388 388
389 389 fromcache = total - len(missed)
390 390 progress.update(fromcache, total=total)
391 391 self.ui.log(
392 392 b"remotefilelog",
393 393 b"remote cache hit rate is %r of %r\n",
394 394 fromcache,
395 395 total,
396 396 hit=fromcache,
397 397 total=total,
398 398 )
399 399
400 400 oldumask = os.umask(0o002)
401 401 try:
402 402 # receive cache misses from master
403 403 if missed:
404 404 # When verbose is true, sshpeer prints 'running ssh...'
405 405 # to stdout, which can interfere with some command
406 406 # outputs
407 407 verbose = self.ui.verbose
408 408 self.ui.verbose = False
409 409 try:
410 410 with self._connect() as conn:
411 411 remote = conn.peer
412 412 if remote.capable(
413 413 constants.NETWORK_CAP_LEGACY_SSH_GETFILES
414 414 ):
415 415 if not isinstance(remote, _sshv1peer):
416 416 raise error.Abort(
417 417 b'remotefilelog requires ssh servers'
418 418 )
419 419 step = self.ui.configint(
420 420 b'remotefilelog', b'getfilesstep'
421 421 )
422 422 getfilestype = self.ui.config(
423 423 b'remotefilelog', b'getfilestype'
424 424 )
425 425 if getfilestype == b'threaded':
426 426 _getfiles = _getfiles_threaded
427 427 else:
428 428 _getfiles = _getfiles_optimistic
429 429 _getfiles(
430 430 remote,
431 431 self.receivemissing,
432 432 progress.increment,
433 433 missed,
434 434 idmap,
435 435 step,
436 436 )
437 437 elif remote.capable(b"x_rfl_getfile"):
438 438 if remote.capable(b'batch'):
439 439 batchdefault = 100
440 440 else:
441 441 batchdefault = 10
442 442 batchsize = self.ui.configint(
443 443 b'remotefilelog', b'batchsize', batchdefault
444 444 )
445 445 self.ui.debug(
446 446 b'requesting %d files from '
447 447 b'remotefilelog server...\n' % len(missed)
448 448 )
449 449 _getfilesbatch(
450 450 remote,
451 451 self.receivemissing,
452 452 progress.increment,
453 453 missed,
454 454 idmap,
455 455 batchsize,
456 456 )
457 457 else:
458 458 raise error.Abort(
459 459 b"configured remotefilelog server"
460 460 b" does not support remotefilelog"
461 461 )
462 462
463 463 self.ui.log(
464 464 b"remotefilefetchlog",
465 465 b"Success\n",
466 466 fetched_files=progress.pos - fromcache,
467 467 total_to_fetch=total - fromcache,
468 468 )
469 469 except Exception:
470 470 self.ui.log(
471 471 b"remotefilefetchlog",
472 472 b"Fail\n",
473 473 fetched_files=progress.pos - fromcache,
474 474 total_to_fetch=total - fromcache,
475 475 )
476 476 raise
477 477 finally:
478 478 self.ui.verbose = verbose
479 479 # send to memcache
480 480 request = b"set\n%d\n%s\n" % (len(missed), b"\n".join(missed))
481 481 cache.request(request)
482 482
483 483 progress.complete()
484 484
485 485 # mark ourselves as a user of this cache
486 486 writedata.markrepo(self.repo.path)
487 487 finally:
488 488 os.umask(oldumask)
489 489
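
The cacheprocess side of the exchange above is simple to state: it reads get/set/exit commands on stdin, each followed by a count line and that many key lines, and for a get it writes back one line per missing key, optional _hits_<count>_ progress lines, and a terminating 0 line. A deliberately dumb sketch that treats every key as a miss (illustrative only; a real process would consult its cache and report hits):

    #!/usr/bin/env python
    import sys

    def main():
        stdin, stdout = sys.stdin.buffer, sys.stdout.buffer  # byte streams
        while True:
            cmd = stdin.readline()[:-1]
            if not cmd or cmd == b'exit':
                break
            count = int(stdin.readline())
            keys = [stdin.readline()[:-1] for _ in range(count)]
            if cmd == b'get':
                for key in keys:          # report everything as a miss
                    stdout.write(key + b'\n')
                stdout.write(b'0\n')      # end-of-misses sentinel
                stdout.flush()
            # b'set' announces keys the client has now written; nothing to do.

    if __name__ == '__main__':
        main()
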
490 490 def receivemissing(self, pipe, filename, node):
491 491 line = pipe.readline()[:-1]
492 492 if not line:
493 493 raise error.ResponseError(
494 494 _(b"error downloading file contents:"),
495 495 _(b"connection closed early"),
496 496 )
497 497 size = int(line)
498 498 data = pipe.read(size)
499 499 if len(data) != size:
500 500 raise error.ResponseError(
501 501 _(b"error downloading file contents:"),
502 502 _(b"only received %s of %s bytes") % (len(data), size),
503 503 )
504 504
505 505 self.writedata.addremotefilelognode(
506 506 filename, bin(node), zlib.decompress(data)
507 507 )
508 508
509 509 def connect(self):
510 510 if self.cacheprocess:
511 511 cmd = b"%s %s" % (self.cacheprocess, self.writedata._path)
512 512 self.remotecache.connect(cmd)
513 513 else:
514 514 # If no cache process is specified, we fake one that always
515 515 # returns cache misses. This enables tests to run easily
516 516 # and may eventually allow us to be a drop in replacement
517 517 # for the largefiles extension.
518 518 class simplecache(object):
519 519 def __init__(self):
520 520 self.missingids = []
521 521 self.connected = True
522 522
523 523 def close(self):
524 524 pass
525 525
526 526 def request(self, value, flush=True):
527 527 lines = value.split(b"\n")
528 528 if lines[0] != b"get":
529 529 return
530 530 self.missingids = lines[2:-1]
531 531 self.missingids.append(b'0')
532 532
533 533 def receiveline(self):
534 534 if len(self.missingids) > 0:
535 535 return self.missingids.pop(0)
536 536 return None
537 537
538 538 self.remotecache = simplecache()
539 539
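
Concretely, the fake cache just echoes every requested key back as a miss and then the 0 sentinel (the keys below are made up):

    c = simplecache()
    c.request(b"get\n2\nmyrepo/ab/cdef12/1111\nmyrepo/cd/ef0123/2222\n")
    assert c.receiveline() == b"myrepo/ab/cdef12/1111"
    assert c.receiveline() == b"myrepo/cd/ef0123/2222"
    assert c.receiveline() == b"0"     # end-of-misses sentinel
    assert c.receiveline() is None     # drained
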
540 540 def close(self):
541 541 if fetches:
542 542 msg = (
543 543 b"%d files fetched over %d fetches - "
544 544 + b"(%d misses, %0.2f%% hit ratio) over %0.2fs\n"
545 545 ) % (
546 546 fetched,
547 547 fetches,
548 548 fetchmisses,
549 549 float(fetched - fetchmisses) / float(fetched) * 100.0,
550 550 fetchcost,
551 551 )
552 552 if self.debugoutput:
553 553 self.ui.warn(msg)
554 554 self.ui.log(
555 555 b"remotefilelog.prefetch",
556 556 msg.replace(b"%", b"%%"),
557 557 remotefilelogfetched=fetched,
558 558 remotefilelogfetches=fetches,
559 559 remotefilelogfetchmisses=fetchmisses,
560 560 remotefilelogfetchtime=fetchcost * 1000,
561 561 )
562 562
563 563 if self.remotecache.connected:
564 564 self.remotecache.close()
565 565
566 566 def prefetch(
567 567 self, fileids, force=False, fetchdata=True, fetchhistory=False
568 568 ):
569 569 """downloads the given file versions to the cache
570 570 """
571 571 repo = self.repo
572 572 idstocheck = []
573 573 for file, id in fileids:
574 574 # hack
575 575 # - we don't use .hgtags
576 576 # - workingctx produces ids with length 42,
577 577 # which we skip since they aren't in any cache
578 578 if (
579 579 file == b'.hgtags'
580 580 or len(id) == 42
581 581 or not repo.shallowmatch(file)
582 582 ):
583 583 continue
584 584
585 585 idstocheck.append((file, bin(id)))
586 586
587 587 datastore = self.datastore
588 588 historystore = self.historystore
589 589 if force:
590 590 datastore = contentstore.unioncontentstore(*repo.shareddatastores)
591 591 historystore = metadatastore.unionmetadatastore(
592 592 *repo.sharedhistorystores
593 593 )
594 594
595 595 missingids = set()
596 596 if fetchdata:
597 597 missingids.update(datastore.getmissing(idstocheck))
598 598 if fetchhistory:
599 599 missingids.update(historystore.getmissing(idstocheck))
600 600
601 601 # partition missing nodes into nullid and not-nullid so we can
602 602 # warn about this filtering potentially shadowing bugs.
603 603 nullids = len([None for unused, id in missingids if id == nullid])
604 604 if nullids:
605 605 missingids = [(f, id) for f, id in missingids if id != nullid]
606 606 repo.ui.develwarn(
607 607 (
608 608 b'remotefilelog not fetching %d null revs'
609 609 b' - this is likely hiding bugs' % nullids
610 610 ),
611 611 config=b'remotefilelog-ext',
612 612 )
613 613 if missingids:
614 614 global fetches, fetched, fetchcost
615 615 fetches += 1
616 616
617 617 # We want to be able to detect excess individual file downloads, so
618 618 # let's log that information for debugging.
619 619 if fetches >= 15 and fetches < 18:
620 620 if fetches == 15:
621 621 fetchwarning = self.ui.config(
622 622 b'remotefilelog', b'fetchwarning'
623 623 )
624 624 if fetchwarning:
625 625 self.ui.warn(fetchwarning + b'\n')
626 626 self.logstacktrace()
627 627 missingids = [(file, hex(id)) for file, id in sorted(missingids)]
628 628 fetched += len(missingids)
629 629 start = time.time()
630 630 missingids = self.request(missingids)
631 631 if missingids:
632 632 raise error.Abort(
633 633 _(b"unable to download %d files") % len(missingids)
634 634 )
635 635 fetchcost += time.time() - start
636 636 self._lfsprefetch(fileids)
637 637
638 638 def _lfsprefetch(self, fileids):
639 639 if not _lfsmod or not util.safehasattr(
640 640 self.repo.svfs, b'lfslocalblobstore'
641 641 ):
642 642 return
643 643 if not _lfsmod.wrapper.candownload(self.repo):
644 644 return
645 645 pointers = []
646 646 store = self.repo.svfs.lfslocalblobstore
647 647 for file, id in fileids:
648 648 node = bin(id)
649 649 rlog = self.repo.file(file)
650 650 if rlog.flags(node) & revlog.REVIDX_EXTSTORED:
651 651 text = rlog.rawdata(node)
652 652 p = _lfsmod.pointer.deserialize(text)
653 653 oid = p.oid()
654 654 if not store.has(oid):
655 655 pointers.append(p)
656 656 if len(pointers) > 0:
657 657 self.repo.svfs.lfsremoteblobstore.readbatch(pointers, store)
658 658 assert all(store.has(p.oid()) for p in pointers)
659 659
660 660 def logstacktrace(self):
661 661 import traceback
662 662
663 663 self.ui.log(
664 664 b'remotefilelog',
665 665 b'excess remotefilelog fetching:\n%s\n',
666 b''.join(traceback.format_stack()),
666 b''.join(pycompat.sysbytes(s) for s in traceback.format_stack()),
667 667 )