##// END OF EJS Templates
hooks: added few python3 related fixes to handle bytes vs str on Mercurial hooks
super-admin -
r1108:62b12ad9 python3
parent child Browse files
Show More
@@ -25,6 +25,8 b' import collections'
25 import importlib
25 import importlib
26 import base64
26 import base64
27 import msgpack
27 import msgpack
28 import dataclasses
29 import pygit2
28
30
29 from http.client import HTTPConnection
31 from http.client import HTTPConnection
30
32
@@ -34,7 +36,8 b' import mercurial.node'
34
36
35 from vcsserver.lib.rc_json import json
37 from vcsserver.lib.rc_json import json
36 from vcsserver import exceptions, subprocessio, settings
38 from vcsserver import exceptions, subprocessio, settings
37 from vcsserver.str_utils import safe_bytes
39 from vcsserver.str_utils import ascii_str, safe_str
40 from vcsserver.remote.git import Repository
38
41
39 log = logging.getLogger(__name__)
42 log = logging.getLogger(__name__)
40
43
@@ -104,7 +107,7 b' class HgMessageWriter(RemoteMessageWrite'
104 def __init__(self, ui):
107 def __init__(self, ui):
105 self.ui = ui
108 self.ui = ui
106
109
107 def write(self, message):
110 def write(self, message: str):
108 # TODO: Check why the quiet flag is set by default.
111 # TODO: Check why the quiet flag is set by default.
109 old = self.ui.quiet
112 old = self.ui.quiet
110 self.ui.quiet = False
113 self.ui.quiet = False
@@ -118,8 +121,8 b' class GitMessageWriter(RemoteMessageWrit'
118 def __init__(self, stdout=None):
121 def __init__(self, stdout=None):
119 self.stdout = stdout or sys.stdout
122 self.stdout = stdout or sys.stdout
120
123
121 def write(self, message):
124 def write(self, message: str):
122 self.stdout.write(safe_bytes(message))
125 self.stdout.write(message)
123
126
124
127
125 class SvnMessageWriter(RemoteMessageWriter):
128 class SvnMessageWriter(RemoteMessageWriter):
@@ -147,8 +150,9 b' def _handle_exception(result):'
147 elif exception_class == 'RepositoryError':
150 elif exception_class == 'RepositoryError':
148 raise exceptions.VcsException()(*result['exception_args'])
151 raise exceptions.VcsException()(*result['exception_args'])
149 elif exception_class:
152 elif exception_class:
150 raise Exception('Got remote exception "%s" with args "%s"' %
153 raise Exception(
151 (exception_class, result['exception_args']))
154 f"""Got remote exception "{exception_class}" with args "{result['exception_args']}" """
155 )
152
156
153
157
154 def _get_hooks_client(extras):
158 def _get_hooks_client(extras):
@@ -167,7 +171,6 b' def _call_hook(hook_name, extras, writer'
167 log.debug('Hooks, using client:%s', hooks_client)
171 log.debug('Hooks, using client:%s', hooks_client)
168 result = hooks_client(hook_name, extras)
172 result = hooks_client(hook_name, extras)
169 log.debug('Hooks got result: %s', result)
173 log.debug('Hooks got result: %s', result)
170
171 _handle_exception(result)
174 _handle_exception(result)
172 writer.write(result['output'])
175 writer.write(result['output'])
173
176
@@ -198,8 +201,8 b' def _rev_range_hash(repo, node, check_he'
198 for rev in range(start, end):
201 for rev in range(start, end):
199 revs.append(rev)
202 revs.append(rev)
200 ctx = get_ctx(repo, rev)
203 ctx = get_ctx(repo, rev)
201 commit_id = mercurial.node.hex(ctx.node())
204 commit_id = ascii_str(mercurial.node.hex(ctx.node()))
202 branch = ctx.branch()
205 branch = safe_str(ctx.branch())
203 commits.append((commit_id, branch))
206 commits.append((commit_id, branch))
204
207
205 parent_heads = []
208 parent_heads = []
@@ -223,8 +226,8 b' def _check_heads(repo, start, end, commi'
223 for p in parents:
226 for p in parents:
224 branch = get_ctx(repo, p).branch()
227 branch = get_ctx(repo, p).branch()
225 # The heads descending from that parent, on the same branch
228 # The heads descending from that parent, on the same branch
226 parent_heads = set([p])
229 parent_heads = {p}
227 reachable = set([p])
230 reachable = {p}
228 for x in range(p + 1, end):
231 for x in range(p + 1, end):
229 if get_ctx(repo, x).branch() != branch:
232 if get_ctx(repo, x).branch() != branch:
230 continue
233 continue
@@ -301,14 +304,16 b' def pre_push(ui, repo, node=None, **kwar'
301 detect_force_push = extras.get('detect_force_push')
304 detect_force_push = extras.get('detect_force_push')
302
305
303 rev_data = []
306 rev_data = []
304 if node and kwargs.get('hooktype') == 'pretxnchangegroup':
307 hook_type: str = safe_str(kwargs.get('hooktype'))
308
309 if node and hook_type == 'pretxnchangegroup':
305 branches = collections.defaultdict(list)
310 branches = collections.defaultdict(list)
306 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
311 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
307 for commit_id, branch in commits:
312 for commit_id, branch in commits:
308 branches[branch].append(commit_id)
313 branches[branch].append(commit_id)
309
314
310 for branch, commits in branches.items():
315 for branch, commits in branches.items():
311 old_rev = kwargs.get('node_last') or commits[0]
316 old_rev = ascii_str(kwargs.get('node_last')) or commits[0]
312 rev_data.append({
317 rev_data.append({
313 'total_commits': len(commits),
318 'total_commits': len(commits),
314 'old_rev': old_rev,
319 'old_rev': old_rev,
@@ -325,10 +330,10 b' def pre_push(ui, repo, node=None, **kwar'
325 extras.get('repo_store', ''), extras.get('repository', ''))
330 extras.get('repo_store', ''), extras.get('repository', ''))
326 push_ref['hg_env'] = _get_hg_env(
331 push_ref['hg_env'] = _get_hg_env(
327 old_rev=push_ref['old_rev'],
332 old_rev=push_ref['old_rev'],
328 new_rev=push_ref['new_rev'], txnid=kwargs.get('txnid'),
333 new_rev=push_ref['new_rev'], txnid=ascii_str(kwargs.get('txnid')),
329 repo_path=repo_path)
334 repo_path=repo_path)
330
335
331 extras['hook_type'] = kwargs.get('hooktype', 'pre_push')
336 extras['hook_type'] = hook_type or 'pre_push'
332 extras['commit_ids'] = rev_data
337 extras['commit_ids'] = rev_data
333
338
334 return _call_hook('pre_push', extras, HgMessageWriter(ui))
339 return _call_hook('pre_push', extras, HgMessageWriter(ui))
@@ -369,6 +374,7 b' def post_push(ui, repo, node, **kwargs):'
369 branches = []
374 branches = []
370 bookmarks = []
375 bookmarks = []
371 tags = []
376 tags = []
377 hook_type: str = safe_str(kwargs.get('hooktype'))
372
378
373 commits, _heads = _rev_range_hash(repo, node)
379 commits, _heads = _rev_range_hash(repo, node)
374 for commit_id, branch in commits:
380 for commit_id, branch in commits:
@@ -376,11 +382,12 b' def post_push(ui, repo, node, **kwargs):'
376 if branch not in branches:
382 if branch not in branches:
377 branches.append(branch)
383 branches.append(branch)
378
384
379 if hasattr(ui, '_rc_pushkey_branches'):
385 if hasattr(ui, '_rc_pushkey_bookmarks'):
380 bookmarks = ui._rc_pushkey_branches
386 bookmarks = ui._rc_pushkey_bookmarks
381
387
382 extras['hook_type'] = kwargs.get('hooktype', 'post_push')
388 extras['hook_type'] = hook_type or 'post_push'
383 extras['commit_ids'] = commit_ids
389 extras['commit_ids'] = commit_ids
390
384 extras['new_refs'] = {
391 extras['new_refs'] = {
385 'branches': branches,
392 'branches': branches,
386 'bookmarks': bookmarks,
393 'bookmarks': bookmarks,
@@ -401,9 +408,10 b' def post_push_ssh(ui, repo, node, **kwar'
401
408
402 def key_push(ui, repo, **kwargs):
409 def key_push(ui, repo, **kwargs):
403 from vcsserver.hgcompat import get_ctx
410 from vcsserver.hgcompat import get_ctx
404 if kwargs['new'] != '0' and kwargs['namespace'] == 'bookmarks':
411
412 if kwargs['new'] != b'0' and kwargs['namespace'] == b'bookmarks':
405 # store new bookmarks in our UI object propagated later to post_push
413 # store new bookmarks in our UI object propagated later to post_push
406 ui._rc_pushkey_branches = get_ctx(repo, kwargs['key']).bookmarks()
414 ui._rc_pushkey_bookmarks = get_ctx(repo, kwargs['key']).bookmarks()
407 return
415 return
408
416
409
417
@@ -432,10 +440,13 b' def handle_git_post_receive(unused_repo_'
432 pass
440 pass
433
441
434
442
435 HookResponse = collections.namedtuple('HookResponse', ('status', 'output'))
443 @dataclasses.dataclass
444 class HookResponse:
445 status: int
446 output: str
436
447
437
448
438 def git_pre_pull(extras):
449 def git_pre_pull(extras) -> HookResponse:
439 """
450 """
440 Pre pull hook.
451 Pre pull hook.
441
452
@@ -449,19 +460,19 b' def git_pre_pull(extras):'
449 if 'pull' not in extras['hooks']:
460 if 'pull' not in extras['hooks']:
450 return HookResponse(0, '')
461 return HookResponse(0, '')
451
462
452 stdout = io.BytesIO()
463 stdout = io.StringIO()
453 try:
464 try:
454 status = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
465 status_code = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
455
466
456 except Exception as error:
467 except Exception as error:
457 log.exception('Failed to call pre_pull hook')
468 log.exception('Failed to call pre_pull hook')
458 status = 128
469 status_code = 128
459 stdout.write(safe_bytes(f'ERROR: {error}\n'))
470 stdout.write(f'ERROR: {error}\n')
460
471
461 return HookResponse(status, stdout.getvalue())
472 return HookResponse(status_code, stdout.getvalue())
462
473
463
474
464 def git_post_pull(extras):
475 def git_post_pull(extras) -> HookResponse:
465 """
476 """
466 Post pull hook.
477 Post pull hook.
467
478
@@ -474,12 +485,12 b' def git_post_pull(extras):'
474 if 'pull' not in extras['hooks']:
485 if 'pull' not in extras['hooks']:
475 return HookResponse(0, '')
486 return HookResponse(0, '')
476
487
477 stdout = io.BytesIO()
488 stdout = io.StringIO()
478 try:
489 try:
479 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
490 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
480 except Exception as error:
491 except Exception as error:
481 status = 128
492 status = 128
482 stdout.write(safe_bytes(f'ERROR: {error}\n'))
493 stdout.write(f'ERROR: {error}\n')
483
494
484 return HookResponse(status, stdout.getvalue())
495 return HookResponse(status, stdout.getvalue())
485
496
@@ -504,15 +515,11 b' def _parse_git_ref_lines(revision_lines)'
504 return rev_data
515 return rev_data
505
516
506
517
507 def git_pre_receive(unused_repo_path, revision_lines, env):
518 def git_pre_receive(unused_repo_path, revision_lines, env) -> int:
508 """
519 """
509 Pre push hook.
520 Pre push hook.
510
521
511 :param extras: dictionary containing the keys defined in simplevcs
512 :type extras: dict
513
514 :return: status code of the hook. 0 for success.
522 :return: status code of the hook. 0 for success.
515 :rtype: int
516 """
523 """
517 extras = json.loads(env['RC_SCM_DATA'])
524 extras = json.loads(env['RC_SCM_DATA'])
518 rev_data = _parse_git_ref_lines(revision_lines)
525 rev_data = _parse_git_ref_lines(revision_lines)
@@ -545,18 +552,18 b' def git_pre_receive(unused_repo_path, re'
545
552
546 extras['hook_type'] = 'pre_receive'
553 extras['hook_type'] = 'pre_receive'
547 extras['commit_ids'] = rev_data
554 extras['commit_ids'] = rev_data
548 return _call_hook('pre_push', extras, GitMessageWriter())
555
556 stdout = sys.stdout
557 status_code = _call_hook('pre_push', extras, GitMessageWriter(stdout))
558
559 return status_code
549
560
550
561
551 def git_post_receive(unused_repo_path, revision_lines, env):
562 def git_post_receive(unused_repo_path, revision_lines, env) -> int:
552 """
563 """
553 Post push hook.
564 Post push hook.
554
565
555 :param extras: dictionary containing the keys defined in simplevcs
556 :type extras: dict
557
558 :return: status code of the hook. 0 for success.
566 :return: status code of the hook. 0 for success.
559 :rtype: int
560 """
567 """
561 extras = json.loads(env['RC_SCM_DATA'])
568 extras = json.loads(env['RC_SCM_DATA'])
562 if 'push' not in extras['hooks']:
569 if 'push' not in extras['hooks']:
@@ -576,26 +583,28 b' def git_post_receive(unused_repo_path, r'
576 type_ = push_ref['type']
583 type_ = push_ref['type']
577
584
578 if type_ == 'heads':
585 if type_ == 'heads':
586 # starting new branch case
579 if push_ref['old_rev'] == empty_commit_id:
587 if push_ref['old_rev'] == empty_commit_id:
580 # starting new branch case
588 push_ref_name = push_ref['name']
581 if push_ref['name'] not in branches:
589
582 branches.append(push_ref['name'])
590 if push_ref_name not in branches:
591 branches.append(push_ref_name)
583
592
584 # Fix up head revision if needed
593 need_head_set = ''
585 cmd = [settings.GIT_EXECUTABLE, 'show', 'HEAD']
594 with Repository(os.getcwd()) as repo:
586 try:
595 try:
587 subprocessio.run_command(cmd, env=os.environ.copy())
596 repo.head
588 except Exception:
597 except pygit2.GitError:
589 push_ref_name = push_ref['name']
598 need_head_set = f'refs/heads/{push_ref_name}'
590 cmd = [settings.GIT_EXECUTABLE, 'symbolic-ref', '"HEAD"', f'"refs/heads/{push_ref_name}"']
591 print(f"Setting default branch to {push_ref_name}")
592 subprocessio.run_command(cmd, env=os.environ.copy())
593
599
594 cmd = [settings.GIT_EXECUTABLE, 'for-each-ref',
600 if need_head_set:
595 '--format=%(refname)', 'refs/heads/*']
601 repo.set_head(need_head_set)
602 print(f"Setting default branch to {push_ref_name}")
603
604 cmd = [settings.GIT_EXECUTABLE, 'for-each-ref', '--format=%(refname)', 'refs/heads/*']
596 stdout, stderr = subprocessio.run_command(
605 stdout, stderr = subprocessio.run_command(
597 cmd, env=os.environ.copy())
606 cmd, env=os.environ.copy())
598 heads = stdout
607 heads = safe_str(stdout)
599 heads = heads.replace(push_ref['ref'], '')
608 heads = heads.replace(push_ref['ref'], '')
600 heads = ' '.join(head for head
609 heads = ' '.join(head for head
601 in heads.splitlines() if head) or '.'
610 in heads.splitlines() if head) or '.'
@@ -604,9 +613,10 b' def git_post_receive(unused_repo_path, r'
604 '--not', heads]
613 '--not', heads]
605 stdout, stderr = subprocessio.run_command(
614 stdout, stderr = subprocessio.run_command(
606 cmd, env=os.environ.copy())
615 cmd, env=os.environ.copy())
607 git_revs.extend(stdout.splitlines())
616 git_revs.extend(list(map(ascii_str, stdout.splitlines())))
617
618 # delete branch case
608 elif push_ref['new_rev'] == empty_commit_id:
619 elif push_ref['new_rev'] == empty_commit_id:
609 # delete branch case
610 git_revs.append('delete_branch=>%s' % push_ref['name'])
620 git_revs.append('delete_branch=>%s' % push_ref['name'])
611 else:
621 else:
612 if push_ref['name'] not in branches:
622 if push_ref['name'] not in branches:
@@ -617,7 +627,25 b' def git_post_receive(unused_repo_path, r'
617 '--reverse', '--pretty=format:%H']
627 '--reverse', '--pretty=format:%H']
618 stdout, stderr = subprocessio.run_command(
628 stdout, stderr = subprocessio.run_command(
619 cmd, env=os.environ.copy())
629 cmd, env=os.environ.copy())
620 git_revs.extend(stdout.splitlines())
630 # we get bytes from stdout, we need str to be consistent
631 log_revs = list(map(ascii_str, stdout.splitlines()))
632 git_revs.extend(log_revs)
633
634 # Pure pygit2 impl. but still 2-3x slower :/
635 # results = []
636 #
637 # with Repository(os.getcwd()) as repo:
638 # repo_new_rev = repo[push_ref['new_rev']]
639 # repo_old_rev = repo[push_ref['old_rev']]
640 # walker = repo.walk(repo_new_rev.id, pygit2.GIT_SORT_TOPOLOGICAL)
641 #
642 # for commit in walker:
643 # if commit.id == repo_old_rev.id:
644 # break
645 # results.append(commit.id.hex)
646 # # reverse the order, can't use GIT_SORT_REVERSE
647 # log_revs = results[::-1]
648
621 elif type_ == 'tags':
649 elif type_ == 'tags':
622 if push_ref['name'] not in tags:
650 if push_ref['name'] not in tags:
623 tags.append(push_ref['name'])
651 tags.append(push_ref['name'])
@@ -631,13 +659,16 b' def git_post_receive(unused_repo_path, r'
631 'tags': tags,
659 'tags': tags,
632 }
660 }
633
661
662 stdout = sys.stdout
663
634 if 'repo_size' in extras['hooks']:
664 if 'repo_size' in extras['hooks']:
635 try:
665 try:
636 _call_hook('repo_size', extras, GitMessageWriter())
666 _call_hook('repo_size', extras, GitMessageWriter(stdout))
637 except Exception:
667 except Exception:
638 pass
668 pass
639
669
640 return _call_hook('post_push', extras, GitMessageWriter())
670 status_code = _call_hook('post_push', extras, GitMessageWriter(stdout))
671 return status_code
641
672
642
673
643 def _get_extras_from_txn_id(path, txn_id):
674 def _get_extras_from_txn_id(path, txn_id):
@@ -336,8 +336,9 b' class GitRepository(object):'
336 pre_pull_messages = ''
336 pre_pull_messages = ''
337 # Upload-pack == clone
337 # Upload-pack == clone
338 if git_command == 'git-upload-pack':
338 if git_command == 'git-upload-pack':
339 status, pre_pull_messages = hooks.git_pre_pull(self.extras)
339 hook_response = hooks.git_pre_pull(self.extras)
340 if status != 0:
340 if hook_response.status != 0:
341 pre_pull_messages = hook_response.output
341 resp.app_iter = self._build_failed_pre_pull_response(
342 resp.app_iter = self._build_failed_pre_pull_response(
342 capabilities, pre_pull_messages)
343 capabilities, pre_pull_messages)
343 return resp
344 return resp
@@ -385,8 +386,8 b' class GitRepository(object):'
385
386
386 # Upload-pack == clone
387 # Upload-pack == clone
387 if git_command == 'git-upload-pack':
388 if git_command == 'git-upload-pack':
388 unused_status, post_pull_messages = hooks.git_post_pull(self.extras)
389 hook_response = hooks.git_post_pull(self.extras)
389
390 post_pull_messages = hook_response.output
390 resp.app_iter = self._build_post_pull_response(out, capabilities, pre_pull_messages, post_pull_messages)
391 resp.app_iter = self._build_post_pull_response(out, capabilities, pre_pull_messages, post_pull_messages)
391 else:
392 else:
392 resp.app_iter = out
393 resp.app_iter = out
@@ -22,8 +22,9 b' import posixpath as vcspath'
22 import re
22 import re
23 import stat
23 import stat
24 import traceback
24 import traceback
25 import urllib.request, urllib.parse, urllib.error
25 import urllib.request
26 import urllib.request, urllib.error, urllib.parse
26 import urllib.parse
27 import urllib.error
27 from functools import wraps
28 from functools import wraps
28
29
29 import more_itertools
30 import more_itertools
@@ -40,7 +41,7 b' from dulwich.repo import Repo as Dulwich'
40 from dulwich.server import update_server_info
41 from dulwich.server import update_server_info
41
42
42 from vcsserver import exceptions, settings, subprocessio
43 from vcsserver import exceptions, settings, subprocessio
43 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_str, ascii_bytes
44 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes
44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, archive_repo, BinaryEnvelope
45 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, archive_repo, BinaryEnvelope
45 from vcsserver.hgcompat import (
46 from vcsserver.hgcompat import (
46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
47 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
@@ -69,7 +70,7 b' def reraise_safe_exceptions(func):'
69 except (HangupException, UnexpectedCommandError) as e:
70 except (HangupException, UnexpectedCommandError) as e:
70 exc = exceptions.VcsException(org_exc=e)
71 exc = exceptions.VcsException(org_exc=e)
71 raise exc(safe_str(e))
72 raise exc(safe_str(e))
72 except Exception as e:
73 except Exception:
73 # NOTE(marcink): because of how dulwich handles some exceptions
74 # NOTE(marcink): because of how dulwich handles some exceptions
74 # (KeyError on empty repos), we cannot track this and catch all
75 # (KeyError on empty repos), we cannot track this and catch all
75 # exceptions, it's an exceptions from other handlers
76 # exceptions, it's an exceptions from other handlers
@@ -107,7 +108,7 b' class GitFactory(RepoFactory):'
107
108
108 def _create_repo(self, wire, create, use_libgit2=False):
109 def _create_repo(self, wire, create, use_libgit2=False):
109 if use_libgit2:
110 if use_libgit2:
110 return Repository(safe_bytes(wire['path']))
111 repo = Repository(safe_bytes(wire['path']))
111 else:
112 else:
112 # dulwich mode
113 # dulwich mode
113 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
114 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
General Comments 0
You need to be logged in to leave comments. Login now