diff --git a/rhodecode/api/tests/test_create_pull_request.py b/rhodecode/api/tests/test_create_pull_request.py --- a/rhodecode/api/tests/test_create_pull_request.py +++ b/rhodecode/api/tests/test_create_pull_request.py @@ -56,6 +56,25 @@ class TestCreatePullRequestApi(object): assert_error(id_, expected, given=response.body) @pytest.mark.backends("git", "hg") + @pytest.mark.parametrize('source_ref', [ + 'bookmarg:default:initial' + ]) + def test_create_with_wrong_refs_data(self, backend, source_ref): + + data = self._prepare_data(backend) + data['source_ref'] = source_ref + + id_, params = build_data( + self.apikey_regular, 'create_pull_request', **data) + + response = api_call(self.app, params) + + expected = "Ref `{}` type is not allowed. " \ + "Only:['bookmark', 'book', 'tag', 'branch'] " \ + "are possible.".format(source_ref) + assert_error(id_, expected, given=response.body) + + @pytest.mark.backends("git", "hg") def test_create_with_correct_data(self, backend): data = self._prepare_data(backend) RepoModel().revoke_user_permission( diff --git a/rhodecode/api/tests/test_utils.py b/rhodecode/api/tests/test_utils.py --- a/rhodecode/api/tests/test_utils.py +++ b/rhodecode/api/tests/test_utils.py @@ -84,11 +84,11 @@ class TestResolveRefOrError(object): def test_non_supported_refs(self): repo = Mock() - ref = 'ancestor:ref' + ref = 'bookmark:ref' with pytest.raises(JSONRPCError) as excinfo: utils.resolve_ref_or_error(ref, repo) expected_message = ( - 'The specified value:ancestor:`ref` does not exist, or is not allowed.') + 'The specified value:bookmark:`ref` does not exist, or is not allowed.') assert excinfo.value.message == expected_message def test_branch_is_not_found(self): diff --git a/rhodecode/api/utils.py b/rhodecode/api/utils.py --- a/rhodecode/api/utils.py +++ b/rhodecode/api/utils.py @@ -388,7 +388,19 @@ def get_commit_or_error(ref, repo): raise JSONRPCError('Ref `{ref}` does not exist'.format(ref=ref)) -def resolve_ref_or_error(ref, repo): +def _get_ref_hash(repo, type_, name): + vcs_repo = repo.scm_instance() + if type_ in ['branch'] and vcs_repo.alias in ('hg', 'git'): + return vcs_repo.branches[name] + elif type_ in ['bookmark', 'book'] and vcs_repo.alias == 'hg': + return vcs_repo.bookmarks[name] + else: + raise ValueError() + + +def resolve_ref_or_error(ref, repo, allowed_ref_types=None): + allowed_ref_types = allowed_ref_types or ['bookmark', 'book', 'tag', 'branch'] + def _parse_ref(type_, name, hash_=None): return type_, name, hash_ @@ -399,6 +411,12 @@ def resolve_ref_or_error(ref, repo): 'Ref `{ref}` given in a wrong format. Please check the API' ' documentation for more details'.format(ref=ref)) + if ref_type not in allowed_ref_types: + raise JSONRPCError( + 'Ref `{ref}` type is not allowed. ' + 'Only:{allowed_refs} are possible.'.format( + ref=ref, allowed_refs=allowed_ref_types)) + try: ref_hash = ref_hash or _get_ref_hash(repo, ref_type, ref_name) except (KeyError, ValueError): @@ -429,13 +447,3 @@ def _get_commit_dict( "raw_diff": raw_diff, "stats": stats } - - -def _get_ref_hash(repo, type_, name): - vcs_repo = repo.scm_instance() - if type_ == 'branch' and vcs_repo.alias in ('hg', 'git'): - return vcs_repo.branches[name] - elif type_ == 'bookmark' and vcs_repo.alias == 'hg': - return vcs_repo.bookmarks[name] - else: - raise ValueError() diff --git a/rhodecode/model/pull_request.py b/rhodecode/model/pull_request.py --- a/rhodecode/model/pull_request.py +++ b/rhodecode/model/pull_request.py @@ -129,6 +129,8 @@ class PullRequestModel(BaseModel): 'This pull request cannot be updated because the source ' 'reference is missing.'), } + REF_TYPES = ['bookmark', 'book', 'tag', 'branch'] + UPDATABLE_REF_TYPES = ['bookmark', 'book', 'branch'] def __get_pull_request(self, pull_request): return self._get_instance(( @@ -671,7 +673,7 @@ class PullRequestModel(BaseModel): def has_valid_update_type(self, pull_request): source_ref_type = pull_request.source_ref_parts.type - return source_ref_type in ['book', 'branch', 'tag'] + return source_ref_type in self.REF_TYPES def update_commits(self, pull_request): """ @@ -751,7 +753,7 @@ class PullRequestModel(BaseModel): pull_request_version = pull_request try: - if target_ref_type in ('tag', 'branch', 'book'): + if target_ref_type in self.REF_TYPES: target_commit = target_repo.get_commit(target_ref_name) else: target_commit = target_repo.get_commit(target_ref_id) @@ -1326,7 +1328,7 @@ class PullRequestModel(BaseModel): return merge_state def _refresh_reference(self, reference, vcs_repository): - if reference.type in ('branch', 'book'): + if reference.type in self.UPDATABLE_REF_TYPES: name_or_id = reference.name else: name_or_id = reference.commit_id