fix(auth-info): fixed auth credentials problem for HG. Fixes RCCE-33...
super-admin -
r1196:b009ad36 default
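The substance of this change: GitRemote._build_opener drops its locally defined _convert_to_strings helper in favor of a shared convert_to_str utility imported from vcsserver.str_utils, and the Mercurial module below gains the same import so HG auth credentials get the same bytes-to-str normalization before being handed to urllib's password manager. A minimal sketch of such a helper, assuming it mirrors the removed local function (the real implementation lives in vcsserver.str_utils and may differ):

    def convert_to_str(data):
        # Sketch only: recursively decode bytes (including bytes nested in
        # tuples) to str; the removed _convert_to_strings helper below did
        # the same via safe_str.
        if isinstance(data, bytes):
            return data.decode('utf-8')
        if isinstance(data, tuple):
            return tuple(convert_to_str(item) for item in data)
        return data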
@@ -1,1525 +1,1517 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import collections
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import traceback
24 24 import urllib.request
25 25 import urllib.parse
26 26 import urllib.error
27 27 from functools import wraps
28 28
29 29 import more_itertools
30 30 import pygit2
31 31 from pygit2 import Repository as LibGit2Repo
32 32 from pygit2 import index as LibGit2Index
33 33 from dulwich import index, objects
34 34 from dulwich.client import HttpGitClient, LocalGitClient, FetchPackResult
35 35 from dulwich.errors import (
36 36 NotGitRepository, ChecksumMismatch, WrongObjectException,
37 37 MissingCommitError, ObjectMissing, HangupException,
38 38 UnexpectedCommandError)
39 39 from dulwich.repo import Repo as DulwichRepo
40 40 from dulwich.server import update_server_info
41 41
42 42 import rhodecode
43 43 from vcsserver import exceptions, settings, subprocessio
44 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes
44 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes, convert_to_str
45 45 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, store_archive_in_cache, BytesEnvelope, BinaryEnvelope
46 46 from vcsserver.hgcompat import (
47 47 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
48 48 from vcsserver.git_lfs.lib import LFSOidStore
49 49 from vcsserver.vcs_base import RemoteBase
50 50
51 51 DIR_STAT = stat.S_IFDIR
52 52 FILE_MODE = stat.S_IFMT
53 53 GIT_LINK = objects.S_IFGITLINK
54 54 PEELED_REF_MARKER = b'^{}'
55 55 HEAD_MARKER = b'HEAD'
56 56
57 57 log = logging.getLogger(__name__)
58 58
59 59
60 60 def reraise_safe_exceptions(func):
61 61 """Converts Dulwich exceptions to something neutral."""
62 62
63 63 @wraps(func)
64 64 def wrapper(*args, **kwargs):
65 65 try:
66 66 return func(*args, **kwargs)
67 67 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
68 68 exc = exceptions.LookupException(org_exc=e)
69 69 raise exc(safe_str(e))
70 70 except (HangupException, UnexpectedCommandError) as e:
71 71 exc = exceptions.VcsException(org_exc=e)
72 72 raise exc(safe_str(e))
73 73 except Exception:
74 74 # NOTE(marcink): because of how dulwich handles some exceptions
75 75 # (KeyError on empty repos), we cannot track this and catch all
 76 76 # exceptions; these are exceptions raised by other handlers
77 77 #if not hasattr(e, '_vcs_kind'):
78 78 #log.exception("Unhandled exception in git remote call")
79 79 #raise_from_original(exceptions.UnhandledException)
80 80 raise
81 81 return wrapper
82 82
83 83
84 84 class Repo(DulwichRepo):
85 85 """
86 86 A wrapper for dulwich Repo class.
87 87
 88 88 Since dulwich sometimes keeps .idx file descriptors open, it can lead to a
89 89 "Too many open files" error. We need to close all opened file descriptors
90 90 once the repo object is destroyed.
91 91 """
92 92 def __del__(self):
93 93 if hasattr(self, 'object_store'):
94 94 self.close()
95 95
96 96
97 97 class Repository(LibGit2Repo):
98 98
99 99 def __enter__(self):
100 100 return self
101 101
102 102 def __exit__(self, exc_type, exc_val, exc_tb):
103 103 self.free()
104 104
105 105
106 106 class GitFactory(RepoFactory):
107 107 repo_type = 'git'
108 108
109 109 def _create_repo(self, wire, create, use_libgit2=False):
110 110 if use_libgit2:
111 111 repo = Repository(safe_bytes(wire['path']))
112 112 else:
113 113 # dulwich mode
114 114 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
115 115 repo = Repo(repo_path)
116 116
117 117 log.debug('repository created: got GIT object: %s', repo)
118 118 return repo
119 119
120 120 def repo(self, wire, create=False, use_libgit2=False):
121 121 """
122 122 Get a repository instance for the given path.
123 123 """
124 124 return self._create_repo(wire, create, use_libgit2)
125 125
126 126 def repo_libgit2(self, wire):
127 127 return self.repo(wire, use_libgit2=True)
128 128
129 129
130 130 def create_signature_from_string(author_str, **kwargs):
131 131 """
132 132 Creates a pygit2.Signature object from a string of the format 'Name <email>'.
133 133
134 134 :param author_str: String of the format 'Name <email>'
135 135 :return: pygit2.Signature object
136 136 """
137 137 match = re.match(r'^(.+) <(.+)>$', author_str)
138 138 if match is None:
139 139 raise ValueError(f"Invalid format: {author_str}")
140 140
141 141 name, email = match.groups()
142 142 return pygit2.Signature(name, email, **kwargs)
143 143
144 144
145 145 def get_obfuscated_url(url_obj):
146 146 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
147 147 url_obj.query = obfuscate_qs(url_obj.query)
148 148 obfuscated_uri = str(url_obj)
149 149 return obfuscated_uri
150 150
151 151
152 152 class GitRemote(RemoteBase):
153 153
154 154 def __init__(self, factory):
155 155 self._factory = factory
156 156 self._bulk_methods = {
157 157 "date": self.date,
158 158 "author": self.author,
159 159 "branch": self.branch,
160 160 "message": self.message,
161 161 "parents": self.parents,
162 162 "_commit": self.revision,
163 163 }
164 164 self._bulk_file_methods = {
165 165 "size": self.get_node_size,
166 166 "data": self.get_node_data,
167 167 "flags": self.get_node_flags,
168 168 "is_binary": self.get_node_is_binary,
169 169 "md5": self.md5_hash
170 170 }
171 171
172 172 def _wire_to_config(self, wire):
173 173 if 'config' in wire:
174 174 return {x[0] + '_' + x[1]: x[2] for x in wire['config']}
175 175 return {}
176 176
177 177 def _remote_conf(self, config):
178 178 params = [
179 179 '-c', 'core.askpass=""',
180 180 ]
181 181 config_attrs = {
182 182 'vcs_ssl_dir': 'http.sslCAinfo={}',
183 183 'vcs_git_lfs_store_location': 'lfs.storage={}'
184 184 }
185 185 for key, param in config_attrs.items():
186 186 if value := config.get(key):
187 187 params.extend(['-c', param.format(value)])
188 188 return params
189 189
190 190 @reraise_safe_exceptions
191 191 def discover_git_version(self):
192 192 stdout, _ = self.run_git_command(
193 193 {}, ['--version'], _bare=True, _safe=True)
194 194 prefix = b'git version'
195 195 if stdout.startswith(prefix):
196 196 stdout = stdout[len(prefix):]
197 197 return safe_str(stdout.strip())
198 198
199 199 @reraise_safe_exceptions
200 200 def is_empty(self, wire):
201 201 repo_init = self._factory.repo_libgit2(wire)
202 202 with repo_init as repo:
203 203
204 204 try:
205 205 has_head = repo.head.name
206 206 if has_head:
207 207 return False
208 208
209 209 # NOTE(marcink): check again using more expensive method
210 210 return repo.is_empty
211 211 except Exception:
212 212 pass
213 213
214 214 return True
215 215
216 216 @reraise_safe_exceptions
217 217 def assert_correct_path(self, wire):
218 218 cache_on, context_uid, repo_id = self._cache_on(wire)
219 219 region = self._region(wire)
220 220
221 221 @region.conditional_cache_on_arguments(condition=cache_on)
222 222 def _assert_correct_path(_context_uid, _repo_id, fast_check):
223 223 if fast_check:
224 224 path = safe_str(wire['path'])
225 225 if pygit2.discover_repository(path):
226 226 return True
227 227 return False
228 228 else:
229 229 try:
230 230 repo_init = self._factory.repo_libgit2(wire)
231 231 with repo_init:
232 232 pass
233 233 except pygit2.GitError:
234 234 path = wire.get('path')
235 235 tb = traceback.format_exc()
236 236 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
237 237 return False
238 238 return True
239 239
240 240 return _assert_correct_path(context_uid, repo_id, True)
241 241
242 242 @reraise_safe_exceptions
243 243 def bare(self, wire):
244 244 repo_init = self._factory.repo_libgit2(wire)
245 245 with repo_init as repo:
246 246 return repo.is_bare
247 247
248 248 @reraise_safe_exceptions
249 249 def get_node_data(self, wire, commit_id, path):
250 250 repo_init = self._factory.repo_libgit2(wire)
251 251 with repo_init as repo:
252 252 commit = repo[commit_id]
253 253 blob_obj = commit.tree[path]
254 254
255 255 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
256 256 raise exceptions.LookupException()(
257 257 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
258 258
259 259 return BytesEnvelope(blob_obj.data)
260 260
261 261 @reraise_safe_exceptions
262 262 def get_node_size(self, wire, commit_id, path):
263 263 repo_init = self._factory.repo_libgit2(wire)
264 264 with repo_init as repo:
265 265 commit = repo[commit_id]
266 266 blob_obj = commit.tree[path]
267 267
268 268 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
269 269 raise exceptions.LookupException()(
270 270 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
271 271
272 272 return blob_obj.size
273 273
274 274 @reraise_safe_exceptions
275 275 def get_node_flags(self, wire, commit_id, path):
276 276 repo_init = self._factory.repo_libgit2(wire)
277 277 with repo_init as repo:
278 278 commit = repo[commit_id]
279 279 blob_obj = commit.tree[path]
280 280
281 281 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
282 282 raise exceptions.LookupException()(
283 283 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
284 284
285 285 return blob_obj.filemode
286 286
287 287 @reraise_safe_exceptions
288 288 def get_node_is_binary(self, wire, commit_id, path):
289 289 repo_init = self._factory.repo_libgit2(wire)
290 290 with repo_init as repo:
291 291 commit = repo[commit_id]
292 292 blob_obj = commit.tree[path]
293 293
294 294 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
295 295 raise exceptions.LookupException()(
296 296 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
297 297
298 298 return blob_obj.is_binary
299 299
300 300 @reraise_safe_exceptions
301 301 def blob_as_pretty_string(self, wire, sha):
302 302 repo_init = self._factory.repo_libgit2(wire)
303 303 with repo_init as repo:
304 304 blob_obj = repo[sha]
305 305 return BytesEnvelope(blob_obj.data)
306 306
307 307 @reraise_safe_exceptions
308 308 def blob_raw_length(self, wire, sha):
309 309 cache_on, context_uid, repo_id = self._cache_on(wire)
310 310 region = self._region(wire)
311 311
312 312 @region.conditional_cache_on_arguments(condition=cache_on)
313 313 def _blob_raw_length(_repo_id, _sha):
314 314
315 315 repo_init = self._factory.repo_libgit2(wire)
316 316 with repo_init as repo:
317 317 blob = repo[sha]
318 318 return blob.size
319 319
320 320 return _blob_raw_length(repo_id, sha)
321 321
322 322 def _parse_lfs_pointer(self, raw_content):
323 323 spec_string = b'version https://git-lfs.github.com/spec'
324 324 if raw_content and raw_content.startswith(spec_string):
325 325
326 326 pattern = re.compile(rb"""
327 327 (?:\n)?
328 328 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
329 329 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
330 330 ^size[ ](?P<oid_size>[0-9]+)\n
331 331 (?:\n)?
332 332 """, re.VERBOSE | re.MULTILINE)
333 333 match = pattern.match(raw_content)
334 334 if match:
335 335 return match.groupdict()
336 336
337 337 return {}
338 338
339 339 @reraise_safe_exceptions
340 340 def is_large_file(self, wire, commit_id):
341 341 cache_on, context_uid, repo_id = self._cache_on(wire)
342 342 region = self._region(wire)
343 343
344 344 @region.conditional_cache_on_arguments(condition=cache_on)
345 345 def _is_large_file(_repo_id, _sha):
346 346 repo_init = self._factory.repo_libgit2(wire)
347 347 with repo_init as repo:
348 348 blob = repo[commit_id]
349 349 if blob.is_binary:
350 350 return {}
351 351
352 352 return self._parse_lfs_pointer(blob.data)
353 353
354 354 return _is_large_file(repo_id, commit_id)
355 355
356 356 @reraise_safe_exceptions
357 357 def is_binary(self, wire, tree_id):
358 358 cache_on, context_uid, repo_id = self._cache_on(wire)
359 359 region = self._region(wire)
360 360
361 361 @region.conditional_cache_on_arguments(condition=cache_on)
362 362 def _is_binary(_repo_id, _tree_id):
363 363 repo_init = self._factory.repo_libgit2(wire)
364 364 with repo_init as repo:
365 365 blob_obj = repo[tree_id]
366 366 return blob_obj.is_binary
367 367
368 368 return _is_binary(repo_id, tree_id)
369 369
370 370 @reraise_safe_exceptions
371 371 def md5_hash(self, wire, commit_id, path):
372 372 cache_on, context_uid, repo_id = self._cache_on(wire)
373 373 region = self._region(wire)
374 374
375 375 @region.conditional_cache_on_arguments(condition=cache_on)
376 376 def _md5_hash(_repo_id, _commit_id, _path):
377 377 repo_init = self._factory.repo_libgit2(wire)
378 378 with repo_init as repo:
379 379 commit = repo[_commit_id]
380 380 blob_obj = commit.tree[_path]
381 381
382 382 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
383 383 raise exceptions.LookupException()(
384 384 f'Tree for commit_id:{_commit_id} is not a blob: {blob_obj.type_str}')
385 385
386 386 return ''
387 387
388 388 return _md5_hash(repo_id, commit_id, path)
389 389
390 390 @reraise_safe_exceptions
391 391 def in_largefiles_store(self, wire, oid):
392 392 conf = self._wire_to_config(wire)
393 393 repo_init = self._factory.repo_libgit2(wire)
394 394 with repo_init as repo:
395 395 repo_name = repo.path
396 396
397 397 store_location = conf.get('vcs_git_lfs_store_location')
398 398 if store_location:
399 399
400 400 store = LFSOidStore(
401 401 oid=oid, repo=repo_name, store_location=store_location)
402 402 return store.has_oid()
403 403
404 404 return False
405 405
406 406 @reraise_safe_exceptions
407 407 def store_path(self, wire, oid):
408 408 conf = self._wire_to_config(wire)
409 409 repo_init = self._factory.repo_libgit2(wire)
410 410 with repo_init as repo:
411 411 repo_name = repo.path
412 412
413 413 store_location = conf.get('vcs_git_lfs_store_location')
414 414 if store_location:
415 415 store = LFSOidStore(
416 416 oid=oid, repo=repo_name, store_location=store_location)
417 417 return store.oid_path
418 418 raise ValueError(f'Unable to fetch oid with path {oid}')
419 419
420 420 @reraise_safe_exceptions
421 421 def bulk_request(self, wire, rev, pre_load):
422 422 cache_on, context_uid, repo_id = self._cache_on(wire)
423 423 region = self._region(wire)
424 424
425 425 @region.conditional_cache_on_arguments(condition=cache_on)
426 426 def _bulk_request(_repo_id, _rev, _pre_load):
427 427 result = {}
428 428 for attr in pre_load:
429 429 try:
430 430 method = self._bulk_methods[attr]
431 431 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
432 432 args = [wire, rev]
433 433 result[attr] = method(*args)
434 434 except KeyError as e:
435 435 raise exceptions.VcsException(e)(f"Unknown bulk attribute: {attr}")
436 436 return result
437 437
438 438 return _bulk_request(repo_id, rev, sorted(pre_load))
439 439
440 440 @reraise_safe_exceptions
441 441 def bulk_file_request(self, wire, commit_id, path, pre_load):
442 442 cache_on, context_uid, repo_id = self._cache_on(wire)
443 443 region = self._region(wire)
444 444
445 445 @region.conditional_cache_on_arguments(condition=cache_on)
446 446 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
447 447 result = {}
448 448 for attr in pre_load:
449 449 try:
450 450 method = self._bulk_file_methods[attr]
451 451 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
452 452 result[attr] = method(wire, _commit_id, _path)
453 453 except KeyError as e:
454 454 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
455 455 return result
456 456
457 457 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
458 458
459 459 def _build_opener(self, url: str):
460 460 handlers = []
461 461 url_obj = url_parser(safe_bytes(url))
462 462 authinfo = url_obj.authinfo()[1]
463 463
464 def _convert_to_strings(data):
465 if isinstance(data, bytes):
466 return safe_str(data)
467 elif isinstance(data, tuple):
468 return tuple(_convert_to_strings(item) for item in data)
469 else:
470 return data
471
472 464 if authinfo:
473 465 # create a password manager
474 466 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
475 passmgr.add_password(*_convert_to_strings(authinfo))
467 passmgr.add_password(*convert_to_str(authinfo))
476 468
477 469 handlers.extend((httpbasicauthhandler(passmgr),
478 470 httpdigestauthhandler(passmgr)))
479 471
480 472 return urllib.request.build_opener(*handlers)
481 473
482 474 @reraise_safe_exceptions
483 475 def check_url(self, url, config):
484 476 url_obj = url_parser(safe_bytes(url))
485 477
486 478 test_uri = safe_str(url_obj.authinfo()[0])
487 479 obfuscated_uri = get_obfuscated_url(url_obj)
488 480
489 481 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
490 482
491 483 if not test_uri.endswith('info/refs'):
492 484 test_uri = test_uri.rstrip('/') + '/info/refs'
493 485
494 486 o = self._build_opener(url=url)
495 487 o.addheaders = [('User-Agent', 'git/1.7.8.0')] # fake some git
496 488
497 489 q = {"service": 'git-upload-pack'}
498 490 qs = f'?{urllib.parse.urlencode(q)}'
499 491 cu = f"{test_uri}{qs}"
500 492
501 493 try:
502 494 req = urllib.request.Request(cu, None, {})
503 495 log.debug("Trying to open URL %s", obfuscated_uri)
504 496 resp = o.open(req)
505 497 if resp.code != 200:
506 498 raise exceptions.URLError()('Return Code is not 200')
507 499 except Exception as e:
508 500 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
509 501 # means it cannot be cloned
510 502 raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")
511 503
512 504 # now detect if it's proper git repo
513 505 gitdata: bytes = resp.read()
514 506
515 507 if b'service=git-upload-pack' in gitdata:
516 508 pass
517 509 elif re.findall(br'[0-9a-fA-F]{40}\s+refs', gitdata):
518 510 # old style git can return some other format!
519 511 pass
520 512 else:
521 513 e = None
522 514 raise exceptions.URLError(e)(
523 515 f"url [{obfuscated_uri}] does not look like an hg repo org_exc: {e}")
524 516
525 517 return True
526 518
527 519 @reraise_safe_exceptions
528 520 def clone(self, wire, url, deferred, valid_refs, update_after_clone):
529 521 # TODO(marcink): deprecate this method. Last i checked we don't use it anymore
530 522 remote_refs = self.pull(wire, url, apply_refs=False)
531 523 repo = self._factory.repo(wire)
532 524 if isinstance(valid_refs, list):
533 525 valid_refs = tuple(valid_refs)
534 526
535 527 for k in remote_refs:
536 528 # only parse heads/tags and skip so called deferred tags
537 529 if k.startswith(valid_refs) and not k.endswith(deferred):
538 530 repo[k] = remote_refs[k]
539 531
540 532 if update_after_clone:
541 533 # we want to checkout HEAD
542 534 repo["HEAD"] = remote_refs["HEAD"]
543 535 index.build_index_from_tree(repo.path, repo.index_path(),
544 536 repo.object_store, repo["HEAD"].tree)
545 537
546 538 @reraise_safe_exceptions
547 539 def branch(self, wire, commit_id):
548 540 cache_on, context_uid, repo_id = self._cache_on(wire)
549 541 region = self._region(wire)
550 542
551 543 @region.conditional_cache_on_arguments(condition=cache_on)
552 544 def _branch(_context_uid, _repo_id, _commit_id):
553 545 regex = re.compile('^refs/heads')
554 546
555 547 def filter_with(ref):
556 548 return regex.match(ref[0]) and ref[1] == _commit_id
557 549
558 550 branches = list(filter(filter_with, list(self.get_refs(wire).items())))
559 551 return [x[0].split('refs/heads/')[-1] for x in branches]
560 552
561 553 return _branch(context_uid, repo_id, commit_id)
562 554
563 555 @reraise_safe_exceptions
564 556 def commit_branches(self, wire, commit_id):
565 557 cache_on, context_uid, repo_id = self._cache_on(wire)
566 558 region = self._region(wire)
567 559
568 560 @region.conditional_cache_on_arguments(condition=cache_on)
569 561 def _commit_branches(_context_uid, _repo_id, _commit_id):
570 562 repo_init = self._factory.repo_libgit2(wire)
571 563 with repo_init as repo:
572 564 branches = [x for x in repo.branches.with_commit(_commit_id)]
573 565 return branches
574 566
575 567 return _commit_branches(context_uid, repo_id, commit_id)
576 568
577 569 @reraise_safe_exceptions
578 570 def add_object(self, wire, content):
579 571 repo_init = self._factory.repo_libgit2(wire)
580 572 with repo_init as repo:
581 573 blob = objects.Blob()
582 574 blob.set_raw_string(content)
583 575 repo.object_store.add_object(blob)
584 576 return blob.id
585 577
586 578 @reraise_safe_exceptions
587 579 def create_commit(self, wire, author, committer, message, branch, new_tree_id,
588 580 date_args: list[int, int] = None,
589 581 parents: list | None = None):
590 582
591 583 repo_init = self._factory.repo_libgit2(wire)
592 584 with repo_init as repo:
593 585
594 586 if date_args:
595 587 current_time, offset = date_args
596 588
597 589 kw = {
598 590 'time': current_time,
599 591 'offset': offset
600 592 }
601 593 author = create_signature_from_string(author, **kw)
602 594 committer = create_signature_from_string(committer, **kw)
603 595
604 596 tree = new_tree_id
605 597 if isinstance(tree, (bytes, str)):
606 598 # validate this tree is in the repo...
607 599 tree = repo[safe_str(tree)].id
608 600
609 601 if parents:
610 602 # run via sha's and validate them in repo
611 603 parents = [repo[c].id for c in parents]
612 604 else:
613 605 parents = []
614 606 # ensure we COMMIT on top of given branch head
615 607 # check if this repo has ANY branches, otherwise it's a new branch case we need to make
616 608 if branch in repo.branches.local:
617 609 parents += [repo.branches[branch].target]
618 610 elif [x for x in repo.branches.local]:
619 611 parents += [repo.head.target]
620 612 #else:
621 613 # in case we want to commit on new branch we create it on top of HEAD
622 614 #repo.branches.local.create(branch, repo.revparse_single('HEAD'))
623 615
624 616 # # Create a new commit
625 617 commit_oid = repo.create_commit(
626 618 f'refs/heads/{branch}', # the name of the reference to update
627 619 author, # the author of the commit
628 620 committer, # the committer of the commit
629 621 message, # the commit message
630 622 tree, # the tree produced by the index
631 623 parents # list of parents for the new commit, usually just one,
632 624 )
633 625
634 626 new_commit_id = safe_str(commit_oid)
635 627
636 628 return new_commit_id
637 629
638 630 @reraise_safe_exceptions
639 631 def commit(self, wire, commit_data, branch, commit_tree, updated, removed):
640 632
641 633 def mode2pygit(mode):
642 634 """
 643 635 git only supports two filemodes: 644 and 755
644 636
645 637 0o100755 -> 33261
646 638 0o100644 -> 33188
647 639 """
648 640 return {
649 641 0o100644: pygit2.GIT_FILEMODE_BLOB,
650 642 0o100755: pygit2.GIT_FILEMODE_BLOB_EXECUTABLE,
651 643 0o120000: pygit2.GIT_FILEMODE_LINK
652 644 }.get(mode) or pygit2.GIT_FILEMODE_BLOB
653 645
654 646 repo_init = self._factory.repo_libgit2(wire)
655 647 with repo_init as repo:
656 648 repo_index = repo.index
657 649
658 650 commit_parents = None
659 651 if commit_tree and commit_data['parents']:
660 652 commit_parents = commit_data['parents']
661 653 parent_commit = repo[commit_parents[0]]
662 654 repo_index.read_tree(parent_commit.tree)
663 655
664 656 for pathspec in updated:
665 657 blob_id = repo.create_blob(pathspec['content'])
666 658 ie = pygit2.IndexEntry(pathspec['path'], blob_id, mode2pygit(pathspec['mode']))
667 659 repo_index.add(ie)
668 660
669 661 for pathspec in removed:
670 662 repo_index.remove(pathspec)
671 663
672 664 # Write changes to the index
673 665 repo_index.write()
674 666
675 667 # Create a tree from the updated index
676 668 written_commit_tree = repo_index.write_tree()
677 669
678 670 new_tree_id = written_commit_tree
679 671
680 672 author = commit_data['author']
681 673 committer = commit_data['committer']
682 674 message = commit_data['message']
683 675
684 676 date_args = [int(commit_data['commit_time']), int(commit_data['commit_timezone'])]
685 677
686 678 new_commit_id = self.create_commit(wire, author, committer, message, branch,
687 679 new_tree_id, date_args=date_args, parents=commit_parents)
688 680
689 681 # libgit2, ensure the branch is there and exists
690 682 self.create_branch(wire, branch, new_commit_id)
691 683
692 684 # libgit2, set new ref to this created commit
693 685 self.set_refs(wire, f'refs/heads/{branch}', new_commit_id)
694 686
695 687 return new_commit_id
696 688
697 689 @reraise_safe_exceptions
698 690 def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
699 691 if url != 'default' and '://' not in url:
700 692 client = LocalGitClient(url)
701 693 else:
702 694 url_obj = url_parser(safe_bytes(url))
703 695 o = self._build_opener(url)
704 696 url = url_obj.authinfo()[0]
705 697 client = HttpGitClient(base_url=url, opener=o)
706 698 repo = self._factory.repo(wire)
707 699
708 700 determine_wants = repo.object_store.determine_wants_all
709 701
710 702 if refs:
711 703 refs: list[bytes] = [ascii_bytes(x) for x in refs]
712 704
713 705 def determine_wants_requested(_remote_refs):
714 706 determined = []
715 707 for ref_name, ref_hash in _remote_refs.items():
716 708 bytes_ref_name = safe_bytes(ref_name)
717 709
718 710 if bytes_ref_name in refs:
719 711 bytes_ref_hash = safe_bytes(ref_hash)
720 712 determined.append(bytes_ref_hash)
721 713 return determined
722 714
723 715 # swap with our custom requested wants
724 716 determine_wants = determine_wants_requested
725 717
726 718 try:
727 719 remote_refs = client.fetch(
728 720 path=url, target=repo, determine_wants=determine_wants)
729 721
730 722 except NotGitRepository as e:
731 723 log.warning(
732 724 'Trying to fetch from "%s" failed, not a Git repository.', url)
733 725 # Exception can contain unicode which we convert
734 726 raise exceptions.AbortException(e)(repr(e))
735 727
736 728 # mikhail: client.fetch() returns all the remote refs, but fetches only
737 729 # refs filtered by `determine_wants` function. We need to filter result
738 730 # as well
739 731 if refs:
740 732 remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}
741 733
742 734 if apply_refs:
743 735 # TODO: johbo: Needs proper test coverage with a git repository
744 736 # that contains a tag object, so that we would end up with
745 737 # a peeled ref at this point.
746 738 for k in remote_refs:
747 739 if k.endswith(PEELED_REF_MARKER):
748 740 log.debug("Skipping peeled reference %s", k)
749 741 continue
750 742 repo[k] = remote_refs[k]
751 743
752 744 if refs and not update_after:
753 745 # update to ref
754 746 # mikhail: explicitly set the head to the last ref.
755 747 update_to_ref = refs[-1]
756 748 if isinstance(update_after, str):
757 749 update_to_ref = update_after
758 750
759 751 repo[HEAD_MARKER] = remote_refs[update_to_ref]
760 752
761 753 if update_after:
762 754 # we want to check out HEAD
763 755 repo[HEAD_MARKER] = remote_refs[HEAD_MARKER]
764 756 index.build_index_from_tree(repo.path, repo.index_path(),
765 757 repo.object_store, repo[HEAD_MARKER].tree)
766 758
767 759 if isinstance(remote_refs, FetchPackResult):
768 760 return remote_refs.refs
769 761 return remote_refs
770 762
771 763 @reraise_safe_exceptions
772 764 def sync_fetch(self, wire, url, refs=None, all_refs=False, **kwargs):
773 765 self._factory.repo(wire)
774 766 if refs and not isinstance(refs, (list, tuple)):
775 767 refs = [refs]
776 768
777 769 config = self._wire_to_config(wire)
778 770 # get all remote refs we'll use to fetch later
779 771 cmd = ['ls-remote']
780 772 if not all_refs:
781 773 cmd += ['--heads', '--tags']
782 774 cmd += [url]
783 775 output, __ = self.run_git_command(
784 776 wire, cmd, fail_on_stderr=False,
785 777 _copts=self._remote_conf(config),
786 778 extra_env={'GIT_TERMINAL_PROMPT': '0'})
787 779
788 780 remote_refs = collections.OrderedDict()
789 781 fetch_refs = []
790 782
791 783 for ref_line in output.splitlines():
792 784 sha, ref = ref_line.split(b'\t')
793 785 sha = sha.strip()
794 786 if ref in remote_refs:
795 787 # duplicate, skip
796 788 continue
797 789 if ref.endswith(PEELED_REF_MARKER):
798 790 log.debug("Skipping peeled reference %s", ref)
799 791 continue
800 792 # don't sync HEAD
801 793 if ref in [HEAD_MARKER]:
802 794 continue
803 795
804 796 remote_refs[ref] = sha
805 797
806 798 if refs and sha in refs:
807 799 # we filter fetch using our specified refs
808 800 fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
809 801 elif not refs:
810 802 fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
811 803 log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))
812 804
813 805 if fetch_refs:
814 806 for chunk in more_itertools.chunked(fetch_refs, 128):
815 807 fetch_refs_chunks = list(chunk)
816 808 log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
817 809 self.run_git_command(
818 810 wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
819 811 fail_on_stderr=False,
820 812 _copts=self._remote_conf(config),
821 813 extra_env={'GIT_TERMINAL_PROMPT': '0'})
822 814 if kwargs.get('sync_large_objects'):
823 815 self.run_git_command(
824 816 wire, ['lfs', 'fetch', url, '--all'],
825 817 fail_on_stderr=False,
826 818 _copts=self._remote_conf(config),
827 819 )
828 820
829 821 return remote_refs
830 822
831 823 @reraise_safe_exceptions
832 824 def sync_push(self, wire, url, refs=None, **kwargs):
833 825 if not self.check_url(url, wire):
834 826 return
835 827 config = self._wire_to_config(wire)
836 828 self._factory.repo(wire)
837 829 self.run_git_command(
838 830 wire, ['push', url, '--mirror'], fail_on_stderr=False,
839 831 _copts=self._remote_conf(config),
840 832 extra_env={'GIT_TERMINAL_PROMPT': '0'})
841 833 if kwargs.get('sync_large_objects'):
842 834 self.run_git_command(
843 835 wire, ['lfs', 'push', url, '--all'],
844 836 fail_on_stderr=False,
845 837 _copts=self._remote_conf(config),
846 838 )
847 839
848 840 @reraise_safe_exceptions
849 841 def get_remote_refs(self, wire, url):
850 842 repo = Repo(url)
851 843 return repo.get_refs()
852 844
853 845 @reraise_safe_exceptions
854 846 def get_description(self, wire):
855 847 repo = self._factory.repo(wire)
856 848 return repo.get_description()
857 849
858 850 @reraise_safe_exceptions
859 851 def get_missing_revs(self, wire, rev1, rev2, other_repo_path):
860 852 origin_repo_path = wire['path']
861 853 repo = self._factory.repo(wire)
862 854 # fetch from other_repo_path to our origin repo
863 855 LocalGitClient(thin_packs=False).fetch(other_repo_path, repo)
864 856
865 857 wire_remote = wire.copy()
866 858 wire_remote['path'] = other_repo_path
867 859 repo_remote = self._factory.repo(wire_remote)
868 860
869 861 # fetch from origin_repo_path to our remote repo
870 862 LocalGitClient(thin_packs=False).fetch(origin_repo_path, repo_remote)
871 863
872 864 revs = [
873 865 x.commit.id
874 866 for x in repo_remote.get_walker(include=[safe_bytes(rev2)], exclude=[safe_bytes(rev1)])]
875 867 return revs
876 868
877 869 @reraise_safe_exceptions
878 870 def get_object(self, wire, sha, maybe_unreachable=False):
879 871 cache_on, context_uid, repo_id = self._cache_on(wire)
880 872 region = self._region(wire)
881 873
882 874 @region.conditional_cache_on_arguments(condition=cache_on)
883 875 def _get_object(_context_uid, _repo_id, _sha):
884 876 repo_init = self._factory.repo_libgit2(wire)
885 877 with repo_init as repo:
886 878
887 879 missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
888 880 try:
889 881 commit = repo.revparse_single(sha)
890 882 except KeyError:
891 883 # NOTE(marcink): KeyError doesn't give us any meaningful information
892 884 # here, we instead give something more explicit
893 885 e = exceptions.RefNotFoundException('SHA: %s not found', sha)
894 886 raise exceptions.LookupException(e)(missing_commit_err)
895 887 except ValueError as e:
896 888 raise exceptions.LookupException(e)(missing_commit_err)
897 889
898 890 is_tag = False
899 891 if isinstance(commit, pygit2.Tag):
900 892 commit = repo.get(commit.target)
901 893 is_tag = True
902 894
903 895 check_dangling = True
904 896 if is_tag:
905 897 check_dangling = False
906 898
907 899 if check_dangling and maybe_unreachable:
908 900 check_dangling = False
909 901
 910 902 # we used a reference and it parsed, which means we do not have a dangling commit
911 903 if sha != commit.hex:
912 904 check_dangling = False
913 905
914 906 if check_dangling:
915 907 # check for dangling commit
916 908 for branch in repo.branches.with_commit(commit.hex):
917 909 if branch:
918 910 break
919 911 else:
920 912 # NOTE(marcink): Empty error doesn't give us any meaningful information
921 913 # here, we instead give something more explicit
922 914 e = exceptions.RefNotFoundException('SHA: %s not found in branches', sha)
923 915 raise exceptions.LookupException(e)(missing_commit_err)
924 916
925 917 commit_id = commit.hex
926 918 type_str = commit.type_str
927 919
928 920 return {
929 921 'id': commit_id,
930 922 'type': type_str,
931 923 'commit_id': commit_id,
932 924 'idx': 0
933 925 }
934 926
935 927 return _get_object(context_uid, repo_id, sha)
936 928
937 929 @reraise_safe_exceptions
938 930 def get_refs(self, wire):
939 931 cache_on, context_uid, repo_id = self._cache_on(wire)
940 932 region = self._region(wire)
941 933
942 934 @region.conditional_cache_on_arguments(condition=cache_on)
943 935 def _get_refs(_context_uid, _repo_id):
944 936
945 937 repo_init = self._factory.repo_libgit2(wire)
946 938 with repo_init as repo:
947 939 regex = re.compile('^refs/(heads|tags)/')
948 940 return {x.name: x.target.hex for x in
949 941 [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}
950 942
951 943 return _get_refs(context_uid, repo_id)
952 944
953 945 @reraise_safe_exceptions
954 946 def get_branch_pointers(self, wire):
955 947 cache_on, context_uid, repo_id = self._cache_on(wire)
956 948 region = self._region(wire)
957 949
958 950 @region.conditional_cache_on_arguments(condition=cache_on)
959 951 def _get_branch_pointers(_context_uid, _repo_id):
960 952
961 953 repo_init = self._factory.repo_libgit2(wire)
962 954 regex = re.compile('^refs/heads')
963 955 with repo_init as repo:
964 956 branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
965 957 return {x.target.hex: x.shorthand for x in branches}
966 958
967 959 return _get_branch_pointers(context_uid, repo_id)
968 960
969 961 @reraise_safe_exceptions
970 962 def head(self, wire, show_exc=True):
971 963 cache_on, context_uid, repo_id = self._cache_on(wire)
972 964 region = self._region(wire)
973 965
974 966 @region.conditional_cache_on_arguments(condition=cache_on)
975 967 def _head(_context_uid, _repo_id, _show_exc):
976 968 repo_init = self._factory.repo_libgit2(wire)
977 969 with repo_init as repo:
978 970 try:
979 971 return repo.head.peel().hex
980 972 except Exception:
981 973 if show_exc:
982 974 raise
983 975 return _head(context_uid, repo_id, show_exc)
984 976
985 977 @reraise_safe_exceptions
986 978 def init(self, wire):
987 979 repo_path = safe_str(wire['path'])
988 980 os.makedirs(repo_path, mode=0o755)
989 981 pygit2.init_repository(repo_path, bare=False)
990 982
991 983 @reraise_safe_exceptions
992 984 def init_bare(self, wire):
993 985 repo_path = safe_str(wire['path'])
994 986 os.makedirs(repo_path, mode=0o755)
995 987 pygit2.init_repository(repo_path, bare=True)
996 988
997 989 @reraise_safe_exceptions
998 990 def revision(self, wire, rev):
999 991
1000 992 cache_on, context_uid, repo_id = self._cache_on(wire)
1001 993 region = self._region(wire)
1002 994
1003 995 @region.conditional_cache_on_arguments(condition=cache_on)
1004 996 def _revision(_context_uid, _repo_id, _rev):
1005 997 repo_init = self._factory.repo_libgit2(wire)
1006 998 with repo_init as repo:
1007 999 commit = repo[rev]
1008 1000 obj_data = {
1009 1001 'id': commit.id.hex,
1010 1002 }
 1011 1003 # tree objects themselves don't have a tree_id attribute
1012 1004 if hasattr(commit, 'tree_id'):
1013 1005 obj_data['tree'] = commit.tree_id.hex
1014 1006
1015 1007 return obj_data
1016 1008 return _revision(context_uid, repo_id, rev)
1017 1009
1018 1010 @reraise_safe_exceptions
1019 1011 def date(self, wire, commit_id):
1020 1012 cache_on, context_uid, repo_id = self._cache_on(wire)
1021 1013 region = self._region(wire)
1022 1014
1023 1015 @region.conditional_cache_on_arguments(condition=cache_on)
1024 1016 def _date(_repo_id, _commit_id):
1025 1017 repo_init = self._factory.repo_libgit2(wire)
1026 1018 with repo_init as repo:
1027 1019 commit = repo[commit_id]
1028 1020
1029 1021 if hasattr(commit, 'commit_time'):
1030 1022 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
1031 1023 else:
1032 1024 commit = commit.get_object()
1033 1025 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
1034 1026
1035 1027 # TODO(marcink): check dulwich difference of offset vs timezone
1036 1028 return [commit_time, commit_time_offset]
1037 1029 return _date(repo_id, commit_id)
1038 1030
1039 1031 @reraise_safe_exceptions
1040 1032 def author(self, wire, commit_id):
1041 1033 cache_on, context_uid, repo_id = self._cache_on(wire)
1042 1034 region = self._region(wire)
1043 1035
1044 1036 @region.conditional_cache_on_arguments(condition=cache_on)
1045 1037 def _author(_repo_id, _commit_id):
1046 1038 repo_init = self._factory.repo_libgit2(wire)
1047 1039 with repo_init as repo:
1048 1040 commit = repo[commit_id]
1049 1041
1050 1042 if hasattr(commit, 'author'):
1051 1043 author = commit.author
1052 1044 else:
1053 1045 author = commit.get_object().author
1054 1046
1055 1047 if author.email:
1056 1048 return f"{author.name} <{author.email}>"
1057 1049
1058 1050 try:
1059 1051 return f"{author.name}"
1060 1052 except Exception:
1061 1053 return f"{safe_str(author.raw_name)}"
1062 1054
1063 1055 return _author(repo_id, commit_id)
1064 1056
1065 1057 @reraise_safe_exceptions
1066 1058 def message(self, wire, commit_id):
1067 1059 cache_on, context_uid, repo_id = self._cache_on(wire)
1068 1060 region = self._region(wire)
1069 1061
1070 1062 @region.conditional_cache_on_arguments(condition=cache_on)
1071 1063 def _message(_repo_id, _commit_id):
1072 1064 repo_init = self._factory.repo_libgit2(wire)
1073 1065 with repo_init as repo:
1074 1066 commit = repo[commit_id]
1075 1067 return commit.message
1076 1068 return _message(repo_id, commit_id)
1077 1069
1078 1070 @reraise_safe_exceptions
1079 1071 def parents(self, wire, commit_id):
1080 1072 cache_on, context_uid, repo_id = self._cache_on(wire)
1081 1073 region = self._region(wire)
1082 1074
1083 1075 @region.conditional_cache_on_arguments(condition=cache_on)
1084 1076 def _parents(_repo_id, _commit_id):
1085 1077 repo_init = self._factory.repo_libgit2(wire)
1086 1078 with repo_init as repo:
1087 1079 commit = repo[commit_id]
1088 1080 if hasattr(commit, 'parent_ids'):
1089 1081 parent_ids = commit.parent_ids
1090 1082 else:
1091 1083 parent_ids = commit.get_object().parent_ids
1092 1084
1093 1085 return [x.hex for x in parent_ids]
1094 1086 return _parents(repo_id, commit_id)
1095 1087
1096 1088 @reraise_safe_exceptions
1097 1089 def children(self, wire, commit_id):
1098 1090 cache_on, context_uid, repo_id = self._cache_on(wire)
1099 1091 region = self._region(wire)
1100 1092
1101 1093 head = self.head(wire)
1102 1094
1103 1095 @region.conditional_cache_on_arguments(condition=cache_on)
1104 1096 def _children(_repo_id, _commit_id):
1105 1097
1106 1098 output, __ = self.run_git_command(
1107 1099 wire, ['rev-list', '--all', '--children', f'{commit_id}^..{head}'])
1108 1100
1109 1101 child_ids = []
1110 1102 pat = re.compile(fr'^{commit_id}')
1111 1103 for line in output.splitlines():
1112 1104 line = safe_str(line)
1113 1105 if pat.match(line):
1114 1106 found_ids = line.split(' ')[1:]
1115 1107 child_ids.extend(found_ids)
1116 1108 break
1117 1109
1118 1110 return child_ids
1119 1111 return _children(repo_id, commit_id)
1120 1112
1121 1113 @reraise_safe_exceptions
1122 1114 def set_refs(self, wire, key, value):
1123 1115 repo_init = self._factory.repo_libgit2(wire)
1124 1116 with repo_init as repo:
1125 1117 repo.references.create(key, value, force=True)
1126 1118
1127 1119 @reraise_safe_exceptions
1128 1120 def update_refs(self, wire, key, value):
1129 1121 repo_init = self._factory.repo_libgit2(wire)
1130 1122 with repo_init as repo:
1131 1123 if key not in repo.references:
1132 1124 raise ValueError(f'Reference {key} not found in the repository')
1133 1125 repo.references.create(key, value, force=True)
1134 1126
1135 1127 @reraise_safe_exceptions
1136 1128 def create_branch(self, wire, branch_name, commit_id, force=False):
1137 1129 repo_init = self._factory.repo_libgit2(wire)
1138 1130 with repo_init as repo:
1139 1131 if commit_id:
1140 1132 commit = repo[commit_id]
1141 1133 else:
1142 1134 # if commit is not given just use the HEAD
1143 1135 commit = repo.head()
1144 1136
1145 1137 if force:
1146 1138 repo.branches.local.create(branch_name, commit, force=force)
1147 1139 elif not repo.branches.get(branch_name):
 1148 1140 # create only if that branch doesn't exist yet
1149 1141 repo.branches.local.create(branch_name, commit, force=force)
1150 1142
1151 1143 @reraise_safe_exceptions
1152 1144 def remove_ref(self, wire, key):
1153 1145 repo_init = self._factory.repo_libgit2(wire)
1154 1146 with repo_init as repo:
1155 1147 repo.references.delete(key)
1156 1148
1157 1149 @reraise_safe_exceptions
1158 1150 def tag_remove(self, wire, tag_name):
1159 1151 repo_init = self._factory.repo_libgit2(wire)
1160 1152 with repo_init as repo:
1161 1153 key = f'refs/tags/{tag_name}'
1162 1154 repo.references.delete(key)
1163 1155
1164 1156 @reraise_safe_exceptions
1165 1157 def tree_changes(self, wire, source_id, target_id):
1166 1158 repo = self._factory.repo(wire)
1167 1159 # source can be empty
1168 1160 source_id = safe_bytes(source_id if source_id else b'')
1169 1161 target_id = safe_bytes(target_id)
1170 1162
1171 1163 source = repo[source_id].tree if source_id else None
1172 1164 target = repo[target_id].tree
1173 1165 result = repo.object_store.tree_changes(source, target)
1174 1166
1175 1167 added = set()
1176 1168 modified = set()
1177 1169 deleted = set()
1178 1170 for (old_path, new_path), (_, _), (_, _) in list(result):
1179 1171 if new_path and old_path:
1180 1172 modified.add(new_path)
1181 1173 elif new_path and not old_path:
1182 1174 added.add(new_path)
1183 1175 elif not new_path and old_path:
1184 1176 deleted.add(old_path)
1185 1177
1186 1178 return list(added), list(modified), list(deleted)
1187 1179
1188 1180 @reraise_safe_exceptions
1189 1181 def tree_and_type_for_path(self, wire, commit_id, path):
1190 1182
1191 1183 cache_on, context_uid, repo_id = self._cache_on(wire)
1192 1184 region = self._region(wire)
1193 1185
1194 1186 @region.conditional_cache_on_arguments(condition=cache_on)
1195 1187 def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
1196 1188 repo_init = self._factory.repo_libgit2(wire)
1197 1189
1198 1190 with repo_init as repo:
1199 1191 commit = repo[commit_id]
1200 1192 try:
1201 1193 tree = commit.tree[path]
1202 1194 except KeyError:
1203 1195 return None, None, None
1204 1196
1205 1197 return tree.id.hex, tree.type_str, tree.filemode
1206 1198 return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)
1207 1199
1208 1200 @reraise_safe_exceptions
1209 1201 def tree_items(self, wire, tree_id):
1210 1202 cache_on, context_uid, repo_id = self._cache_on(wire)
1211 1203 region = self._region(wire)
1212 1204
1213 1205 @region.conditional_cache_on_arguments(condition=cache_on)
1214 1206 def _tree_items(_repo_id, _tree_id):
1215 1207
1216 1208 repo_init = self._factory.repo_libgit2(wire)
1217 1209 with repo_init as repo:
1218 1210 try:
1219 1211 tree = repo[tree_id]
1220 1212 except KeyError:
1221 1213 raise ObjectMissing(f'No tree with id: {tree_id}')
1222 1214
1223 1215 result = []
1224 1216 for item in tree:
1225 1217 item_sha = item.hex
1226 1218 item_mode = item.filemode
1227 1219 item_type = item.type_str
1228 1220
1229 1221 if item_type == 'commit':
1230 1222 # NOTE(marcink): submodules we translate to 'link' for backward compat
1231 1223 item_type = 'link'
1232 1224
1233 1225 result.append((item.name, item_mode, item_sha, item_type))
1234 1226 return result
1235 1227 return _tree_items(repo_id, tree_id)
1236 1228
1237 1229 @reraise_safe_exceptions
1238 1230 def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1239 1231 """
1240 1232 Old version that uses subprocess to call diff
1241 1233 """
1242 1234
1243 1235 flags = [
1244 1236 f'-U{context}', '--patch',
1245 1237 '--binary',
1246 1238 '--find-renames',
1247 1239 '--no-indent-heuristic',
1248 1240 # '--indent-heuristic',
1249 1241 #'--full-index',
1250 1242 #'--abbrev=40'
1251 1243 ]
1252 1244
1253 1245 if opt_ignorews:
1254 1246 flags.append('--ignore-all-space')
1255 1247
1256 1248 if commit_id_1 == self.EMPTY_COMMIT:
1257 1249 cmd = ['show'] + flags + [commit_id_2]
1258 1250 else:
1259 1251 cmd = ['diff'] + flags + [commit_id_1, commit_id_2]
1260 1252
1261 1253 if file_filter:
1262 1254 cmd.extend(['--', file_filter])
1263 1255
1264 1256 diff, __ = self.run_git_command(wire, cmd)
1265 1257 # If we used 'show' command, strip first few lines (until actual diff
1266 1258 # starts)
1267 1259 if commit_id_1 == self.EMPTY_COMMIT:
1268 1260 lines = diff.splitlines()
1269 1261 x = 0
1270 1262 for line in lines:
1271 1263 if line.startswith(b'diff'):
1272 1264 break
1273 1265 x += 1
 1274 1266 # Append new line just like the 'diff' command does
1275 1267 diff = '\n'.join(lines[x:]) + '\n'
1276 1268 return diff
1277 1269
1278 1270 @reraise_safe_exceptions
1279 1271 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1280 1272 repo_init = self._factory.repo_libgit2(wire)
1281 1273
1282 1274 with repo_init as repo:
1283 1275 swap = True
1284 1276 flags = 0
1285 1277 flags |= pygit2.GIT_DIFF_SHOW_BINARY
1286 1278
1287 1279 if opt_ignorews:
1288 1280 flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE
1289 1281
1290 1282 if commit_id_1 == self.EMPTY_COMMIT:
1291 1283 comm1 = repo[commit_id_2]
1292 1284 diff_obj = comm1.tree.diff_to_tree(
1293 1285 flags=flags, context_lines=context, swap=swap)
1294 1286
1295 1287 else:
1296 1288 comm1 = repo[commit_id_2]
1297 1289 comm2 = repo[commit_id_1]
1298 1290 diff_obj = comm1.tree.diff_to_tree(
1299 1291 comm2.tree, flags=flags, context_lines=context, swap=swap)
1300 1292 similar_flags = 0
1301 1293 similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
1302 1294 diff_obj.find_similar(flags=similar_flags)
1303 1295
1304 1296 if file_filter:
1305 1297 for p in diff_obj:
1306 1298 if p.delta.old_file.path == file_filter:
1307 1299 return BytesEnvelope(p.data) or BytesEnvelope(b'')
 1308 1300 # no matching path == no diff
1309 1301 return BytesEnvelope(b'')
1310 1302
1311 1303 return BytesEnvelope(safe_bytes(diff_obj.patch)) or BytesEnvelope(b'')
1312 1304
1313 1305 @reraise_safe_exceptions
1314 1306 def node_history(self, wire, commit_id, path, limit):
1315 1307 cache_on, context_uid, repo_id = self._cache_on(wire)
1316 1308 region = self._region(wire)
1317 1309
1318 1310 @region.conditional_cache_on_arguments(condition=cache_on)
1319 1311 def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
1320 1312 # optimize for n==1, rev-list is much faster for that use-case
1321 1313 if limit == 1:
1322 1314 cmd = ['rev-list', '-1', commit_id, '--', path]
1323 1315 else:
1324 1316 cmd = ['log']
1325 1317 if limit:
1326 1318 cmd.extend(['-n', str(safe_int(limit, 0))])
1327 1319 cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])
1328 1320
1329 1321 output, __ = self.run_git_command(wire, cmd)
1330 1322 commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)
1331 1323
1332 1324 return [x for x in commit_ids]
1333 1325 return _node_history(context_uid, repo_id, commit_id, path, limit)
1334 1326
1335 1327 @reraise_safe_exceptions
1336 1328 def node_annotate_legacy(self, wire, commit_id, path):
1337 1329 # note: replaced by pygit2 implementation
1338 1330 cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
1339 1331 # -l ==> outputs long shas (and we need all 40 characters)
1340 1332 # --root ==> doesn't put '^' character for boundaries
1341 1333 # -r commit_id ==> blames for the given commit
1342 1334 output, __ = self.run_git_command(wire, cmd)
1343 1335
1344 1336 result = []
1345 1337 for i, blame_line in enumerate(output.splitlines()[:-1]):
1346 1338 line_no = i + 1
1347 1339 blame_commit_id, line = re.split(rb' ', blame_line, 1)
1348 1340 result.append((line_no, blame_commit_id, line))
1349 1341
1350 1342 return result
1351 1343
1352 1344 @reraise_safe_exceptions
1353 1345 def node_annotate(self, wire, commit_id, path):
1354 1346
1355 1347 result_libgit = []
1356 1348 repo_init = self._factory.repo_libgit2(wire)
1357 1349 with repo_init as repo:
1358 1350 commit = repo[commit_id]
1359 1351 blame_obj = repo.blame(path, newest_commit=commit_id)
1360 1352 for i, line in enumerate(commit.tree[path].data.splitlines()):
1361 1353 line_no = i + 1
1362 1354 hunk = blame_obj.for_line(line_no)
1363 1355 blame_commit_id = hunk.final_commit_id.hex
1364 1356
1365 1357 result_libgit.append((line_no, blame_commit_id, line))
1366 1358
1367 1359 return BinaryEnvelope(result_libgit)
1368 1360
1369 1361 @reraise_safe_exceptions
1370 1362 def update_server_info(self, wire):
1371 1363 repo = self._factory.repo(wire)
1372 1364 update_server_info(repo)
1373 1365
1374 1366 @reraise_safe_exceptions
1375 1367 def get_all_commit_ids(self, wire):
1376 1368
1377 1369 cache_on, context_uid, repo_id = self._cache_on(wire)
1378 1370 region = self._region(wire)
1379 1371
1380 1372 @region.conditional_cache_on_arguments(condition=cache_on)
1381 1373 def _get_all_commit_ids(_context_uid, _repo_id):
1382 1374
1383 1375 cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
1384 1376 try:
1385 1377 output, __ = self.run_git_command(wire, cmd)
1386 1378 return output.splitlines()
1387 1379 except Exception:
1388 1380 # Can be raised for empty repositories
1389 1381 return []
1390 1382
1391 1383 @region.conditional_cache_on_arguments(condition=cache_on)
1392 1384 def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
1393 1385 repo_init = self._factory.repo_libgit2(wire)
1394 1386 from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
1395 1387 results = []
1396 1388 with repo_init as repo:
1397 1389 for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
1398 1390 results.append(commit.id.hex)
1399 1391
1400 1392 return _get_all_commit_ids(context_uid, repo_id)
1401 1393
1402 1394 @reraise_safe_exceptions
1403 1395 def run_git_command(self, wire, cmd, **opts):
1404 1396 path = wire.get('path', None)
1405 1397 debug_mode = rhodecode.ConfigGet().get_bool('debug')
1406 1398
1407 1399 if path and os.path.isdir(path):
1408 1400 opts['cwd'] = path
1409 1401
1410 1402 if '_bare' in opts:
1411 1403 _copts = []
1412 1404 del opts['_bare']
1413 1405 else:
1414 1406 _copts = ['-c', 'core.quotepath=false', '-c', 'advice.diverging=false']
1415 1407 safe_call = False
1416 1408 if '_safe' in opts:
1417 1409 # no exc on failure
1418 1410 del opts['_safe']
1419 1411 safe_call = True
1420 1412
1421 1413 if '_copts' in opts:
1422 1414 _copts.extend(opts['_copts'] or [])
1423 1415 del opts['_copts']
1424 1416
1425 1417 gitenv = os.environ.copy()
1426 1418 gitenv.update(opts.pop('extra_env', {}))
 1427 1419 # need to clean/fix GIT_DIR!
1428 1420 if 'GIT_DIR' in gitenv:
1429 1421 del gitenv['GIT_DIR']
1430 1422 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1431 1423 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1432 1424
1433 1425 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1434 1426 _opts = {'env': gitenv, 'shell': False}
1435 1427
1436 1428 proc = None
1437 1429 try:
1438 1430 _opts.update(opts)
1439 1431 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1440 1432
1441 1433 return b''.join(proc), b''.join(proc.stderr)
1442 1434 except OSError as err:
1443 1435 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
1444 1436 call_opts = {}
1445 1437 if debug_mode:
1446 1438 call_opts = _opts
1447 1439
1448 1440 tb_err = ("Couldn't run git command ({}).\n"
1449 1441 "Original error was:{}\n"
1450 1442 "Call options:{}\n"
1451 1443 .format(cmd, err, call_opts))
1452 1444 log.exception(tb_err)
1453 1445 if safe_call:
1454 1446 return '', err
1455 1447 else:
1456 1448 raise exceptions.VcsException()(tb_err)
1457 1449 finally:
1458 1450 if proc:
1459 1451 proc.close()
1460 1452
1461 1453 @reraise_safe_exceptions
1462 1454 def install_hooks(self, wire, force=False):
1463 1455 from vcsserver.hook_utils import install_git_hooks
1464 1456 bare = self.bare(wire)
1465 1457 path = wire['path']
1466 1458 binary_dir = settings.BINARY_DIR
1467 1459 if binary_dir:
1468 1460 os.path.join(binary_dir, 'python3')
1469 1461 return install_git_hooks(path, bare, force_create=force)
1470 1462
1471 1463 @reraise_safe_exceptions
1472 1464 def get_hooks_info(self, wire):
1473 1465 from vcsserver.hook_utils import (
1474 1466 get_git_pre_hook_version, get_git_post_hook_version)
1475 1467 bare = self.bare(wire)
1476 1468 path = wire['path']
1477 1469 return {
1478 1470 'pre_version': get_git_pre_hook_version(path, bare),
1479 1471 'post_version': get_git_post_hook_version(path, bare),
1480 1472 }
1481 1473
1482 1474 @reraise_safe_exceptions
1483 1475 def set_head_ref(self, wire, head_name):
1484 1476 log.debug('Setting refs/head to `%s`', head_name)
1485 1477 repo_init = self._factory.repo_libgit2(wire)
1486 1478 with repo_init as repo:
1487 1479 repo.set_head(f'refs/heads/{head_name}')
1488 1480
1489 1481 return [head_name] + [f'set HEAD to refs/heads/{head_name}']
1490 1482
1491 1483 @reraise_safe_exceptions
1492 1484 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1493 1485 archive_dir_name, commit_id, cache_config):
1494 1486
1495 1487 def file_walker(_commit_id, path):
1496 1488 repo_init = self._factory.repo_libgit2(wire)
1497 1489
1498 1490 with repo_init as repo:
1499 1491 commit = repo[commit_id]
1500 1492
1501 1493 if path in ['', '/']:
1502 1494 tree = commit.tree
1503 1495 else:
1504 1496 tree = commit.tree[path.rstrip('/')]
1505 1497 tree_id = tree.id.hex
1506 1498 try:
1507 1499 tree = repo[tree_id]
1508 1500 except KeyError:
1509 1501 raise ObjectMissing(f'No tree with id: {tree_id}')
1510 1502
1511 1503 index = LibGit2Index.Index()
1512 1504 index.read_tree(tree)
1513 1505 file_iter = index
1514 1506
1515 1507 for file_node in file_iter:
1516 1508 file_path = file_node.path
1517 1509 mode = file_node.mode
1518 1510 is_link = stat.S_ISLNK(mode)
1519 1511 if mode == pygit2.GIT_FILEMODE_COMMIT:
1520 1512 log.debug('Skipping path %s as a commit node', file_path)
1521 1513 continue
1522 1514 yield ArchiveNode(file_path, mode, is_link, repo[file_node.hex].read_raw)
1523 1515
1524 1516 return store_archive_in_cache(
1525 1517 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
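The Mercurial module below imports the same convert_to_str helper (and drops the now-unused purge and hgutil imports); the hunk that applies it to HG auth handling falls outside this excerpt. As a hedged usage illustration, consistent with the sketch above and using made-up credential values:

    # Hypothetical authinfo tuple of the (realm, uris, user, password) shape
    # that url_obj.authinfo()[1] can return, containing raw bytes:
    authinfo = (None, (b'https://code.example.com', b'code.example.com'), b'bob', b's3cret')
    assert convert_to_str(authinfo) == (
        None, ('https://code.example.com', 'code.example.com'), 'bob', 's3cret')
    # urllib.request.HTTPPasswordMgrWithDefaultRealm.add_password() expects str
    # arguments, so the converted tuple can be splatted into it safely:
    #   passmgr.add_password(*convert_to_str(authinfo))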
@@ -1,1213 +1,1213 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17 18 import binascii
18 19 import io
19 20 import logging
20 21 import stat
21 22 import sys
22 23 import urllib.request
23 24 import urllib.parse
24 25 import hashlib
25 26
26 from hgext import largefiles, rebase, purge
27 from hgext import largefiles, rebase
27 28
28 29 from mercurial import commands
29 30 from mercurial import unionrepo
30 31 from mercurial import verify
31 32 from mercurial import repair
32 33 from mercurial.error import AmbiguousPrefixLookupError
33 34
34 35 import vcsserver
35 36 from vcsserver import exceptions
36 37 from vcsserver.base import (
37 38 RepoFactory,
38 39 obfuscate_qs,
39 40 raise_from_original,
40 41 store_archive_in_cache,
41 42 ArchiveNode,
42 43 BytesEnvelope,
43 44 BinaryEnvelope,
44 45 )
45 46 from vcsserver.hgcompat import (
46 47 archival,
47 48 bin,
48 49 clone,
49 50 config as hgconfig,
50 51 diffopts,
51 52 hex,
52 53 get_ctx,
53 54 hg_url as url_parser,
54 55 httpbasicauthhandler,
55 56 httpdigestauthhandler,
56 57 makepeer,
57 58 instance,
58 59 match,
59 60 memctx,
60 61 exchange,
61 62 memfilectx,
62 63 nullrev,
63 64 hg_merge,
64 65 patch,
65 66 peer,
66 67 revrange,
67 68 ui,
68 69 hg_tag,
69 70 Abort,
70 71 LookupError,
71 72 RepoError,
72 73 RepoLookupError,
73 74 InterventionRequired,
74 75 RequirementError,
75 76 alwaysmatcher,
76 77 patternmatcher,
77 hgutil,
78 78 hgext_strip,
79 79 )
80 from vcsserver.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes
80 from vcsserver.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes, convert_to_str
81 81 from vcsserver.vcs_base import RemoteBase
82 82 from vcsserver.config import hooks as hooks_config
83 83 from vcsserver.lib.exc_tracking import format_exc
84 84
85 85 log = logging.getLogger(__name__)
86 86
87 87
88 88 def make_ui_from_config(repo_config):
89 89
90 90 class LoggingUI(ui.ui):
91 91
92 92 def status(self, *msg, **opts):
93 93 str_msg = map(safe_str, msg)
94 94 log.info(' '.join(str_msg).rstrip('\n'))
95 95 #super(LoggingUI, self).status(*msg, **opts)
96 96
97 97 def warn(self, *msg, **opts):
98 98 str_msg = map(safe_str, msg)
99 99 log.warning('ui_logger:'+' '.join(str_msg).rstrip('\n'))
100 100 #super(LoggingUI, self).warn(*msg, **opts)
101 101
102 102 def error(self, *msg, **opts):
103 103 str_msg = map(safe_str, msg)
104 104 log.error('ui_logger:'+' '.join(str_msg).rstrip('\n'))
105 105 #super(LoggingUI, self).error(*msg, **opts)
106 106
107 107 def note(self, *msg, **opts):
108 108 str_msg = map(safe_str, msg)
109 109 log.info('ui_logger:'+' '.join(str_msg).rstrip('\n'))
110 110 #super(LoggingUI, self).note(*msg, **opts)
111 111
112 112 def debug(self, *msg, **opts):
113 113 str_msg = map(safe_str, msg)
114 114 log.debug('ui_logger:'+' '.join(str_msg).rstrip('\n'))
115 115 #super(LoggingUI, self).debug(*msg, **opts)
116 116
117 117 baseui = LoggingUI()
118 118
119 119 # clean the baseui object
120 120 baseui._ocfg = hgconfig.config()
121 121 baseui._ucfg = hgconfig.config()
122 122 baseui._tcfg = hgconfig.config()
123 123
124 124 for section, option, value in repo_config:
125 125 baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))
126 126
127 127 # make our hgweb quiet so it doesn't print output
128 128 baseui.setconfig(b'ui', b'quiet', b'true')
129 129
130 130 baseui.setconfig(b'ui', b'paginate', b'never')
131 131     # for better error reporting from Mercurial
132 132 baseui.setconfig(b'ui', b'message-output', b'stderr')
133 133
134 134 # force mercurial to only use 1 thread, otherwise it may try to set a
135 135 # signal in a non-main thread, thus generating a ValueError.
136 136 baseui.setconfig(b'worker', b'numcpus', 1)
137 137
138 138 # If there is no config for the largefiles extension, we explicitly disable
139 139     # it here. This overrides settings from the repository's hgrc file. Recent
140 140 # mercurial versions enable largefiles in hgrc on clone from largefile
141 141 # repo.
142 142 if not baseui.hasconfig(b'extensions', b'largefiles'):
143 143 log.debug('Explicitly disable largefiles extension for repo.')
144 144 baseui.setconfig(b'extensions', b'largefiles', b'!')
145 145
146 146 return baseui
147 147
148 148
149 149 def reraise_safe_exceptions(func):
150 150 """Decorator for converting mercurial exceptions to something neutral."""
151 151
152 152 def wrapper(*args, **kwargs):
153 153 try:
154 154 return func(*args, **kwargs)
155 155 except (Abort, InterventionRequired) as e:
156 156 raise_from_original(exceptions.AbortException(e), e)
157 157 except RepoLookupError as e:
158 158 raise_from_original(exceptions.LookupException(e), e)
159 159 except RequirementError as e:
160 160 raise_from_original(exceptions.RequirementException(e), e)
161 161 except RepoError as e:
162 162 raise_from_original(exceptions.VcsException(e), e)
163 163 except LookupError as e:
164 164 raise_from_original(exceptions.LookupException(e), e)
165 165 except Exception as e:
166 166 if not hasattr(e, '_vcs_kind'):
167 167 log.exception("Unhandled exception in hg remote call")
168 168 raise_from_original(exceptions.UnhandledException(e), e)
169 169
170 170 raise
171 171 return wrapper
172 172
173 173
174 174 class MercurialFactory(RepoFactory):
175 175 repo_type = 'hg'
176 176
177 177 def _create_config(self, config, hooks=True):
178 178 if not hooks:
179 179
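            # when hooks are disabled, strip the RhodeCode hook entries from the incoming config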
180 180 hooks_to_clean = {
181 181
182 182 hooks_config.HOOK_REPO_SIZE,
183 183 hooks_config.HOOK_PRE_PULL,
184 184 hooks_config.HOOK_PULL,
185 185
186 186 hooks_config.HOOK_PRE_PUSH,
187 187 # TODO: what about PRETXT, this was disabled in pre 5.0.0
188 188 hooks_config.HOOK_PRETX_PUSH,
189 189
190 190 }
191 191 new_config = []
192 192 for section, option, value in config:
193 193 if section == 'hooks' and option in hooks_to_clean:
194 194 continue
195 195 new_config.append((section, option, value))
196 196 config = new_config
197 197
198 198 baseui = make_ui_from_config(config)
199 199 return baseui
200 200
201 201 def _create_repo(self, wire, create):
202 202 baseui = self._create_config(wire["config"])
203 203 repo = instance(baseui, safe_bytes(wire["path"]), create)
204 204 log.debug('repository created: got HG object: %s', repo)
205 205 return repo
206 206
207 207 def repo(self, wire, create=False):
208 208 """
209 209 Get a repository instance for the given path.
210 210 """
211 211 return self._create_repo(wire, create)
212 212
213 213
214 214 def patch_ui_message_output(baseui):
215 215 baseui.setconfig(b'ui', b'quiet', b'false')
216 216 output = io.BytesIO()
217 217
218 218 def write(data, **unused_kwargs):
219 219 output.write(data)
220 220
221 221 baseui.status = write
222 222 baseui.write = write
223 223 baseui.warn = write
224 224 baseui.debug = write
225 225
226 226 return baseui, output
227 227
228 228
229 229 def get_obfuscated_url(url_obj):
230 230 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
231 231 url_obj.query = obfuscate_qs(url_obj.query)
232 232 obfuscated_uri = str(url_obj)
233 233 return obfuscated_uri
234 234
235 235
236 236 def normalize_url_for_hg(url: str):
237 237 _proto = None
238 238
239 239 if '+' in url[:url.find('://')]:
240 240 _proto = url[0:url.find('+')]
241 241 url = url[url.find('+') + 1:]
242 242 return url, _proto
243 243
244 244
245 245 class HgRemote(RemoteBase):
246 246
247 247 def __init__(self, factory):
248 248 self._factory = factory
249 249 self._bulk_methods = {
250 250 "affected_files": self.ctx_files,
251 251 "author": self.ctx_user,
252 252 "branch": self.ctx_branch,
253 253 "children": self.ctx_children,
254 254 "date": self.ctx_date,
255 255 "message": self.ctx_description,
256 256 "parents": self.ctx_parents,
257 257 "status": self.ctx_status,
258 258 "obsolete": self.ctx_obsolete,
259 259 "phase": self.ctx_phase,
260 260 "hidden": self.ctx_hidden,
261 261 "_file_paths": self.ctx_list,
262 262 }
263 263 self._bulk_file_methods = {
264 264 "size": self.fctx_size,
265 265 "data": self.fctx_node_data,
266 266 "flags": self.fctx_flags,
267 267 "is_binary": self.is_binary,
268 268 "md5": self.md5_hash,
269 269 }
270 270
271 271 def _get_ctx(self, repo, ref):
272 272 return get_ctx(repo, ref)
273 273
274 274 @reraise_safe_exceptions
275 275 def discover_hg_version(self):
276 276 from mercurial import util
277 277 return safe_str(util.version())
278 278
279 279 @reraise_safe_exceptions
280 280 def is_empty(self, wire):
281 281 repo = self._factory.repo(wire)
282 282
283 283 try:
284 284 return len(repo) == 0
285 285 except Exception:
286 286 log.exception("failed to read object_store")
287 287 return False
288 288
289 289 @reraise_safe_exceptions
290 290 def bookmarks(self, wire):
291 291 cache_on, context_uid, repo_id = self._cache_on(wire)
292 292 region = self._region(wire)
293 293
294 294 @region.conditional_cache_on_arguments(condition=cache_on)
295 295 def _bookmarks(_context_uid, _repo_id):
296 296 repo = self._factory.repo(wire)
297 297 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo._bookmarks.items()}
298 298
299 299 return _bookmarks(context_uid, repo_id)
300 300
301 301 @reraise_safe_exceptions
302 302 def branches(self, wire, normal, closed):
303 303 cache_on, context_uid, repo_id = self._cache_on(wire)
304 304 region = self._region(wire)
305 305
306 306 @region.conditional_cache_on_arguments(condition=cache_on)
307 307 def _branches(_context_uid, _repo_id, _normal, _closed):
308 308 repo = self._factory.repo(wire)
309 309 iter_branches = repo.branchmap().iterbranches()
310 310 bt = {}
311 311 for branch_name, _heads, tip_node, is_closed in iter_branches:
312 312 if normal and not is_closed:
313 313 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
314 314 if closed and is_closed:
315 315 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
316 316
317 317 return bt
318 318
319 319 return _branches(context_uid, repo_id, normal, closed)
320 320
321 321 @reraise_safe_exceptions
322 322 def bulk_request(self, wire, commit_id, pre_load):
323 323 cache_on, context_uid, repo_id = self._cache_on(wire)
324 324 region = self._region(wire)
325 325
326 326 @region.conditional_cache_on_arguments(condition=cache_on)
327 327 def _bulk_request(_repo_id, _commit_id, _pre_load):
328 328 result = {}
329 329 for attr in pre_load:
330 330 try:
331 331 method = self._bulk_methods[attr]
332 332 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
333 333 result[attr] = method(wire, commit_id)
334 334 except KeyError as e:
335 335 raise exceptions.VcsException(e)(
336 336 f'Unknown bulk attribute: "{attr}"')
337 337 return result
338 338
339 339 return _bulk_request(repo_id, commit_id, sorted(pre_load))
340 340
341 341 @reraise_safe_exceptions
342 342 def ctx_branch(self, wire, commit_id):
343 343 cache_on, context_uid, repo_id = self._cache_on(wire)
344 344 region = self._region(wire)
345 345
346 346 @region.conditional_cache_on_arguments(condition=cache_on)
347 347 def _ctx_branch(_repo_id, _commit_id):
348 348 repo = self._factory.repo(wire)
349 349 ctx = self._get_ctx(repo, commit_id)
350 350 return ctx.branch()
351 351 return _ctx_branch(repo_id, commit_id)
352 352
353 353 @reraise_safe_exceptions
354 354 def ctx_date(self, wire, commit_id):
355 355 cache_on, context_uid, repo_id = self._cache_on(wire)
356 356 region = self._region(wire)
357 357
358 358 @region.conditional_cache_on_arguments(condition=cache_on)
359 359 def _ctx_date(_repo_id, _commit_id):
360 360 repo = self._factory.repo(wire)
361 361 ctx = self._get_ctx(repo, commit_id)
362 362 return ctx.date()
363 363 return _ctx_date(repo_id, commit_id)
364 364
365 365 @reraise_safe_exceptions
366 366 def ctx_description(self, wire, revision):
367 367 repo = self._factory.repo(wire)
368 368 ctx = self._get_ctx(repo, revision)
369 369 return ctx.description()
370 370
371 371 @reraise_safe_exceptions
372 372 def ctx_files(self, wire, commit_id):
373 373 cache_on, context_uid, repo_id = self._cache_on(wire)
374 374 region = self._region(wire)
375 375
376 376 @region.conditional_cache_on_arguments(condition=cache_on)
377 377 def _ctx_files(_repo_id, _commit_id):
378 378 repo = self._factory.repo(wire)
379 379 ctx = self._get_ctx(repo, commit_id)
380 380 return ctx.files()
381 381
382 382 return _ctx_files(repo_id, commit_id)
383 383
384 384 @reraise_safe_exceptions
385 385 def ctx_list(self, path, revision):
386 386 repo = self._factory.repo(path)
387 387 ctx = self._get_ctx(repo, revision)
388 388 return list(ctx)
389 389
390 390 @reraise_safe_exceptions
391 391 def ctx_parents(self, wire, commit_id):
392 392 cache_on, context_uid, repo_id = self._cache_on(wire)
393 393 region = self._region(wire)
394 394
395 395 @region.conditional_cache_on_arguments(condition=cache_on)
396 396 def _ctx_parents(_repo_id, _commit_id):
397 397 repo = self._factory.repo(wire)
398 398 ctx = self._get_ctx(repo, commit_id)
399 399 return [parent.hex() for parent in ctx.parents()
400 400 if not (parent.hidden() or parent.obsolete())]
401 401
402 402 return _ctx_parents(repo_id, commit_id)
403 403
404 404 @reraise_safe_exceptions
405 405 def ctx_children(self, wire, commit_id):
406 406 cache_on, context_uid, repo_id = self._cache_on(wire)
407 407 region = self._region(wire)
408 408
409 409 @region.conditional_cache_on_arguments(condition=cache_on)
410 410 def _ctx_children(_repo_id, _commit_id):
411 411 repo = self._factory.repo(wire)
412 412 ctx = self._get_ctx(repo, commit_id)
413 413 return [child.hex() for child in ctx.children()
414 414 if not (child.hidden() or child.obsolete())]
415 415
416 416 return _ctx_children(repo_id, commit_id)
417 417
418 418 @reraise_safe_exceptions
419 419 def ctx_phase(self, wire, commit_id):
420 420 cache_on, context_uid, repo_id = self._cache_on(wire)
421 421 region = self._region(wire)
422 422
423 423 @region.conditional_cache_on_arguments(condition=cache_on)
424 424 def _ctx_phase(_context_uid, _repo_id, _commit_id):
425 425 repo = self._factory.repo(wire)
426 426 ctx = self._get_ctx(repo, commit_id)
427 427             # public=0, draft=1, secret=2
428 428 return ctx.phase()
429 429 return _ctx_phase(context_uid, repo_id, commit_id)
430 430
431 431 @reraise_safe_exceptions
432 432 def ctx_obsolete(self, wire, commit_id):
433 433 cache_on, context_uid, repo_id = self._cache_on(wire)
434 434 region = self._region(wire)
435 435
436 436 @region.conditional_cache_on_arguments(condition=cache_on)
437 437 def _ctx_obsolete(_context_uid, _repo_id, _commit_id):
438 438 repo = self._factory.repo(wire)
439 439 ctx = self._get_ctx(repo, commit_id)
440 440 return ctx.obsolete()
441 441 return _ctx_obsolete(context_uid, repo_id, commit_id)
442 442
443 443 @reraise_safe_exceptions
444 444 def ctx_hidden(self, wire, commit_id):
445 445 cache_on, context_uid, repo_id = self._cache_on(wire)
446 446 region = self._region(wire)
447 447
448 448 @region.conditional_cache_on_arguments(condition=cache_on)
449 449 def _ctx_hidden(_context_uid, _repo_id, _commit_id):
450 450 repo = self._factory.repo(wire)
451 451 ctx = self._get_ctx(repo, commit_id)
452 452 return ctx.hidden()
453 453 return _ctx_hidden(context_uid, repo_id, commit_id)
454 454
455 455 @reraise_safe_exceptions
456 456 def ctx_substate(self, wire, revision):
457 457 repo = self._factory.repo(wire)
458 458 ctx = self._get_ctx(repo, revision)
459 459 return ctx.substate
460 460
461 461 @reraise_safe_exceptions
462 462 def ctx_status(self, wire, revision):
463 463 repo = self._factory.repo(wire)
464 464 ctx = self._get_ctx(repo, revision)
465 465 status = repo[ctx.p1().node()].status(other=ctx.node())
466 466         # the status object (an odd, custom named tuple in mercurial) is not
467 467         # correctly serializable; we convert it to a list, as the underlying
468 468         # API expects a list
469 469 return list(status)
470 470
471 471 @reraise_safe_exceptions
472 472 def ctx_user(self, wire, revision):
473 473 repo = self._factory.repo(wire)
474 474 ctx = self._get_ctx(repo, revision)
475 475 return ctx.user()
476 476
477 477 @reraise_safe_exceptions
478 478 def check_url(self, url, config):
479 479 url, _proto = normalize_url_for_hg(url)
480 480 url_obj = url_parser(safe_bytes(url))
481 481
482 482 test_uri = safe_str(url_obj.authinfo()[0])
483 483 authinfo = url_obj.authinfo()[1]
484 484 obfuscated_uri = get_obfuscated_url(url_obj)
485 485 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
486 486
487 487 handlers = []
488 488 if authinfo:
489 489 # create a password manager
490 490 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
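            # authinfo coming from Mercurial's url parser may contain bytes;
            # convert it to str before handing it to urllib's password manager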
491 passmgr.add_password(*authinfo)
491 passmgr.add_password(*convert_to_str(authinfo))
492 492
493 493 handlers.extend((httpbasicauthhandler(passmgr),
494 494 httpdigestauthhandler(passmgr)))
495 495
496 496 o = urllib.request.build_opener(*handlers)
497 497 o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
498 498 ('Accept', 'application/mercurial-0.1')]
499 499
500 500 q = {"cmd": 'between'}
501 501 q.update({'pairs': "{}-{}".format('0' * 40, '0' * 40)})
502 502 qs = f'?{urllib.parse.urlencode(q)}'
503 503 cu = f"{test_uri}{qs}"
504 504
505 505 try:
506 506 req = urllib.request.Request(cu, None, {})
507 507 log.debug("Trying to open URL %s", obfuscated_uri)
508 508 resp = o.open(req)
509 509 if resp.code != 200:
510 510 raise exceptions.URLError()('Return Code is not 200')
511 511 except Exception as e:
512 512 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
513 513 # means it cannot be cloned
514 514 raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")
515 515
516 516 # now check if it's a proper hg repo, but don't do it for svn
517 517 try:
518 518 if _proto == 'svn':
519 519 pass
520 520 else:
521 521 # check for pure hg repos
522 522 log.debug(
523 523 "Verifying if URL is a Mercurial repository: %s", obfuscated_uri)
524 524 ui = make_ui_from_config(config)
525 525 peer_checker = makepeer(ui, safe_bytes(url))
526 526 peer_checker.lookup(b'tip')
527 527 except Exception as e:
528 528 log.warning("URL is not a valid Mercurial repository: %s",
529 529 obfuscated_uri)
530 530 raise exceptions.URLError(e)(
531 531 f"url [{obfuscated_uri}] does not look like an hg repo org_exc: {e}")
532 532
533 533 log.info("URL is a valid Mercurial repository: %s", obfuscated_uri)
534 534 return True
535 535
536 536 @reraise_safe_exceptions
537 537 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_git, opt_ignorews, context):
538 538 repo = self._factory.repo(wire)
539 539
540 540 if file_filter:
541 541 # unpack the file-filter
542 542 repo_path, node_path = file_filter
543 543 match_filter = match(safe_bytes(repo_path), b'', [safe_bytes(node_path)])
544 544 else:
545 545 match_filter = file_filter
546 546 opts = diffopts(git=opt_git, ignorews=opt_ignorews, context=context, showfunc=1)
547 547
548 548 try:
549 549 diff_iter = patch.diff(
550 550 repo, node1=commit_id_1, node2=commit_id_2, match=match_filter, opts=opts)
551 551 return BytesEnvelope(b"".join(diff_iter))
552 552 except RepoLookupError as e:
553 553 raise exceptions.LookupException(e)()
554 554
555 555 @reraise_safe_exceptions
556 556 def node_history(self, wire, revision, path, limit):
557 557 cache_on, context_uid, repo_id = self._cache_on(wire)
558 558 region = self._region(wire)
559 559
560 560 @region.conditional_cache_on_arguments(condition=cache_on)
561 561 def _node_history(_context_uid, _repo_id, _revision, _path, _limit):
562 562 repo = self._factory.repo(wire)
563 563
564 564 ctx = self._get_ctx(repo, revision)
565 565 fctx = ctx.filectx(safe_bytes(path))
566 566
567 567 def history_iter():
568 568 limit_rev = fctx.rev()
569 569
570 570 for fctx_candidate in reversed(list(fctx.filelog())):
571 571 f_obj = fctx.filectx(fctx_candidate)
572 572
573 573                     # NOTE: This can be problematic... if the ONLY node in the history is hidden, the result is an empty history
574 574 _ctx = f_obj.changectx()
575 575 if _ctx.hidden() or _ctx.obsolete():
576 576 continue
577 577
578 578 if limit_rev >= f_obj.rev():
579 579 yield f_obj
580 580
581 581 history = []
582 582 for cnt, obj in enumerate(history_iter()):
583 583 if limit and cnt >= limit:
584 584 break
585 585 history.append(hex(obj.node()))
586 586
587 587 return [x for x in history]
588 588 return _node_history(context_uid, repo_id, revision, path, limit)
589 589
590 590 @reraise_safe_exceptions
591 591 def node_history_until(self, wire, revision, path, limit):
592 592 cache_on, context_uid, repo_id = self._cache_on(wire)
593 593 region = self._region(wire)
594 594
595 595 @region.conditional_cache_on_arguments(condition=cache_on)
596 596 def _node_history_until(_context_uid, _repo_id):
597 597 repo = self._factory.repo(wire)
598 598 ctx = self._get_ctx(repo, revision)
599 599 fctx = ctx.filectx(safe_bytes(path))
600 600
601 601 file_log = list(fctx.filelog())
602 602 if limit:
603 603 # Limit to the last n items
604 604 file_log = file_log[-limit:]
605 605
606 606 return [hex(fctx.filectx(cs).node()) for cs in reversed(file_log)]
607 607 return _node_history_until(context_uid, repo_id, revision, path, limit)
608 608
609 609 @reraise_safe_exceptions
610 610 def bulk_file_request(self, wire, commit_id, path, pre_load):
611 611 cache_on, context_uid, repo_id = self._cache_on(wire)
612 612 region = self._region(wire)
613 613
614 614 @region.conditional_cache_on_arguments(condition=cache_on)
615 615 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
616 616 result = {}
617 617 for attr in pre_load:
618 618 try:
619 619 method = self._bulk_file_methods[attr]
620 620 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
621 621 result[attr] = method(wire, _commit_id, _path)
622 622 except KeyError as e:
623 623 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
624 624 return result
625 625
626 626 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
627 627
628 628 @reraise_safe_exceptions
629 629 def fctx_annotate(self, wire, revision, path):
630 630 repo = self._factory.repo(wire)
631 631 ctx = self._get_ctx(repo, revision)
632 632 fctx = ctx.filectx(safe_bytes(path))
633 633
634 634 result = []
635 635 for i, annotate_obj in enumerate(fctx.annotate(), 1):
636 636 ln_no = i
637 637 sha = hex(annotate_obj.fctx.node())
638 638 content = annotate_obj.text
639 639 result.append((ln_no, ascii_str(sha), content))
640 640 return BinaryEnvelope(result)
641 641
642 642 @reraise_safe_exceptions
643 643 def fctx_node_data(self, wire, revision, path):
644 644 repo = self._factory.repo(wire)
645 645 ctx = self._get_ctx(repo, revision)
646 646 fctx = ctx.filectx(safe_bytes(path))
647 647 return BytesEnvelope(fctx.data())
648 648
649 649 @reraise_safe_exceptions
650 650 def fctx_flags(self, wire, commit_id, path):
651 651 cache_on, context_uid, repo_id = self._cache_on(wire)
652 652 region = self._region(wire)
653 653
654 654 @region.conditional_cache_on_arguments(condition=cache_on)
655 655 def _fctx_flags(_repo_id, _commit_id, _path):
656 656 repo = self._factory.repo(wire)
657 657 ctx = self._get_ctx(repo, commit_id)
658 658 fctx = ctx.filectx(safe_bytes(path))
659 659 return fctx.flags()
660 660
661 661 return _fctx_flags(repo_id, commit_id, path)
662 662
663 663 @reraise_safe_exceptions
664 664 def fctx_size(self, wire, commit_id, path):
665 665 cache_on, context_uid, repo_id = self._cache_on(wire)
666 666 region = self._region(wire)
667 667
668 668 @region.conditional_cache_on_arguments(condition=cache_on)
669 669 def _fctx_size(_repo_id, _revision, _path):
670 670 repo = self._factory.repo(wire)
671 671 ctx = self._get_ctx(repo, commit_id)
672 672 fctx = ctx.filectx(safe_bytes(path))
673 673 return fctx.size()
674 674 return _fctx_size(repo_id, commit_id, path)
675 675
676 676 @reraise_safe_exceptions
677 677 def get_all_commit_ids(self, wire, name):
678 678 cache_on, context_uid, repo_id = self._cache_on(wire)
679 679 region = self._region(wire)
680 680
681 681 @region.conditional_cache_on_arguments(condition=cache_on)
682 682 def _get_all_commit_ids(_context_uid, _repo_id, _name):
683 683 repo = self._factory.repo(wire)
684 684 revs = [ascii_str(repo[x].hex()) for x in repo.filtered(b'visible').changelog.revs()]
685 685 return revs
686 686 return _get_all_commit_ids(context_uid, repo_id, name)
687 687
688 688 @reraise_safe_exceptions
689 689 def get_config_value(self, wire, section, name, untrusted=False):
690 690 repo = self._factory.repo(wire)
691 691 return repo.ui.config(ascii_bytes(section), ascii_bytes(name), untrusted=untrusted)
692 692
693 693 @reraise_safe_exceptions
694 694 def is_large_file(self, wire, commit_id, path):
695 695 cache_on, context_uid, repo_id = self._cache_on(wire)
696 696 region = self._region(wire)
697 697
698 698 @region.conditional_cache_on_arguments(condition=cache_on)
699 699 def _is_large_file(_context_uid, _repo_id, _commit_id, _path):
700 700 return largefiles.lfutil.isstandin(safe_bytes(path))
701 701
702 702 return _is_large_file(context_uid, repo_id, commit_id, path)
703 703
704 704 @reraise_safe_exceptions
705 705 def is_binary(self, wire, revision, path):
706 706 cache_on, context_uid, repo_id = self._cache_on(wire)
707 707 region = self._region(wire)
708 708
709 709 @region.conditional_cache_on_arguments(condition=cache_on)
710 710 def _is_binary(_repo_id, _sha, _path):
711 711 repo = self._factory.repo(wire)
712 712 ctx = self._get_ctx(repo, revision)
713 713 fctx = ctx.filectx(safe_bytes(path))
714 714 return fctx.isbinary()
715 715
716 716 return _is_binary(repo_id, revision, path)
717 717
718 718 @reraise_safe_exceptions
719 719 def md5_hash(self, wire, revision, path):
720 720 cache_on, context_uid, repo_id = self._cache_on(wire)
721 721 region = self._region(wire)
722 722
723 723 @region.conditional_cache_on_arguments(condition=cache_on)
724 724 def _md5_hash(_repo_id, _sha, _path):
725 725 repo = self._factory.repo(wire)
726 726 ctx = self._get_ctx(repo, revision)
727 727 fctx = ctx.filectx(safe_bytes(path))
728 728 return hashlib.md5(fctx.data()).hexdigest()
729 729
730 730 return _md5_hash(repo_id, revision, path)
731 731
732 732 @reraise_safe_exceptions
733 733 def in_largefiles_store(self, wire, sha):
734 734 repo = self._factory.repo(wire)
735 735 return largefiles.lfutil.instore(repo, sha)
736 736
737 737 @reraise_safe_exceptions
738 738 def in_user_cache(self, wire, sha):
739 739 repo = self._factory.repo(wire)
740 740 return largefiles.lfutil.inusercache(repo.ui, sha)
741 741
742 742 @reraise_safe_exceptions
743 743 def store_path(self, wire, sha):
744 744 repo = self._factory.repo(wire)
745 745 return largefiles.lfutil.storepath(repo, sha)
746 746
747 747 @reraise_safe_exceptions
748 748 def link(self, wire, sha, path):
749 749 repo = self._factory.repo(wire)
750 750 largefiles.lfutil.link(
751 751 largefiles.lfutil.usercachepath(repo.ui, sha), path)
752 752
753 753 @reraise_safe_exceptions
754 754 def localrepository(self, wire, create=False):
755 755 self._factory.repo(wire, create=create)
756 756
757 757 @reraise_safe_exceptions
758 758 def lookup(self, wire, revision, both):
759 759 cache_on, context_uid, repo_id = self._cache_on(wire)
760 760 region = self._region(wire)
761 761
762 762 @region.conditional_cache_on_arguments(condition=cache_on)
763 763 def _lookup(_context_uid, _repo_id, _revision, _both):
764 764 repo = self._factory.repo(wire)
765 765 rev = _revision
766 766 if isinstance(rev, int):
767 767 # NOTE(marcink):
768 768 # since Mercurial doesn't support negative indexes properly
769 769                 # we need to shift by one to get the proper index, e.g.
770 770 # repo[-1] => repo[-2]
771 771 # repo[0] => repo[-1]
772 772 if rev <= 0:
773 773 rev = rev + -1
774 774 try:
775 775 ctx = self._get_ctx(repo, rev)
776 776 except AmbiguousPrefixLookupError:
777 777 e = RepoLookupError(rev)
778 778 e._org_exc_tb = format_exc(sys.exc_info())
779 779 raise exceptions.LookupException(e)(rev)
780 780 except (TypeError, RepoLookupError, binascii.Error) as e:
781 781 e._org_exc_tb = format_exc(sys.exc_info())
782 782 raise exceptions.LookupException(e)(rev)
783 783 except LookupError as e:
784 784 e._org_exc_tb = format_exc(sys.exc_info())
785 785 raise exceptions.LookupException(e)(e.name)
786 786
787 787 if not both:
788 788 return ctx.hex()
789 789
790 790 ctx = repo[ctx.hex()]
791 791 return ctx.hex(), ctx.rev()
792 792
793 793 return _lookup(context_uid, repo_id, revision, both)
794 794
795 795 @reraise_safe_exceptions
796 796 def sync_push(self, wire, url):
797 797 if not self.check_url(url, wire['config']):
798 798 return
799 799
800 800 repo = self._factory.repo(wire)
801 801
802 802 # Disable any prompts for this repo
803 803 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
804 804
805 805 bookmarks = list(dict(repo._bookmarks).keys())
806 806 remote = peer(repo, {}, safe_bytes(url))
807 807 # Disable any prompts for this remote
808 808 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
809 809
810 810 return exchange.push(
811 811 repo, remote, newbranch=True, bookmarks=bookmarks).cgresult
812 812
813 813 @reraise_safe_exceptions
814 814 def revision(self, wire, rev):
815 815 repo = self._factory.repo(wire)
816 816 ctx = self._get_ctx(repo, rev)
817 817 return ctx.rev()
818 818
819 819 @reraise_safe_exceptions
820 820 def rev_range(self, wire, commit_filter):
821 821 cache_on, context_uid, repo_id = self._cache_on(wire)
822 822 region = self._region(wire)
823 823
824 824 @region.conditional_cache_on_arguments(condition=cache_on)
825 825 def _rev_range(_context_uid, _repo_id, _filter):
826 826 repo = self._factory.repo(wire)
827 827 revisions = [
828 828 ascii_str(repo[rev].hex())
829 829 for rev in revrange(repo, list(map(ascii_bytes, commit_filter)))
830 830 ]
831 831 return revisions
832 832
833 833 return _rev_range(context_uid, repo_id, sorted(commit_filter))
834 834
835 835 @reraise_safe_exceptions
836 836 def rev_range_hash(self, wire, node):
837 837 repo = self._factory.repo(wire)
838 838
839 839 def get_revs(repo, rev_opt):
840 840 if rev_opt:
841 841 revs = revrange(repo, rev_opt)
842 842 if len(revs) == 0:
843 843 return (nullrev, nullrev)
844 844 return max(revs), min(revs)
845 845 else:
846 846 return len(repo) - 1, 0
847 847
848 848 stop, start = get_revs(repo, [node + ':'])
849 849 revs = [ascii_str(repo[r].hex()) for r in range(start, stop + 1)]
850 850 return revs
851 851
852 852 @reraise_safe_exceptions
853 853 def revs_from_revspec(self, wire, rev_spec, *args, **kwargs):
854 854 org_path = safe_bytes(wire["path"])
855 855 other_path = safe_bytes(kwargs.pop('other_path', ''))
856 856
857 857 # case when we want to compare two independent repositories
858 858 if other_path and other_path != wire["path"]:
859 859 baseui = self._factory._create_config(wire["config"])
860 860 repo = unionrepo.makeunionrepository(baseui, other_path, org_path)
861 861 else:
862 862 repo = self._factory.repo(wire)
863 863 return list(repo.revs(rev_spec, *args))
864 864
865 865 @reraise_safe_exceptions
866 866 def verify(self, wire,):
867 867 repo = self._factory.repo(wire)
868 868 baseui = self._factory._create_config(wire['config'])
869 869
870 870 baseui, output = patch_ui_message_output(baseui)
871 871
872 872 repo.ui = baseui
873 873 verify.verify(repo)
874 874 return output.getvalue()
875 875
876 876 @reraise_safe_exceptions
877 877 def hg_update_cache(self, wire,):
878 878 repo = self._factory.repo(wire)
879 879 baseui = self._factory._create_config(wire['config'])
880 880 baseui, output = patch_ui_message_output(baseui)
881 881
882 882 repo.ui = baseui
883 883 with repo.wlock(), repo.lock():
884 884 repo.updatecaches(full=True)
885 885
886 886 return output.getvalue()
887 887
888 888 @reraise_safe_exceptions
889 889 def hg_rebuild_fn_cache(self, wire,):
890 890 repo = self._factory.repo(wire)
891 891 baseui = self._factory._create_config(wire['config'])
892 892 baseui, output = patch_ui_message_output(baseui)
893 893
894 894 repo.ui = baseui
895 895
896 896 repair.rebuildfncache(baseui, repo)
897 897
898 898 return output.getvalue()
899 899
900 900 @reraise_safe_exceptions
901 901 def tags(self, wire):
902 902 cache_on, context_uid, repo_id = self._cache_on(wire)
903 903 region = self._region(wire)
904 904
905 905 @region.conditional_cache_on_arguments(condition=cache_on)
906 906 def _tags(_context_uid, _repo_id):
907 907 repo = self._factory.repo(wire)
908 908 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo.tags().items()}
909 909
910 910 return _tags(context_uid, repo_id)
911 911
912 912 @reraise_safe_exceptions
913 913 def update(self, wire, node='', clean=False):
914 914 repo = self._factory.repo(wire)
915 915 baseui = self._factory._create_config(wire['config'])
916 916 node = safe_bytes(node)
917 917
918 918 commands.update(baseui, repo, node=node, clean=clean)
919 919
920 920 @reraise_safe_exceptions
921 921 def identify(self, wire):
922 922 repo = self._factory.repo(wire)
923 923 baseui = self._factory._create_config(wire['config'])
924 924 output = io.BytesIO()
925 925 baseui.write = output.write
926 926 # This is required to get a full node id
927 927 baseui.debugflag = True
928 928 commands.identify(baseui, repo, id=True)
929 929
930 930 return output.getvalue()
931 931
932 932 @reraise_safe_exceptions
933 933 def heads(self, wire, branch=None):
934 934 repo = self._factory.repo(wire)
935 935 baseui = self._factory._create_config(wire['config'])
936 936 output = io.BytesIO()
937 937
938 938 def write(data, **unused_kwargs):
939 939 output.write(data)
940 940
941 941 baseui.write = write
942 942 if branch:
943 943 args = [safe_bytes(branch)]
944 944 else:
945 945 args = []
946 946 commands.heads(baseui, repo, template=b'{node} ', *args)
947 947
948 948 return output.getvalue()
949 949
950 950 @reraise_safe_exceptions
951 951 def ancestor(self, wire, revision1, revision2):
952 952 repo = self._factory.repo(wire)
953 953 changelog = repo.changelog
954 954 lookup = repo.lookup
955 955 a = changelog.ancestor(lookup(safe_bytes(revision1)), lookup(safe_bytes(revision2)))
956 956 return hex(a)
957 957
958 958 @reraise_safe_exceptions
959 959 def clone(self, wire, source, dest, update_after_clone=False, hooks=True):
960 960 baseui = self._factory._create_config(wire["config"], hooks=hooks)
961 961 clone(baseui, safe_bytes(source), safe_bytes(dest), noupdate=not update_after_clone)
962 962
963 963 @reraise_safe_exceptions
964 964 def commitctx(self, wire, message, parents, commit_time, commit_timezone, user, files, extra, removed, updated):
965 965
966 966 repo = self._factory.repo(wire)
967 967 baseui = self._factory._create_config(wire['config'])
968 968 publishing = baseui.configbool(b'phases', b'publish')
969 969
970 970 def _filectxfn(_repo, ctx, path: bytes):
971 971 """
972 972 Marks given path as added/changed/removed in a given _repo. This is
973 973 for internal mercurial commit function.
974 974 """
975 975
976 976 # check if this path is removed
977 977 if safe_str(path) in removed:
978 978 # returning None is a way to mark node for removal
979 979 return None
980 980
981 981 # check if this path is added
982 982 for node in updated:
983 983 if safe_bytes(node['path']) == path:
984 984 return memfilectx(
985 985 _repo,
986 986 changectx=ctx,
987 987 path=safe_bytes(node['path']),
988 988 data=safe_bytes(node['content']),
989 989 islink=False,
990 990 isexec=bool(node['mode'] & stat.S_IXUSR),
991 991 copysource=False)
992 992 abort_exc = exceptions.AbortException()
993 993                 raise abort_exc(f"Given path hasn't been marked as added, changed or removed ({path})")
994 994
995 995 if publishing:
996 996 new_commit_phase = b'public'
997 997 else:
998 998 new_commit_phase = b'draft'
999 999 with repo.ui.configoverride({(b'phases', b'new-commit'): new_commit_phase}):
1000 1000 kwargs = {safe_bytes(k): safe_bytes(v) for k, v in extra.items()}
1001 1001 commit_ctx = memctx(
1002 1002 repo=repo,
1003 1003 parents=parents,
1004 1004 text=safe_bytes(message),
1005 1005 files=[safe_bytes(x) for x in files],
1006 1006 filectxfn=_filectxfn,
1007 1007 user=safe_bytes(user),
1008 1008 date=(commit_time, commit_timezone),
1009 1009 extra=kwargs)
1010 1010
1011 1011 n = repo.commitctx(commit_ctx)
1012 1012 new_id = hex(n)
1013 1013
1014 1014 return new_id
1015 1015
1016 1016 @reraise_safe_exceptions
1017 1017 def pull(self, wire, url, commit_ids=None):
1018 1018 repo = self._factory.repo(wire)
1019 1019 # Disable any prompts for this repo
1020 1020 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
1021 1021
1022 1022 remote = peer(repo, {}, safe_bytes(url))
1023 1023 # Disable any prompts for this remote
1024 1024 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
1025 1025
1026 1026 if commit_ids:
1027 1027 commit_ids = [bin(commit_id) for commit_id in commit_ids]
1028 1028
1029 1029 return exchange.pull(
1030 1030 repo, remote, heads=commit_ids, force=None).cgresult
1031 1031
1032 1032 @reraise_safe_exceptions
1033 1033 def pull_cmd(self, wire, source, bookmark='', branch='', revision='', hooks=True):
1034 1034 repo = self._factory.repo(wire)
1035 1035 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1036 1036
1037 1037 source = safe_bytes(source)
1038 1038
1039 1039         # Mercurial internally has a lot of logic that checks ONLY whether an
1040 1040         # option is defined, so we only pass the options that are actually set
1041 1041 opts = {}
1042 1042
1043 1043 if bookmark:
1044 1044 opts['bookmark'] = [safe_bytes(x) for x in bookmark] \
1045 1045 if isinstance(bookmark, list) else safe_bytes(bookmark)
1046 1046
1047 1047 if branch:
1048 1048 opts['branch'] = [safe_bytes(x) for x in branch] \
1049 1049 if isinstance(branch, list) else safe_bytes(branch)
1050 1050
1051 1051 if revision:
1052 1052 opts['rev'] = [safe_bytes(x) for x in revision] \
1053 1053 if isinstance(revision, list) else safe_bytes(revision)
1054 1054
1055 1055 commands.pull(baseui, repo, source, **opts)
1056 1056
1057 1057 @reraise_safe_exceptions
1058 1058 def push(self, wire, revisions, dest_path, hooks: bool = True, push_branches: bool = False):
1059 1059 repo = self._factory.repo(wire)
1060 1060 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1061 1061
1062 1062 revisions = [safe_bytes(x) for x in revisions] \
1063 1063 if isinstance(revisions, list) else safe_bytes(revisions)
1064 1064
1065 1065 commands.push(baseui, repo, safe_bytes(dest_path),
1066 1066 rev=revisions,
1067 1067 new_branch=push_branches)
1068 1068
1069 1069 @reraise_safe_exceptions
1070 1070 def strip(self, wire, revision, update, backup):
1071 1071 repo = self._factory.repo(wire)
1072 1072 ctx = self._get_ctx(repo, revision)
1073 1073 hgext_strip.strip(
1074 1074 repo.baseui, repo, ctx.node(), update=update, backup=backup)
1075 1075
1076 1076 @reraise_safe_exceptions
1077 1077 def get_unresolved_files(self, wire):
1078 1078 repo = self._factory.repo(wire)
1079 1079
1080 1080 log.debug('Calculating unresolved files for repo: %s', repo)
1081 1081 output = io.BytesIO()
1082 1082
1083 1083 def write(data, **unused_kwargs):
1084 1084 output.write(data)
1085 1085
1086 1086 baseui = self._factory._create_config(wire['config'])
1087 1087 baseui.write = write
1088 1088
1089 1089 commands.resolve(baseui, repo, list=True)
1090 1090 unresolved = output.getvalue().splitlines(0)
1091 1091 return unresolved
1092 1092
1093 1093 @reraise_safe_exceptions
1094 1094 def merge(self, wire, revision):
1095 1095 repo = self._factory.repo(wire)
1096 1096 baseui = self._factory._create_config(wire['config'])
1097 1097 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1098 1098
1099 1099         # In case sub repositories are used, mercurial prompts the user in
1100 1100         # case of merge conflicts or different sub repository sources. By
1101 1101         # setting the interactive flag to `False` mercurial doesn't prompt the
1102 1102         # user but instead uses a default value.
1103 1103 repo.ui.setconfig(b'ui', b'interactive', False)
1104 1104 commands.merge(baseui, repo, rev=safe_bytes(revision))
1105 1105
1106 1106 @reraise_safe_exceptions
1107 1107 def merge_state(self, wire):
1108 1108 repo = self._factory.repo(wire)
1109 1109 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1110 1110
1111 1111         # In case sub repositories are used, mercurial prompts the user in
1112 1112         # case of merge conflicts or different sub repository sources. By
1113 1113         # setting the interactive flag to `False` mercurial doesn't prompt the
1114 1114         # user but instead uses a default value.
1115 1115 repo.ui.setconfig(b'ui', b'interactive', False)
1116 1116 ms = hg_merge.mergestate(repo)
1117 1117 return [x for x in ms.unresolved()]
1118 1118
1119 1119 @reraise_safe_exceptions
1120 1120 def commit(self, wire, message, username, close_branch=False):
1121 1121 repo = self._factory.repo(wire)
1122 1122 baseui = self._factory._create_config(wire['config'])
1123 1123 repo.ui.setconfig(b'ui', b'username', safe_bytes(username))
1124 1124 commands.commit(baseui, repo, message=safe_bytes(message), close_branch=close_branch)
1125 1125
1126 1126 @reraise_safe_exceptions
1127 1127 def rebase(self, wire, source='', dest='', abort=False):
1128 1128
1129 1129 repo = self._factory.repo(wire)
1130 1130 baseui = self._factory._create_config(wire['config'])
1131 1131 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1132 1132         # In case sub repositories are used, mercurial prompts the user in
1133 1133         # case of merge conflicts or different sub repository sources. By
1134 1134         # setting the interactive flag to `False` mercurial doesn't prompt the
1135 1135         # user but instead uses a default value.
1136 1136 repo.ui.setconfig(b'ui', b'interactive', False)
1137 1137
1138 1138 rebase_kws = dict(
1139 1139 keep=not abort,
1140 1140 abort=abort
1141 1141 )
1142 1142
1143 1143 if source:
1144 1144 source = repo[source]
1145 1145 rebase_kws['base'] = [source.hex()]
1146 1146 if dest:
1147 1147 dest = repo[dest]
1148 1148 rebase_kws['dest'] = dest.hex()
1149 1149
1150 1150 rebase.rebase(baseui, repo, **rebase_kws)
1151 1151
1152 1152 @reraise_safe_exceptions
1153 1153 def tag(self, wire, name, revision, message, local, user, tag_time, tag_timezone):
1154 1154 repo = self._factory.repo(wire)
1155 1155 ctx = self._get_ctx(repo, revision)
1156 1156 node = ctx.node()
1157 1157
1158 1158 date = (tag_time, tag_timezone)
1159 1159 try:
1160 1160 hg_tag.tag(repo, safe_bytes(name), node, safe_bytes(message), local, safe_bytes(user), date)
1161 1161 except Abort as e:
1162 1162 log.exception("Tag operation aborted")
1163 1163 # Exception can contain unicode which we convert
1164 1164 raise exceptions.AbortException(e)(repr(e))
1165 1165
1166 1166 @reraise_safe_exceptions
1167 1167 def bookmark(self, wire, bookmark, revision=''):
1168 1168 repo = self._factory.repo(wire)
1169 1169 baseui = self._factory._create_config(wire['config'])
1170 1170 revision = revision or ''
1171 1171 commands.bookmark(baseui, repo, safe_bytes(bookmark), rev=safe_bytes(revision), force=True)
1172 1172
1173 1173 @reraise_safe_exceptions
1174 1174 def install_hooks(self, wire, force=False):
1175 1175 # we don't need any special hooks for Mercurial
1176 1176 pass
1177 1177
1178 1178 @reraise_safe_exceptions
1179 1179 def get_hooks_info(self, wire):
1180 1180 return {
1181 1181 'pre_version': vcsserver.get_version(),
1182 1182 'post_version': vcsserver.get_version(),
1183 1183 }
1184 1184
1185 1185 @reraise_safe_exceptions
1186 1186 def set_head_ref(self, wire, head_name):
1187 1187 pass
1188 1188
1189 1189 @reraise_safe_exceptions
1190 1190 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1191 1191 archive_dir_name, commit_id, cache_config):
1192 1192
1193 1193 def file_walker(_commit_id, path):
1194 1194 repo = self._factory.repo(wire)
1195 1195 ctx = repo[_commit_id]
1196 1196 is_root = path in ['', '/']
1197 1197 if is_root:
1198 1198 matcher = alwaysmatcher(badfn=None)
1199 1199 else:
1200 1200 matcher = patternmatcher('', [(b'glob', safe_bytes(path)+b'/**', b'')], badfn=None)
1201 1201 file_iter = ctx.manifest().walk(matcher)
1202 1202
1203 1203 for fn in file_iter:
1204 1204 file_path = fn
1205 1205 flags = ctx.flags(fn)
1206 1206 mode = b'x' in flags and 0o755 or 0o644
1207 1207 is_link = b'l' in flags
1208 1208
1209 1209 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1210 1210
1211 1211 return store_archive_in_cache(
1212 1212 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
1213 1213
@@ -1,133 +1,144 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import typing
19 19 import base64
20 20 import logging
21 21
22 22
23 23 log = logging.getLogger(__name__)
24 24
25 25
26 26 def safe_int(val, default=None) -> int:
27 27 """
28 28     Returns int() of val; if val is not convertible to int, the default is
29 29     returned instead
30 30
31 31 :param val:
32 32 :param default:
33 33 """
34 34
35 35 try:
36 36 val = int(val)
37 37 except (ValueError, TypeError):
38 38 val = default
39 39
40 40 return val
41 41
42 42
43 43 def base64_to_str(text) -> str:
44 44 return safe_str(base64.encodebytes(safe_bytes(text))).strip()
45 45
46 46
47 47 def get_default_encodings() -> list[str]:
48 48 return ['utf8']
49 49
50 50
51 51 def safe_str(str_, to_encoding=None) -> str:
52 52 """
53 53     safe str function. Does a few tricks to turn the given value into a str
54 54
55 55     :param str_: value to convert to str (bytes are decoded)
56 56     :param to_encoding: encoding(s) used to decode bytes, UTF8 by default
57 57 """
58 58 if isinstance(str_, str):
59 59 return str_
60 60
61 61 # if it's bytes cast to str
62 62 if not isinstance(str_, bytes):
63 63 return str(str_)
64 64
65 65 to_encoding = to_encoding or get_default_encodings()
66 66 if not isinstance(to_encoding, (list, tuple)):
67 67 to_encoding = [to_encoding]
68 68
69 69 for enc in to_encoding:
70 70 try:
71 71 return str(str_, enc)
72 72 except UnicodeDecodeError:
73 73 pass
74 74
75 75 return str(str_, to_encoding[0], 'replace')
76 76
77 77
78 78 def safe_bytes(str_, from_encoding=None) -> bytes:
79 79 """
80 80     safe bytes function. Does a few tricks to turn str_ into a bytes string:
81 81
82 82     :param str_: string to encode
83 83     :param from_encoding: encoding(s) used to encode the string, UTF8 by default
84 84 """
85 85 if isinstance(str_, bytes):
86 86 return str_
87 87
88 88 if not isinstance(str_, str):
89 89 raise ValueError(f'safe_bytes cannot convert other types than str: got: {type(str_)}')
90 90
91 91 from_encoding = from_encoding or get_default_encodings()
92 92 if not isinstance(from_encoding, (list, tuple)):
93 93 from_encoding = [from_encoding]
94 94
95 95 for enc in from_encoding:
96 96 try:
97 97 return str_.encode(enc)
98 98         except UnicodeEncodeError:
99 99 pass
100 100
101 101 return str_.encode(from_encoding[0], 'replace')
102 102
103 103
104 104 def ascii_bytes(str_, allow_bytes=False) -> bytes:
105 105 """
106 106 Simple conversion from str to bytes, with assumption that str_ is pure ASCII.
107 107 Fails with UnicodeError on invalid input.
108 108 This should be used where encoding and "safe" ambiguity should be avoided.
109 109 Where strings already have been encoded in other ways but still are unicode
110 110 string - for example to hex, base64, json, urlencoding, or are known to be
111 111 identifiers.
112 112 """
113 113 if allow_bytes and isinstance(str_, bytes):
114 114 return str_
115 115
116 116 if not isinstance(str_, str):
117 117 raise ValueError(f'ascii_bytes cannot convert other types than str: got: {type(str_)}')
118 118 return str_.encode('ascii')
119 119
120 120
121 121 def ascii_str(str_) -> str:
122 122 """
123 123 Simple conversion from bytes to str, with assumption that str_ is pure ASCII.
124 124 Fails with UnicodeError on invalid input.
125 125 This should be used where encoding and "safe" ambiguity should be avoided.
126 126 Where strings are encoded but also in other ways are known to be ASCII, and
127 127 where a unicode string is wanted without caring about encoding. For example
128 128 to hex, base64, urlencoding, or are known to be identifiers.
129 129 """
130 130
131 131 if not isinstance(str_, bytes):
132 132 raise ValueError(f'ascii_str cannot convert other types than bytes: got: {type(str_)}')
133 133 return str_.decode('ascii')
134
135
136 def convert_to_str(data):
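    """
    Recursively convert bytes values to str inside tuples and lists; any
    other type is returned unchanged.
    """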
137 if isinstance(data, bytes):
138 return safe_str(data)
139 elif isinstance(data, tuple):
140 return tuple(convert_to_str(item) for item in data)
141 elif isinstance(data, list):
142 return list(convert_to_str(item) for item in data)
143 else:
144 return data
@@ -1,53 +1,69 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import pytest
19 from vcsserver.str_utils import ascii_bytes, ascii_str
19 from vcsserver.str_utils import ascii_bytes, ascii_str, convert_to_str
20 20
21 21
22 22 @pytest.mark.parametrize('given, expected', [
23 23 ('a', b'a'),
24 24 ('a', b'a'),
25 25 ])
26 26 def test_ascii_bytes(given, expected):
27 27 assert ascii_bytes(given) == expected
28 28
29 29
30 30 @pytest.mark.parametrize('given', [
31 31 'å',
32 32 'å'.encode('utf8')
33 33 ])
34 34 def test_ascii_bytes_raises(given):
35 35 with pytest.raises(ValueError):
36 36 ascii_bytes(given)
37 37
38 38
39 39 @pytest.mark.parametrize('given, expected', [
40 40 (b'a', 'a'),
41 41 ])
42 42 def test_ascii_str(given, expected):
43 43 assert ascii_str(given) == expected
44 44
45 45
46 46 @pytest.mark.parametrize('given', [
47 47 'a',
48 48 'å'.encode('utf8'),
49 49 'å'
50 50 ])
51 51 def test_ascii_str_raises(given):
52 52 with pytest.raises(ValueError):
53 53 ascii_str(given)
54
55
56 @pytest.mark.parametrize('given, expected', [
57 ('a', 'a'),
58 (b'a', 'a'),
59 # tuple
60 (('a', b'b', b'c'), ('a', 'b', 'c')),
61 # nested tuple
62 (('a', b'b', (b'd', b'e')), ('a', 'b', ('d', 'e'))),
63 # list
64 (['a', b'b', b'c'], ['a', 'b', 'c']),
65 # mixed
66 (['a', b'b', b'c', (b'b1', b'b2')], ['a', 'b', 'c', ('b1', 'b2')])
67 ])
68 def test_convert_to_str(given, expected):
69 assert convert_to_str(given) == expected