##// END OF EJS Templates
feat(celery-hooks): added all needed changes to support new celery backend, removed DummyHooksCallbackDaemon, updated tests. Fixes: RCCE-55
ilin.s -
r5298:25044729 default
parent child Browse files
Show More
@@ -1,262 +1,263 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import re
21 21 import logging
22 22 import datetime
23 23 import configparser
24 24 from sqlalchemy import Table
25 25
26 26 from rhodecode.lib.utils2 import AttributeDict
27 27 from rhodecode.model.scm import ScmModel
28 28
29 29 from .hg import MercurialServer
30 30 from .git import GitServer
31 31 from .svn import SubversionServer
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 class SshWrapper(object):
36 36 hg_cmd_pat = re.compile(r'^hg\s+\-R\s+(\S+)\s+serve\s+\-\-stdio$')
37 37 git_cmd_pat = re.compile(r'^git-(receive-pack|upload-pack)\s\'[/]?(\S+?)(|\.git)\'$')
38 38 svn_cmd_pat = re.compile(r'^svnserve -t')
39 39
40 40 def __init__(self, command, connection_info, mode,
41 41 user, user_id, key_id: int, shell, ini_path: str, env):
42 42 self.command = command
43 43 self.connection_info = connection_info
44 44 self.mode = mode
45 45 self.username = user
46 46 self.user_id = user_id
47 47 self.key_id = key_id
48 48 self.shell = shell
49 49 self.ini_path = ini_path
50 50 self.env = env
51 51
52 52 self.config = self.parse_config(ini_path)
53 53 self.server_impl = None
54 54
55 55 def parse_config(self, config_path):
56 56 parser = configparser.ConfigParser()
57 57 parser.read(config_path)
58 58 return parser
59 59
60 60 def update_key_access_time(self, key_id):
61 61 from rhodecode.model.meta import raw_query_executor, Base
62 62
63 63 table = Table('user_ssh_keys', Base.metadata, autoload=False)
64 64 atime = datetime.datetime.utcnow()
65 65 stmt = (
66 66 table.update()
67 67 .where(table.c.ssh_key_id == key_id)
68 68 .values(accessed_on=atime)
69 69 # no MySQL Support for .returning :((
70 70 #.returning(table.c.accessed_on, table.c.ssh_key_fingerprint)
71 71 )
72 72
73 73 res_count = None
74 74 with raw_query_executor() as session:
75 75 result = session.execute(stmt)
76 76 if result.rowcount:
77 77 res_count = result.rowcount
78 78
79 79 if res_count:
80 80 log.debug('Update key id:`%s` access time', key_id)
81 81
82 82 def get_user(self, user_id):
83 83 user = AttributeDict()
84 84 # lazy load db imports
85 85 from rhodecode.model.db import User
86 86 dbuser = User.get(user_id)
87 87 if not dbuser:
88 88 return None
89 89 user.user_id = dbuser.user_id
90 90 user.username = dbuser.username
91 91 user.auth_user = dbuser.AuthUser()
92 92 return user
93 93
94 94 def get_connection_info(self):
95 95 """
96 96 connection_info
97 97
98 98 Identifies the client and server ends of the connection.
99 99 The variable contains four space-separated values: client IP address,
100 100 client port number, server IP address, and server port number.
101 101 """
102 102 conn = dict(
103 103 client_ip=None,
104 104 client_port=None,
105 105 server_ip=None,
106 106 server_port=None,
107 107 )
108 108
109 109 info = self.connection_info.split(' ')
110 110 if len(info) == 4:
111 111 conn['client_ip'] = info[0]
112 112 conn['client_port'] = info[1]
113 113 conn['server_ip'] = info[2]
114 114 conn['server_port'] = info[3]
115 115
116 116 return conn
117 117
118 118 def maybe_translate_repo_uid(self, repo_name):
119 119 _org_name = repo_name
120 120 if _org_name.startswith('_'):
121 121 # remove format of _ID/subrepo
122 122 _org_name = _org_name.split('/', 1)[0]
123 123
124 124 if repo_name.startswith('_'):
125 125 from rhodecode.model.repo import RepoModel
126 126 org_repo_name = repo_name
127 127 log.debug('translating UID repo %s', org_repo_name)
128 128 by_id_match = RepoModel().get_repo_by_id(repo_name)
129 129 if by_id_match:
130 130 repo_name = by_id_match.repo_name
131 131 log.debug('translation of UID repo %s got `%s`', org_repo_name, repo_name)
132 132
133 133 return repo_name, _org_name
134 134
135 135 def get_repo_details(self, mode):
136 136 vcs_type = mode if mode in ['svn', 'hg', 'git'] else None
137 137 repo_name = None
138 138
139 139 hg_match = self.hg_cmd_pat.match(self.command)
140 140 if hg_match is not None:
141 141 vcs_type = 'hg'
142 142 repo_id = hg_match.group(1).strip('/')
143 143 repo_name, org_name = self.maybe_translate_repo_uid(repo_id)
144 144 return vcs_type, repo_name, mode
145 145
146 146 git_match = self.git_cmd_pat.match(self.command)
147 147 if git_match is not None:
148 148 mode = git_match.group(1)
149 149 vcs_type = 'git'
150 150 repo_id = git_match.group(2).strip('/')
151 151 repo_name, org_name = self.maybe_translate_repo_uid(repo_id)
152 152 return vcs_type, repo_name, mode
153 153
154 154 svn_match = self.svn_cmd_pat.match(self.command)
155 155 if svn_match is not None:
156 156 vcs_type = 'svn'
157 157 # Repo name should be extracted from the input stream, we're unable to
158 158 # extract it at this point in execution
159 159 return vcs_type, repo_name, mode
160 160
161 161 return vcs_type, repo_name, mode
162 162
163 163 def serve(self, vcs, repo, mode, user, permissions, branch_permissions):
164 164 store = ScmModel().repos_path
165 165
166 166 check_branch_perms = False
167 167 detect_force_push = False
168 168
169 169 if branch_permissions:
170 170 check_branch_perms = True
171 171 detect_force_push = True
172 172
173 173 log.debug(
174 174 'VCS detected:`%s` mode: `%s` repo_name: %s, branch_permission_checks:%s',
175 175 vcs, mode, repo, check_branch_perms)
176 176
177 177 # detect if we have to check branch permissions
178 178 extras = {
179 179 'detect_force_push': detect_force_push,
180 180 'check_branch_perms': check_branch_perms,
181 'config': self.ini_path
181 182 }
182 183
183 184 if vcs == 'hg':
184 185 server = MercurialServer(
185 186 store=store, ini_path=self.ini_path,
186 187 repo_name=repo, user=user,
187 188 user_permissions=permissions, config=self.config, env=self.env)
188 189 self.server_impl = server
189 190 return server.run(tunnel_extras=extras)
190 191
191 192 elif vcs == 'git':
192 193 server = GitServer(
193 194 store=store, ini_path=self.ini_path,
194 195 repo_name=repo, repo_mode=mode, user=user,
195 196 user_permissions=permissions, config=self.config, env=self.env)
196 197 self.server_impl = server
197 198 return server.run(tunnel_extras=extras)
198 199
199 200 elif vcs == 'svn':
200 201 server = SubversionServer(
201 202 store=store, ini_path=self.ini_path,
202 203 repo_name=None, user=user,
203 204 user_permissions=permissions, config=self.config, env=self.env)
204 205 self.server_impl = server
205 206 return server.run(tunnel_extras=extras)
206 207
207 208 else:
208 209 raise Exception(f'Unrecognised VCS: {vcs}')
209 210
210 211 def wrap(self):
211 212 mode = self.mode
212 213 username = self.username
213 214 user_id = self.user_id
214 215 key_id = self.key_id
215 216 shell = self.shell
216 217
217 218 scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)
218 219
219 220 log.debug(
220 221 'Mode: `%s` User: `name:%s : id:%s` Shell: `%s` SSH Command: `\"%s\"` '
221 222 'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
222 223 mode, username, user_id, shell, self.command,
223 224 scm_detected, scm_mode, scm_repo)
224 225
225 226 log.debug('SSH Connection info %s', self.get_connection_info())
226 227
227 228 # update last access time for this key
228 229 if key_id:
229 230 self.update_key_access_time(key_id)
230 231
231 232 if shell and self.command is None:
232 233 log.info('Dropping to shell, no command given and shell is allowed')
233 234 os.execl('/bin/bash', '-l')
234 235 exit_code = 1
235 236
236 237 elif scm_detected:
237 238 user = self.get_user(user_id)
238 239 if not user:
239 240 log.warning('User with id %s not found', user_id)
240 241 exit_code = -1
241 242 return exit_code
242 243
243 244 auth_user = user.auth_user
244 245 permissions = auth_user.permissions['repositories']
245 246 repo_branch_permissions = auth_user.get_branch_permissions(scm_repo)
246 247 try:
247 248 exit_code, is_updated = self.serve(
248 249 scm_detected, scm_repo, scm_mode, user, permissions,
249 250 repo_branch_permissions)
250 251 except Exception:
251 252 log.exception('Error occurred during execution of SshWrapper')
252 253 exit_code = -1
253 254
254 255 elif self.command is None and shell is False:
255 256 log.error('No Command given.')
256 257 exit_code = -1
257 258
258 259 else:
259 260 log.error('Unhandled Command: "%s" Aborting.', self.command)
260 261 exit_code = -1
261 262
262 263 return exit_code
@@ -1,161 +1,160 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import sys
21 21 import logging
22 22
23 23 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
24 24 from rhodecode.lib.ext_json import sjson as json
25 25 from rhodecode.lib.vcs.conf import settings as vcs_settings
26 26 from rhodecode.model.scm import ScmModel
27 27
28 28 log = logging.getLogger(__name__)
29 29
30 30
31 31 class VcsServer(object):
32 32 repo_user_agent = None # set in child classes
33 33 _path = None # set executable path for hg/git/svn binary
34 34 backend = None # set in child classes
35 35 tunnel = None # subprocess handling tunnel
36 36 write_perms = ['repository.admin', 'repository.write']
37 37 read_perms = ['repository.read', 'repository.admin', 'repository.write']
38 38
39 39 def __init__(self, user, user_permissions, config, env):
40 40 self.user = user
41 41 self.user_permissions = user_permissions
42 42 self.config = config
43 43 self.env = env
44 44 self.stdin = sys.stdin
45 45
46 46 self.repo_name = None
47 47 self.repo_mode = None
48 48 self.store = ''
49 49 self.ini_path = ''
50 50
51 51 def _invalidate_cache(self, repo_name):
52 52 """
53 53 Set's cache for this repository for invalidation on next access
54 54
55 55 :param repo_name: full repo name, also a cache key
56 56 """
57 57 ScmModel().mark_for_invalidation(repo_name)
58 58
59 59 def has_write_perm(self):
60 60 permission = self.user_permissions.get(self.repo_name)
61 61 if permission in ['repository.write', 'repository.admin']:
62 62 return True
63 63
64 64 return False
65 65
66 66 def _check_permissions(self, action):
67 67 permission = self.user_permissions.get(self.repo_name)
68 68 log.debug('permission for %s on %s are: %s',
69 69 self.user, self.repo_name, permission)
70 70
71 71 if not permission:
72 72 log.error('user `%s` permissions to repo:%s are empty. Forbidding access.',
73 73 self.user, self.repo_name)
74 74 return -2
75 75
76 76 if action == 'pull':
77 77 if permission in self.read_perms:
78 78 log.info(
79 79 'READ Permissions for User "%s" detected to repo "%s"!',
80 80 self.user, self.repo_name)
81 81 return 0
82 82 else:
83 83 if permission in self.write_perms:
84 84 log.info(
85 85 'WRITE, or Higher Permissions for User "%s" detected to repo "%s"!',
86 86 self.user, self.repo_name)
87 87 return 0
88 88
89 89 log.error('Cannot properly fetch or verify user `%s` permissions. '
90 90 'Permissions: %s, vcs action: %s',
91 91 self.user, permission, action)
92 92 return -2
93 93
94 94 def update_environment(self, action, extras=None):
95 95
96 96 scm_data = {
97 97 'ip': os.environ['SSH_CLIENT'].split()[0],
98 98 'username': self.user.username,
99 99 'user_id': self.user.user_id,
100 100 'action': action,
101 101 'repository': self.repo_name,
102 102 'scm': self.backend,
103 103 'config': self.ini_path,
104 104 'repo_store': self.store,
105 105 'make_lock': None,
106 106 'locked_by': [None, None],
107 107 'server_url': None,
108 108 'user_agent': f'{self.repo_user_agent}/ssh-user-agent',
109 109 'hooks': ['push', 'pull'],
110 110 'hooks_module': 'rhodecode.lib.hooks_daemon',
111 111 'is_shadow_repo': False,
112 112 'detect_force_push': False,
113 113 'check_branch_perms': False,
114 114
115 115 'SSH': True,
116 116 'SSH_PERMISSIONS': self.user_permissions.get(self.repo_name),
117 117 }
118 118 if extras:
119 119 scm_data.update(extras)
120 120 os.putenv("RC_SCM_DATA", json.dumps(scm_data))
121 121
122 122 def get_root_store(self):
123 123 root_store = self.store
124 124 if not root_store.endswith('/'):
125 125 # always append trailing slash
126 126 root_store = root_store + '/'
127 127 return root_store
128 128
129 129 def _handle_tunnel(self, extras):
130 130 # pre-auth
131 131 action = 'pull'
132 132 exit_code = self._check_permissions(action)
133 133 if exit_code:
134 134 return exit_code, False
135 135
136 136 req = self.env['request']
137 137 server_url = req.host_url + req.script_name
138 138 extras['server_url'] = server_url
139 139
140 140 log.debug('Using %s binaries from path %s', self.backend, self._path)
141 141 exit_code = self.tunnel.run(extras)
142 142
143 143 return exit_code, action == "push"
144 144
145 145 def run(self, tunnel_extras=None):
146 146 tunnel_extras = tunnel_extras or {}
147 147 extras = {}
148 148 extras.update(tunnel_extras)
149 149
150 150 callback_daemon, extras = prepare_callback_daemon(
151 151 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
152 host=vcs_settings.HOOKS_HOST,
153 use_direct_calls=False)
152 host=vcs_settings.HOOKS_HOST)
154 153
155 154 with callback_daemon:
156 155 try:
157 156 return self._handle_tunnel(extras)
158 157 finally:
159 158 log.debug('Running cleanup with cache invalidation')
160 159 if self.repo_name:
161 160 self._invalidate_cache(self.repo_name)
@@ -1,118 +1,117 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import platform
21 21
22 22 from rhodecode.model import init_model
23 23
24 24
25 25 def configure_vcs(config):
26 26 """
27 27 Patch VCS config with some RhodeCode specific stuff
28 28 """
29 29 from rhodecode.lib.vcs import conf
30 30 import rhodecode.lib.vcs.conf.settings
31 31
32 32 conf.settings.BACKENDS = {
33 33 'hg': 'rhodecode.lib.vcs.backends.hg.MercurialRepository',
34 34 'git': 'rhodecode.lib.vcs.backends.git.GitRepository',
35 35 'svn': 'rhodecode.lib.vcs.backends.svn.SubversionRepository',
36 36 }
37 37
38 38 conf.settings.HOOKS_PROTOCOL = config['vcs.hooks.protocol']
39 39 conf.settings.HOOKS_HOST = config['vcs.hooks.host']
40 conf.settings.HOOKS_DIRECT_CALLS = config['vcs.hooks.direct_calls']
41 40 conf.settings.DEFAULT_ENCODINGS = config['default_encoding']
42 41 conf.settings.ALIASES[:] = config['vcs.backends']
43 42 conf.settings.SVN_COMPATIBLE_VERSION = config['vcs.svn.compatible_version']
44 43
45 44
46 45 def initialize_database(config):
47 46 from rhodecode.lib.utils2 import engine_from_config, get_encryption_key
48 47 engine = engine_from_config(config, 'sqlalchemy.db1.')
49 48 init_model(engine, encryption_key=get_encryption_key(config))
50 49
51 50
52 51 def initialize_test_environment(settings, test_env=None):
53 52 if test_env is None:
54 53 test_env = not int(os.environ.get('RC_NO_TMP_PATH', 0))
55 54
56 55 from rhodecode.lib.utils import (
57 56 create_test_directory, create_test_database, create_test_repositories,
58 57 create_test_index)
59 58 from rhodecode.tests import TESTS_TMP_PATH
60 59 from rhodecode.lib.vcs.backends.hg import largefiles_store
61 60 from rhodecode.lib.vcs.backends.git import lfs_store
62 61
63 62 # test repos
64 63 if test_env:
65 64 create_test_directory(TESTS_TMP_PATH)
66 65 # large object stores
67 66 create_test_directory(largefiles_store(TESTS_TMP_PATH))
68 67 create_test_directory(lfs_store(TESTS_TMP_PATH))
69 68
70 69 create_test_database(TESTS_TMP_PATH, settings)
71 70 create_test_repositories(TESTS_TMP_PATH, settings)
72 71 create_test_index(TESTS_TMP_PATH, settings)
73 72
74 73
75 74 def get_vcs_server_protocol(config):
76 75 return config['vcs.server.protocol']
77 76
78 77
79 78 def set_instance_id(config):
80 79 """
81 80 Sets a dynamic generated config['instance_id'] if missing or '*'
82 81 E.g instance_id = *cluster-1 or instance_id = *
83 82 """
84 83
85 84 config['instance_id'] = config.get('instance_id') or ''
86 85 instance_id = config['instance_id']
87 86 if instance_id.startswith('*') or not instance_id:
88 87 prefix = instance_id.lstrip('*')
89 88 _platform_id = platform.uname()[1] or 'instance'
90 89 config['instance_id'] = '{prefix}uname:{platform}-pid:{pid}'.format(
91 90 prefix=prefix,
92 91 platform=_platform_id,
93 92 pid=os.getpid())
94 93
95 94
96 95 def get_default_user_id():
97 96 DEFAULT_USER = 'default'
98 97 from sqlalchemy import text
99 98 from rhodecode.model import meta
100 99
101 100 engine = meta.get_engine()
102 101 with meta.SA_Session(engine) as session:
103 102 result = session.execute(text("SELECT user_id from users where username = :uname"), {'uname': DEFAULT_USER})
104 103 user_id = result.first()[0]
105 104
106 105 return user_id
107 106
108 107
109 108 def get_default_base_path():
110 109 from sqlalchemy import text
111 110 from rhodecode.model import meta
112 111
113 112 engine = meta.get_engine()
114 113 with meta.SA_Session(engine) as session:
115 114 result = session.execute(text("SELECT ui_value from rhodecode_ui where ui_key = '/'"))
116 115 base_path = result.first()[0]
117 116
118 117 return base_path
@@ -1,412 +1,448 b''
1 1 # Copyright (C) 2012-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 """
20 20 RhodeCode task modules, containing all task that suppose to be run
21 21 by celery daemon
22 22 """
23 23
24 24 import os
25 25 import time
26 26
27 27 from pyramid_mailer.mailer import Mailer
28 28 from pyramid_mailer.message import Message
29 29 from email.utils import formatdate
30 30
31 31 import rhodecode
32 32 from rhodecode.lib import audit_logger
33 33 from rhodecode.lib.celerylib import get_logger, async_task, RequestContextTask, run_task
34 34 from rhodecode.lib import hooks_base
35 from rhodecode.lib.utils import adopt_for_celery
35 36 from rhodecode.lib.utils2 import safe_int, str2bool, aslist
36 37 from rhodecode.lib.statsd_client import StatsdClient
37 38 from rhodecode.model.db import (
38 39 true, null, Session, IntegrityError, Repository, RepoGroup, User)
39 40 from rhodecode.model.permission import PermissionModel
40 41
41 42
42 43 @async_task(ignore_result=True, base=RequestContextTask)
43 44 def send_email(recipients, subject, body='', html_body='', email_config=None,
44 45 extra_headers=None):
45 46 """
46 47 Sends an email with defined parameters from the .ini files.
47 48
48 49 :param recipients: list of recipients, it this is empty the defined email
49 50 address from field 'email_to' is used instead
50 51 :param subject: subject of the mail
51 52 :param body: body of the mail
52 53 :param html_body: html version of body
53 54 :param email_config: specify custom configuration for mailer
54 55 :param extra_headers: specify custom headers
55 56 """
56 57 log = get_logger(send_email)
57 58
58 59 email_config = email_config or rhodecode.CONFIG
59 60
60 61 mail_server = email_config.get('smtp_server') or None
61 62 if mail_server is None:
62 63 log.error("SMTP server information missing. Sending email failed. "
63 64 "Make sure that `smtp_server` variable is configured "
64 65 "inside the .ini file")
65 66 return False
66 67
67 68 subject = "%s %s" % (email_config.get('email_prefix', ''), subject)
68 69
69 70 if recipients:
70 71 if isinstance(recipients, str):
71 72 recipients = recipients.split(',')
72 73 else:
73 74 # if recipients are not defined we send to email_config + all admins
74 75 admins = []
75 76 for u in User.query().filter(User.admin == true()).all():
76 77 if u.email:
77 78 admins.append(u.email)
78 79 recipients = []
79 80 config_email = email_config.get('email_to')
80 81 if config_email:
81 82 recipients += [config_email]
82 83 recipients += admins
83 84
84 85 # translate our LEGACY config into the one that pyramid_mailer supports
85 86 email_conf = dict(
86 87 host=mail_server,
87 88 port=email_config.get('smtp_port', 25),
88 89 username=email_config.get('smtp_username'),
89 90 password=email_config.get('smtp_password'),
90 91
91 92 tls=str2bool(email_config.get('smtp_use_tls')),
92 93 ssl=str2bool(email_config.get('smtp_use_ssl')),
93 94
94 95 # SSL key file
95 96 # keyfile='',
96 97
97 98 # SSL certificate file
98 99 # certfile='',
99 100
100 101 # Location of maildir
101 102 # queue_path='',
102 103
103 104 default_sender=email_config.get('app_email_from', 'RhodeCode-noreply@rhodecode.com'),
104 105
105 106 debug=str2bool(email_config.get('smtp_debug')),
106 107 # /usr/sbin/sendmail Sendmail executable
107 108 # sendmail_app='',
108 109
109 110 # {sendmail_app} -t -i -f {sender} Template for sendmail execution
110 111 # sendmail_template='',
111 112 )
112 113
113 114 if extra_headers is None:
114 115 extra_headers = {}
115 116
116 117 extra_headers.setdefault('Date', formatdate(time.time()))
117 118
118 119 if 'thread_ids' in extra_headers:
119 120 thread_ids = extra_headers.pop('thread_ids')
120 121 extra_headers['References'] = ' '.join('<{}>'.format(t) for t in thread_ids)
121 122
122 123 try:
123 124 mailer = Mailer(**email_conf)
124 125
125 126 message = Message(subject=subject,
126 127 sender=email_conf['default_sender'],
127 128 recipients=recipients,
128 129 body=body, html=html_body,
129 130 extra_headers=extra_headers)
130 131 mailer.send_immediately(message)
131 132 statsd = StatsdClient.statsd
132 133 if statsd:
133 134 statsd.incr('rhodecode_email_sent_total')
134 135
135 136 except Exception:
136 137 log.exception('Mail sending failed')
137 138 return False
138 139 return True
139 140
140 141
141 142 @async_task(ignore_result=True, base=RequestContextTask)
142 143 def create_repo(form_data, cur_user):
143 144 from rhodecode.model.repo import RepoModel
144 145 from rhodecode.model.user import UserModel
145 146 from rhodecode.model.scm import ScmModel
146 147 from rhodecode.model.settings import SettingsModel
147 148
148 149 log = get_logger(create_repo)
149 150
150 151 cur_user = UserModel()._get_user(cur_user)
151 152 owner = cur_user
152 153
153 154 repo_name = form_data['repo_name']
154 155 repo_name_full = form_data['repo_name_full']
155 156 repo_type = form_data['repo_type']
156 157 description = form_data['repo_description']
157 158 private = form_data['repo_private']
158 159 clone_uri = form_data.get('clone_uri')
159 160 repo_group = safe_int(form_data['repo_group'])
160 161 copy_fork_permissions = form_data.get('copy_permissions')
161 162 copy_group_permissions = form_data.get('repo_copy_permissions')
162 163 fork_of = form_data.get('fork_parent_id')
163 164 state = form_data.get('repo_state', Repository.STATE_PENDING)
164 165
165 166 # repo creation defaults, private and repo_type are filled in form
166 167 defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
167 168 enable_statistics = form_data.get(
168 169 'enable_statistics', defs.get('repo_enable_statistics'))
169 170 enable_locking = form_data.get(
170 171 'enable_locking', defs.get('repo_enable_locking'))
171 172 enable_downloads = form_data.get(
172 173 'enable_downloads', defs.get('repo_enable_downloads'))
173 174
174 175 # set landing rev based on default branches for SCM
175 176 landing_ref, _label = ScmModel.backend_landing_ref(repo_type)
176 177
177 178 try:
178 179 RepoModel()._create_repo(
179 180 repo_name=repo_name_full,
180 181 repo_type=repo_type,
181 182 description=description,
182 183 owner=owner,
183 184 private=private,
184 185 clone_uri=clone_uri,
185 186 repo_group=repo_group,
186 187 landing_rev=landing_ref,
187 188 fork_of=fork_of,
188 189 copy_fork_permissions=copy_fork_permissions,
189 190 copy_group_permissions=copy_group_permissions,
190 191 enable_statistics=enable_statistics,
191 192 enable_locking=enable_locking,
192 193 enable_downloads=enable_downloads,
193 194 state=state
194 195 )
195 196 Session().commit()
196 197
197 198 # now create this repo on Filesystem
198 199 RepoModel()._create_filesystem_repo(
199 200 repo_name=repo_name,
200 201 repo_type=repo_type,
201 202 repo_group=RepoModel()._get_repo_group(repo_group),
202 203 clone_uri=clone_uri,
203 204 )
204 205 repo = Repository.get_by_repo_name(repo_name_full)
205 206 hooks_base.create_repository(created_by=owner.username, **repo.get_dict())
206 207
207 208 # update repo commit caches initially
208 209 repo.update_commit_cache()
209 210
210 211 # set new created state
211 212 repo.set_state(Repository.STATE_CREATED)
212 213 repo_id = repo.repo_id
213 214 repo_data = repo.get_api_data()
214 215
215 216 audit_logger.store(
216 217 'repo.create', action_data={'data': repo_data},
217 218 user=cur_user,
218 219 repo=audit_logger.RepoWrap(repo_name=repo_name, repo_id=repo_id))
219 220
220 221 Session().commit()
221 222
222 223 PermissionModel().trigger_permission_flush()
223 224
224 225 except Exception as e:
225 226 log.warning('Exception occurred when creating repository, '
226 227 'doing cleanup...', exc_info=True)
227 228 if isinstance(e, IntegrityError):
228 229 Session().rollback()
229 230
230 231 # rollback things manually !
231 232 repo = Repository.get_by_repo_name(repo_name_full)
232 233 if repo:
233 234 Repository.delete(repo.repo_id)
234 235 Session().commit()
235 236 RepoModel()._delete_filesystem_repo(repo)
236 237 log.info('Cleanup of repo %s finished', repo_name_full)
237 238 raise
238 239
239 240 return True
240 241
241 242
242 243 @async_task(ignore_result=True, base=RequestContextTask)
243 244 def create_repo_fork(form_data, cur_user):
244 245 """
245 246 Creates a fork of repository using internal VCS methods
246 247 """
247 248 from rhodecode.model.repo import RepoModel
248 249 from rhodecode.model.user import UserModel
249 250
250 251 log = get_logger(create_repo_fork)
251 252
252 253 cur_user = UserModel()._get_user(cur_user)
253 254 owner = cur_user
254 255
255 256 repo_name = form_data['repo_name'] # fork in this case
256 257 repo_name_full = form_data['repo_name_full']
257 258 repo_type = form_data['repo_type']
258 259 description = form_data['description']
259 260 private = form_data['private']
260 261 clone_uri = form_data.get('clone_uri')
261 262 repo_group = safe_int(form_data['repo_group'])
262 263 landing_ref = form_data['landing_rev']
263 264 copy_fork_permissions = form_data.get('copy_permissions')
264 265 fork_id = safe_int(form_data.get('fork_parent_id'))
265 266
266 267 try:
267 268 fork_of = RepoModel()._get_repo(fork_id)
268 269 RepoModel()._create_repo(
269 270 repo_name=repo_name_full,
270 271 repo_type=repo_type,
271 272 description=description,
272 273 owner=owner,
273 274 private=private,
274 275 clone_uri=clone_uri,
275 276 repo_group=repo_group,
276 277 landing_rev=landing_ref,
277 278 fork_of=fork_of,
278 279 copy_fork_permissions=copy_fork_permissions
279 280 )
280 281
281 282 Session().commit()
282 283
283 284 base_path = Repository.base_path()
284 285 source_repo_path = os.path.join(base_path, fork_of.repo_name)
285 286
286 287 # now create this repo on Filesystem
287 288 RepoModel()._create_filesystem_repo(
288 289 repo_name=repo_name,
289 290 repo_type=repo_type,
290 291 repo_group=RepoModel()._get_repo_group(repo_group),
291 292 clone_uri=source_repo_path,
292 293 )
293 294 repo = Repository.get_by_repo_name(repo_name_full)
294 295 hooks_base.create_repository(created_by=owner.username, **repo.get_dict())
295 296
296 297 # update repo commit caches initially
297 298 config = repo._config
298 299 config.set('extensions', 'largefiles', '')
299 300 repo.update_commit_cache(config=config)
300 301
301 302 # set new created state
302 303 repo.set_state(Repository.STATE_CREATED)
303 304
304 305 repo_id = repo.repo_id
305 306 repo_data = repo.get_api_data()
306 307 audit_logger.store(
307 308 'repo.fork', action_data={'data': repo_data},
308 309 user=cur_user,
309 310 repo=audit_logger.RepoWrap(repo_name=repo_name, repo_id=repo_id))
310 311
311 312 Session().commit()
312 313 except Exception as e:
313 314 log.warning('Exception occurred when forking repository, '
314 315 'doing cleanup...', exc_info=True)
315 316 if isinstance(e, IntegrityError):
316 317 Session().rollback()
317 318
318 319 # rollback things manually !
319 320 repo = Repository.get_by_repo_name(repo_name_full)
320 321 if repo:
321 322 Repository.delete(repo.repo_id)
322 323 Session().commit()
323 324 RepoModel()._delete_filesystem_repo(repo)
324 325 log.info('Cleanup of repo %s finished', repo_name_full)
325 326 raise
326 327
327 328 return True
328 329
329 330
330 331 @async_task(ignore_result=True, base=RequestContextTask)
331 332 def repo_maintenance(repoid):
332 333 from rhodecode.lib import repo_maintenance as repo_maintenance_lib
333 334 log = get_logger(repo_maintenance)
334 335 repo = Repository.get_by_id_or_repo_name(repoid)
335 336 if repo:
336 337 maintenance = repo_maintenance_lib.RepoMaintenance()
337 338 tasks = maintenance.get_tasks_for_repo(repo)
338 339 log.debug('Executing %s tasks on repo `%s`', tasks, repoid)
339 340 executed_types = maintenance.execute(repo)
340 341 log.debug('Got execution results %s', executed_types)
341 342 else:
342 343 log.debug('Repo `%s` not found or without a clone_url', repoid)
343 344
344 345
345 346 @async_task(ignore_result=True, base=RequestContextTask)
346 347 def check_for_update(send_email_notification=True, email_recipients=None):
347 348 from rhodecode.model.update import UpdateModel
348 349 from rhodecode.model.notification import EmailNotificationModel
349 350
350 351 log = get_logger(check_for_update)
351 352 update_url = UpdateModel().get_update_url()
352 353 cur_ver = rhodecode.__version__
353 354
354 355 try:
355 356 data = UpdateModel().get_update_data(update_url)
356 357
357 358 current_ver = UpdateModel().get_stored_version(fallback=cur_ver)
358 359 latest_ver = data['versions'][0]['version']
359 360 UpdateModel().store_version(latest_ver)
360 361
361 362 if send_email_notification:
362 363 log.debug('Send email notification is enabled. '
363 364 'Current RhodeCode version: %s, latest known: %s', current_ver, latest_ver)
364 365 if UpdateModel().is_outdated(current_ver, latest_ver):
365 366
366 367 email_kwargs = {
367 368 'current_ver': current_ver,
368 369 'latest_ver': latest_ver,
369 370 }
370 371
371 372 (subject, email_body, email_body_plaintext) = EmailNotificationModel().render_email(
372 373 EmailNotificationModel.TYPE_UPDATE_AVAILABLE, **email_kwargs)
373 374
374 375 email_recipients = aslist(email_recipients, sep=',') or \
375 376 [user.email for user in User.get_all_super_admins()]
376 377 run_task(send_email, email_recipients, subject,
377 378 email_body_plaintext, email_body)
378 379
379 380 except Exception:
380 381 log.exception('Failed to check for update')
381 382 raise
382 383
383 384
384 385 def sync_last_update_for_objects(*args, **kwargs):
385 386 skip_repos = kwargs.get('skip_repos')
386 387 if not skip_repos:
387 388 repos = Repository.query() \
388 389 .order_by(Repository.group_id.asc())
389 390
390 391 for repo in repos:
391 392 repo.update_commit_cache()
392 393
393 394 skip_groups = kwargs.get('skip_groups')
394 395 if not skip_groups:
395 396 repo_groups = RepoGroup.query() \
396 397 .filter(RepoGroup.group_parent_id == null())
397 398
398 399 for root_gr in repo_groups:
399 400 for repo_gr in reversed(root_gr.recursive_groups()):
400 401 repo_gr.update_commit_cache()
401 402
402 403
403 404 @async_task(ignore_result=True, base=RequestContextTask)
404 405 def sync_last_update(*args, **kwargs):
405 406 sync_last_update_for_objects(*args, **kwargs)
406 407
407 408
408 409 @async_task(ignore_result=False)
409 410 def beat_check(*args, **kwargs):
410 411 log = get_logger(beat_check)
411 412 log.info('%r: Got args: %r and kwargs %r', beat_check, args, kwargs)
412 413 return time.time()
414
415
416 @async_task
417 @adopt_for_celery
418 def repo_size(extras):
419 from rhodecode.lib.hooks_base import repo_size
420 return repo_size(extras)
421
422
423 @async_task
424 @adopt_for_celery
425 def pre_pull(extras):
426 from rhodecode.lib.hooks_base import pre_pull
427 return pre_pull(extras)
428
429
430 @async_task
431 @adopt_for_celery
432 def post_pull(extras):
433 from rhodecode.lib.hooks_base import post_pull
434 return post_pull(extras)
435
436
437 @async_task
438 @adopt_for_celery
439 def pre_push(extras):
440 from rhodecode.lib.hooks_base import pre_push
441 return pre_push(extras)
442
443
444 @async_task
445 @adopt_for_celery
446 def post_push(extras):
447 from rhodecode.lib.hooks_base import post_push
448 return post_push(extras)
@@ -1,531 +1,534 b''
1 1 # Copyright (C) 2013-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19
20 20 """
21 21 Set of hooks run by RhodeCode Enterprise
22 22 """
23 23
24 24 import os
25 25 import logging
26 26
27 27 import rhodecode
28 28 from rhodecode import events
29 29 from rhodecode.lib import helpers as h
30 30 from rhodecode.lib import audit_logger
31 31 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
32 32 from rhodecode.lib.exceptions import (
33 33 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
34 34 from rhodecode.model.db import Repository, User
35 35 from rhodecode.lib.statsd_client import StatsdClient
36 36
37 37 log = logging.getLogger(__name__)
38 38
39 39
40 40 class HookResponse(object):
41 41 def __init__(self, status, output):
42 42 self.status = status
43 43 self.output = output
44 44
45 45 def __add__(self, other):
46 46 other_status = getattr(other, 'status', 0)
47 47 new_status = max(self.status, other_status)
48 48 other_output = getattr(other, 'output', '')
49 49 new_output = self.output + other_output
50 50
51 51 return HookResponse(new_status, new_output)
52 52
53 53 def __bool__(self):
54 54 return self.status == 0
55 55
56 def to_json(self):
57 return {'status': self.status, 'output': self.output}
58
56 59
57 60 def is_shadow_repo(extras):
58 61 """
59 62 Returns ``True`` if this is an action executed against a shadow repository.
60 63 """
61 64 return extras['is_shadow_repo']
62 65
63 66
64 67 def _get_scm_size(alias, root_path):
65 68
66 69 if not alias.startswith('.'):
67 70 alias += '.'
68 71
69 72 size_scm, size_root = 0, 0
70 73 for path, unused_dirs, files in os.walk(safe_str(root_path)):
71 74 if path.find(alias) != -1:
72 75 for f in files:
73 76 try:
74 77 size_scm += os.path.getsize(os.path.join(path, f))
75 78 except OSError:
76 79 pass
77 80 else:
78 81 for f in files:
79 82 try:
80 83 size_root += os.path.getsize(os.path.join(path, f))
81 84 except OSError:
82 85 pass
83 86
84 87 size_scm_f = h.format_byte_size_binary(size_scm)
85 88 size_root_f = h.format_byte_size_binary(size_root)
86 89 size_total_f = h.format_byte_size_binary(size_root + size_scm)
87 90
88 91 return size_scm_f, size_root_f, size_total_f
89 92
90 93
91 94 # actual hooks called by Mercurial internally, and GIT by our Python Hooks
92 95 def repo_size(extras):
93 96 """Present size of repository after push."""
94 97 repo = Repository.get_by_repo_name(extras.repository)
95 98 vcs_part = f'.{repo.repo_type}'
96 99 size_vcs, size_root, size_total = _get_scm_size(vcs_part, repo.repo_full_path)
97 100 msg = (f'RhodeCode: `{repo.repo_name}` size summary {vcs_part}:{size_vcs} repo:{size_root} total:{size_total}\n')
98 101 return HookResponse(0, msg)
99 102
100 103
101 104 def pre_push(extras):
102 105 """
103 106 Hook executed before pushing code.
104 107
105 108 It bans pushing when the repository is locked.
106 109 """
107 110
108 111 user = User.get_by_username(extras.username)
109 112 output = ''
110 113 if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
111 114 locked_by = User.get(extras.locked_by[0]).username
112 115 reason = extras.locked_by[2]
113 116 # this exception is interpreted in git/hg middlewares and based
114 117 # on that proper return code is server to client
115 118 _http_ret = HTTPLockedRC(
116 119 _locked_by_explanation(extras.repository, locked_by, reason))
117 120 if str(_http_ret.code).startswith('2'):
118 121 # 2xx Codes don't raise exceptions
119 122 output = _http_ret.title
120 123 else:
121 124 raise _http_ret
122 125
123 126 hook_response = ''
124 127 if not is_shadow_repo(extras):
125 128
126 129 if extras.commit_ids and extras.check_branch_perms:
127 130 auth_user = user.AuthUser()
128 131 repo = Repository.get_by_repo_name(extras.repository)
129 132 affected_branches = []
130 133 if repo.repo_type == 'hg':
131 134 for entry in extras.commit_ids:
132 135 if entry['type'] == 'branch':
133 136 is_forced = bool(entry['multiple_heads'])
134 137 affected_branches.append([entry['name'], is_forced])
135 138 elif repo.repo_type == 'git':
136 139 for entry in extras.commit_ids:
137 140 if entry['type'] == 'heads':
138 141 is_forced = bool(entry['pruned_sha'])
139 142 affected_branches.append([entry['name'], is_forced])
140 143
141 144 for branch_name, is_forced in affected_branches:
142 145
143 146 rule, branch_perm = auth_user.get_rule_and_branch_permission(
144 147 extras.repository, branch_name)
145 148 if not branch_perm:
146 149 # no branch permission found for this branch, just keep checking
147 150 continue
148 151
149 152 if branch_perm == 'branch.push_force':
150 153 continue
151 154 elif branch_perm == 'branch.push' and is_forced is False:
152 155 continue
153 156 elif branch_perm == 'branch.push' and is_forced is True:
154 157 halt_message = f'Branch `{branch_name}` changes rejected by rule {rule}. ' \
155 158 f'FORCE PUSH FORBIDDEN.'
156 159 else:
157 160 halt_message = f'Branch `{branch_name}` changes rejected by rule {rule}.'
158 161
159 162 if halt_message:
160 163 _http_ret = HTTPBranchProtected(halt_message)
161 164 raise _http_ret
162 165
163 166 # Propagate to external components. This is done after checking the
164 167 # lock, for consistent behavior.
165 168 hook_response = pre_push_extension(
166 169 repo_store_path=Repository.base_path(), **extras)
167 170 events.trigger(events.RepoPrePushEvent(
168 171 repo_name=extras.repository, extras=extras))
169 172
170 173 return HookResponse(0, output) + hook_response
171 174
172 175
173 176 def pre_pull(extras):
174 177 """
175 178 Hook executed before pulling the code.
176 179
177 180 It bans pulling when the repository is locked.
178 181 """
179 182
180 183 output = ''
181 184 if extras.locked_by[0]:
182 185 locked_by = User.get(extras.locked_by[0]).username
183 186 reason = extras.locked_by[2]
184 187 # this exception is interpreted in git/hg middlewares and based
185 188 # on that proper return code is server to client
186 189 _http_ret = HTTPLockedRC(
187 190 _locked_by_explanation(extras.repository, locked_by, reason))
188 191 if str(_http_ret.code).startswith('2'):
189 192 # 2xx Codes don't raise exceptions
190 193 output = _http_ret.title
191 194 else:
192 195 raise _http_ret
193 196
194 197 # Propagate to external components. This is done after checking the
195 198 # lock, for consistent behavior.
196 199 hook_response = ''
197 200 if not is_shadow_repo(extras):
198 201 extras.hook_type = extras.hook_type or 'pre_pull'
199 202 hook_response = pre_pull_extension(
200 203 repo_store_path=Repository.base_path(), **extras)
201 204 events.trigger(events.RepoPrePullEvent(
202 205 repo_name=extras.repository, extras=extras))
203 206
204 207 return HookResponse(0, output) + hook_response
205 208
206 209
207 210 def post_pull(extras):
208 211 """Hook executed after client pulls the code."""
209 212
210 213 audit_user = audit_logger.UserWrap(
211 214 username=extras.username,
212 215 ip_addr=extras.ip)
213 216 repo = audit_logger.RepoWrap(repo_name=extras.repository)
214 217 audit_logger.store(
215 218 'user.pull', action_data={'user_agent': extras.user_agent},
216 219 user=audit_user, repo=repo, commit=True)
217 220
218 221 statsd = StatsdClient.statsd
219 222 if statsd:
220 223 statsd.incr('rhodecode_pull_total', tags=[
221 224 f'user-agent:{user_agent_normalizer(extras.user_agent)}',
222 225 ])
223 226 output = ''
224 227 # make lock is a tri state False, True, None. We only make lock on True
225 228 if extras.make_lock is True and not is_shadow_repo(extras):
226 229 user = User.get_by_username(extras.username)
227 230 Repository.lock(Repository.get_by_repo_name(extras.repository),
228 231 user.user_id,
229 232 lock_reason=Repository.LOCK_PULL)
230 233 msg = 'Made lock on repo `{}`'.format(extras.repository)
231 234 output += msg
232 235
233 236 if extras.locked_by[0]:
234 237 locked_by = User.get(extras.locked_by[0]).username
235 238 reason = extras.locked_by[2]
236 239 _http_ret = HTTPLockedRC(
237 240 _locked_by_explanation(extras.repository, locked_by, reason))
238 241 if str(_http_ret.code).startswith('2'):
239 242 # 2xx Codes don't raise exceptions
240 243 output += _http_ret.title
241 244
242 245 # Propagate to external components.
243 246 hook_response = ''
244 247 if not is_shadow_repo(extras):
245 248 extras.hook_type = extras.hook_type or 'post_pull'
246 249 hook_response = post_pull_extension(
247 250 repo_store_path=Repository.base_path(), **extras)
248 251 events.trigger(events.RepoPullEvent(
249 252 repo_name=extras.repository, extras=extras))
250 253
251 254 return HookResponse(0, output) + hook_response
252 255
253 256
254 257 def post_push(extras):
255 258 """Hook executed after user pushes to the repository."""
256 259 commit_ids = extras.commit_ids
257 260
258 261 # log the push call
259 262 audit_user = audit_logger.UserWrap(
260 263 username=extras.username, ip_addr=extras.ip)
261 264 repo = audit_logger.RepoWrap(repo_name=extras.repository)
262 265 audit_logger.store(
263 266 'user.push', action_data={
264 267 'user_agent': extras.user_agent,
265 268 'commit_ids': commit_ids[:400]},
266 269 user=audit_user, repo=repo, commit=True)
267 270
268 271 statsd = StatsdClient.statsd
269 272 if statsd:
270 273 statsd.incr('rhodecode_push_total', tags=[
271 274 f'user-agent:{user_agent_normalizer(extras.user_agent)}',
272 275 ])
273 276
274 277 # Propagate to external components.
275 278 output = ''
276 279 # make lock is a tri state False, True, None. We only release lock on False
277 280 if extras.make_lock is False and not is_shadow_repo(extras):
278 281 Repository.unlock(Repository.get_by_repo_name(extras.repository))
279 282 msg = f'Released lock on repo `{extras.repository}`\n'
280 283 output += msg
281 284
282 285 if extras.locked_by[0]:
283 286 locked_by = User.get(extras.locked_by[0]).username
284 287 reason = extras.locked_by[2]
285 288 _http_ret = HTTPLockedRC(
286 289 _locked_by_explanation(extras.repository, locked_by, reason))
287 290 # TODO: johbo: if not?
288 291 if str(_http_ret.code).startswith('2'):
289 292 # 2xx Codes don't raise exceptions
290 293 output += _http_ret.title
291 294
292 295 if extras.new_refs:
293 296 tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
294 297 safe_str(extras.server_url), safe_str(extras.repository))
295 298
296 299 for branch_name in extras.new_refs['branches']:
297 300 pr_link = tmpl.format(ref_type='branch', ref_name=safe_str(branch_name))
298 301 output += f'RhodeCode: open pull request link: {pr_link}\n'
299 302
300 303 for book_name in extras.new_refs['bookmarks']:
301 304 pr_link = tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name))
302 305 output += f'RhodeCode: open pull request link: {pr_link}\n'
303 306
304 307 hook_response = ''
305 308 if not is_shadow_repo(extras):
306 309 hook_response = post_push_extension(
307 310 repo_store_path=Repository.base_path(),
308 311 **extras)
309 312 events.trigger(events.RepoPushEvent(
310 313 repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))
311 314
312 315 output += 'RhodeCode: push completed\n'
313 316 return HookResponse(0, output) + hook_response
314 317
315 318
316 319 def _locked_by_explanation(repo_name, user_name, reason):
317 320 message = f'Repository `{repo_name}` locked by user `{user_name}`. Reason:`{reason}`'
318 321 return message
319 322
320 323
321 324 def check_allowed_create_user(user_dict, created_by, **kwargs):
322 325 # pre create hooks
323 326 if pre_create_user.is_active():
324 327 hook_result = pre_create_user(created_by=created_by, **user_dict)
325 328 allowed = hook_result.status == 0
326 329 if not allowed:
327 330 reason = hook_result.output
328 331 raise UserCreationError(reason)
329 332
330 333
331 334 class ExtensionCallback(object):
332 335 """
333 336 Forwards a given call to rcextensions, sanitizes keyword arguments.
334 337
335 338 Does check if there is an extension active for that hook. If it is
336 339 there, it will forward all `kwargs_keys` keyword arguments to the
337 340 extension callback.
338 341 """
339 342
340 343 def __init__(self, hook_name, kwargs_keys):
341 344 self._hook_name = hook_name
342 345 self._kwargs_keys = set(kwargs_keys)
343 346
344 347 def __call__(self, *args, **kwargs):
345 348 log.debug('Calling extension callback for `%s`', self._hook_name)
346 349 callback = self._get_callback()
347 350 if not callback:
348 351 log.debug('extension callback `%s` not found, skipping...', self._hook_name)
349 352 return
350 353
351 354 kwargs_to_pass = {}
352 355 for key in self._kwargs_keys:
353 356 try:
354 357 kwargs_to_pass[key] = kwargs[key]
355 358 except KeyError:
356 359 log.error('Failed to fetch %s key from given kwargs. '
357 360 'Expected keys: %s', key, self._kwargs_keys)
358 361 raise
359 362
360 363 # backward compat for removed api_key for old hooks. This was it works
361 364 # with older rcextensions that require api_key present
362 365 if self._hook_name in ['CREATE_USER_HOOK', 'DELETE_USER_HOOK']:
363 366 kwargs_to_pass['api_key'] = '_DEPRECATED_'
364 367 return callback(**kwargs_to_pass)
365 368
366 369 def is_active(self):
367 370 return hasattr(rhodecode.EXTENSIONS, self._hook_name)
368 371
369 372 def _get_callback(self):
370 373 return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
371 374
372 375
373 376 pre_pull_extension = ExtensionCallback(
374 377 hook_name='PRE_PULL_HOOK',
375 378 kwargs_keys=(
376 379 'server_url', 'config', 'scm', 'username', 'ip', 'action',
377 380 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
378 381
379 382
380 383 post_pull_extension = ExtensionCallback(
381 384 hook_name='PULL_HOOK',
382 385 kwargs_keys=(
383 386 'server_url', 'config', 'scm', 'username', 'ip', 'action',
384 387 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
385 388
386 389
387 390 pre_push_extension = ExtensionCallback(
388 391 hook_name='PRE_PUSH_HOOK',
389 392 kwargs_keys=(
390 393 'server_url', 'config', 'scm', 'username', 'ip', 'action',
391 394 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
392 395
393 396
394 397 post_push_extension = ExtensionCallback(
395 398 hook_name='PUSH_HOOK',
396 399 kwargs_keys=(
397 400 'server_url', 'config', 'scm', 'username', 'ip', 'action',
398 401 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
399 402
400 403
401 404 pre_create_user = ExtensionCallback(
402 405 hook_name='PRE_CREATE_USER_HOOK',
403 406 kwargs_keys=(
404 407 'username', 'password', 'email', 'firstname', 'lastname', 'active',
405 408 'admin', 'created_by'))
406 409
407 410
408 411 create_pull_request = ExtensionCallback(
409 412 hook_name='CREATE_PULL_REQUEST',
410 413 kwargs_keys=(
411 414 'server_url', 'config', 'scm', 'username', 'ip', 'action',
412 415 'repository', 'pull_request_id', 'url', 'title', 'description',
413 416 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
414 417 'mergeable', 'source', 'target', 'author', 'reviewers'))
415 418
416 419
417 420 merge_pull_request = ExtensionCallback(
418 421 hook_name='MERGE_PULL_REQUEST',
419 422 kwargs_keys=(
420 423 'server_url', 'config', 'scm', 'username', 'ip', 'action',
421 424 'repository', 'pull_request_id', 'url', 'title', 'description',
422 425 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
423 426 'mergeable', 'source', 'target', 'author', 'reviewers'))
424 427
425 428
426 429 close_pull_request = ExtensionCallback(
427 430 hook_name='CLOSE_PULL_REQUEST',
428 431 kwargs_keys=(
429 432 'server_url', 'config', 'scm', 'username', 'ip', 'action',
430 433 'repository', 'pull_request_id', 'url', 'title', 'description',
431 434 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
432 435 'mergeable', 'source', 'target', 'author', 'reviewers'))
433 436
434 437
435 438 review_pull_request = ExtensionCallback(
436 439 hook_name='REVIEW_PULL_REQUEST',
437 440 kwargs_keys=(
438 441 'server_url', 'config', 'scm', 'username', 'ip', 'action',
439 442 'repository', 'pull_request_id', 'url', 'title', 'description',
440 443 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
441 444 'mergeable', 'source', 'target', 'author', 'reviewers'))
442 445
443 446
444 447 comment_pull_request = ExtensionCallback(
445 448 hook_name='COMMENT_PULL_REQUEST',
446 449 kwargs_keys=(
447 450 'server_url', 'config', 'scm', 'username', 'ip', 'action',
448 451 'repository', 'pull_request_id', 'url', 'title', 'description',
449 452 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
450 453 'mergeable', 'source', 'target', 'author', 'reviewers'))
451 454
452 455
453 456 comment_edit_pull_request = ExtensionCallback(
454 457 hook_name='COMMENT_EDIT_PULL_REQUEST',
455 458 kwargs_keys=(
456 459 'server_url', 'config', 'scm', 'username', 'ip', 'action',
457 460 'repository', 'pull_request_id', 'url', 'title', 'description',
458 461 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
459 462 'mergeable', 'source', 'target', 'author', 'reviewers'))
460 463
461 464
462 465 update_pull_request = ExtensionCallback(
463 466 hook_name='UPDATE_PULL_REQUEST',
464 467 kwargs_keys=(
465 468 'server_url', 'config', 'scm', 'username', 'ip', 'action',
466 469 'repository', 'pull_request_id', 'url', 'title', 'description',
467 470 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
468 471 'mergeable', 'source', 'target', 'author', 'reviewers'))
469 472
470 473
471 474 create_user = ExtensionCallback(
472 475 hook_name='CREATE_USER_HOOK',
473 476 kwargs_keys=(
474 477 'username', 'full_name_or_username', 'full_contact', 'user_id',
475 478 'name', 'firstname', 'short_contact', 'admin', 'lastname',
476 479 'ip_addresses', 'extern_type', 'extern_name',
477 480 'email', 'api_keys', 'last_login',
478 481 'full_name', 'active', 'password', 'emails',
479 482 'inherit_default_permissions', 'created_by', 'created_on'))
480 483
481 484
482 485 delete_user = ExtensionCallback(
483 486 hook_name='DELETE_USER_HOOK',
484 487 kwargs_keys=(
485 488 'username', 'full_name_or_username', 'full_contact', 'user_id',
486 489 'name', 'firstname', 'short_contact', 'admin', 'lastname',
487 490 'ip_addresses',
488 491 'email', 'last_login',
489 492 'full_name', 'active', 'password', 'emails',
490 493 'inherit_default_permissions', 'deleted_by'))
491 494
492 495
493 496 create_repository = ExtensionCallback(
494 497 hook_name='CREATE_REPO_HOOK',
495 498 kwargs_keys=(
496 499 'repo_name', 'repo_type', 'description', 'private', 'created_on',
497 500 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
498 501 'clone_uri', 'fork_id', 'group_id', 'created_by'))
499 502
500 503
501 504 delete_repository = ExtensionCallback(
502 505 hook_name='DELETE_REPO_HOOK',
503 506 kwargs_keys=(
504 507 'repo_name', 'repo_type', 'description', 'private', 'created_on',
505 508 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
506 509 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
507 510
508 511
509 512 comment_commit_repository = ExtensionCallback(
510 513 hook_name='COMMENT_COMMIT_REPO_HOOK',
511 514 kwargs_keys=(
512 515 'repo_name', 'repo_type', 'description', 'private', 'created_on',
513 516 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
514 517 'clone_uri', 'fork_id', 'group_id',
515 518 'repository', 'created_by', 'comment', 'commit'))
516 519
517 520 comment_edit_commit_repository = ExtensionCallback(
518 521 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
519 522 kwargs_keys=(
520 523 'repo_name', 'repo_type', 'description', 'private', 'created_on',
521 524 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
522 525 'clone_uri', 'fork_id', 'group_id',
523 526 'repository', 'created_by', 'comment', 'commit'))
524 527
525 528
526 529 create_repository_group = ExtensionCallback(
527 530 hook_name='CREATE_REPO_GROUP_HOOK',
528 531 kwargs_keys=(
529 532 'group_name', 'group_parent_id', 'group_description',
530 533 'group_id', 'user_id', 'created_by', 'created_on',
531 534 'enable_locking'))
@@ -1,436 +1,451 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import time
21 21 import logging
22 22 import tempfile
23 23 import traceback
24 24 import threading
25 25 import socket
26 26 import msgpack
27 27 import gevent
28 28
29 29 from http.server import BaseHTTPRequestHandler
30 30 from socketserver import TCPServer
31 31
32 32 import rhodecode
33 33 from rhodecode.lib.exceptions import HTTPLockedRC, HTTPBranchProtected
34 34 from rhodecode.model import meta
35 35 from rhodecode.lib import hooks_base
36 36 from rhodecode.lib.utils2 import AttributeDict
37 from rhodecode.lib.pyramid_utils import get_config
37 38 from rhodecode.lib.ext_json import json
38 39 from rhodecode.lib import rc_cache
39 40
40 41 log = logging.getLogger(__name__)
41 42
42 43
43 44 class HooksHttpHandler(BaseHTTPRequestHandler):
44 45
45 46 JSON_HOOKS_PROTO = 'json.v1'
46 47 MSGPACK_HOOKS_PROTO = 'msgpack.v1'
47 48 # starting with RhodeCode 5.0.0 MsgPack is the default, prior it used json
48 49 DEFAULT_HOOKS_PROTO = MSGPACK_HOOKS_PROTO
49 50
50 51 @classmethod
51 52 def serialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
52 53 if proto == cls.MSGPACK_HOOKS_PROTO:
53 54 return msgpack.packb(data)
54 55 return json.dumps(data)
55 56
56 57 @classmethod
57 58 def deserialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
58 59 if proto == cls.MSGPACK_HOOKS_PROTO:
59 60 return msgpack.unpackb(data)
60 61 return json.loads(data)
61 62
62 63 def do_POST(self):
63 64 hooks_proto, method, extras = self._read_request()
64 65 log.debug('Handling HooksHttpHandler %s with %s proto', method, hooks_proto)
65 66
66 67 txn_id = getattr(self.server, 'txn_id', None)
67 68 if txn_id:
68 69 log.debug('Computing TXN_ID based on `%s`:`%s`',
69 70 extras['repository'], extras['txn_id'])
70 71 computed_txn_id = rc_cache.utils.compute_key_from_params(
71 72 extras['repository'], extras['txn_id'])
72 73 if txn_id != computed_txn_id:
73 74 raise Exception(
74 75 'TXN ID fail: expected {} got {} instead'.format(
75 76 txn_id, computed_txn_id))
76 77
77 78 request = getattr(self.server, 'request', None)
78 79 try:
79 80 hooks = Hooks(request=request, log_prefix='HOOKS: {} '.format(self.server.server_address))
80 81 result = self._call_hook_method(hooks, method, extras)
81 82
82 83 except Exception as e:
83 84 exc_tb = traceback.format_exc()
84 85 result = {
85 86 'exception': e.__class__.__name__,
86 87 'exception_traceback': exc_tb,
87 88 'exception_args': e.args
88 89 }
89 90 self._write_response(hooks_proto, result)
90 91
91 92 def _read_request(self):
92 93 length = int(self.headers['Content-Length'])
93 94 # respect sent headers, fallback to OLD proto for compatability
94 95 hooks_proto = self.headers.get('rc-hooks-protocol') or self.JSON_HOOKS_PROTO
95 96 if hooks_proto == self.MSGPACK_HOOKS_PROTO:
96 97 # support for new vcsserver msgpack based protocol hooks
97 98 body = self.rfile.read(length)
98 99 data = self.deserialize_data(body)
99 100 else:
100 101 body = self.rfile.read(length)
101 102 data = self.deserialize_data(body)
102 103
103 104 return hooks_proto, data['method'], data['extras']
104 105
105 106 def _write_response(self, hooks_proto, result):
106 107 self.send_response(200)
107 108 if hooks_proto == self.MSGPACK_HOOKS_PROTO:
108 109 self.send_header("Content-type", "application/msgpack")
109 110 self.end_headers()
110 111 data = self.serialize_data(result)
111 112 self.wfile.write(data)
112 113 else:
113 114 self.send_header("Content-type", "text/json")
114 115 self.end_headers()
115 116 data = self.serialize_data(result)
116 117 self.wfile.write(data)
117 118
118 119 def _call_hook_method(self, hooks, method, extras):
119 120 try:
120 121 result = getattr(hooks, method)(extras)
121 122 finally:
122 123 meta.Session.remove()
123 124 return result
124 125
125 126 def log_message(self, format, *args):
126 127 """
127 128 This is an overridden method of BaseHTTPRequestHandler which logs using
128 129 logging library instead of writing directly to stderr.
129 130 """
130 131
131 132 message = format % args
132 133
133 134 log.debug(
134 135 "HOOKS: client=%s - - [%s] %s", self.client_address,
135 136 self.log_date_time_string(), message)
136 137
137 138
138 class DummyHooksCallbackDaemon(object):
139 hooks_uri = ''
140
139 class BaseHooksCallbackDaemon:
140 """
141 Basic context manager for actions that don't require some extra
142 """
141 143 def __init__(self):
142 144 self.hooks_module = Hooks.__module__
143 145
144 146 def __enter__(self):
145 147 log.debug('Running `%s` callback daemon', self.__class__.__name__)
146 148 return self
147 149
148 150 def __exit__(self, exc_type, exc_val, exc_tb):
149 151 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
150 152
151 153
154 class CeleryHooksCallbackDaemon(BaseHooksCallbackDaemon):
155 """
156 Context manger for achieving a compatibility with celery backend
157 """
158
159 def __init__(self, config):
160 self.task_queue = config.get('app:main', 'celery.broker_url')
161 self.task_backend = config.get('app:main', 'celery.result_backend')
162
163
152 164 class ThreadedHookCallbackDaemon(object):
153 165
154 166 _callback_thread = None
155 167 _daemon = None
156 168 _done = False
157 169 use_gevent = False
158 170
159 171 def __init__(self, txn_id=None, host=None, port=None):
160 172 self._prepare(txn_id=txn_id, host=host, port=port)
161 173 if self.use_gevent:
162 174 self._run_func = self._run_gevent
163 175 self._stop_func = self._stop_gevent
164 176 else:
165 177 self._run_func = self._run
166 178 self._stop_func = self._stop
167 179
168 180 def __enter__(self):
169 181 log.debug('Running `%s` callback daemon', self.__class__.__name__)
170 182 self._run_func()
171 183 return self
172 184
173 185 def __exit__(self, exc_type, exc_val, exc_tb):
174 186 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
175 187 self._stop_func()
176 188
177 189 def _prepare(self, txn_id=None, host=None, port=None):
178 190 raise NotImplementedError()
179 191
180 192 def _run(self):
181 193 raise NotImplementedError()
182 194
183 195 def _stop(self):
184 196 raise NotImplementedError()
185 197
186 198 def _run_gevent(self):
187 199 raise NotImplementedError()
188 200
189 201 def _stop_gevent(self):
190 202 raise NotImplementedError()
191 203
192 204
193 205 class HttpHooksCallbackDaemon(ThreadedHookCallbackDaemon):
194 206 """
195 207 Context manager which will run a callback daemon in a background thread.
196 208 """
197 209
198 210 hooks_uri = None
199 211
200 212 # From Python docs: Polling reduces our responsiveness to a shutdown
201 213 # request and wastes cpu at all other times.
202 214 POLL_INTERVAL = 0.01
203 215
204 216 use_gevent = False
205 217
206 218 @property
207 219 def _hook_prefix(self):
208 220 return 'HOOKS: {} '.format(self.hooks_uri)
209 221
210 222 def get_hostname(self):
211 223 return socket.gethostname() or '127.0.0.1'
212 224
213 225 def get_available_port(self, min_port=20000, max_port=65535):
214 226 from rhodecode.lib.utils2 import get_available_port as _get_port
215 227 return _get_port(min_port, max_port)
216 228
217 229 def _prepare(self, txn_id=None, host=None, port=None):
218 230 from pyramid.threadlocal import get_current_request
219 231
220 232 if not host or host == "*":
221 233 host = self.get_hostname()
222 234 if not port:
223 235 port = self.get_available_port()
224 236
225 237 server_address = (host, port)
226 238 self.hooks_uri = '{}:{}'.format(host, port)
227 239 self.txn_id = txn_id
228 240 self._done = False
229 241
230 242 log.debug(
231 243 "%s Preparing HTTP callback daemon registering hook object: %s",
232 244 self._hook_prefix, HooksHttpHandler)
233 245
234 246 self._daemon = TCPServer(server_address, HooksHttpHandler)
235 247 # inject transaction_id for later verification
236 248 self._daemon.txn_id = self.txn_id
237 249
238 250 # pass the WEB app request into daemon
239 251 self._daemon.request = get_current_request()
240 252
241 253 def _run(self):
242 254 log.debug("Running thread-based loop of callback daemon in background")
243 255 callback_thread = threading.Thread(
244 256 target=self._daemon.serve_forever,
245 257 kwargs={'poll_interval': self.POLL_INTERVAL})
246 258 callback_thread.daemon = True
247 259 callback_thread.start()
248 260 self._callback_thread = callback_thread
249 261
250 262 def _run_gevent(self):
251 263 log.debug("Running gevent-based loop of callback daemon in background")
252 264 # create a new greenlet for the daemon's serve_forever method
253 265 callback_greenlet = gevent.spawn(
254 266 self._daemon.serve_forever,
255 267 poll_interval=self.POLL_INTERVAL)
256 268
257 269 # store reference to greenlet
258 270 self._callback_greenlet = callback_greenlet
259 271
260 272 # switch to this greenlet
261 273 gevent.sleep(0.01)
262 274
263 275 def _stop(self):
264 276 log.debug("Waiting for background thread to finish.")
265 277 self._daemon.shutdown()
266 278 self._callback_thread.join()
267 279 self._daemon = None
268 280 self._callback_thread = None
269 281 if self.txn_id:
270 282 txn_id_file = get_txn_id_data_path(self.txn_id)
271 283 log.debug('Cleaning up TXN ID %s', txn_id_file)
272 284 if os.path.isfile(txn_id_file):
273 285 os.remove(txn_id_file)
274 286
275 287 log.debug("Background thread done.")
276 288
277 289 def _stop_gevent(self):
278 290 log.debug("Waiting for background greenlet to finish.")
279 291
280 292 # if greenlet exists and is running
281 293 if self._callback_greenlet and not self._callback_greenlet.dead:
282 294 # shutdown daemon if it exists
283 295 if self._daemon:
284 296 self._daemon.shutdown()
285 297
286 298 # kill the greenlet
287 299 self._callback_greenlet.kill()
288 300
289 301 self._daemon = None
290 302 self._callback_greenlet = None
291 303
292 304 if self.txn_id:
293 305 txn_id_file = get_txn_id_data_path(self.txn_id)
294 306 log.debug('Cleaning up TXN ID %s', txn_id_file)
295 307 if os.path.isfile(txn_id_file):
296 308 os.remove(txn_id_file)
297 309
298 310 log.debug("Background greenlet done.")
299 311
300 312
301 313 def get_txn_id_data_path(txn_id):
302 314 import rhodecode
303 315
304 316 root = rhodecode.CONFIG.get('cache_dir') or tempfile.gettempdir()
305 317 final_dir = os.path.join(root, 'svn_txn_id')
306 318
307 319 if not os.path.isdir(final_dir):
308 320 os.makedirs(final_dir)
309 321 return os.path.join(final_dir, 'rc_txn_id_{}'.format(txn_id))
310 322
311 323
312 324 def store_txn_id_data(txn_id, data_dict):
313 325 if not txn_id:
314 326 log.warning('Cannot store txn_id because it is empty')
315 327 return
316 328
317 329 path = get_txn_id_data_path(txn_id)
318 330 try:
319 331 with open(path, 'wb') as f:
320 332 f.write(json.dumps(data_dict))
321 333 except Exception:
322 334 log.exception('Failed to write txn_id metadata')
323 335
324 336
325 337 def get_txn_id_from_store(txn_id):
326 338 """
327 339 Reads txn_id from store and if present returns the data for callback manager
328 340 """
329 341 path = get_txn_id_data_path(txn_id)
330 342 try:
331 343 with open(path, 'rb') as f:
332 344 return json.loads(f.read())
333 345 except Exception:
334 346 return {}
335 347
336 348
337 def prepare_callback_daemon(extras, protocol, host, use_direct_calls, txn_id=None):
349 def prepare_callback_daemon(extras, protocol, host, txn_id=None):
338 350 txn_details = get_txn_id_from_store(txn_id)
339 351 port = txn_details.get('port', 0)
340 if use_direct_calls:
341 callback_daemon = DummyHooksCallbackDaemon()
342 extras['hooks_module'] = callback_daemon.hooks_module
343 else:
344 if protocol == 'http':
352 match protocol:
353 case 'http':
345 354 callback_daemon = HttpHooksCallbackDaemon(
346 355 txn_id=txn_id, host=host, port=port)
347 else:
356 case 'celery':
357 callback_daemon = CeleryHooksCallbackDaemon(get_config(extras['config']))
358 case 'local':
359 callback_daemon = BaseHooksCallbackDaemon()
360 case _:
348 361 log.error('Unsupported callback daemon protocol "%s"', protocol)
349 362 raise Exception('Unsupported callback daemon protocol.')
350 363
351 extras['hooks_uri'] = callback_daemon.hooks_uri
364 extras['hooks_uri'] = getattr(callback_daemon, 'hooks_uri', '')
365 extras['task_queue'] = getattr(callback_daemon, 'task_queue', '')
366 extras['task_backend'] = getattr(callback_daemon, 'task_backend', '')
352 367 extras['hooks_protocol'] = protocol
353 368 extras['time'] = time.time()
354 369
355 370 # register txn_id
356 371 extras['txn_id'] = txn_id
357 log.debug('Prepared a callback daemon: %s at url `%s`',
358 callback_daemon.__class__.__name__, callback_daemon.hooks_uri)
372 log.debug('Prepared a callback daemon: %s',
373 callback_daemon.__class__.__name__)
359 374 return callback_daemon, extras
360 375
361 376
362 377 class Hooks(object):
363 378 """
364 379 Exposes the hooks for remote call backs
365 380 """
366 381 def __init__(self, request=None, log_prefix=''):
367 382 self.log_prefix = log_prefix
368 383 self.request = request
369 384
370 385 def repo_size(self, extras):
371 386 log.debug("%sCalled repo_size of %s object", self.log_prefix, self)
372 387 return self._call_hook(hooks_base.repo_size, extras)
373 388
374 389 def pre_pull(self, extras):
375 390 log.debug("%sCalled pre_pull of %s object", self.log_prefix, self)
376 391 return self._call_hook(hooks_base.pre_pull, extras)
377 392
378 393 def post_pull(self, extras):
379 394 log.debug("%sCalled post_pull of %s object", self.log_prefix, self)
380 395 return self._call_hook(hooks_base.post_pull, extras)
381 396
382 397 def pre_push(self, extras):
383 398 log.debug("%sCalled pre_push of %s object", self.log_prefix, self)
384 399 return self._call_hook(hooks_base.pre_push, extras)
385 400
386 401 def post_push(self, extras):
387 402 log.debug("%sCalled post_push of %s object", self.log_prefix, self)
388 403 return self._call_hook(hooks_base.post_push, extras)
389 404
390 405 def _call_hook(self, hook, extras):
391 406 extras = AttributeDict(extras)
392 407 server_url = extras['server_url']
393 408
394 409 extras.request = self.request
395 410
396 411 try:
397 412 result = hook(extras)
398 413 if result is None:
399 414 raise Exception(
400 415 'Failed to obtain hook result from func: {}'.format(hook))
401 416 except HTTPBranchProtected as handled_error:
402 417 # Those special cases doesn't need error reporting. It's a case of
403 418 # locked repo or protected branch
404 419 result = AttributeDict({
405 420 'status': handled_error.code,
406 421 'output': handled_error.explanation
407 422 })
408 423 except (HTTPLockedRC, Exception) as error:
409 424 # locked needs different handling since we need to also
410 425 # handle PULL operations
411 426 exc_tb = ''
412 427 if not isinstance(error, HTTPLockedRC):
413 428 exc_tb = traceback.format_exc()
414 429 log.exception('%sException when handling hook %s', self.log_prefix, hook)
415 430 error_args = error.args
416 431 return {
417 432 'status': 128,
418 433 'output': '',
419 434 'exception': type(error).__name__,
420 435 'exception_traceback': exc_tb,
421 436 'exception_args': error_args,
422 437 }
423 438 finally:
424 439 meta.Session.remove()
425 440
426 441 log.debug('%sGot hook call response %s', self.log_prefix, result)
427 442 return {
428 443 'status': result.status,
429 444 'output': result.output,
430 445 }
431 446
432 447 def __enter__(self):
433 448 return self
434 449
435 450 def __exit__(self, exc_type, exc_val, exc_tb):
436 451 pass
@@ -1,701 +1,701 b''
1 1
2 2
3 3 # Copyright (C) 2014-2023 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleVCS middleware for handling protocol request (push/clone etc.)
23 23 It's implemented with basic auth function
24 24 """
25 25
26 26 import os
27 27 import re
28 28 import io
29 29 import logging
30 30 import importlib
31 31 from functools import wraps
32 32 from lxml import etree
33 33
34 34 import time
35 35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
36 36
37 37 from pyramid.httpexceptions import (
38 38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
39 39 from zope.cachedescriptors.property import Lazy as LazyProperty
40 40
41 41 import rhodecode
42 42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
43 43 from rhodecode.lib import rc_cache
44 44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
45 45 from rhodecode.lib.base import (
46 46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
47 47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
48 48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
49 49 from rhodecode.lib.middleware import appenlight
50 50 from rhodecode.lib.middleware.utils import scm_app_http
51 51 from rhodecode.lib.str_utils import safe_bytes
52 52 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
53 53 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool
54 54 from rhodecode.lib.vcs.conf import settings as vcs_settings
55 55 from rhodecode.lib.vcs.backends import base
56 56
57 57 from rhodecode.model import meta
58 58 from rhodecode.model.db import User, Repository, PullRequest
59 59 from rhodecode.model.scm import ScmModel
60 60 from rhodecode.model.pull_request import PullRequestModel
61 61 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
62 62
63 63 log = logging.getLogger(__name__)
64 64
65 65
66 66 def extract_svn_txn_id(acl_repo_name, data: bytes):
67 67 """
68 68 Helper method for extraction of svn txn_id from submitted XML data during
69 69 POST operations
70 70 """
71 71
72 72 try:
73 73 root = etree.fromstring(data)
74 74 pat = re.compile(r'/txn/(?P<txn_id>.*)')
75 75 for el in root:
76 76 if el.tag == '{DAV:}source':
77 77 for sub_el in el:
78 78 if sub_el.tag == '{DAV:}href':
79 79 match = pat.search(sub_el.text)
80 80 if match:
81 81 svn_tx_id = match.groupdict()['txn_id']
82 82 txn_id = rc_cache.utils.compute_key_from_params(
83 83 acl_repo_name, svn_tx_id)
84 84 return txn_id
85 85 except Exception:
86 86 log.exception('Failed to extract txn_id')
87 87
88 88
89 89 def initialize_generator(factory):
90 90 """
91 91 Initializes the returned generator by draining its first element.
92 92
93 93 This can be used to give a generator an initializer, which is the code
94 94 up to the first yield statement. This decorator enforces that the first
95 95 produced element has the value ``"__init__"`` to make its special
96 96 purpose very explicit in the using code.
97 97 """
98 98
99 99 @wraps(factory)
100 100 def wrapper(*args, **kwargs):
101 101 gen = factory(*args, **kwargs)
102 102 try:
103 103 init = next(gen)
104 104 except StopIteration:
105 105 raise ValueError('Generator must yield at least one element.')
106 106 if init != "__init__":
107 107 raise ValueError('First yielded element must be "__init__".')
108 108 return gen
109 109 return wrapper
110 110
111 111
112 112 class SimpleVCS(object):
113 113 """Common functionality for SCM HTTP handlers."""
114 114
115 115 SCM = 'unknown'
116 116
117 117 acl_repo_name = None
118 118 url_repo_name = None
119 119 vcs_repo_name = None
120 120 rc_extras = {}
121 121
122 122 # We have to handle requests to shadow repositories different than requests
123 123 # to normal repositories. Therefore we have to distinguish them. To do this
124 124 # we use this regex which will match only on URLs pointing to shadow
125 125 # repositories.
126 126 shadow_repo_re = re.compile(
127 127 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
128 128 '(?P<target>{slug_pat})/' # target repo
129 129 'pull-request/(?P<pr_id>\\d+)/' # pull request
130 130 'repository$' # shadow repo
131 131 .format(slug_pat=SLUG_RE.pattern))
132 132
133 133 def __init__(self, config, registry):
134 134 self.registry = registry
135 135 self.config = config
136 136 # re-populated by specialized middleware
137 137 self.repo_vcs_config = base.Config()
138 138
139 139 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
140 140 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
141 141
142 142 # authenticate this VCS request using authfunc
143 143 auth_ret_code_detection = \
144 144 str2bool(self.config.get('auth_ret_code_detection', False))
145 145 self.authenticate = BasicAuth(
146 146 '', authenticate, registry, config.get('auth_ret_code'),
147 147 auth_ret_code_detection, rc_realm=realm)
148 148 self.ip_addr = '0.0.0.0'
149 149
150 150 @LazyProperty
151 151 def global_vcs_config(self):
152 152 try:
153 153 return VcsSettingsModel().get_ui_settings_as_config_obj()
154 154 except Exception:
155 155 return base.Config()
156 156
157 157 @property
158 158 def base_path(self):
159 159 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
160 160
161 161 if not settings_path:
162 162 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
163 163
164 164 if not settings_path:
165 165 # try, maybe we passed in explicitly as config option
166 166 settings_path = self.config.get('base_path')
167 167
168 168 if not settings_path:
169 169 raise ValueError('FATAL: base_path is empty')
170 170 return settings_path
171 171
172 172 def set_repo_names(self, environ):
173 173 """
174 174 This will populate the attributes acl_repo_name, url_repo_name,
175 175 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
176 176 shadow) repositories all names are equal. In case of requests to a
177 177 shadow repository the acl-name points to the target repo of the pull
178 178 request and the vcs-name points to the shadow repo file system path.
179 179 The url-name is always the URL used by the vcs client program.
180 180
181 181 Example in case of a shadow repo:
182 182 acl_repo_name = RepoGroup/MyRepo
183 183 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
184 184 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3'
185 185 """
186 186 # First we set the repo name from URL for all attributes. This is the
187 187 # default if handling normal (non shadow) repo requests.
188 188 self.url_repo_name = self._get_repository_name(environ)
189 189 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
190 190 self.is_shadow_repo = False
191 191
192 192 # Check if this is a request to a shadow repository.
193 193 match = self.shadow_repo_re.match(self.url_repo_name)
194 194 if match:
195 195 match_dict = match.groupdict()
196 196
197 197 # Build acl repo name from regex match.
198 198 acl_repo_name = safe_str('{groups}{target}'.format(
199 199 groups=match_dict['groups'] or '',
200 200 target=match_dict['target']))
201 201
202 202 # Retrieve pull request instance by ID from regex match.
203 203 pull_request = PullRequest.get(match_dict['pr_id'])
204 204
205 205 # Only proceed if we got a pull request and if acl repo name from
206 206 # URL equals the target repo name of the pull request.
207 207 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
208 208
209 209 # Get file system path to shadow repository.
210 210 workspace_id = PullRequestModel()._workspace_id(pull_request)
211 211 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
212 212
213 213 # Store names for later usage.
214 214 self.vcs_repo_name = vcs_repo_name
215 215 self.acl_repo_name = acl_repo_name
216 216 self.is_shadow_repo = True
217 217
218 218 log.debug('Setting all VCS repository names: %s', {
219 219 'acl_repo_name': self.acl_repo_name,
220 220 'url_repo_name': self.url_repo_name,
221 221 'vcs_repo_name': self.vcs_repo_name,
222 222 })
223 223
224 224 @property
225 225 def scm_app(self):
226 226 custom_implementation = self.config['vcs.scm_app_implementation']
227 227 if custom_implementation == 'http':
228 228 log.debug('Using HTTP implementation of scm app.')
229 229 scm_app_impl = scm_app_http
230 230 else:
231 231 log.debug('Using custom implementation of scm_app: "{}"'.format(
232 232 custom_implementation))
233 233 scm_app_impl = importlib.import_module(custom_implementation)
234 234 return scm_app_impl
235 235
236 236 def _get_by_id(self, repo_name):
237 237 """
238 238 Gets a special pattern _<ID> from clone url and tries to replace it
239 239 with a repository_name for support of _<ID> non changeable urls
240 240 """
241 241
242 242 data = repo_name.split('/')
243 243 if len(data) >= 2:
244 244 from rhodecode.model.repo import RepoModel
245 245 by_id_match = RepoModel().get_repo_by_id(repo_name)
246 246 if by_id_match:
247 247 data[1] = by_id_match.repo_name
248 248
249 249 # Because PEP-3333-WSGI uses bytes-tunneled-in-latin-1 as PATH_INFO
250 250 # and we use this data
251 251 maybe_new_path = '/'.join(data)
252 252 return safe_bytes(maybe_new_path).decode('latin1')
253 253
254 254 def _invalidate_cache(self, repo_name):
255 255 """
256 256 Set's cache for this repository for invalidation on next access
257 257
258 258 :param repo_name: full repo name, also a cache key
259 259 """
260 260 ScmModel().mark_for_invalidation(repo_name)
261 261
262 262 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
263 263 db_repo = Repository.get_by_repo_name(repo_name)
264 264 if not db_repo:
265 265 log.debug('Repository `%s` not found inside the database.',
266 266 repo_name)
267 267 return False
268 268
269 269 if db_repo.repo_type != scm_type:
270 270 log.warning(
271 271 'Repository `%s` have incorrect scm_type, expected %s got %s',
272 272 repo_name, db_repo.repo_type, scm_type)
273 273 return False
274 274
275 275 config = db_repo._config
276 276 config.set('extensions', 'largefiles', '')
277 277 return is_valid_repo(
278 278 repo_name, base_path,
279 279 explicit_scm=scm_type, expect_scm=scm_type, config=config)
280 280
281 281 def valid_and_active_user(self, user):
282 282 """
283 283 Checks if that user is not empty, and if it's actually object it checks
284 284 if he's active.
285 285
286 286 :param user: user object or None
287 287 :return: boolean
288 288 """
289 289 if user is None:
290 290 return False
291 291
292 292 elif user.active:
293 293 return True
294 294
295 295 return False
296 296
297 297 @property
298 298 def is_shadow_repo_dir(self):
299 299 return os.path.isdir(self.vcs_repo_name)
300 300
301 301 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
302 302 plugin_id='', plugin_cache_active=False, cache_ttl=0):
303 303 """
304 304 Checks permissions using action (push/pull) user and repository
305 305 name. If plugin_cache and ttl is set it will use the plugin which
306 306 authenticated the user to store the cached permissions result for N
307 307 amount of seconds as in cache_ttl
308 308
309 309 :param action: push or pull action
310 310 :param user: user instance
311 311 :param repo_name: repository name
312 312 """
313 313
314 314 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
315 315 plugin_id, plugin_cache_active, cache_ttl)
316 316
317 317 user_id = user.user_id
318 318 cache_namespace_uid = f'cache_user_auth.{rc_cache.PERMISSIONS_CACHE_VER}.{user_id}'
319 319 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
320 320
321 321 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
322 322 expiration_time=cache_ttl,
323 323 condition=plugin_cache_active)
324 324 def compute_perm_vcs(
325 325 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
326 326
327 327 log.debug('auth: calculating permission access now...')
328 328 # check IP
329 329 inherit = user.inherit_default_permissions
330 330 ip_allowed = AuthUser.check_ip_allowed(
331 331 user_id, ip_addr, inherit_from_default=inherit)
332 332 if ip_allowed:
333 333 log.info('Access for IP:%s allowed', ip_addr)
334 334 else:
335 335 return False
336 336
337 337 if action == 'push':
338 338 perms = ('repository.write', 'repository.admin')
339 339 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
340 340 return False
341 341
342 342 else:
343 343 # any other action need at least read permission
344 344 perms = (
345 345 'repository.read', 'repository.write', 'repository.admin')
346 346 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
347 347 return False
348 348
349 349 return True
350 350
351 351 start = time.time()
352 352 log.debug('Running plugin `%s` permissions check', plugin_id)
353 353
354 354 # for environ based auth, password can be empty, but then the validation is
355 355 # on the server that fills in the env data needed for authentication
356 356 perm_result = compute_perm_vcs(
357 357 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
358 358
359 359 auth_time = time.time() - start
360 360 log.debug('Permissions for plugin `%s` completed in %.4fs, '
361 361 'expiration time of fetched cache %.1fs.',
362 362 plugin_id, auth_time, cache_ttl)
363 363
364 364 return perm_result
365 365
366 366 def _get_http_scheme(self, environ):
367 367 try:
368 368 return environ['wsgi.url_scheme']
369 369 except Exception:
370 370 log.exception('Failed to read http scheme')
371 371 return 'http'
372 372
373 373 def _check_ssl(self, environ, start_response):
374 374 """
375 375 Checks the SSL check flag and returns False if SSL is not present
376 376 and required True otherwise
377 377 """
378 378 org_proto = environ['wsgi._org_proto']
379 379 # check if we have SSL required ! if not it's a bad request !
380 380 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
381 381 if require_ssl and org_proto == 'http':
382 382 log.debug(
383 383 'Bad request: detected protocol is `%s` and '
384 384 'SSL/HTTPS is required.', org_proto)
385 385 return False
386 386 return True
387 387
388 388 def _get_default_cache_ttl(self):
389 389 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
390 390 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
391 391 plugin_settings = plugin.get_settings()
392 392 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
393 393 plugin_settings) or (False, 0)
394 394 return plugin_cache_active, cache_ttl
395 395
396 396 def __call__(self, environ, start_response):
397 397 try:
398 398 return self._handle_request(environ, start_response)
399 399 except Exception:
400 400 log.exception("Exception while handling request")
401 401 appenlight.track_exception(environ)
402 402 return HTTPInternalServerError()(environ, start_response)
403 403 finally:
404 404 meta.Session.remove()
405 405
406 406 def _handle_request(self, environ, start_response):
407 407 if not self._check_ssl(environ, start_response):
408 408 reason = ('SSL required, while RhodeCode was unable '
409 409 'to detect this as SSL request')
410 410 log.debug('User not allowed to proceed, %s', reason)
411 411 return HTTPNotAcceptable(reason)(environ, start_response)
412 412
413 413 if not self.url_repo_name:
414 414 log.warning('Repository name is empty: %s', self.url_repo_name)
415 415 # failed to get repo name, we fail now
416 416 return HTTPNotFound()(environ, start_response)
417 417 log.debug('Extracted repo name is %s', self.url_repo_name)
418 418
419 419 ip_addr = get_ip_addr(environ)
420 420 user_agent = get_user_agent(environ)
421 421 username = None
422 422
423 423 # skip passing error to error controller
424 424 environ['pylons.status_code_redirect'] = True
425 425
426 426 # ======================================================================
427 427 # GET ACTION PULL or PUSH
428 428 # ======================================================================
429 429 action = self._get_action(environ)
430 430
431 431 # ======================================================================
432 432 # Check if this is a request to a shadow repository of a pull request.
433 433 # In this case only pull action is allowed.
434 434 # ======================================================================
435 435 if self.is_shadow_repo and action != 'pull':
436 436 reason = 'Only pull action is allowed for shadow repositories.'
437 437 log.debug('User not allowed to proceed, %s', reason)
438 438 return HTTPNotAcceptable(reason)(environ, start_response)
439 439
440 440 # Check if the shadow repo actually exists, in case someone refers
441 441 # to it, and it has been deleted because of successful merge.
442 442 if self.is_shadow_repo and not self.is_shadow_repo_dir:
443 443 log.debug(
444 444 'Shadow repo detected, and shadow repo dir `%s` is missing',
445 445 self.is_shadow_repo_dir)
446 446 return HTTPNotFound()(environ, start_response)
447 447
448 448 # ======================================================================
449 449 # CHECK ANONYMOUS PERMISSION
450 450 # ======================================================================
451 451 detect_force_push = False
452 452 check_branch_perms = False
453 453 if action in ['pull', 'push']:
454 454 user_obj = anonymous_user = User.get_default_user()
455 455 auth_user = user_obj.AuthUser()
456 456 username = anonymous_user.username
457 457 if anonymous_user.active:
458 458 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
459 459 # ONLY check permissions if the user is activated
460 460 anonymous_perm = self._check_permission(
461 461 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
462 462 plugin_id='anonymous_access',
463 463 plugin_cache_active=plugin_cache_active,
464 464 cache_ttl=cache_ttl,
465 465 )
466 466 else:
467 467 anonymous_perm = False
468 468
469 469 if not anonymous_user.active or not anonymous_perm:
470 470 if not anonymous_user.active:
471 471 log.debug('Anonymous access is disabled, running '
472 472 'authentication')
473 473
474 474 if not anonymous_perm:
475 475 log.debug('Not enough credentials to access repo: `%s` '
476 476 'repository as anonymous user', self.acl_repo_name)
477 477
478 478
479 479 username = None
480 480 # ==============================================================
481 481 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
482 482 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
483 483 # ==============================================================
484 484
485 485 # try to auth based on environ, container auth methods
486 486 log.debug('Running PRE-AUTH for container|headers based authentication')
487 487
488 488 # headers auth, by just reading special headers and bypass the auth with user/passwd
489 489 pre_auth = authenticate(
490 490 '', '', environ, VCS_TYPE, registry=self.registry,
491 491 acl_repo_name=self.acl_repo_name)
492 492
493 493 if pre_auth and pre_auth.get('username'):
494 494 username = pre_auth['username']
495 495 log.debug('PRE-AUTH got `%s` as username', username)
496 496 if pre_auth:
497 497 log.debug('PRE-AUTH successful from %s',
498 498 pre_auth.get('auth_data', {}).get('_plugin'))
499 499
500 500 # If not authenticated by the container, running basic auth
501 501 # before inject the calling repo_name for special scope checks
502 502 self.authenticate.acl_repo_name = self.acl_repo_name
503 503
504 504 plugin_cache_active, cache_ttl = False, 0
505 505 plugin = None
506 506
507 507 # regular auth chain
508 508 if not username:
509 509 self.authenticate.realm = self.authenticate.get_rc_realm()
510 510
511 511 try:
512 512 auth_result = self.authenticate(environ)
513 513 except (UserCreationError, NotAllowedToCreateUserError) as e:
514 514 log.error(e)
515 515 reason = safe_str(e)
516 516 return HTTPNotAcceptable(reason)(environ, start_response)
517 517
518 518 if isinstance(auth_result, dict):
519 519 AUTH_TYPE.update(environ, 'basic')
520 520 REMOTE_USER.update(environ, auth_result['username'])
521 521 username = auth_result['username']
522 522 plugin = auth_result.get('auth_data', {}).get('_plugin')
523 523 log.info(
524 524 'MAIN-AUTH successful for user `%s` from %s plugin',
525 525 username, plugin)
526 526
527 527 plugin_cache_active, cache_ttl = auth_result.get(
528 528 'auth_data', {}).get('_ttl_cache') or (False, 0)
529 529 else:
530 530 return auth_result.wsgi_application(environ, start_response)
531 531
532 532 # ==============================================================
533 533 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
534 534 # ==============================================================
535 535 user = User.get_by_username(username)
536 536 if not self.valid_and_active_user(user):
537 537 return HTTPForbidden()(environ, start_response)
538 538 username = user.username
539 539 user_id = user.user_id
540 540
541 541 # check user attributes for password change flag
542 542 user_obj = user
543 543 auth_user = user_obj.AuthUser()
544 544 if user_obj and user_obj.username != User.DEFAULT_USER and \
545 545 user_obj.user_data.get('force_password_change'):
546 546 reason = 'password change required'
547 547 log.debug('User not allowed to authenticate, %s', reason)
548 548 return HTTPNotAcceptable(reason)(environ, start_response)
549 549
550 550 # check permissions for this repository
551 551 perm = self._check_permission(
552 552 action, user, auth_user, self.acl_repo_name, ip_addr,
553 553 plugin, plugin_cache_active, cache_ttl)
554 554 if not perm:
555 555 return HTTPForbidden()(environ, start_response)
556 556 environ['rc_auth_user_id'] = str(user_id)
557 557
558 558 if action == 'push':
559 559 perms = auth_user.get_branch_permissions(self.acl_repo_name)
560 560 if perms:
561 561 check_branch_perms = True
562 562 detect_force_push = True
563 563
564 564 # extras are injected into UI object and later available
565 565 # in hooks executed by RhodeCode
566 566 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
567 567
568 568 extras = vcs_operation_context(
569 569 environ, repo_name=self.acl_repo_name, username=username,
570 570 action=action, scm=self.SCM, check_locking=check_locking,
571 571 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
572 572 detect_force_push=detect_force_push
573 573 )
574 574
575 575 # ======================================================================
576 576 # REQUEST HANDLING
577 577 # ======================================================================
578 578 repo_path = os.path.join(
579 579 safe_str(self.base_path), safe_str(self.vcs_repo_name))
580 580 log.debug('Repository path is %s', repo_path)
581 581
582 582 fix_PATH()
583 583
584 584 log.info(
585 585 '%s action on %s repo "%s" by "%s" from %s %s',
586 586 action, self.SCM, safe_str(self.url_repo_name),
587 587 safe_str(username), ip_addr, user_agent)
588 588
589 589 return self._generate_vcs_response(
590 590 environ, start_response, repo_path, extras, action)
591 591
592 592 @initialize_generator
593 593 def _generate_vcs_response(
594 594 self, environ, start_response, repo_path, extras, action):
595 595 """
596 596 Returns a generator for the response content.
597 597
598 598 This method is implemented as a generator, so that it can trigger
599 599 the cache validation after all content sent back to the client. It
600 600 also handles the locking exceptions which will be triggered when
601 601 the first chunk is produced by the underlying WSGI application.
602 602 """
603 603
604 604 txn_id = ''
605 605 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
606 606 # case for SVN, we want to re-use the callback daemon port
607 607 # so we use the txn_id, for this we peek the body, and still save
608 608 # it as wsgi.input
609 609
610 610 stream = environ['wsgi.input']
611 611
612 612 if isinstance(stream, io.BytesIO):
613 613 data: bytes = stream.getvalue()
614 614 elif hasattr(stream, 'buf'): # most likely gunicorn.http.body.Body
615 615 data: bytes = stream.buf.getvalue()
616 616 else:
617 617 # fallback to the crudest way, copy the iterator
618 618 data = safe_bytes(stream.read())
619 619 environ['wsgi.input'] = io.BytesIO(data)
620 620
621 621 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
622 622
623 623 callback_daemon, extras = self._prepare_callback_daemon(
624 624 extras, environ, action, txn_id=txn_id)
625 625 log.debug('HOOKS extras is %s', extras)
626 626
627 627 http_scheme = self._get_http_scheme(environ)
628 628
629 629 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
630 630 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
631 631 with callback_daemon:
632 632 app.rc_extras = extras
633 633
634 634 try:
635 635 response = app(environ, start_response)
636 636 finally:
637 637 # This statement works together with the decorator
638 638 # "initialize_generator" above. The decorator ensures that
639 639 # we hit the first yield statement before the generator is
640 640 # returned back to the WSGI server. This is needed to
641 641 # ensure that the call to "app" above triggers the
642 642 # needed callback to "start_response" before the
643 643 # generator is actually used.
644 644 yield "__init__"
645 645
646 646 # iter content
647 647 for chunk in response:
648 648 yield chunk
649 649
650 650 try:
651 651 # invalidate cache on push
652 652 if action == 'push':
653 653 self._invalidate_cache(self.url_repo_name)
654 654 finally:
655 655 meta.Session.remove()
656 656
657 657 def _get_repository_name(self, environ):
658 658 """Get repository name out of the environmnent
659 659
660 660 :param environ: WSGI environment
661 661 """
662 662 raise NotImplementedError()
663 663
664 664 def _get_action(self, environ):
665 665 """Map request commands into a pull or push command.
666 666
667 667 :param environ: WSGI environment
668 668 """
669 669 raise NotImplementedError()
670 670
671 671 def _create_wsgi_app(self, repo_path, repo_name, config):
672 672 """Return the WSGI app that will finally handle the request."""
673 673 raise NotImplementedError()
674 674
675 675 def _create_config(self, extras, repo_name, scheme='http'):
676 676 """Create a safe config representation."""
677 677 raise NotImplementedError()
678 678
679 679 def _should_use_callback_daemon(self, extras, environ, action):
680 680 if extras.get('is_shadow_repo'):
681 681 # we don't want to execute hooks, and callback daemon for shadow repos
682 682 return False
683 683 return True
684 684
685 685 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
686 direct_calls = vcs_settings.HOOKS_DIRECT_CALLS
686 protocol = vcs_settings.HOOKS_PROTOCOL
687 687 if not self._should_use_callback_daemon(extras, environ, action):
688 688 # disable callback daemon for actions that don't require it
689 direct_calls = True
689 protocol = 'local'
690 690
691 691 return prepare_callback_daemon(
692 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
693 host=vcs_settings.HOOKS_HOST, use_direct_calls=direct_calls, txn_id=txn_id)
692 extras, protocol=protocol,
693 host=vcs_settings.HOOKS_HOST, txn_id=txn_id)
694 694
695 695
696 696 def _should_check_locking(query_string):
697 697 # this is kind of hacky, but due to how mercurial handles client-server
698 698 # server see all operation on commit; bookmarks, phases and
699 699 # obsolescence marker in different transaction, we don't want to check
700 700 # locking on those
701 701 return query_string not in ['cmd=listkeys']
@@ -1,808 +1,823 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 """
20 20 Utilities library for RhodeCode
21 21 """
22 22
23 23 import datetime
24 24 import decorator
25 25 import logging
26 26 import os
27 27 import re
28 28 import sys
29 29 import shutil
30 30 import socket
31 31 import tempfile
32 32 import traceback
33 33 import tarfile
34 34 import warnings
35 from functools import wraps
35 36 from os.path import join as jn
36 37
37 38 import paste
38 39 import pkg_resources
39 40 from webhelpers2.text import collapse, strip_tags, convert_accented_entities, convert_misc_entities
40 41
41 42 from mako import exceptions
42 43
43 44 from rhodecode.lib.hash_utils import sha256_safe, md5, sha1
45 from rhodecode.lib.type_utils import AttributeDict
44 46 from rhodecode.lib.str_utils import safe_bytes, safe_str
45 47 from rhodecode.lib.vcs.backends.base import Config
46 48 from rhodecode.lib.vcs.exceptions import VCSError
47 49 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
48 50 from rhodecode.lib.ext_json import sjson as json
49 51 from rhodecode.model import meta
50 52 from rhodecode.model.db import (
51 53 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
52 54 from rhodecode.model.meta import Session
53 55
54 56
55 57 log = logging.getLogger(__name__)
56 58
57 59 REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
58 60
59 61 # String which contains characters that are not allowed in slug names for
60 62 # repositories or repository groups. It is properly escaped to use it in
61 63 # regular expressions.
62 64 SLUG_BAD_CHARS = re.escape(r'`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')
63 65
64 66 # Regex that matches forbidden characters in repo/group slugs.
65 67 SLUG_BAD_CHAR_RE = re.compile(r'[{}\x00-\x08\x0b-\x0c\x0e-\x1f]'.format(SLUG_BAD_CHARS))
66 68
67 69 # Regex that matches allowed characters in repo/group slugs.
68 70 SLUG_GOOD_CHAR_RE = re.compile(r'[^{}]'.format(SLUG_BAD_CHARS))
69 71
70 72 # Regex that matches whole repo/group slugs.
71 73 SLUG_RE = re.compile(r'[^{}]+'.format(SLUG_BAD_CHARS))
72 74
73 75 _license_cache = None
74 76
75 77
78 def adopt_for_celery(func):
79 """
80 Decorator designed to adopt hooks (from rhodecode.lib.hooks_base)
81 for further usage as a celery tasks.
82 """
83 @wraps(func)
84 def wrapper(extras):
85 extras = AttributeDict(extras)
86 # HooksResponse implements to_json method which must be used there.
87 return func(extras).to_json()
88 return wrapper
89
90
76 91 def repo_name_slug(value):
77 92 """
78 93 Return slug of name of repository
79 94 This function is called on each creation/modification
80 95 of repository to prevent bad names in repo
81 96 """
82 97
83 98 replacement_char = '-'
84 99
85 100 slug = strip_tags(value)
86 101 slug = convert_accented_entities(slug)
87 102 slug = convert_misc_entities(slug)
88 103
89 104 slug = SLUG_BAD_CHAR_RE.sub('', slug)
90 105 slug = re.sub(r'[\s]+', '-', slug)
91 106 slug = collapse(slug, replacement_char)
92 107
93 108 return slug
94 109
95 110
96 111 #==============================================================================
97 112 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
98 113 #==============================================================================
99 114 def get_repo_slug(request):
100 115 _repo = ''
101 116
102 117 if hasattr(request, 'db_repo_name'):
103 118 # if our requests has set db reference use it for name, this
104 119 # translates the example.com/_<id> into proper repo names
105 120 _repo = request.db_repo_name
106 121 elif getattr(request, 'matchdict', None):
107 122 # pyramid
108 123 _repo = request.matchdict.get('repo_name')
109 124
110 125 if _repo:
111 126 _repo = _repo.rstrip('/')
112 127 return _repo
113 128
114 129
115 130 def get_repo_group_slug(request):
116 131 _group = ''
117 132 if hasattr(request, 'db_repo_group'):
118 133 # if our requests has set db reference use it for name, this
119 134 # translates the example.com/_<id> into proper repo group names
120 135 _group = request.db_repo_group.group_name
121 136 elif getattr(request, 'matchdict', None):
122 137 # pyramid
123 138 _group = request.matchdict.get('repo_group_name')
124 139
125 140 if _group:
126 141 _group = _group.rstrip('/')
127 142 return _group
128 143
129 144
130 145 def get_user_group_slug(request):
131 146 _user_group = ''
132 147
133 148 if hasattr(request, 'db_user_group'):
134 149 _user_group = request.db_user_group.users_group_name
135 150 elif getattr(request, 'matchdict', None):
136 151 # pyramid
137 152 _user_group = request.matchdict.get('user_group_id')
138 153 _user_group_name = request.matchdict.get('user_group_name')
139 154 try:
140 155 if _user_group:
141 156 _user_group = UserGroup.get(_user_group)
142 157 elif _user_group_name:
143 158 _user_group = UserGroup.get_by_group_name(_user_group_name)
144 159
145 160 if _user_group:
146 161 _user_group = _user_group.users_group_name
147 162 except Exception:
148 163 log.exception('Failed to get user group by id and name')
149 164 # catch all failures here
150 165 return None
151 166
152 167 return _user_group
153 168
154 169
155 170 def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
156 171 """
157 172 Scans given path for repos and return (name,(type,path)) tuple
158 173
159 174 :param path: path to scan for repositories
160 175 :param recursive: recursive search and return names with subdirs in front
161 176 """
162 177
163 178 # remove ending slash for better results
164 179 path = path.rstrip(os.sep)
165 180 log.debug('now scanning in %s location recursive:%s...', path, recursive)
166 181
167 182 def _get_repos(p):
168 183 dirpaths = get_dirpaths(p)
169 184 if not _is_dir_writable(p):
170 185 log.warning('repo path without write access: %s', p)
171 186
172 187 for dirpath in dirpaths:
173 188 if os.path.isfile(os.path.join(p, dirpath)):
174 189 continue
175 190 cur_path = os.path.join(p, dirpath)
176 191
177 192 # skip removed repos
178 193 if skip_removed_repos and REMOVED_REPO_PAT.match(dirpath):
179 194 continue
180 195
181 196 #skip .<somethin> dirs
182 197 if dirpath.startswith('.'):
183 198 continue
184 199
185 200 try:
186 201 scm_info = get_scm(cur_path)
187 202 yield scm_info[1].split(path, 1)[-1].lstrip(os.sep), scm_info
188 203 except VCSError:
189 204 if not recursive:
190 205 continue
191 206 #check if this dir containts other repos for recursive scan
192 207 rec_path = os.path.join(p, dirpath)
193 208 if os.path.isdir(rec_path):
194 209 yield from _get_repos(rec_path)
195 210
196 211 return _get_repos(path)
197 212
198 213
199 214 def get_dirpaths(p: str) -> list:
200 215 try:
201 216 # OS-independable way of checking if we have at least read-only
202 217 # access or not.
203 218 dirpaths = os.listdir(p)
204 219 except OSError:
205 220 log.warning('ignoring repo path without read access: %s', p)
206 221 return []
207 222
208 223 # os.listpath has a tweak: If a unicode is passed into it, then it tries to
209 224 # decode paths and suddenly returns unicode objects itself. The items it
210 225 # cannot decode are returned as strings and cause issues.
211 226 #
212 227 # Those paths are ignored here until a solid solution for path handling has
213 228 # been built.
214 229 expected_type = type(p)
215 230
216 231 def _has_correct_type(item):
217 232 if type(item) is not expected_type:
218 233 log.error(
219 234 "Ignoring path %s since it cannot be decoded into str.",
220 235 # Using "repr" to make sure that we see the byte value in case
221 236 # of support.
222 237 repr(item))
223 238 return False
224 239 return True
225 240
226 241 dirpaths = [item for item in dirpaths if _has_correct_type(item)]
227 242
228 243 return dirpaths
229 244
230 245
231 246 def _is_dir_writable(path):
232 247 """
233 248 Probe if `path` is writable.
234 249
235 250 Due to trouble on Cygwin / Windows, this is actually probing if it is
236 251 possible to create a file inside of `path`, stat does not produce reliable
237 252 results in this case.
238 253 """
239 254 try:
240 255 with tempfile.TemporaryFile(dir=path):
241 256 pass
242 257 except OSError:
243 258 return False
244 259 return True
245 260
246 261
247 262 def is_valid_repo(repo_name, base_path, expect_scm=None, explicit_scm=None, config=None):
248 263 """
249 264 Returns True if given path is a valid repository False otherwise.
250 265 If expect_scm param is given also, compare if given scm is the same
251 266 as expected from scm parameter. If explicit_scm is given don't try to
252 267 detect the scm, just use the given one to check if repo is valid
253 268
254 269 :param repo_name:
255 270 :param base_path:
256 271 :param expect_scm:
257 272 :param explicit_scm:
258 273 :param config:
259 274
260 275 :return True: if given path is a valid repository
261 276 """
262 277 full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
263 278 log.debug('Checking if `%s` is a valid path for repository. '
264 279 'Explicit type: %s', repo_name, explicit_scm)
265 280
266 281 try:
267 282 if explicit_scm:
268 283 detected_scms = [get_scm_backend(explicit_scm)(
269 284 full_path, config=config).alias]
270 285 else:
271 286 detected_scms = get_scm(full_path)
272 287
273 288 if expect_scm:
274 289 return detected_scms[0] == expect_scm
275 290 log.debug('path: %s is an vcs object:%s', full_path, detected_scms)
276 291 return True
277 292 except VCSError:
278 293 log.debug('path: %s is not a valid repo !', full_path)
279 294 return False
280 295
281 296
282 297 def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
283 298 """
284 299 Returns True if a given path is a repository group, False otherwise
285 300
286 301 :param repo_group_name:
287 302 :param base_path:
288 303 """
289 304 full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
290 305 log.debug('Checking if `%s` is a valid path for repository group',
291 306 repo_group_name)
292 307
293 308 # check if it's not a repo
294 309 if is_valid_repo(repo_group_name, base_path):
295 310 log.debug('Repo called %s exist, it is not a valid repo group', repo_group_name)
296 311 return False
297 312
298 313 try:
299 314 # we need to check bare git repos at higher level
300 315 # since we might match branches/hooks/info/objects or possible
301 316 # other things inside bare git repo
302 317 maybe_repo = os.path.dirname(full_path)
303 318 if maybe_repo == base_path:
304 319 # skip root level repo check; we know root location CANNOT BE a repo group
305 320 return False
306 321
307 322 scm_ = get_scm(maybe_repo)
308 323 log.debug('path: %s is a vcs object:%s, not valid repo group', full_path, scm_)
309 324 return False
310 325 except VCSError:
311 326 pass
312 327
313 328 # check if it's a valid path
314 329 if skip_path_check or os.path.isdir(full_path):
315 330 log.debug('path: %s is a valid repo group !', full_path)
316 331 return True
317 332
318 333 log.debug('path: %s is not a valid repo group !', full_path)
319 334 return False
320 335
321 336
322 337 def ask_ok(prompt, retries=4, complaint='[y]es or [n]o please!'):
323 338 while True:
324 339 ok = input(prompt)
325 340 if ok.lower() in ('y', 'ye', 'yes'):
326 341 return True
327 342 if ok.lower() in ('n', 'no', 'nop', 'nope'):
328 343 return False
329 344 retries = retries - 1
330 345 if retries < 0:
331 346 raise OSError
332 347 print(complaint)
333 348
334 349 # propagated from mercurial documentation
335 350 ui_sections = [
336 351 'alias', 'auth',
337 352 'decode/encode', 'defaults',
338 353 'diff', 'email',
339 354 'extensions', 'format',
340 355 'merge-patterns', 'merge-tools',
341 356 'hooks', 'http_proxy',
342 357 'smtp', 'patch',
343 358 'paths', 'profiling',
344 359 'server', 'trusted',
345 360 'ui', 'web', ]
346 361
347 362
348 363 def config_data_from_db(clear_session=True, repo=None):
349 364 """
350 365 Read the configuration data from the database and return configuration
351 366 tuples.
352 367 """
353 368 from rhodecode.model.settings import VcsSettingsModel
354 369
355 370 config = []
356 371
357 372 sa = meta.Session()
358 373 settings_model = VcsSettingsModel(repo=repo, sa=sa)
359 374
360 375 ui_settings = settings_model.get_ui_settings()
361 376
362 377 ui_data = []
363 378 for setting in ui_settings:
364 379 if setting.active:
365 380 ui_data.append((setting.section, setting.key, setting.value))
366 381 config.append((
367 382 safe_str(setting.section), safe_str(setting.key),
368 383 safe_str(setting.value)))
369 384 if setting.key == 'push_ssl':
370 385 # force set push_ssl requirement to False, rhodecode
371 386 # handles that
372 387 config.append((
373 388 safe_str(setting.section), safe_str(setting.key), False))
374 389 log.debug(
375 390 'settings ui from db@repo[%s]: %s',
376 391 repo,
377 392 ','.join(['[{}] {}={}'.format(*s) for s in ui_data]))
378 393 if clear_session:
379 394 meta.Session.remove()
380 395
381 396 # TODO: mikhail: probably it makes no sense to re-read hooks information.
382 397 # It's already there and activated/deactivated
383 398 skip_entries = []
384 399 enabled_hook_classes = get_enabled_hook_classes(ui_settings)
385 400 if 'pull' not in enabled_hook_classes:
386 401 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PULL))
387 402 if 'push' not in enabled_hook_classes:
388 403 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PUSH))
389 404 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRETX_PUSH))
390 405 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PUSH_KEY))
391 406
392 407 config = [entry for entry in config if entry[:2] not in skip_entries]
393 408
394 409 return config
395 410
396 411
397 412 def make_db_config(clear_session=True, repo=None):
398 413 """
399 414 Create a :class:`Config` instance based on the values in the database.
400 415 """
401 416 config = Config()
402 417 config_data = config_data_from_db(clear_session=clear_session, repo=repo)
403 418 for section, option, value in config_data:
404 419 config.set(section, option, value)
405 420 return config
406 421
407 422
408 423 def get_enabled_hook_classes(ui_settings):
409 424 """
410 425 Return the enabled hook classes.
411 426
412 427 :param ui_settings: List of ui_settings as returned
413 428 by :meth:`VcsSettingsModel.get_ui_settings`
414 429
415 430 :return: a list with the enabled hook classes. The order is not guaranteed.
416 431 :rtype: list
417 432 """
418 433 enabled_hooks = []
419 434 active_hook_keys = [
420 435 key for section, key, value, active in ui_settings
421 436 if section == 'hooks' and active]
422 437
423 438 hook_names = {
424 439 RhodeCodeUi.HOOK_PUSH: 'push',
425 440 RhodeCodeUi.HOOK_PULL: 'pull',
426 441 RhodeCodeUi.HOOK_REPO_SIZE: 'repo_size'
427 442 }
428 443
429 444 for key in active_hook_keys:
430 445 hook = hook_names.get(key)
431 446 if hook:
432 447 enabled_hooks.append(hook)
433 448
434 449 return enabled_hooks
435 450
436 451
437 452 def set_rhodecode_config(config):
438 453 """
439 454 Updates pyramid config with new settings from database
440 455
441 456 :param config:
442 457 """
443 458 from rhodecode.model.settings import SettingsModel
444 459 app_settings = SettingsModel().get_all_settings()
445 460
446 461 for k, v in list(app_settings.items()):
447 462 config[k] = v
448 463
449 464
450 465 def get_rhodecode_realm():
451 466 """
452 467 Return the rhodecode realm from database.
453 468 """
454 469 from rhodecode.model.settings import SettingsModel
455 470 realm = SettingsModel().get_setting_by_name('realm')
456 471 return safe_str(realm.app_settings_value)
457 472
458 473
459 474 def get_rhodecode_base_path():
460 475 """
461 476 Returns the base path. The base path is the filesystem path which points
462 477 to the repository store.
463 478 """
464 479
465 480 import rhodecode
466 481 return rhodecode.CONFIG['default_base_path']
467 482
468 483
469 484 def map_groups(path):
470 485 """
471 486 Given a full path to a repository, create all nested groups that this
472 487 repo is inside. This function creates parent-child relationships between
473 488 groups and creates default perms for all new groups.
474 489
475 490 :param paths: full path to repository
476 491 """
477 492 from rhodecode.model.repo_group import RepoGroupModel
478 493 sa = meta.Session()
479 494 groups = path.split(Repository.NAME_SEP)
480 495 parent = None
481 496 group = None
482 497
483 498 # last element is repo in nested groups structure
484 499 groups = groups[:-1]
485 500 rgm = RepoGroupModel(sa)
486 501 owner = User.get_first_super_admin()
487 502 for lvl, group_name in enumerate(groups):
488 503 group_name = '/'.join(groups[:lvl] + [group_name])
489 504 group = RepoGroup.get_by_group_name(group_name)
490 505 desc = '%s group' % group_name
491 506
492 507 # skip folders that are now removed repos
493 508 if REMOVED_REPO_PAT.match(group_name):
494 509 break
495 510
496 511 if group is None:
497 512 log.debug('creating group level: %s group_name: %s',
498 513 lvl, group_name)
499 514 group = RepoGroup(group_name, parent)
500 515 group.group_description = desc
501 516 group.user = owner
502 517 sa.add(group)
503 518 perm_obj = rgm._create_default_perms(group)
504 519 sa.add(perm_obj)
505 520 sa.flush()
506 521
507 522 parent = group
508 523 return group
509 524
510 525
511 526 def repo2db_mapper(initial_repo_list, remove_obsolete=False, force_hooks_rebuild=False):
512 527 """
513 528 maps all repos given in initial_repo_list, non existing repositories
514 529 are created, if remove_obsolete is True it also checks for db entries
515 530 that are not in initial_repo_list and removes them.
516 531
517 532 :param initial_repo_list: list of repositories found by scanning methods
518 533 :param remove_obsolete: check for obsolete entries in database
519 534 """
520 535 from rhodecode.model.repo import RepoModel
521 536 from rhodecode.model.repo_group import RepoGroupModel
522 537 from rhodecode.model.settings import SettingsModel
523 538
524 539 sa = meta.Session()
525 540 repo_model = RepoModel()
526 541 user = User.get_first_super_admin()
527 542 added = []
528 543
529 544 # creation defaults
530 545 defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
531 546 enable_statistics = defs.get('repo_enable_statistics')
532 547 enable_locking = defs.get('repo_enable_locking')
533 548 enable_downloads = defs.get('repo_enable_downloads')
534 549 private = defs.get('repo_private')
535 550
536 551 for name, repo in list(initial_repo_list.items()):
537 552 group = map_groups(name)
538 553 str_name = safe_str(name)
539 554 db_repo = repo_model.get_by_repo_name(str_name)
540 555
541 556 # found repo that is on filesystem not in RhodeCode database
542 557 if not db_repo:
543 558 log.info('repository `%s` not found in the database, creating now', name)
544 559 added.append(name)
545 560 desc = (repo.description
546 561 if repo.description != 'unknown'
547 562 else '%s repository' % name)
548 563
549 564 db_repo = repo_model._create_repo(
550 565 repo_name=name,
551 566 repo_type=repo.alias,
552 567 description=desc,
553 568 repo_group=getattr(group, 'group_id', None),
554 569 owner=user,
555 570 enable_locking=enable_locking,
556 571 enable_downloads=enable_downloads,
557 572 enable_statistics=enable_statistics,
558 573 private=private,
559 574 state=Repository.STATE_CREATED
560 575 )
561 576 sa.commit()
562 577 # we added that repo just now, and make sure we updated server info
563 578 if db_repo.repo_type == 'git':
564 579 git_repo = db_repo.scm_instance()
565 580 # update repository server-info
566 581 log.debug('Running update server info')
567 582 git_repo._update_server_info(force=True)
568 583
569 584 db_repo.update_commit_cache()
570 585
571 586 config = db_repo._config
572 587 config.set('extensions', 'largefiles', '')
573 588 repo = db_repo.scm_instance(config=config)
574 589 repo.install_hooks(force=force_hooks_rebuild)
575 590
576 591 removed = []
577 592 if remove_obsolete:
578 593 # remove from database those repositories that are not in the filesystem
579 594 for repo in sa.query(Repository).all():
580 595 if repo.repo_name not in list(initial_repo_list.keys()):
581 596 log.debug("Removing non-existing repository found in db `%s`",
582 597 repo.repo_name)
583 598 try:
584 599 RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
585 600 sa.commit()
586 601 removed.append(repo.repo_name)
587 602 except Exception:
588 603 # don't hold further removals on error
589 604 log.error(traceback.format_exc())
590 605 sa.rollback()
591 606
592 607 def splitter(full_repo_name):
593 608 _parts = full_repo_name.rsplit(RepoGroup.url_sep(), 1)
594 609 gr_name = None
595 610 if len(_parts) == 2:
596 611 gr_name = _parts[0]
597 612 return gr_name
598 613
599 614 initial_repo_group_list = [splitter(x) for x in
600 615 list(initial_repo_list.keys()) if splitter(x)]
601 616
602 617 # remove from database those repository groups that are not in the
603 618 # filesystem due to parent child relationships we need to delete them
604 619 # in a specific order of most nested first
605 620 all_groups = [x.group_name for x in sa.query(RepoGroup).all()]
606 621 def nested_sort(gr):
607 622 return len(gr.split('/'))
608 623 for group_name in sorted(all_groups, key=nested_sort, reverse=True):
609 624 if group_name not in initial_repo_group_list:
610 625 repo_group = RepoGroup.get_by_group_name(group_name)
611 626 if (repo_group.children.all() or
612 627 not RepoGroupModel().check_exist_filesystem(
613 628 group_name=group_name, exc_on_failure=False)):
614 629 continue
615 630
616 631 log.info(
617 632 'Removing non-existing repository group found in db `%s`',
618 633 group_name)
619 634 try:
620 635 RepoGroupModel(sa).delete(group_name, fs_remove=False)
621 636 sa.commit()
622 637 removed.append(group_name)
623 638 except Exception:
624 639 # don't hold further removals on error
625 640 log.exception(
626 641 'Unable to remove repository group `%s`',
627 642 group_name)
628 643 sa.rollback()
629 644 raise
630 645
631 646 return added, removed
632 647
633 648
634 649 def load_rcextensions(root_path):
635 650 import rhodecode
636 651 from rhodecode.config import conf
637 652
638 653 path = os.path.join(root_path)
639 654 sys.path.append(path)
640 655
641 656 try:
642 657 rcextensions = __import__('rcextensions')
643 658 except ImportError:
644 659 if os.path.isdir(os.path.join(path, 'rcextensions')):
645 660 log.warning('Unable to load rcextensions from %s', path)
646 661 rcextensions = None
647 662
648 663 if rcextensions:
649 664 log.info('Loaded rcextensions from %s...', rcextensions)
650 665 rhodecode.EXTENSIONS = rcextensions
651 666
652 667 # Additional mappings that are not present in the pygments lexers
653 668 conf.LANGUAGES_EXTENSIONS_MAP.update(
654 669 getattr(rhodecode.EXTENSIONS, 'EXTRA_MAPPINGS', {}))
655 670
656 671
657 672 def get_custom_lexer(extension):
658 673 """
659 674 returns a custom lexer if it is defined in rcextensions module, or None
660 675 if there's no custom lexer defined
661 676 """
662 677 import rhodecode
663 678 from pygments import lexers
664 679
665 680 # custom override made by RhodeCode
666 681 if extension in ['mako']:
667 682 return lexers.get_lexer_by_name('html+mako')
668 683
669 684 # check if we didn't define this extension as other lexer
670 685 extensions = rhodecode.EXTENSIONS and getattr(rhodecode.EXTENSIONS, 'EXTRA_LEXERS', None)
671 686 if extensions and extension in rhodecode.EXTENSIONS.EXTRA_LEXERS:
672 687 _lexer_name = rhodecode.EXTENSIONS.EXTRA_LEXERS[extension]
673 688 return lexers.get_lexer_by_name(_lexer_name)
674 689
675 690
676 691 #==============================================================================
677 692 # TEST FUNCTIONS AND CREATORS
678 693 #==============================================================================
679 694 def create_test_index(repo_location, config):
680 695 """
681 696 Makes default test index.
682 697 """
683 698 try:
684 699 import rc_testdata
685 700 except ImportError:
686 701 raise ImportError('Failed to import rc_testdata, '
687 702 'please make sure this package is installed from requirements_test.txt')
688 703 rc_testdata.extract_search_index(
689 704 'vcs_search_index', os.path.dirname(config['search.location']))
690 705
691 706
692 707 def create_test_directory(test_path):
693 708 """
694 709 Create test directory if it doesn't exist.
695 710 """
696 711 if not os.path.isdir(test_path):
697 712 log.debug('Creating testdir %s', test_path)
698 713 os.makedirs(test_path)
699 714
700 715
701 716 def create_test_database(test_path, config):
702 717 """
703 718 Makes a fresh database.
704 719 """
705 720 from rhodecode.lib.db_manage import DbManage
706 721 from rhodecode.lib.utils2 import get_encryption_key
707 722
708 723 # PART ONE create db
709 724 dbconf = config['sqlalchemy.db1.url']
710 725 enc_key = get_encryption_key(config)
711 726
712 727 log.debug('making test db %s', dbconf)
713 728
714 729 dbmanage = DbManage(log_sql=False, dbconf=dbconf, root=config['here'],
715 730 tests=True, cli_args={'force_ask': True}, enc_key=enc_key)
716 731 dbmanage.create_tables(override=True)
717 732 dbmanage.set_db_version()
718 733 # for tests dynamically set new root paths based on generated content
719 734 dbmanage.create_settings(dbmanage.config_prompt(test_path))
720 735 dbmanage.create_default_user()
721 736 dbmanage.create_test_admin_and_users()
722 737 dbmanage.create_permissions()
723 738 dbmanage.populate_default_permissions()
724 739 Session().commit()
725 740
726 741
727 742 def create_test_repositories(test_path, config):
728 743 """
729 744 Creates test repositories in the temporary directory. Repositories are
730 745 extracted from archives within the rc_testdata package.
731 746 """
732 747 import rc_testdata
733 748 from rhodecode.tests import HG_REPO, GIT_REPO, SVN_REPO
734 749
735 750 log.debug('making test vcs repositories')
736 751
737 752 idx_path = config['search.location']
738 753 data_path = config['cache_dir']
739 754
740 755 # clean index and data
741 756 if idx_path and os.path.exists(idx_path):
742 757 log.debug('remove %s', idx_path)
743 758 shutil.rmtree(idx_path)
744 759
745 760 if data_path and os.path.exists(data_path):
746 761 log.debug('remove %s', data_path)
747 762 shutil.rmtree(data_path)
748 763
749 764 rc_testdata.extract_hg_dump('vcs_test_hg', jn(test_path, HG_REPO))
750 765 rc_testdata.extract_git_dump('vcs_test_git', jn(test_path, GIT_REPO))
751 766
752 767 # Note: Subversion is in the process of being integrated with the system,
753 768 # until we have a properly packed version of the test svn repository, this
754 769 # tries to copy over the repo from a package "rc_testdata"
755 770 svn_repo_path = rc_testdata.get_svn_repo_archive()
756 771 with tarfile.open(svn_repo_path) as tar:
757 772 tar.extractall(jn(test_path, SVN_REPO))
758 773
759 774
760 775 def password_changed(auth_user, session):
761 776 # Never report password change in case of default user or anonymous user.
762 777 if auth_user.username == User.DEFAULT_USER or auth_user.user_id is None:
763 778 return False
764 779
765 780 password_hash = md5(safe_bytes(auth_user.password)) if auth_user.password else None
766 781 rhodecode_user = session.get('rhodecode_user', {})
767 782 session_password_hash = rhodecode_user.get('password', '')
768 783 return password_hash != session_password_hash
769 784
770 785
771 786 def read_opensource_licenses():
772 787 global _license_cache
773 788
774 789 if not _license_cache:
775 790 licenses = pkg_resources.resource_string(
776 791 'rhodecode', 'config/licenses.json')
777 792 _license_cache = json.loads(licenses)
778 793
779 794 return _license_cache
780 795
781 796
782 797 def generate_platform_uuid():
783 798 """
784 799 Generates platform UUID based on it's name
785 800 """
786 801 import platform
787 802
788 803 try:
789 804 uuid_list = [platform.platform()]
790 805 return sha256_safe(':'.join(uuid_list))
791 806 except Exception as e:
792 807 log.error('Failed to generate host uuid: %s', e)
793 808 return 'UNDEFINED'
794 809
795 810
796 811 def send_test_email(recipients, email_body='TEST EMAIL'):
797 812 """
798 813 Simple code for generating test emails.
799 814 Usage::
800 815
801 816 from rhodecode.lib import utils
802 817 utils.send_test_email()
803 818 """
804 819 from rhodecode.lib.celerylib import tasks, run_task
805 820
806 821 email_body = email_body_plaintext = email_body
807 822 subject = f'SUBJECT FROM: {socket.gethostname()}'
808 823 tasks.send_email(recipients, subject, email_body_plaintext, email_body)
@@ -1,74 +1,73 b''
1 1 # Copyright (C) 2014-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 """
20 20 Internal settings for vcs-lib
21 21 """
22 22
23 23 # list of default encoding used in safe_str methods
24 24 DEFAULT_ENCODINGS = ['utf8']
25 25
26 26
27 27 # Compatibility version when creating SVN repositories. None means newest.
28 28 # Other available options are: pre-1.4-compatible, pre-1.5-compatible,
29 29 # pre-1.6-compatible, pre-1.8-compatible
30 30 SVN_COMPATIBLE_VERSION = None
31 31
32 32 ALIASES = ['hg', 'git', 'svn']
33 33
34 34 BACKENDS = {
35 35 'hg': 'rhodecode.lib.vcs.backends.hg.MercurialRepository',
36 36 'git': 'rhodecode.lib.vcs.backends.git.GitRepository',
37 37 'svn': 'rhodecode.lib.vcs.backends.svn.SubversionRepository',
38 38 }
39 39
40 40
41 41 ARCHIVE_SPECS = [
42 42 ('tbz2', 'application/x-bzip2', '.tbz2'),
43 43 ('tbz2', 'application/x-bzip2', '.tar.bz2'),
44 44
45 45 ('tgz', 'application/x-gzip', '.tgz'),
46 46 ('tgz', 'application/x-gzip', '.tar.gz'),
47 47
48 48 ('zip', 'application/zip', '.zip'),
49 49 ]
50 50
51 51 HOOKS_PROTOCOL = None
52 HOOKS_DIRECT_CALLS = False
53 52 HOOKS_HOST = '127.0.0.1'
54 53
55 54
56 55 MERGE_MESSAGE_TMPL = (
57 56 'Merge pull request !{pr_id} from {source_repo} {source_ref_name}\n\n '
58 57 '{pr_title}')
59 58 MERGE_DRY_RUN_MESSAGE = 'dry_run_merge_message_from_rhodecode'
60 59 MERGE_DRY_RUN_USER = 'Dry-Run User'
61 60 MERGE_DRY_RUN_EMAIL = 'dry-run-merge@rhodecode.com'
62 61
63 62
64 63 def available_aliases():
65 64 """
66 65 Mercurial is required for the system to work, so in case vcs.backends does
67 66 not include it, we make sure it will be available internally
68 67 TODO: anderson: refactor vcs.backends so it won't be necessary, VCS server
69 68 should be responsible to dictate available backends.
70 69 """
71 70 aliases = ALIASES[:]
72 71 if 'hg' not in aliases:
73 72 aliases += ['hg']
74 73 return aliases
@@ -1,2390 +1,2389 b''
1 1 # Copyright (C) 2012-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19
20 20 """
21 21 pull request model for RhodeCode
22 22 """
23 23
24 24 import logging
25 25 import os
26 26
27 27 import datetime
28 28 import urllib.request
29 29 import urllib.parse
30 30 import urllib.error
31 31 import collections
32 32
33 33 import dataclasses as dataclasses
34 34 from pyramid.threadlocal import get_current_request
35 35
36 36 from rhodecode.lib.vcs.nodes import FileNode
37 37 from rhodecode.translation import lazy_ugettext
38 38 from rhodecode.lib import helpers as h, hooks_utils, diffs
39 39 from rhodecode.lib import audit_logger
40 40 from collections import OrderedDict
41 41 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
42 42 from rhodecode.lib.ext_json import sjson as json
43 43 from rhodecode.lib.markup_renderer import (
44 44 DEFAULT_COMMENTS_RENDERER, RstTemplateRenderer)
45 45 from rhodecode.lib.hash_utils import md5_safe
46 46 from rhodecode.lib.str_utils import safe_str
47 47 from rhodecode.lib.utils2 import AttributeDict, get_current_rhodecode_user
48 48 from rhodecode.lib.vcs.backends.base import (
49 49 Reference, MergeResponse, MergeFailureReason, UpdateFailureReason,
50 50 TargetRefMissing, SourceRefMissing)
51 51 from rhodecode.lib.vcs.conf import settings as vcs_settings
52 52 from rhodecode.lib.vcs.exceptions import (
53 53 CommitDoesNotExistError, EmptyRepositoryError)
54 54 from rhodecode.model import BaseModel
55 55 from rhodecode.model.changeset_status import ChangesetStatusModel
56 56 from rhodecode.model.comment import CommentsModel
57 57 from rhodecode.model.db import (
58 58 aliased, null, lazyload, and_, or_, select, func, String, cast, PullRequest, PullRequestReviewers, ChangesetStatus,
59 59 PullRequestVersion, ChangesetComment, Repository, RepoReviewRule, User)
60 60 from rhodecode.model.meta import Session
61 61 from rhodecode.model.notification import NotificationModel, \
62 62 EmailNotificationModel
63 63 from rhodecode.model.scm import ScmModel
64 64 from rhodecode.model.settings import VcsSettingsModel
65 65
66 66
67 67 log = logging.getLogger(__name__)
68 68
69 69
70 70 # Data structure to hold the response data when updating commits during a pull
71 71 # request update.
72 72 class UpdateResponse(object):
73 73
74 74 def __init__(self, executed, reason, new, old, common_ancestor_id,
75 75 commit_changes, source_changed, target_changed):
76 76
77 77 self.executed = executed
78 78 self.reason = reason
79 79 self.new = new
80 80 self.old = old
81 81 self.common_ancestor_id = common_ancestor_id
82 82 self.changes = commit_changes
83 83 self.source_changed = source_changed
84 84 self.target_changed = target_changed
85 85
86 86
87 87 def get_diff_info(
88 88 source_repo, source_ref, target_repo, target_ref, get_authors=False,
89 89 get_commit_authors=True):
90 90 """
91 91 Calculates detailed diff information for usage in preview of creation of a pull-request.
92 92 This is also used for default reviewers logic
93 93 """
94 94
95 95 source_scm = source_repo.scm_instance()
96 96 target_scm = target_repo.scm_instance()
97 97
98 98 ancestor_id = target_scm.get_common_ancestor(target_ref, source_ref, source_scm)
99 99 if not ancestor_id:
100 100 raise ValueError(
101 101 'cannot calculate diff info without a common ancestor. '
102 102 'Make sure both repositories are related, and have a common forking commit.')
103 103
104 104 # case here is that want a simple diff without incoming commits,
105 105 # previewing what will be merged based only on commits in the source.
106 106 log.debug('Using ancestor %s as source_ref instead of %s',
107 107 ancestor_id, source_ref)
108 108
109 109 # source of changes now is the common ancestor
110 110 source_commit = source_scm.get_commit(commit_id=ancestor_id)
111 111 # target commit becomes the source ref as it is the last commit
112 112 # for diff generation this logic gives proper diff
113 113 target_commit = source_scm.get_commit(commit_id=source_ref)
114 114
115 115 vcs_diff = \
116 116 source_scm.get_diff(commit1=source_commit, commit2=target_commit,
117 117 ignore_whitespace=False, context=3)
118 118
119 119 diff_processor = diffs.DiffProcessor(vcs_diff, diff_format='newdiff',
120 120 diff_limit=0, file_limit=0, show_full_diff=True)
121 121
122 122 _parsed = diff_processor.prepare()
123 123
124 124 all_files = []
125 125 all_files_changes = []
126 126 changed_lines = {}
127 127 stats = [0, 0]
128 128 for f in _parsed:
129 129 all_files.append(f['filename'])
130 130 all_files_changes.append({
131 131 'filename': f['filename'],
132 132 'stats': f['stats']
133 133 })
134 134 stats[0] += f['stats']['added']
135 135 stats[1] += f['stats']['deleted']
136 136
137 137 changed_lines[f['filename']] = []
138 138 if len(f['chunks']) < 2:
139 139 continue
140 140 # first line is "context" information
141 141 for chunks in f['chunks'][1:]:
142 142 for chunk in chunks['lines']:
143 143 if chunk['action'] not in ('del', 'mod'):
144 144 continue
145 145 changed_lines[f['filename']].append(chunk['old_lineno'])
146 146
147 147 commit_authors = []
148 148 user_counts = {}
149 149 email_counts = {}
150 150 author_counts = {}
151 151 _commit_cache = {}
152 152
153 153 commits = []
154 154 if get_commit_authors:
155 155 log.debug('Obtaining commit authors from set of commits')
156 156 _compare_data = target_scm.compare(
157 157 target_ref, source_ref, source_scm, merge=True,
158 158 pre_load=["author", "date", "message"]
159 159 )
160 160
161 161 for commit in _compare_data:
162 162 # NOTE(marcink): we serialize here, so we don't produce more vcsserver calls on data returned
163 163 # at this function which is later called via JSON serialization
164 164 serialized_commit = dict(
165 165 author=commit.author,
166 166 date=commit.date,
167 167 message=commit.message,
168 168 commit_id=commit.raw_id,
169 169 raw_id=commit.raw_id
170 170 )
171 171 commits.append(serialized_commit)
172 172 user = User.get_from_cs_author(serialized_commit['author'])
173 173 if user and user not in commit_authors:
174 174 commit_authors.append(user)
175 175
176 176 # lines
177 177 if get_authors:
178 178 log.debug('Calculating authors of changed files')
179 179 target_commit = source_repo.get_commit(ancestor_id)
180 180
181 181 for fname, lines in changed_lines.items():
182 182
183 183 try:
184 184 node = target_commit.get_node(fname, pre_load=["is_binary"])
185 185 except Exception:
186 186 log.exception("Failed to load node with path %s", fname)
187 187 continue
188 188
189 189 if not isinstance(node, FileNode):
190 190 continue
191 191
192 192 # NOTE(marcink): for binary node we don't do annotation, just use last author
193 193 if node.is_binary:
194 194 author = node.last_commit.author
195 195 email = node.last_commit.author_email
196 196
197 197 user = User.get_from_cs_author(author)
198 198 if user:
199 199 user_counts[user.user_id] = user_counts.get(user.user_id, 0) + 1
200 200 author_counts[author] = author_counts.get(author, 0) + 1
201 201 email_counts[email] = email_counts.get(email, 0) + 1
202 202
203 203 continue
204 204
205 205 for annotation in node.annotate:
206 206 line_no, commit_id, get_commit_func, line_text = annotation
207 207 if line_no in lines:
208 208 if commit_id not in _commit_cache:
209 209 _commit_cache[commit_id] = get_commit_func()
210 210 commit = _commit_cache[commit_id]
211 211 author = commit.author
212 212 email = commit.author_email
213 213 user = User.get_from_cs_author(author)
214 214 if user:
215 215 user_counts[user.user_id] = user_counts.get(user.user_id, 0) + 1
216 216 author_counts[author] = author_counts.get(author, 0) + 1
217 217 email_counts[email] = email_counts.get(email, 0) + 1
218 218
219 219 log.debug('Default reviewers processing finished')
220 220
221 221 return {
222 222 'commits': commits,
223 223 'files': all_files_changes,
224 224 'stats': stats,
225 225 'ancestor': ancestor_id,
226 226 # original authors of modified files
227 227 'original_authors': {
228 228 'users': user_counts,
229 229 'authors': author_counts,
230 230 'emails': email_counts,
231 231 },
232 232 'commit_authors': commit_authors
233 233 }
234 234
235 235
236 236 class PullRequestModel(BaseModel):
237 237
238 238 cls = PullRequest
239 239
240 240 DIFF_CONTEXT = diffs.DEFAULT_CONTEXT
241 241
242 242 UPDATE_STATUS_MESSAGES = {
243 243 UpdateFailureReason.NONE: lazy_ugettext(
244 244 'Pull request update successful.'),
245 245 UpdateFailureReason.UNKNOWN: lazy_ugettext(
246 246 'Pull request update failed because of an unknown error.'),
247 247 UpdateFailureReason.NO_CHANGE: lazy_ugettext(
248 248 'No update needed because the source and target have not changed.'),
249 249 UpdateFailureReason.WRONG_REF_TYPE: lazy_ugettext(
250 250 'Pull request cannot be updated because the reference type is '
251 251 'not supported for an update. Only Branch, Tag or Bookmark is allowed.'),
252 252 UpdateFailureReason.MISSING_TARGET_REF: lazy_ugettext(
253 253 'This pull request cannot be updated because the target '
254 254 'reference is missing.'),
255 255 UpdateFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
256 256 'This pull request cannot be updated because the source '
257 257 'reference is missing.'),
258 258 }
259 259 REF_TYPES = ['bookmark', 'book', 'tag', 'branch']
260 260 UPDATABLE_REF_TYPES = ['bookmark', 'book', 'branch']
261 261
262 262 def __get_pull_request(self, pull_request):
263 263 return self._get_instance((
264 264 PullRequest, PullRequestVersion), pull_request)
265 265
266 266 def _check_perms(self, perms, pull_request, user, api=False):
267 267 if not api:
268 268 return h.HasRepoPermissionAny(*perms)(
269 269 user=user, repo_name=pull_request.target_repo.repo_name)
270 270 else:
271 271 return h.HasRepoPermissionAnyApi(*perms)(
272 272 user=user, repo_name=pull_request.target_repo.repo_name)
273 273
274 274 def check_user_read(self, pull_request, user, api=False):
275 275 _perms = ('repository.admin', 'repository.write', 'repository.read',)
276 276 return self._check_perms(_perms, pull_request, user, api)
277 277
278 278 def check_user_merge(self, pull_request, user, api=False):
279 279 _perms = ('repository.admin', 'repository.write', 'hg.admin',)
280 280 return self._check_perms(_perms, pull_request, user, api)
281 281
282 282 def check_user_update(self, pull_request, user, api=False):
283 283 owner = user.user_id == pull_request.user_id
284 284 return self.check_user_merge(pull_request, user, api) or owner
285 285
286 286 def check_user_delete(self, pull_request, user):
287 287 owner = user.user_id == pull_request.user_id
288 288 _perms = ('repository.admin',)
289 289 return self._check_perms(_perms, pull_request, user) or owner
290 290
291 291 def is_user_reviewer(self, pull_request, user):
292 292 return user.user_id in [
293 293 x.user_id for x in
294 294 pull_request.get_pull_request_reviewers(PullRequestReviewers.ROLE_REVIEWER)
295 295 if x.user
296 296 ]
297 297
298 298 def check_user_change_status(self, pull_request, user, api=False):
299 299 return self.check_user_update(pull_request, user, api) \
300 300 or self.is_user_reviewer(pull_request, user)
301 301
302 302 def check_user_comment(self, pull_request, user):
303 303 owner = user.user_id == pull_request.user_id
304 304 return self.check_user_read(pull_request, user) or owner
305 305
306 306 def get(self, pull_request):
307 307 return self.__get_pull_request(pull_request)
308 308
309 309 def _prepare_get_all_query(self, repo_name, search_q=None, source=False,
310 310 statuses=None, opened_by=None, order_by=None,
311 311 order_dir='desc', only_created=False):
312 312 repo = None
313 313 if repo_name:
314 314 repo = self._get_repo(repo_name)
315 315
316 316 q = PullRequest.query()
317 317
318 318 if search_q:
319 319 like_expression = u'%{}%'.format(safe_str(search_q))
320 320 q = q.join(User, User.user_id == PullRequest.user_id)
321 321 q = q.filter(or_(
322 322 cast(PullRequest.pull_request_id, String).ilike(like_expression),
323 323 User.username.ilike(like_expression),
324 324 PullRequest.title.ilike(like_expression),
325 325 PullRequest.description.ilike(like_expression),
326 326 ))
327 327
328 328 # source or target
329 329 if repo and source:
330 330 q = q.filter(PullRequest.source_repo == repo)
331 331 elif repo:
332 332 q = q.filter(PullRequest.target_repo == repo)
333 333
334 334 # closed,opened
335 335 if statuses:
336 336 q = q.filter(PullRequest.status.in_(statuses))
337 337
338 338 # opened by filter
339 339 if opened_by:
340 340 q = q.filter(PullRequest.user_id.in_(opened_by))
341 341
342 342 # only get those that are in "created" state
343 343 if only_created:
344 344 q = q.filter(PullRequest.pull_request_state == PullRequest.STATE_CREATED)
345 345
346 346 order_map = {
347 347 'name_raw': PullRequest.pull_request_id,
348 348 'id': PullRequest.pull_request_id,
349 349 'title': PullRequest.title,
350 350 'updated_on_raw': PullRequest.updated_on,
351 351 'target_repo': PullRequest.target_repo_id
352 352 }
353 353 if order_by and order_by in order_map:
354 354 if order_dir == 'asc':
355 355 q = q.order_by(order_map[order_by].asc())
356 356 else:
357 357 q = q.order_by(order_map[order_by].desc())
358 358
359 359 return q
360 360
361 361 def count_all(self, repo_name, search_q=None, source=False, statuses=None,
362 362 opened_by=None):
363 363 """
364 364 Count the number of pull requests for a specific repository.
365 365
366 366 :param repo_name: target or source repo
367 367 :param search_q: filter by text
368 368 :param source: boolean flag to specify if repo_name refers to source
369 369 :param statuses: list of pull request statuses
370 370 :param opened_by: author user of the pull request
371 371 :returns: int number of pull requests
372 372 """
373 373 q = self._prepare_get_all_query(
374 374 repo_name, search_q=search_q, source=source, statuses=statuses,
375 375 opened_by=opened_by)
376 376
377 377 return q.count()
378 378
379 379 def get_all(self, repo_name, search_q=None, source=False, statuses=None,
380 380 opened_by=None, offset=0, length=None, order_by=None, order_dir='desc'):
381 381 """
382 382 Get all pull requests for a specific repository.
383 383
384 384 :param repo_name: target or source repo
385 385 :param search_q: filter by text
386 386 :param source: boolean flag to specify if repo_name refers to source
387 387 :param statuses: list of pull request statuses
388 388 :param opened_by: author user of the pull request
389 389 :param offset: pagination offset
390 390 :param length: length of returned list
391 391 :param order_by: order of the returned list
392 392 :param order_dir: 'asc' or 'desc' ordering direction
393 393 :returns: list of pull requests
394 394 """
395 395 q = self._prepare_get_all_query(
396 396 repo_name, search_q=search_q, source=source, statuses=statuses,
397 397 opened_by=opened_by, order_by=order_by, order_dir=order_dir)
398 398
399 399 if length:
400 400 pull_requests = q.limit(length).offset(offset).all()
401 401 else:
402 402 pull_requests = q.all()
403 403
404 404 return pull_requests
405 405
406 406 def count_awaiting_review(self, repo_name, search_q=None, statuses=None):
407 407 """
408 408 Count the number of pull requests for a specific repository that are
409 409 awaiting review.
410 410
411 411 :param repo_name: target or source repo
412 412 :param search_q: filter by text
413 413 :param statuses: list of pull request statuses
414 414 :returns: int number of pull requests
415 415 """
416 416 pull_requests = self.get_awaiting_review(
417 417 repo_name, search_q=search_q, statuses=statuses)
418 418
419 419 return len(pull_requests)
420 420
421 421 def get_awaiting_review(self, repo_name, search_q=None, statuses=None,
422 422 offset=0, length=None, order_by=None, order_dir='desc'):
423 423 """
424 424 Get all pull requests for a specific repository that are awaiting
425 425 review.
426 426
427 427 :param repo_name: target or source repo
428 428 :param search_q: filter by text
429 429 :param statuses: list of pull request statuses
430 430 :param offset: pagination offset
431 431 :param length: length of returned list
432 432 :param order_by: order of the returned list
433 433 :param order_dir: 'asc' or 'desc' ordering direction
434 434 :returns: list of pull requests
435 435 """
436 436 pull_requests = self.get_all(
437 437 repo_name, search_q=search_q, statuses=statuses,
438 438 order_by=order_by, order_dir=order_dir)
439 439
440 440 _filtered_pull_requests = []
441 441 for pr in pull_requests:
442 442 status = pr.calculated_review_status()
443 443 if status in [ChangesetStatus.STATUS_NOT_REVIEWED,
444 444 ChangesetStatus.STATUS_UNDER_REVIEW]:
445 445 _filtered_pull_requests.append(pr)
446 446 if length:
447 447 return _filtered_pull_requests[offset:offset+length]
448 448 else:
449 449 return _filtered_pull_requests
450 450
451 451 def _prepare_awaiting_my_review_review_query(
452 452 self, repo_name, user_id, search_q=None, statuses=None,
453 453 order_by=None, order_dir='desc'):
454 454
455 455 for_review_statuses = [
456 456 ChangesetStatus.STATUS_UNDER_REVIEW, ChangesetStatus.STATUS_NOT_REVIEWED
457 457 ]
458 458
459 459 pull_request_alias = aliased(PullRequest)
460 460 status_alias = aliased(ChangesetStatus)
461 461 reviewers_alias = aliased(PullRequestReviewers)
462 462 repo_alias = aliased(Repository)
463 463
464 464 last_ver_subq = Session()\
465 465 .query(func.min(ChangesetStatus.version)) \
466 466 .filter(ChangesetStatus.pull_request_id == reviewers_alias.pull_request_id)\
467 467 .filter(ChangesetStatus.user_id == reviewers_alias.user_id) \
468 468 .subquery()
469 469
470 470 q = Session().query(pull_request_alias) \
471 471 .options(lazyload(pull_request_alias.author)) \
472 472 .join(reviewers_alias,
473 473 reviewers_alias.pull_request_id == pull_request_alias.pull_request_id) \
474 474 .join(repo_alias,
475 475 repo_alias.repo_id == pull_request_alias.target_repo_id) \
476 476 .outerjoin(status_alias,
477 477 and_(status_alias.user_id == reviewers_alias.user_id,
478 478 status_alias.pull_request_id == reviewers_alias.pull_request_id)) \
479 479 .filter(or_(status_alias.version == null(),
480 480 status_alias.version == last_ver_subq)) \
481 481 .filter(reviewers_alias.user_id == user_id) \
482 482 .filter(repo_alias.repo_name == repo_name) \
483 483 .filter(or_(status_alias.status == null(), status_alias.status.in_(for_review_statuses))) \
484 484 .group_by(pull_request_alias)
485 485
486 486 # closed,opened
487 487 if statuses:
488 488 q = q.filter(pull_request_alias.status.in_(statuses))
489 489
490 490 if search_q:
491 491 like_expression = u'%{}%'.format(safe_str(search_q))
492 492 q = q.join(User, User.user_id == pull_request_alias.user_id)
493 493 q = q.filter(or_(
494 494 cast(pull_request_alias.pull_request_id, String).ilike(like_expression),
495 495 User.username.ilike(like_expression),
496 496 pull_request_alias.title.ilike(like_expression),
497 497 pull_request_alias.description.ilike(like_expression),
498 498 ))
499 499
500 500 order_map = {
501 501 'name_raw': pull_request_alias.pull_request_id,
502 502 'title': pull_request_alias.title,
503 503 'updated_on_raw': pull_request_alias.updated_on,
504 504 'target_repo': pull_request_alias.target_repo_id
505 505 }
506 506 if order_by and order_by in order_map:
507 507 if order_dir == 'asc':
508 508 q = q.order_by(order_map[order_by].asc())
509 509 else:
510 510 q = q.order_by(order_map[order_by].desc())
511 511
512 512 return q
513 513
514 514 def count_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None):
515 515 """
516 516 Count the number of pull requests for a specific repository that are
517 517 awaiting review from a specific user.
518 518
519 519 :param repo_name: target or source repo
520 520 :param user_id: reviewer user of the pull request
521 521 :param search_q: filter by text
522 522 :param statuses: list of pull request statuses
523 523 :returns: int number of pull requests
524 524 """
525 525 q = self._prepare_awaiting_my_review_review_query(
526 526 repo_name, user_id, search_q=search_q, statuses=statuses)
527 527 return q.count()
528 528
529 529 def get_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None,
530 530 offset=0, length=None, order_by=None, order_dir='desc'):
531 531 """
532 532 Get all pull requests for a specific repository that are awaiting
533 533 review from a specific user.
534 534
535 535 :param repo_name: target or source repo
536 536 :param user_id: reviewer user of the pull request
537 537 :param search_q: filter by text
538 538 :param statuses: list of pull request statuses
539 539 :param offset: pagination offset
540 540 :param length: length of returned list
541 541 :param order_by: order of the returned list
542 542 :param order_dir: 'asc' or 'desc' ordering direction
543 543 :returns: list of pull requests
544 544 """
545 545
546 546 q = self._prepare_awaiting_my_review_review_query(
547 547 repo_name, user_id, search_q=search_q, statuses=statuses,
548 548 order_by=order_by, order_dir=order_dir)
549 549
550 550 if length:
551 551 pull_requests = q.limit(length).offset(offset).all()
552 552 else:
553 553 pull_requests = q.all()
554 554
555 555 return pull_requests
556 556
557 557 def _prepare_im_participating_query(self, user_id=None, statuses=None, query='',
558 558 order_by=None, order_dir='desc'):
559 559 """
560 560 return a query of pull-requests user is an creator, or he's added as a reviewer
561 561 """
562 562 q = PullRequest.query()
563 563 if user_id:
564 564
565 565 base_query = select(PullRequestReviewers)\
566 566 .where(PullRequestReviewers.user_id == user_id)\
567 567 .with_only_columns(PullRequestReviewers.pull_request_id)
568 568
569 569 user_filter = or_(
570 570 PullRequest.user_id == user_id,
571 571 PullRequest.pull_request_id.in_(base_query)
572 572 )
573 573 q = PullRequest.query().filter(user_filter)
574 574
575 575 # closed,opened
576 576 if statuses:
577 577 q = q.filter(PullRequest.status.in_(statuses))
578 578
579 579 if query:
580 580 like_expression = u'%{}%'.format(safe_str(query))
581 581 q = q.join(User, User.user_id == PullRequest.user_id)
582 582 q = q.filter(or_(
583 583 cast(PullRequest.pull_request_id, String).ilike(like_expression),
584 584 User.username.ilike(like_expression),
585 585 PullRequest.title.ilike(like_expression),
586 586 PullRequest.description.ilike(like_expression),
587 587 ))
588 588
589 589 order_map = {
590 590 'name_raw': PullRequest.pull_request_id,
591 591 'title': PullRequest.title,
592 592 'updated_on_raw': PullRequest.updated_on,
593 593 'target_repo': PullRequest.target_repo_id
594 594 }
595 595 if order_by and order_by in order_map:
596 596 if order_dir == 'asc':
597 597 q = q.order_by(order_map[order_by].asc())
598 598 else:
599 599 q = q.order_by(order_map[order_by].desc())
600 600
601 601 return q
602 602
603 603 def count_im_participating_in(self, user_id=None, statuses=None, query=''):
604 604 q = self._prepare_im_participating_query(user_id, statuses=statuses, query=query)
605 605 return q.count()
606 606
607 607 def get_im_participating_in(
608 608 self, user_id=None, statuses=None, query='', offset=0,
609 609 length=None, order_by=None, order_dir='desc'):
610 610 """
611 611 Get all Pull requests that i'm participating in as a reviewer, or i have opened
612 612 """
613 613
614 614 q = self._prepare_im_participating_query(
615 615 user_id, statuses=statuses, query=query, order_by=order_by,
616 616 order_dir=order_dir)
617 617
618 618 if length:
619 619 pull_requests = q.limit(length).offset(offset).all()
620 620 else:
621 621 pull_requests = q.all()
622 622
623 623 return pull_requests
624 624
625 625 def _prepare_participating_in_for_review_query(
626 626 self, user_id, statuses=None, query='', order_by=None, order_dir='desc'):
627 627
628 628 for_review_statuses = [
629 629 ChangesetStatus.STATUS_UNDER_REVIEW, ChangesetStatus.STATUS_NOT_REVIEWED
630 630 ]
631 631
632 632 pull_request_alias = aliased(PullRequest)
633 633 status_alias = aliased(ChangesetStatus)
634 634 reviewers_alias = aliased(PullRequestReviewers)
635 635
636 636 last_ver_subq = Session()\
637 637 .query(func.min(ChangesetStatus.version)) \
638 638 .filter(ChangesetStatus.pull_request_id == reviewers_alias.pull_request_id)\
639 639 .filter(ChangesetStatus.user_id == reviewers_alias.user_id) \
640 640 .subquery()
641 641
642 642 q = Session().query(pull_request_alias) \
643 643 .options(lazyload(pull_request_alias.author)) \
644 644 .join(reviewers_alias,
645 645 reviewers_alias.pull_request_id == pull_request_alias.pull_request_id) \
646 646 .outerjoin(status_alias,
647 647 and_(status_alias.user_id == reviewers_alias.user_id,
648 648 status_alias.pull_request_id == reviewers_alias.pull_request_id)) \
649 649 .filter(or_(status_alias.version == null(),
650 650 status_alias.version == last_ver_subq)) \
651 651 .filter(reviewers_alias.user_id == user_id) \
652 652 .filter(or_(status_alias.status == null(), status_alias.status.in_(for_review_statuses))) \
653 653 .group_by(pull_request_alias)
654 654
655 655 # closed,opened
656 656 if statuses:
657 657 q = q.filter(pull_request_alias.status.in_(statuses))
658 658
659 659 if query:
660 660 like_expression = u'%{}%'.format(safe_str(query))
661 661 q = q.join(User, User.user_id == pull_request_alias.user_id)
662 662 q = q.filter(or_(
663 663 cast(pull_request_alias.pull_request_id, String).ilike(like_expression),
664 664 User.username.ilike(like_expression),
665 665 pull_request_alias.title.ilike(like_expression),
666 666 pull_request_alias.description.ilike(like_expression),
667 667 ))
668 668
669 669 order_map = {
670 670 'name_raw': pull_request_alias.pull_request_id,
671 671 'title': pull_request_alias.title,
672 672 'updated_on_raw': pull_request_alias.updated_on,
673 673 'target_repo': pull_request_alias.target_repo_id
674 674 }
675 675 if order_by and order_by in order_map:
676 676 if order_dir == 'asc':
677 677 q = q.order_by(order_map[order_by].asc())
678 678 else:
679 679 q = q.order_by(order_map[order_by].desc())
680 680
681 681 return q
682 682
683 683 def count_im_participating_in_for_review(self, user_id, statuses=None, query=''):
684 684 q = self._prepare_participating_in_for_review_query(user_id, statuses=statuses, query=query)
685 685 return q.count()
686 686
687 687 def get_im_participating_in_for_review(
688 688 self, user_id, statuses=None, query='', offset=0,
689 689 length=None, order_by=None, order_dir='desc'):
690 690 """
691 691 Get all Pull requests that needs user approval or rejection
692 692 """
693 693
694 694 q = self._prepare_participating_in_for_review_query(
695 695 user_id, statuses=statuses, query=query, order_by=order_by,
696 696 order_dir=order_dir)
697 697
698 698 if length:
699 699 pull_requests = q.limit(length).offset(offset).all()
700 700 else:
701 701 pull_requests = q.all()
702 702
703 703 return pull_requests
704 704
705 705 def get_versions(self, pull_request):
706 706 """
707 707 returns version of pull request sorted by ID descending
708 708 """
709 709 return PullRequestVersion.query()\
710 710 .filter(PullRequestVersion.pull_request == pull_request)\
711 711 .order_by(PullRequestVersion.pull_request_version_id.asc())\
712 712 .all()
713 713
714 714 def get_pr_version(self, pull_request_id, version=None):
715 715 at_version = None
716 716
717 717 if version and version == 'latest':
718 718 pull_request_ver = PullRequest.get(pull_request_id)
719 719 pull_request_obj = pull_request_ver
720 720 _org_pull_request_obj = pull_request_obj
721 721 at_version = 'latest'
722 722 elif version:
723 723 pull_request_ver = PullRequestVersion.get_or_404(version)
724 724 pull_request_obj = pull_request_ver
725 725 _org_pull_request_obj = pull_request_ver.pull_request
726 726 at_version = pull_request_ver.pull_request_version_id
727 727 else:
728 728 _org_pull_request_obj = pull_request_obj = PullRequest.get_or_404(
729 729 pull_request_id)
730 730
731 731 pull_request_display_obj = PullRequest.get_pr_display_object(
732 732 pull_request_obj, _org_pull_request_obj)
733 733
734 734 return _org_pull_request_obj, pull_request_obj, \
735 735 pull_request_display_obj, at_version
736 736
737 737 def pr_commits_versions(self, versions):
738 738 """
739 739 Maps the pull-request commits into all known PR versions. This way we can obtain
740 740 each pr version the commit was introduced in.
741 741 """
742 742 commit_versions = collections.defaultdict(list)
743 743 num_versions = [x.pull_request_version_id for x in versions]
744 744 for ver in versions:
745 745 for commit_id in ver.revisions:
746 746 ver_idx = ChangesetComment.get_index_from_version(
747 747 ver.pull_request_version_id, num_versions=num_versions)
748 748 commit_versions[commit_id].append(ver_idx)
749 749 return commit_versions
750 750
751 751 def create(self, created_by, source_repo, source_ref, target_repo,
752 752 target_ref, revisions, reviewers, observers, title, description=None,
753 753 common_ancestor_id=None,
754 754 description_renderer=None,
755 755 reviewer_data=None, translator=None, auth_user=None):
756 756 translator = translator or get_current_request().translate
757 757
758 758 created_by_user = self._get_user(created_by)
759 759 auth_user = auth_user or created_by_user.AuthUser()
760 760 source_repo = self._get_repo(source_repo)
761 761 target_repo = self._get_repo(target_repo)
762 762
763 763 pull_request = PullRequest()
764 764 pull_request.source_repo = source_repo
765 765 pull_request.source_ref = source_ref
766 766 pull_request.target_repo = target_repo
767 767 pull_request.target_ref = target_ref
768 768 pull_request.revisions = revisions
769 769 pull_request.title = title
770 770 pull_request.description = description
771 771 pull_request.description_renderer = description_renderer
772 772 pull_request.author = created_by_user
773 773 pull_request.reviewer_data = reviewer_data
774 774 pull_request.pull_request_state = pull_request.STATE_CREATING
775 775 pull_request.common_ancestor_id = common_ancestor_id
776 776
777 777 Session().add(pull_request)
778 778 Session().flush()
779 779
780 780 reviewer_ids = set()
781 781 # members / reviewers
782 782 for reviewer_object in reviewers:
783 783 user_id, reasons, mandatory, role, rules = reviewer_object
784 784 user = self._get_user(user_id)
785 785
786 786 # skip duplicates
787 787 if user.user_id in reviewer_ids:
788 788 continue
789 789
790 790 reviewer_ids.add(user.user_id)
791 791
792 792 reviewer = PullRequestReviewers()
793 793 reviewer.user = user
794 794 reviewer.pull_request = pull_request
795 795 reviewer.reasons = reasons
796 796 reviewer.mandatory = mandatory
797 797 reviewer.role = role
798 798
799 799 # NOTE(marcink): pick only first rule for now
800 800 rule_id = list(rules)[0] if rules else None
801 801 rule = RepoReviewRule.get(rule_id) if rule_id else None
802 802 if rule:
803 803 review_group = rule.user_group_vote_rule(user_id)
804 804 # we check if this particular reviewer is member of a voting group
805 805 if review_group:
806 806 # NOTE(marcink):
807 807 # can be that user is member of more but we pick the first same,
808 808 # same as default reviewers algo
809 809 review_group = review_group[0]
810 810
811 811 rule_data = {
812 812 'rule_name':
813 813 rule.review_rule_name,
814 814 'rule_user_group_entry_id':
815 815 review_group.repo_review_rule_users_group_id,
816 816 'rule_user_group_name':
817 817 review_group.users_group.users_group_name,
818 818 'rule_user_group_members':
819 819 [x.user.username for x in review_group.users_group.members],
820 820 'rule_user_group_members_id':
821 821 [x.user.user_id for x in review_group.users_group.members],
822 822 }
823 823 # e.g {'vote_rule': -1, 'mandatory': True}
824 824 rule_data.update(review_group.rule_data())
825 825
826 826 reviewer.rule_data = rule_data
827 827
828 828 Session().add(reviewer)
829 829 Session().flush()
830 830
831 831 for observer_object in observers:
832 832 user_id, reasons, mandatory, role, rules = observer_object
833 833 user = self._get_user(user_id)
834 834
835 835 # skip duplicates from reviewers
836 836 if user.user_id in reviewer_ids:
837 837 continue
838 838
839 839 #reviewer_ids.add(user.user_id)
840 840
841 841 observer = PullRequestReviewers()
842 842 observer.user = user
843 843 observer.pull_request = pull_request
844 844 observer.reasons = reasons
845 845 observer.mandatory = mandatory
846 846 observer.role = role
847 847
848 848 # NOTE(marcink): pick only first rule for now
849 849 rule_id = list(rules)[0] if rules else None
850 850 rule = RepoReviewRule.get(rule_id) if rule_id else None
851 851 if rule:
852 852 # TODO(marcink): do we need this for observers ??
853 853 pass
854 854
855 855 Session().add(observer)
856 856 Session().flush()
857 857
858 858 # Set approval status to "Under Review" for all commits which are
859 859 # part of this pull request.
860 860 ChangesetStatusModel().set_status(
861 861 repo=target_repo,
862 862 status=ChangesetStatus.STATUS_UNDER_REVIEW,
863 863 user=created_by_user,
864 864 pull_request=pull_request
865 865 )
866 866 # we commit early at this point. This has to do with a fact
867 867 # that before queries do some row-locking. And because of that
868 868 # we need to commit and finish transaction before below validate call
869 869 # that for large repos could be long resulting in long row locks
870 870 Session().commit()
871 871
872 872 # prepare workspace, and run initial merge simulation. Set state during that
873 873 # operation
874 874 pull_request = PullRequest.get(pull_request.pull_request_id)
875 875
876 876 # set as merging, for merge simulation, and if finished to created so we mark
877 877 # simulation is working fine
878 878 with pull_request.set_state(PullRequest.STATE_MERGING,
879 879 final_state=PullRequest.STATE_CREATED) as state_obj:
880 880 MergeCheck.validate(
881 881 pull_request, auth_user=auth_user, translator=translator)
882 882
883 883 self.notify_reviewers(pull_request, reviewer_ids, created_by_user)
884 884 self.trigger_pull_request_hook(pull_request, created_by_user, 'create')
885 885
886 886 creation_data = pull_request.get_api_data(with_merge_state=False)
887 887 self._log_audit_action(
888 888 'repo.pull_request.create', {'data': creation_data},
889 889 auth_user, pull_request)
890 890
891 891 return pull_request
892 892
893 893 def trigger_pull_request_hook(self, pull_request, user, action, data=None):
894 894 pull_request = self.__get_pull_request(pull_request)
895 895 target_scm = pull_request.target_repo.scm_instance()
896 896 if action == 'create':
897 897 trigger_hook = hooks_utils.trigger_create_pull_request_hook
898 898 elif action == 'merge':
899 899 trigger_hook = hooks_utils.trigger_merge_pull_request_hook
900 900 elif action == 'close':
901 901 trigger_hook = hooks_utils.trigger_close_pull_request_hook
902 902 elif action == 'review_status_change':
903 903 trigger_hook = hooks_utils.trigger_review_pull_request_hook
904 904 elif action == 'update':
905 905 trigger_hook = hooks_utils.trigger_update_pull_request_hook
906 906 elif action == 'comment':
907 907 trigger_hook = hooks_utils.trigger_comment_pull_request_hook
908 908 elif action == 'comment_edit':
909 909 trigger_hook = hooks_utils.trigger_comment_pull_request_edit_hook
910 910 else:
911 911 return
912 912
913 913 log.debug('Handling pull_request %s trigger_pull_request_hook with action %s and hook: %s',
914 914 pull_request, action, trigger_hook)
915 915 trigger_hook(
916 916 username=user.username,
917 917 repo_name=pull_request.target_repo.repo_name,
918 918 repo_type=target_scm.alias,
919 919 pull_request=pull_request,
920 920 data=data)
921 921
922 922 def _get_commit_ids(self, pull_request):
923 923 """
924 924 Return the commit ids of the merged pull request.
925 925
926 926 This method is not dealing correctly yet with the lack of autoupdates
927 927 nor with the implicit target updates.
928 928 For example: if a commit in the source repo is already in the target it
929 929 will be reported anyways.
930 930 """
931 931 merge_rev = pull_request.merge_rev
932 932 if merge_rev is None:
933 933 raise ValueError('This pull request was not merged yet')
934 934
935 935 commit_ids = list(pull_request.revisions)
936 936 if merge_rev not in commit_ids:
937 937 commit_ids.append(merge_rev)
938 938
939 939 return commit_ids
940 940
941 941 def merge_repo(self, pull_request, user, extras):
942 942 repo_type = pull_request.source_repo.repo_type
943 943 log.debug("Merging pull request %s", pull_request)
944 944
945 945 extras['user_agent'] = '{}/internal-merge'.format(repo_type)
946 946 merge_state = self._merge_pull_request(pull_request, user, extras)
947 947 if merge_state.executed:
948 948 log.debug("Merge was successful, updating the pull request comments.")
949 949 self._comment_and_close_pr(pull_request, user, merge_state)
950 950
951 951 self._log_audit_action(
952 952 'repo.pull_request.merge',
953 953 {'merge_state': merge_state.__dict__},
954 954 user, pull_request)
955 955
956 956 else:
957 957 log.warning("Merge failed, not updating the pull request.")
958 958 return merge_state
959 959
960 960 def _merge_pull_request(self, pull_request, user, extras, merge_msg=None):
961 961 target_vcs = pull_request.target_repo.scm_instance()
962 962 source_vcs = pull_request.source_repo.scm_instance()
963 963
964 964 message = safe_str(merge_msg or vcs_settings.MERGE_MESSAGE_TMPL).format(
965 965 pr_id=pull_request.pull_request_id,
966 966 pr_title=pull_request.title,
967 967 pr_desc=pull_request.description,
968 968 source_repo=source_vcs.name,
969 969 source_ref_name=pull_request.source_ref_parts.name,
970 970 target_repo=target_vcs.name,
971 971 target_ref_name=pull_request.target_ref_parts.name,
972 972 )
973 973
974 974 workspace_id = self._workspace_id(pull_request)
975 975 repo_id = pull_request.target_repo.repo_id
976 976 use_rebase = self._use_rebase_for_merging(pull_request)
977 977 close_branch = self._close_branch_before_merging(pull_request)
978 978 user_name = self._user_name_for_merging(pull_request, user)
979 979
980 980 target_ref = self._refresh_reference(
981 981 pull_request.target_ref_parts, target_vcs)
982 982
983 983 callback_daemon, extras = prepare_callback_daemon(
984 984 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
985 host=vcs_settings.HOOKS_HOST,
986 use_direct_calls=vcs_settings.HOOKS_DIRECT_CALLS)
985 host=vcs_settings.HOOKS_HOST)
987 986
988 987 with callback_daemon:
989 988 # TODO: johbo: Implement a clean way to run a config_override
990 989 # for a single call.
991 990 target_vcs.config.set(
992 991 'rhodecode', 'RC_SCM_DATA', json.dumps(extras))
993 992
994 993 merge_state = target_vcs.merge(
995 994 repo_id, workspace_id, target_ref, source_vcs,
996 995 pull_request.source_ref_parts,
997 996 user_name=user_name, user_email=user.email,
998 997 message=message, use_rebase=use_rebase,
999 998 close_branch=close_branch)
1000 999
1001 1000 return merge_state
1002 1001
1003 1002 def _comment_and_close_pr(self, pull_request, user, merge_state, close_msg=None):
1004 1003 pull_request.merge_rev = merge_state.merge_ref.commit_id
1005 1004 pull_request.updated_on = datetime.datetime.now()
1006 1005 close_msg = close_msg or 'Pull request merged and closed'
1007 1006
1008 1007 CommentsModel().create(
1009 1008 text=safe_str(close_msg),
1010 1009 repo=pull_request.target_repo.repo_id,
1011 1010 user=user.user_id,
1012 1011 pull_request=pull_request.pull_request_id,
1013 1012 f_path=None,
1014 1013 line_no=None,
1015 1014 closing_pr=True
1016 1015 )
1017 1016
1018 1017 Session().add(pull_request)
1019 1018 Session().flush()
1020 1019 # TODO: paris: replace invalidation with less radical solution
1021 1020 ScmModel().mark_for_invalidation(
1022 1021 pull_request.target_repo.repo_name)
1023 1022 self.trigger_pull_request_hook(pull_request, user, 'merge')
1024 1023
1025 1024 def has_valid_update_type(self, pull_request):
1026 1025 source_ref_type = pull_request.source_ref_parts.type
1027 1026 return source_ref_type in self.REF_TYPES
1028 1027
1029 1028 def get_flow_commits(self, pull_request):
1030 1029
1031 1030 # source repo
1032 1031 source_ref_name = pull_request.source_ref_parts.name
1033 1032 source_ref_type = pull_request.source_ref_parts.type
1034 1033 source_ref_id = pull_request.source_ref_parts.commit_id
1035 1034 source_repo = pull_request.source_repo.scm_instance()
1036 1035
1037 1036 try:
1038 1037 if source_ref_type in self.REF_TYPES:
1039 1038 source_commit = source_repo.get_commit(
1040 1039 source_ref_name, reference_obj=pull_request.source_ref_parts)
1041 1040 else:
1042 1041 source_commit = source_repo.get_commit(source_ref_id)
1043 1042 except CommitDoesNotExistError:
1044 1043 raise SourceRefMissing()
1045 1044
1046 1045 # target repo
1047 1046 target_ref_name = pull_request.target_ref_parts.name
1048 1047 target_ref_type = pull_request.target_ref_parts.type
1049 1048 target_ref_id = pull_request.target_ref_parts.commit_id
1050 1049 target_repo = pull_request.target_repo.scm_instance()
1051 1050
1052 1051 try:
1053 1052 if target_ref_type in self.REF_TYPES:
1054 1053 target_commit = target_repo.get_commit(
1055 1054 target_ref_name, reference_obj=pull_request.target_ref_parts)
1056 1055 else:
1057 1056 target_commit = target_repo.get_commit(target_ref_id)
1058 1057 except CommitDoesNotExistError:
1059 1058 raise TargetRefMissing()
1060 1059
1061 1060 return source_commit, target_commit
1062 1061
1063 1062 def update_commits(self, pull_request, updating_user):
1064 1063 """
1065 1064 Get the updated list of commits for the pull request
1066 1065 and return the new pull request version and the list
1067 1066 of commits processed by this update action
1068 1067
1069 1068 updating_user is the user_object who triggered the update
1070 1069 """
1071 1070 pull_request = self.__get_pull_request(pull_request)
1072 1071 source_ref_type = pull_request.source_ref_parts.type
1073 1072 source_ref_name = pull_request.source_ref_parts.name
1074 1073 source_ref_id = pull_request.source_ref_parts.commit_id
1075 1074
1076 1075 target_ref_type = pull_request.target_ref_parts.type
1077 1076 target_ref_name = pull_request.target_ref_parts.name
1078 1077 target_ref_id = pull_request.target_ref_parts.commit_id
1079 1078
1080 1079 if not self.has_valid_update_type(pull_request):
1081 1080 log.debug("Skipping update of pull request %s due to ref type: %s",
1082 1081 pull_request, source_ref_type)
1083 1082 return UpdateResponse(
1084 1083 executed=False,
1085 1084 reason=UpdateFailureReason.WRONG_REF_TYPE,
1086 1085 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1087 1086 source_changed=False, target_changed=False)
1088 1087
1089 1088 try:
1090 1089 source_commit, target_commit = self.get_flow_commits(pull_request)
1091 1090 except SourceRefMissing:
1092 1091 return UpdateResponse(
1093 1092 executed=False,
1094 1093 reason=UpdateFailureReason.MISSING_SOURCE_REF,
1095 1094 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1096 1095 source_changed=False, target_changed=False)
1097 1096 except TargetRefMissing:
1098 1097 return UpdateResponse(
1099 1098 executed=False,
1100 1099 reason=UpdateFailureReason.MISSING_TARGET_REF,
1101 1100 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1102 1101 source_changed=False, target_changed=False)
1103 1102
1104 1103 source_changed = source_ref_id != source_commit.raw_id
1105 1104 target_changed = target_ref_id != target_commit.raw_id
1106 1105
1107 1106 if not (source_changed or target_changed):
1108 1107 log.debug("Nothing changed in pull request %s", pull_request)
1109 1108 return UpdateResponse(
1110 1109 executed=False,
1111 1110 reason=UpdateFailureReason.NO_CHANGE,
1112 1111 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1113 1112 source_changed=target_changed, target_changed=source_changed)
1114 1113
1115 1114 change_in_found = 'target repo' if target_changed else 'source repo'
1116 1115 log.debug('Updating pull request because of change in %s detected',
1117 1116 change_in_found)
1118 1117
1119 1118 # Finally there is a need for an update, in case of source change
1120 1119 # we create a new version, else just an update
1121 1120 if source_changed:
1122 1121 pull_request_version = self._create_version_from_snapshot(pull_request)
1123 1122 self._link_comments_to_version(pull_request_version)
1124 1123 else:
1125 1124 try:
1126 1125 ver = pull_request.versions[-1]
1127 1126 except IndexError:
1128 1127 ver = None
1129 1128
1130 1129 pull_request.pull_request_version_id = \
1131 1130 ver.pull_request_version_id if ver else None
1132 1131 pull_request_version = pull_request
1133 1132
1134 1133 source_repo = pull_request.source_repo.scm_instance()
1135 1134 target_repo = pull_request.target_repo.scm_instance()
1136 1135
1137 1136 # re-compute commit ids
1138 1137 old_commit_ids = pull_request.revisions
1139 1138 pre_load = ["author", "date", "message", "branch"]
1140 1139 commit_ranges = target_repo.compare(
1141 1140 target_commit.raw_id, source_commit.raw_id, source_repo, merge=True,
1142 1141 pre_load=pre_load)
1143 1142
1144 1143 target_ref = target_commit.raw_id
1145 1144 source_ref = source_commit.raw_id
1146 1145 ancestor_commit_id = target_repo.get_common_ancestor(
1147 1146 target_ref, source_ref, source_repo)
1148 1147
1149 1148 if not ancestor_commit_id:
1150 1149 raise ValueError(
1151 1150 'cannot calculate diff info without a common ancestor. '
1152 1151 'Make sure both repositories are related, and have a common forking commit.')
1153 1152
1154 1153 pull_request.common_ancestor_id = ancestor_commit_id
1155 1154
1156 1155 pull_request.source_ref = f'{source_ref_type}:{source_ref_name}:{source_commit.raw_id}'
1157 1156 pull_request.target_ref = f'{target_ref_type}:{target_ref_name}:{ancestor_commit_id}'
1158 1157
1159 1158 pull_request.revisions = [
1160 1159 commit.raw_id for commit in reversed(commit_ranges)]
1161 1160 pull_request.updated_on = datetime.datetime.now()
1162 1161 Session().add(pull_request)
1163 1162 new_commit_ids = pull_request.revisions
1164 1163
1165 1164 old_diff_data, new_diff_data = self._generate_update_diffs(
1166 1165 pull_request, pull_request_version)
1167 1166
1168 1167 # calculate commit and file changes
1169 1168 commit_changes = self._calculate_commit_id_changes(
1170 1169 old_commit_ids, new_commit_ids)
1171 1170 file_changes = self._calculate_file_changes(
1172 1171 old_diff_data, new_diff_data)
1173 1172
1174 1173 # set comments as outdated if DIFFS changed
1175 1174 CommentsModel().outdate_comments(
1176 1175 pull_request, old_diff_data=old_diff_data,
1177 1176 new_diff_data=new_diff_data)
1178 1177
1179 1178 valid_commit_changes = (commit_changes.added or commit_changes.removed)
1180 1179 file_node_changes = (
1181 1180 file_changes.added or file_changes.modified or file_changes.removed)
1182 1181 pr_has_changes = valid_commit_changes or file_node_changes
1183 1182
1184 1183 # Add an automatic comment to the pull request, in case
1185 1184 # anything has changed
1186 1185 if pr_has_changes:
1187 1186 update_comment = CommentsModel().create(
1188 1187 text=self._render_update_message(ancestor_commit_id, commit_changes, file_changes),
1189 1188 repo=pull_request.target_repo,
1190 1189 user=pull_request.author,
1191 1190 pull_request=pull_request,
1192 1191 send_email=False, renderer=DEFAULT_COMMENTS_RENDERER)
1193 1192
1194 1193 # Update status to "Under Review" for added commits
1195 1194 for commit_id in commit_changes.added:
1196 1195 ChangesetStatusModel().set_status(
1197 1196 repo=pull_request.source_repo,
1198 1197 status=ChangesetStatus.STATUS_UNDER_REVIEW,
1199 1198 comment=update_comment,
1200 1199 user=pull_request.author,
1201 1200 pull_request=pull_request,
1202 1201 revision=commit_id)
1203 1202
1204 1203 # initial commit
1205 1204 Session().commit()
1206 1205
1207 1206 if pr_has_changes:
1208 1207 # send update email to users
1209 1208 try:
1210 1209 self.notify_users(pull_request=pull_request, updating_user=updating_user,
1211 1210 ancestor_commit_id=ancestor_commit_id,
1212 1211 commit_changes=commit_changes,
1213 1212 file_changes=file_changes)
1214 1213 Session().commit()
1215 1214 except Exception:
1216 1215 log.exception('Failed to send email notification to users')
1217 1216 Session().rollback()
1218 1217
1219 1218 log.debug(
1220 1219 'Updated pull request %s, added_ids: %s, common_ids: %s, '
1221 1220 'removed_ids: %s', pull_request.pull_request_id,
1222 1221 commit_changes.added, commit_changes.common, commit_changes.removed)
1223 1222 log.debug(
1224 1223 'Updated pull request with the following file changes: %s',
1225 1224 file_changes)
1226 1225
1227 1226 log.info(
1228 1227 "Updated pull request %s from commit %s to commit %s, "
1229 1228 "stored new version %s of this pull request.",
1230 1229 pull_request.pull_request_id, source_ref_id,
1231 1230 pull_request.source_ref_parts.commit_id,
1232 1231 pull_request_version.pull_request_version_id)
1233 1232
1234 1233 self.trigger_pull_request_hook(pull_request, pull_request.author, 'update')
1235 1234
1236 1235 return UpdateResponse(
1237 1236 executed=True, reason=UpdateFailureReason.NONE,
1238 1237 old=pull_request, new=pull_request_version,
1239 1238 common_ancestor_id=ancestor_commit_id, commit_changes=commit_changes,
1240 1239 source_changed=source_changed, target_changed=target_changed)
1241 1240
1242 1241 def _create_version_from_snapshot(self, pull_request):
1243 1242 version = PullRequestVersion()
1244 1243 version.title = pull_request.title
1245 1244 version.description = pull_request.description
1246 1245 version.status = pull_request.status
1247 1246 version.pull_request_state = pull_request.pull_request_state
1248 1247 version.created_on = datetime.datetime.now()
1249 1248 version.updated_on = pull_request.updated_on
1250 1249 version.user_id = pull_request.user_id
1251 1250 version.source_repo = pull_request.source_repo
1252 1251 version.source_ref = pull_request.source_ref
1253 1252 version.target_repo = pull_request.target_repo
1254 1253 version.target_ref = pull_request.target_ref
1255 1254
1256 1255 version._last_merge_source_rev = pull_request._last_merge_source_rev
1257 1256 version._last_merge_target_rev = pull_request._last_merge_target_rev
1258 1257 version.last_merge_status = pull_request.last_merge_status
1259 1258 version.last_merge_metadata = pull_request.last_merge_metadata
1260 1259 version.shadow_merge_ref = pull_request.shadow_merge_ref
1261 1260 version.merge_rev = pull_request.merge_rev
1262 1261 version.reviewer_data = pull_request.reviewer_data
1263 1262
1264 1263 version.revisions = pull_request.revisions
1265 1264 version.common_ancestor_id = pull_request.common_ancestor_id
1266 1265 version.pull_request = pull_request
1267 1266 Session().add(version)
1268 1267 Session().flush()
1269 1268
1270 1269 return version
1271 1270
1272 1271 def _generate_update_diffs(self, pull_request, pull_request_version):
1273 1272
1274 1273 diff_context = (
1275 1274 self.DIFF_CONTEXT +
1276 1275 CommentsModel.needed_extra_diff_context())
1277 1276 hide_whitespace_changes = False
1278 1277 source_repo = pull_request_version.source_repo
1279 1278 source_ref_id = pull_request_version.source_ref_parts.commit_id
1280 1279 target_ref_id = pull_request_version.target_ref_parts.commit_id
1281 1280 old_diff = self._get_diff_from_pr_or_version(
1282 1281 source_repo, source_ref_id, target_ref_id,
1283 1282 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1284 1283
1285 1284 source_repo = pull_request.source_repo
1286 1285 source_ref_id = pull_request.source_ref_parts.commit_id
1287 1286 target_ref_id = pull_request.target_ref_parts.commit_id
1288 1287
1289 1288 new_diff = self._get_diff_from_pr_or_version(
1290 1289 source_repo, source_ref_id, target_ref_id,
1291 1290 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1292 1291
1293 1292 # NOTE: this was using diff_format='gitdiff'
1294 1293 old_diff_data = diffs.DiffProcessor(old_diff, diff_format='newdiff')
1295 1294 old_diff_data.prepare()
1296 1295 new_diff_data = diffs.DiffProcessor(new_diff, diff_format='newdiff')
1297 1296 new_diff_data.prepare()
1298 1297
1299 1298 return old_diff_data, new_diff_data
1300 1299
1301 1300 def _link_comments_to_version(self, pull_request_version):
1302 1301 """
1303 1302 Link all unlinked comments of this pull request to the given version.
1304 1303
1305 1304 :param pull_request_version: The `PullRequestVersion` to which
1306 1305 the comments shall be linked.
1307 1306
1308 1307 """
1309 1308 pull_request = pull_request_version.pull_request
1310 1309 comments = ChangesetComment.query()\
1311 1310 .filter(
1312 1311 # TODO: johbo: Should we query for the repo at all here?
1313 1312 # Pending decision on how comments of PRs are to be related
1314 1313 # to either the source repo, the target repo or no repo at all.
1315 1314 ChangesetComment.repo_id == pull_request.target_repo.repo_id,
1316 1315 ChangesetComment.pull_request == pull_request,
1317 1316 ChangesetComment.pull_request_version == null())\
1318 1317 .order_by(ChangesetComment.comment_id.asc())
1319 1318
1320 1319 # TODO: johbo: Find out why this breaks if it is done in a bulk
1321 1320 # operation.
1322 1321 for comment in comments:
1323 1322 comment.pull_request_version_id = (
1324 1323 pull_request_version.pull_request_version_id)
1325 1324 Session().add(comment)
1326 1325
1327 1326 def _calculate_commit_id_changes(self, old_ids, new_ids):
1328 1327 added = [x for x in new_ids if x not in old_ids]
1329 1328 common = [x for x in new_ids if x in old_ids]
1330 1329 removed = [x for x in old_ids if x not in new_ids]
1331 1330 total = new_ids
1332 1331 return ChangeTuple(added, common, removed, total)
1333 1332
1334 1333 def _calculate_file_changes(self, old_diff_data, new_diff_data):
1335 1334
1336 1335 old_files = OrderedDict()
1337 1336 for diff_data in old_diff_data.parsed_diff:
1338 1337 old_files[diff_data['filename']] = md5_safe(diff_data['raw_diff'])
1339 1338
1340 1339 added_files = []
1341 1340 modified_files = []
1342 1341 removed_files = []
1343 1342 for diff_data in new_diff_data.parsed_diff:
1344 1343 new_filename = diff_data['filename']
1345 1344 new_hash = md5_safe(diff_data['raw_diff'])
1346 1345
1347 1346 old_hash = old_files.get(new_filename)
1348 1347 if not old_hash:
1349 1348 # file is not present in old diff, we have to figure out from parsed diff
1350 1349 # operation ADD/REMOVE
1351 1350 operations_dict = diff_data['stats']['ops']
1352 1351 if diffs.DEL_FILENODE in operations_dict:
1353 1352 removed_files.append(new_filename)
1354 1353 else:
1355 1354 added_files.append(new_filename)
1356 1355 else:
1357 1356 if new_hash != old_hash:
1358 1357 modified_files.append(new_filename)
1359 1358 # now remove a file from old, since we have seen it already
1360 1359 del old_files[new_filename]
1361 1360
1362 1361 # removed files is when there are present in old, but not in NEW,
1363 1362 # since we remove old files that are present in new diff, left-overs
1364 1363 # if any should be the removed files
1365 1364 removed_files.extend(old_files.keys())
1366 1365
1367 1366 return FileChangeTuple(added_files, modified_files, removed_files)
1368 1367
1369 1368 def _render_update_message(self, ancestor_commit_id, changes, file_changes):
1370 1369 """
1371 1370 render the message using DEFAULT_COMMENTS_RENDERER (RST renderer),
1372 1371 so it's always looking the same disregarding on which default
1373 1372 renderer system is using.
1374 1373
1375 1374 :param ancestor_commit_id: ancestor raw_id
1376 1375 :param changes: changes named tuple
1377 1376 :param file_changes: file changes named tuple
1378 1377
1379 1378 """
1380 1379 new_status = ChangesetStatus.get_status_lbl(
1381 1380 ChangesetStatus.STATUS_UNDER_REVIEW)
1382 1381
1383 1382 changed_files = (
1384 1383 file_changes.added + file_changes.modified + file_changes.removed)
1385 1384
1386 1385 params = {
1387 1386 'under_review_label': new_status,
1388 1387 'added_commits': changes.added,
1389 1388 'removed_commits': changes.removed,
1390 1389 'changed_files': changed_files,
1391 1390 'added_files': file_changes.added,
1392 1391 'modified_files': file_changes.modified,
1393 1392 'removed_files': file_changes.removed,
1394 1393 'ancestor_commit_id': ancestor_commit_id
1395 1394 }
1396 1395 renderer = RstTemplateRenderer()
1397 1396 return renderer.render('pull_request_update.mako', **params)
1398 1397
1399 1398 def edit(self, pull_request, title, description, description_renderer, user):
1400 1399 pull_request = self.__get_pull_request(pull_request)
1401 1400 old_data = pull_request.get_api_data(with_merge_state=False)
1402 1401 if pull_request.is_closed():
1403 1402 raise ValueError('This pull request is closed')
1404 1403 if title:
1405 1404 pull_request.title = title
1406 1405 pull_request.description = description
1407 1406 pull_request.updated_on = datetime.datetime.now()
1408 1407 pull_request.description_renderer = description_renderer
1409 1408 Session().add(pull_request)
1410 1409 self._log_audit_action(
1411 1410 'repo.pull_request.edit', {'old_data': old_data},
1412 1411 user, pull_request)
1413 1412
1414 1413 def update_reviewers(self, pull_request, reviewer_data, user):
1415 1414 """
1416 1415 Update the reviewers in the pull request
1417 1416
1418 1417 :param pull_request: the pr to update
1419 1418 :param reviewer_data: list of tuples
1420 1419 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1421 1420 :param user: current use who triggers this action
1422 1421 """
1423 1422
1424 1423 pull_request = self.__get_pull_request(pull_request)
1425 1424 if pull_request.is_closed():
1426 1425 raise ValueError('This pull request is closed')
1427 1426
1428 1427 reviewers = {}
1429 1428 for user_id, reasons, mandatory, role, rules in reviewer_data:
1430 1429 if isinstance(user_id, (int, str)):
1431 1430 user_id = self._get_user(user_id).user_id
1432 1431 reviewers[user_id] = {
1433 1432 'reasons': reasons, 'mandatory': mandatory, 'role': role}
1434 1433
1435 1434 reviewers_ids = set(reviewers.keys())
1436 1435 current_reviewers = PullRequestReviewers.get_pull_request_reviewers(
1437 1436 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_REVIEWER)
1438 1437
1439 1438 current_reviewers_ids = set([x.user.user_id for x in current_reviewers])
1440 1439
1441 1440 ids_to_add = reviewers_ids.difference(current_reviewers_ids)
1442 1441 ids_to_remove = current_reviewers_ids.difference(reviewers_ids)
1443 1442
1444 1443 log.debug("Adding %s reviewers", ids_to_add)
1445 1444 log.debug("Removing %s reviewers", ids_to_remove)
1446 1445 changed = False
1447 1446 added_audit_reviewers = []
1448 1447 removed_audit_reviewers = []
1449 1448
1450 1449 for uid in ids_to_add:
1451 1450 changed = True
1452 1451 _usr = self._get_user(uid)
1453 1452 reviewer = PullRequestReviewers()
1454 1453 reviewer.user = _usr
1455 1454 reviewer.pull_request = pull_request
1456 1455 reviewer.reasons = reviewers[uid]['reasons']
1457 1456 # NOTE(marcink): mandatory shouldn't be changed now
1458 1457 # reviewer.mandatory = reviewers[uid]['reasons']
1459 1458 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1460 1459 reviewer.role = PullRequestReviewers.ROLE_REVIEWER
1461 1460 Session().add(reviewer)
1462 1461 added_audit_reviewers.append(reviewer.get_dict())
1463 1462
1464 1463 for uid in ids_to_remove:
1465 1464 changed = True
1466 1465 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1467 1466 # This is an edge case that handles previous state of having the same reviewer twice.
1468 1467 # this CAN happen due to the lack of DB checks
1469 1468 reviewers = PullRequestReviewers.query()\
1470 1469 .filter(PullRequestReviewers.user_id == uid,
1471 1470 PullRequestReviewers.role == PullRequestReviewers.ROLE_REVIEWER,
1472 1471 PullRequestReviewers.pull_request == pull_request)\
1473 1472 .all()
1474 1473
1475 1474 for obj in reviewers:
1476 1475 added_audit_reviewers.append(obj.get_dict())
1477 1476 Session().delete(obj)
1478 1477
1479 1478 if changed:
1480 1479 Session().expire_all()
1481 1480 pull_request.updated_on = datetime.datetime.now()
1482 1481 Session().add(pull_request)
1483 1482
1484 1483 # finally store audit logs
1485 1484 for user_data in added_audit_reviewers:
1486 1485 self._log_audit_action(
1487 1486 'repo.pull_request.reviewer.add', {'data': user_data},
1488 1487 user, pull_request)
1489 1488 for user_data in removed_audit_reviewers:
1490 1489 self._log_audit_action(
1491 1490 'repo.pull_request.reviewer.delete', {'old_data': user_data},
1492 1491 user, pull_request)
1493 1492
1494 1493 self.notify_reviewers(pull_request, ids_to_add, user)
1495 1494 return ids_to_add, ids_to_remove
1496 1495
1497 1496 def update_observers(self, pull_request, observer_data, user):
1498 1497 """
1499 1498 Update the observers in the pull request
1500 1499
1501 1500 :param pull_request: the pr to update
1502 1501 :param observer_data: list of tuples
1503 1502 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1504 1503 :param user: current use who triggers this action
1505 1504 """
1506 1505 pull_request = self.__get_pull_request(pull_request)
1507 1506 if pull_request.is_closed():
1508 1507 raise ValueError('This pull request is closed')
1509 1508
1510 1509 observers = {}
1511 1510 for user_id, reasons, mandatory, role, rules in observer_data:
1512 1511 if isinstance(user_id, (int, str)):
1513 1512 user_id = self._get_user(user_id).user_id
1514 1513 observers[user_id] = {
1515 1514 'reasons': reasons, 'observers': mandatory, 'role': role}
1516 1515
1517 1516 observers_ids = set(observers.keys())
1518 1517 current_observers = PullRequestReviewers.get_pull_request_reviewers(
1519 1518 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_OBSERVER)
1520 1519
1521 1520 current_observers_ids = set([x.user.user_id for x in current_observers])
1522 1521
1523 1522 ids_to_add = observers_ids.difference(current_observers_ids)
1524 1523 ids_to_remove = current_observers_ids.difference(observers_ids)
1525 1524
1526 1525 log.debug("Adding %s observer", ids_to_add)
1527 1526 log.debug("Removing %s observer", ids_to_remove)
1528 1527 changed = False
1529 1528 added_audit_observers = []
1530 1529 removed_audit_observers = []
1531 1530
1532 1531 for uid in ids_to_add:
1533 1532 changed = True
1534 1533 _usr = self._get_user(uid)
1535 1534 observer = PullRequestReviewers()
1536 1535 observer.user = _usr
1537 1536 observer.pull_request = pull_request
1538 1537 observer.reasons = observers[uid]['reasons']
1539 1538 # NOTE(marcink): mandatory shouldn't be changed now
1540 1539 # observer.mandatory = observer[uid]['reasons']
1541 1540
1542 1541 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1543 1542 observer.role = PullRequestReviewers.ROLE_OBSERVER
1544 1543 Session().add(observer)
1545 1544 added_audit_observers.append(observer.get_dict())
1546 1545
1547 1546 for uid in ids_to_remove:
1548 1547 changed = True
1549 1548 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1550 1549 # This is an edge case that handles previous state of having the same reviewer twice.
1551 1550 # this CAN happen due to the lack of DB checks
1552 1551 observers = PullRequestReviewers.query()\
1553 1552 .filter(PullRequestReviewers.user_id == uid,
1554 1553 PullRequestReviewers.role == PullRequestReviewers.ROLE_OBSERVER,
1555 1554 PullRequestReviewers.pull_request == pull_request)\
1556 1555 .all()
1557 1556
1558 1557 for obj in observers:
1559 1558 added_audit_observers.append(obj.get_dict())
1560 1559 Session().delete(obj)
1561 1560
1562 1561 if changed:
1563 1562 Session().expire_all()
1564 1563 pull_request.updated_on = datetime.datetime.now()
1565 1564 Session().add(pull_request)
1566 1565
1567 1566 # finally store audit logs
1568 1567 for user_data in added_audit_observers:
1569 1568 self._log_audit_action(
1570 1569 'repo.pull_request.observer.add', {'data': user_data},
1571 1570 user, pull_request)
1572 1571 for user_data in removed_audit_observers:
1573 1572 self._log_audit_action(
1574 1573 'repo.pull_request.observer.delete', {'old_data': user_data},
1575 1574 user, pull_request)
1576 1575
1577 1576 self.notify_observers(pull_request, ids_to_add, user)
1578 1577 return ids_to_add, ids_to_remove
1579 1578
1580 1579 def get_url(self, pull_request, request=None, permalink=False):
1581 1580 if not request:
1582 1581 request = get_current_request()
1583 1582
1584 1583 if permalink:
1585 1584 return request.route_url(
1586 1585 'pull_requests_global',
1587 1586 pull_request_id=pull_request.pull_request_id,)
1588 1587 else:
1589 1588 return request.route_url('pullrequest_show',
1590 1589 repo_name=safe_str(pull_request.target_repo.repo_name),
1591 1590 pull_request_id=pull_request.pull_request_id,)
1592 1591
1593 1592 def get_shadow_clone_url(self, pull_request, request=None):
1594 1593 """
1595 1594 Returns qualified url pointing to the shadow repository. If this pull
1596 1595 request is closed there is no shadow repository and ``None`` will be
1597 1596 returned.
1598 1597 """
1599 1598 if pull_request.is_closed():
1600 1599 return None
1601 1600 else:
1602 1601 pr_url = urllib.parse.unquote(self.get_url(pull_request, request=request))
1603 1602 return safe_str('{pr_url}/repository'.format(pr_url=pr_url))
1604 1603
1605 1604 def _notify_reviewers(self, pull_request, user_ids, role, user):
1606 1605 # notification to reviewers/observers
1607 1606 if not user_ids:
1608 1607 return
1609 1608
1610 1609 log.debug('Notify following %s users about pull-request %s', role, user_ids)
1611 1610
1612 1611 pull_request_obj = pull_request
1613 1612 # get the current participants of this pull request
1614 1613 recipients = user_ids
1615 1614 notification_type = EmailNotificationModel.TYPE_PULL_REQUEST
1616 1615
1617 1616 pr_source_repo = pull_request_obj.source_repo
1618 1617 pr_target_repo = pull_request_obj.target_repo
1619 1618
1620 1619 pr_url = h.route_url('pullrequest_show',
1621 1620 repo_name=pr_target_repo.repo_name,
1622 1621 pull_request_id=pull_request_obj.pull_request_id,)
1623 1622
1624 1623 # set some variables for email notification
1625 1624 pr_target_repo_url = h.route_url(
1626 1625 'repo_summary', repo_name=pr_target_repo.repo_name)
1627 1626
1628 1627 pr_source_repo_url = h.route_url(
1629 1628 'repo_summary', repo_name=pr_source_repo.repo_name)
1630 1629
1631 1630 # pull request specifics
1632 1631 pull_request_commits = [
1633 1632 (x.raw_id, x.message)
1634 1633 for x in map(pr_source_repo.get_commit, pull_request.revisions)]
1635 1634
1636 1635 current_rhodecode_user = user
1637 1636 kwargs = {
1638 1637 'user': current_rhodecode_user,
1639 1638 'pull_request_author': pull_request.author,
1640 1639 'pull_request': pull_request_obj,
1641 1640 'pull_request_commits': pull_request_commits,
1642 1641
1643 1642 'pull_request_target_repo': pr_target_repo,
1644 1643 'pull_request_target_repo_url': pr_target_repo_url,
1645 1644
1646 1645 'pull_request_source_repo': pr_source_repo,
1647 1646 'pull_request_source_repo_url': pr_source_repo_url,
1648 1647
1649 1648 'pull_request_url': pr_url,
1650 1649 'thread_ids': [pr_url],
1651 1650 'user_role': role
1652 1651 }
1653 1652
1654 1653 # create notification objects, and emails
1655 1654 NotificationModel().create(
1656 1655 created_by=current_rhodecode_user,
1657 1656 notification_subject='', # Filled in based on the notification_type
1658 1657 notification_body='', # Filled in based on the notification_type
1659 1658 notification_type=notification_type,
1660 1659 recipients=recipients,
1661 1660 email_kwargs=kwargs,
1662 1661 )
1663 1662
1664 1663 def notify_reviewers(self, pull_request, reviewers_ids, user):
1665 1664 return self._notify_reviewers(pull_request, reviewers_ids,
1666 1665 PullRequestReviewers.ROLE_REVIEWER, user)
1667 1666
1668 1667 def notify_observers(self, pull_request, observers_ids, user):
1669 1668 return self._notify_reviewers(pull_request, observers_ids,
1670 1669 PullRequestReviewers.ROLE_OBSERVER, user)
1671 1670
1672 1671 def notify_users(self, pull_request, updating_user, ancestor_commit_id,
1673 1672 commit_changes, file_changes):
1674 1673
1675 1674 updating_user_id = updating_user.user_id
1676 1675 reviewers = set([x.user.user_id for x in pull_request.get_pull_request_reviewers()])
1677 1676 # NOTE(marcink): send notification to all other users except to
1678 1677 # person who updated the PR
1679 1678 recipients = reviewers.difference(set([updating_user_id]))
1680 1679
1681 1680 log.debug('Notify following recipients about pull-request update %s', recipients)
1682 1681
1683 1682 pull_request_obj = pull_request
1684 1683
1685 1684 # send email about the update
1686 1685 changed_files = (
1687 1686 file_changes.added + file_changes.modified + file_changes.removed)
1688 1687
1689 1688 pr_source_repo = pull_request_obj.source_repo
1690 1689 pr_target_repo = pull_request_obj.target_repo
1691 1690
1692 1691 pr_url = h.route_url('pullrequest_show',
1693 1692 repo_name=pr_target_repo.repo_name,
1694 1693 pull_request_id=pull_request_obj.pull_request_id,)
1695 1694
1696 1695 # set some variables for email notification
1697 1696 pr_target_repo_url = h.route_url(
1698 1697 'repo_summary', repo_name=pr_target_repo.repo_name)
1699 1698
1700 1699 pr_source_repo_url = h.route_url(
1701 1700 'repo_summary', repo_name=pr_source_repo.repo_name)
1702 1701
1703 1702 email_kwargs = {
1704 1703 'date': datetime.datetime.now(),
1705 1704 'updating_user': updating_user,
1706 1705
1707 1706 'pull_request': pull_request_obj,
1708 1707
1709 1708 'pull_request_target_repo': pr_target_repo,
1710 1709 'pull_request_target_repo_url': pr_target_repo_url,
1711 1710
1712 1711 'pull_request_source_repo': pr_source_repo,
1713 1712 'pull_request_source_repo_url': pr_source_repo_url,
1714 1713
1715 1714 'pull_request_url': pr_url,
1716 1715
1717 1716 'ancestor_commit_id': ancestor_commit_id,
1718 1717 'added_commits': commit_changes.added,
1719 1718 'removed_commits': commit_changes.removed,
1720 1719 'changed_files': changed_files,
1721 1720 'added_files': file_changes.added,
1722 1721 'modified_files': file_changes.modified,
1723 1722 'removed_files': file_changes.removed,
1724 1723 'thread_ids': [pr_url],
1725 1724 }
1726 1725
1727 1726 # create notification objects, and emails
1728 1727 NotificationModel().create(
1729 1728 created_by=updating_user,
1730 1729 notification_subject='', # Filled in based on the notification_type
1731 1730 notification_body='', # Filled in based on the notification_type
1732 1731 notification_type=EmailNotificationModel.TYPE_PULL_REQUEST_UPDATE,
1733 1732 recipients=recipients,
1734 1733 email_kwargs=email_kwargs,
1735 1734 )
1736 1735
1737 1736 def delete(self, pull_request, user=None):
1738 1737 if not user:
1739 1738 user = getattr(get_current_rhodecode_user(), 'username', None)
1740 1739
1741 1740 pull_request = self.__get_pull_request(pull_request)
1742 1741 old_data = pull_request.get_api_data(with_merge_state=False)
1743 1742 self._cleanup_merge_workspace(pull_request)
1744 1743 self._log_audit_action(
1745 1744 'repo.pull_request.delete', {'old_data': old_data},
1746 1745 user, pull_request)
1747 1746 Session().delete(pull_request)
1748 1747
1749 1748 def close_pull_request(self, pull_request, user):
1750 1749 pull_request = self.__get_pull_request(pull_request)
1751 1750 self._cleanup_merge_workspace(pull_request)
1752 1751 pull_request.status = PullRequest.STATUS_CLOSED
1753 1752 pull_request.updated_on = datetime.datetime.now()
1754 1753 Session().add(pull_request)
1755 1754 self.trigger_pull_request_hook(pull_request, pull_request.author, 'close')
1756 1755
1757 1756 pr_data = pull_request.get_api_data(with_merge_state=False)
1758 1757 self._log_audit_action(
1759 1758 'repo.pull_request.close', {'data': pr_data}, user, pull_request)
1760 1759
1761 1760 def close_pull_request_with_comment(
1762 1761 self, pull_request, user, repo, message=None, auth_user=None):
1763 1762
1764 1763 pull_request_review_status = pull_request.calculated_review_status()
1765 1764
1766 1765 if pull_request_review_status == ChangesetStatus.STATUS_APPROVED:
1767 1766 # approved only if we have voting consent
1768 1767 status = ChangesetStatus.STATUS_APPROVED
1769 1768 else:
1770 1769 status = ChangesetStatus.STATUS_REJECTED
1771 1770 status_lbl = ChangesetStatus.get_status_lbl(status)
1772 1771
1773 1772 default_message = (
1774 1773 'Closing with status change {transition_icon} {status}.'
1775 1774 ).format(transition_icon='>', status=status_lbl)
1776 1775 text = message or default_message
1777 1776
1778 1777 # create a comment, and link it to new status
1779 1778 comment = CommentsModel().create(
1780 1779 text=text,
1781 1780 repo=repo.repo_id,
1782 1781 user=user.user_id,
1783 1782 pull_request=pull_request.pull_request_id,
1784 1783 status_change=status_lbl,
1785 1784 status_change_type=status,
1786 1785 closing_pr=True,
1787 1786 auth_user=auth_user,
1788 1787 )
1789 1788
1790 1789 # calculate old status before we change it
1791 1790 old_calculated_status = pull_request.calculated_review_status()
1792 1791 ChangesetStatusModel().set_status(
1793 1792 repo.repo_id,
1794 1793 status,
1795 1794 user.user_id,
1796 1795 comment=comment,
1797 1796 pull_request=pull_request.pull_request_id
1798 1797 )
1799 1798
1800 1799 Session().flush()
1801 1800
1802 1801 self.trigger_pull_request_hook(pull_request, user, 'comment',
1803 1802 data={'comment': comment})
1804 1803
1805 1804 # we now calculate the status of pull request again, and based on that
1806 1805 # calculation trigger status change. This might happen in cases
1807 1806 # that non-reviewer admin closes a pr, which means his vote doesn't
1808 1807 # change the status, while if he's a reviewer this might change it.
1809 1808 calculated_status = pull_request.calculated_review_status()
1810 1809 if old_calculated_status != calculated_status:
1811 1810 self.trigger_pull_request_hook(pull_request, user, 'review_status_change',
1812 1811 data={'status': calculated_status})
1813 1812
1814 1813 # finally close the PR
1815 1814 PullRequestModel().close_pull_request(pull_request.pull_request_id, user)
1816 1815
1817 1816 return comment, status
1818 1817
1819 1818 def merge_status(self, pull_request, translator=None, force_shadow_repo_refresh=False):
1820 1819 _ = translator or get_current_request().translate
1821 1820
1822 1821 if not self._is_merge_enabled(pull_request):
1823 1822 return None, False, _('Server-side pull request merging is disabled.')
1824 1823
1825 1824 if pull_request.is_closed():
1826 1825 return None, False, _('This pull request is closed.')
1827 1826
1828 1827 merge_possible, msg = self._check_repo_requirements(
1829 1828 target=pull_request.target_repo, source=pull_request.source_repo,
1830 1829 translator=_)
1831 1830 if not merge_possible:
1832 1831 return None, merge_possible, msg
1833 1832
1834 1833 try:
1835 1834 merge_response = self._try_merge(
1836 1835 pull_request, force_shadow_repo_refresh=force_shadow_repo_refresh)
1837 1836 log.debug("Merge response: %s", merge_response)
1838 1837 return merge_response, merge_response.possible, merge_response.merge_status_message
1839 1838 except NotImplementedError:
1840 1839 return None, False, _('Pull request merging is not supported.')
1841 1840
1842 1841 def _check_repo_requirements(self, target, source, translator):
1843 1842 """
1844 1843 Check if `target` and `source` have compatible requirements.
1845 1844
1846 1845 Currently this is just checking for largefiles.
1847 1846 """
1848 1847 _ = translator
1849 1848 target_has_largefiles = self._has_largefiles(target)
1850 1849 source_has_largefiles = self._has_largefiles(source)
1851 1850 merge_possible = True
1852 1851 message = u''
1853 1852
1854 1853 if target_has_largefiles != source_has_largefiles:
1855 1854 merge_possible = False
1856 1855 if source_has_largefiles:
1857 1856 message = _(
1858 1857 'Target repository large files support is disabled.')
1859 1858 else:
1860 1859 message = _(
1861 1860 'Source repository large files support is disabled.')
1862 1861
1863 1862 return merge_possible, message
1864 1863
1865 1864 def _has_largefiles(self, repo):
1866 1865 largefiles_ui = VcsSettingsModel(repo=repo).get_ui_settings(
1867 1866 'extensions', 'largefiles')
1868 1867 return largefiles_ui and largefiles_ui[0].active
1869 1868
1870 1869 def _try_merge(self, pull_request, force_shadow_repo_refresh=False):
1871 1870 """
1872 1871 Try to merge the pull request and return the merge status.
1873 1872 """
1874 1873 log.debug(
1875 1874 "Trying out if the pull request %s can be merged. Force_refresh=%s",
1876 1875 pull_request.pull_request_id, force_shadow_repo_refresh)
1877 1876 target_vcs = pull_request.target_repo.scm_instance()
1878 1877 # Refresh the target reference.
1879 1878 try:
1880 1879 target_ref = self._refresh_reference(
1881 1880 pull_request.target_ref_parts, target_vcs)
1882 1881 except CommitDoesNotExistError:
1883 1882 merge_state = MergeResponse(
1884 1883 False, False, None, MergeFailureReason.MISSING_TARGET_REF,
1885 1884 metadata={'target_ref': pull_request.target_ref_parts})
1886 1885 return merge_state
1887 1886
1888 1887 target_locked = pull_request.target_repo.locked
1889 1888 if target_locked and target_locked[0]:
1890 1889 locked_by = 'user:{}'.format(target_locked[0])
1891 1890 log.debug("The target repository is locked by %s.", locked_by)
1892 1891 merge_state = MergeResponse(
1893 1892 False, False, None, MergeFailureReason.TARGET_IS_LOCKED,
1894 1893 metadata={'locked_by': locked_by})
1895 1894 elif force_shadow_repo_refresh or self._needs_merge_state_refresh(
1896 1895 pull_request, target_ref):
1897 1896 log.debug("Refreshing the merge status of the repository.")
1898 1897 merge_state = self._refresh_merge_state(
1899 1898 pull_request, target_vcs, target_ref)
1900 1899 else:
1901 1900 possible = pull_request.last_merge_status == MergeFailureReason.NONE
1902 1901 metadata = {
1903 1902 'unresolved_files': '',
1904 1903 'target_ref': pull_request.target_ref_parts,
1905 1904 'source_ref': pull_request.source_ref_parts,
1906 1905 }
1907 1906 if pull_request.last_merge_metadata:
1908 1907 metadata.update(pull_request.last_merge_metadata_parsed)
1909 1908
1910 1909 if not possible and target_ref.type == 'branch':
1911 1910 # NOTE(marcink): case for mercurial multiple heads on branch
1912 1911 heads = target_vcs._heads(target_ref.name)
1913 1912 if len(heads) != 1:
1914 1913 heads = '\n,'.join(target_vcs._heads(target_ref.name))
1915 1914 metadata.update({
1916 1915 'heads': heads
1917 1916 })
1918 1917
1919 1918 merge_state = MergeResponse(
1920 1919 possible, False, None, pull_request.last_merge_status, metadata=metadata)
1921 1920
1922 1921 return merge_state
1923 1922
1924 1923 def _refresh_reference(self, reference, vcs_repository):
1925 1924 if reference.type in self.UPDATABLE_REF_TYPES:
1926 1925 name_or_id = reference.name
1927 1926 else:
1928 1927 name_or_id = reference.commit_id
1929 1928
1930 1929 refreshed_commit = vcs_repository.get_commit(name_or_id)
1931 1930 refreshed_reference = Reference(
1932 1931 reference.type, reference.name, refreshed_commit.raw_id)
1933 1932 return refreshed_reference
1934 1933
1935 1934 def _needs_merge_state_refresh(self, pull_request, target_reference):
1936 1935 return not(
1937 1936 pull_request.revisions and
1938 1937 pull_request.revisions[0] == pull_request._last_merge_source_rev and
1939 1938 target_reference.commit_id == pull_request._last_merge_target_rev)
1940 1939
1941 1940 def _refresh_merge_state(self, pull_request, target_vcs, target_reference):
1942 1941 workspace_id = self._workspace_id(pull_request)
1943 1942 source_vcs = pull_request.source_repo.scm_instance()
1944 1943 repo_id = pull_request.target_repo.repo_id
1945 1944 use_rebase = self._use_rebase_for_merging(pull_request)
1946 1945 close_branch = self._close_branch_before_merging(pull_request)
1947 1946 merge_state = target_vcs.merge(
1948 1947 repo_id, workspace_id,
1949 1948 target_reference, source_vcs, pull_request.source_ref_parts,
1950 1949 dry_run=True, use_rebase=use_rebase,
1951 1950 close_branch=close_branch)
1952 1951
1953 1952 # Do not store the response if there was an unknown error.
1954 1953 if merge_state.failure_reason != MergeFailureReason.UNKNOWN:
1955 1954 pull_request._last_merge_source_rev = \
1956 1955 pull_request.source_ref_parts.commit_id
1957 1956 pull_request._last_merge_target_rev = target_reference.commit_id
1958 1957 pull_request.last_merge_status = merge_state.failure_reason
1959 1958 pull_request.last_merge_metadata = merge_state.metadata
1960 1959
1961 1960 pull_request.shadow_merge_ref = merge_state.merge_ref
1962 1961 Session().add(pull_request)
1963 1962 Session().commit()
1964 1963
1965 1964 return merge_state
1966 1965
1967 1966 def _workspace_id(self, pull_request):
1968 1967 workspace_id = 'pr-%s' % pull_request.pull_request_id
1969 1968 return workspace_id
1970 1969
1971 1970 def generate_repo_data(self, repo, commit_id=None, branch=None,
1972 1971 bookmark=None, translator=None):
1973 1972 from rhodecode.model.repo import RepoModel
1974 1973
1975 1974 all_refs, selected_ref = \
1976 1975 self._get_repo_pullrequest_sources(
1977 1976 repo.scm_instance(), commit_id=commit_id,
1978 1977 branch=branch, bookmark=bookmark, translator=translator)
1979 1978
1980 1979 refs_select2 = []
1981 1980 for element in all_refs:
1982 1981 children = [{'id': x[0], 'text': x[1]} for x in element[0]]
1983 1982 refs_select2.append({'text': element[1], 'children': children})
1984 1983
1985 1984 return {
1986 1985 'user': {
1987 1986 'user_id': repo.user.user_id,
1988 1987 'username': repo.user.username,
1989 1988 'firstname': repo.user.first_name,
1990 1989 'lastname': repo.user.last_name,
1991 1990 'gravatar_link': h.gravatar_url(repo.user.email, 14),
1992 1991 },
1993 1992 'name': repo.repo_name,
1994 1993 'link': RepoModel().get_url(repo),
1995 1994 'description': h.chop_at_smart(repo.description_safe, '\n'),
1996 1995 'refs': {
1997 1996 'all_refs': all_refs,
1998 1997 'selected_ref': selected_ref,
1999 1998 'select2_refs': refs_select2
2000 1999 }
2001 2000 }
2002 2001
2003 2002 def generate_pullrequest_title(self, source, source_ref, target):
2004 2003 return u'{source}#{at_ref} to {target}'.format(
2005 2004 source=source,
2006 2005 at_ref=source_ref,
2007 2006 target=target,
2008 2007 )
2009 2008
2010 2009 def _cleanup_merge_workspace(self, pull_request):
2011 2010 # Merging related cleanup
2012 2011 repo_id = pull_request.target_repo.repo_id
2013 2012 target_scm = pull_request.target_repo.scm_instance()
2014 2013 workspace_id = self._workspace_id(pull_request)
2015 2014
2016 2015 try:
2017 2016 target_scm.cleanup_merge_workspace(repo_id, workspace_id)
2018 2017 except NotImplementedError:
2019 2018 pass
2020 2019
2021 2020 def _get_repo_pullrequest_sources(
2022 2021 self, repo, commit_id=None, branch=None, bookmark=None,
2023 2022 translator=None):
2024 2023 """
2025 2024 Return a structure with repo's interesting commits, suitable for
2026 2025 the selectors in pullrequest controller
2027 2026
2028 2027 :param commit_id: a commit that must be in the list somehow
2029 2028 and selected by default
2030 2029 :param branch: a branch that must be in the list and selected
2031 2030 by default - even if closed
2032 2031 :param bookmark: a bookmark that must be in the list and selected
2033 2032 """
2034 2033 _ = translator or get_current_request().translate
2035 2034
2036 2035 commit_id = safe_str(commit_id) if commit_id else None
2037 2036 branch = safe_str(branch) if branch else None
2038 2037 bookmark = safe_str(bookmark) if bookmark else None
2039 2038
2040 2039 selected = None
2041 2040
2042 2041 # order matters: first source that has commit_id in it will be selected
2043 2042 sources = []
2044 2043 sources.append(('book', repo.bookmarks.items(), _('Bookmarks'), bookmark))
2045 2044 sources.append(('branch', repo.branches.items(), _('Branches'), branch))
2046 2045
2047 2046 if commit_id:
2048 2047 ref_commit = (h.short_id(commit_id), commit_id)
2049 2048 sources.append(('rev', [ref_commit], _('Commit IDs'), commit_id))
2050 2049
2051 2050 sources.append(
2052 2051 ('branch', repo.branches_closed.items(), _('Closed Branches'), branch),
2053 2052 )
2054 2053
2055 2054 groups = []
2056 2055
2057 2056 for group_key, ref_list, group_name, match in sources:
2058 2057 group_refs = []
2059 2058 for ref_name, ref_id in ref_list:
2060 2059 ref_key = u'{}:{}:{}'.format(group_key, ref_name, ref_id)
2061 2060 group_refs.append((ref_key, ref_name))
2062 2061
2063 2062 if not selected:
2064 2063 if set([commit_id, match]) & set([ref_id, ref_name]):
2065 2064 selected = ref_key
2066 2065
2067 2066 if group_refs:
2068 2067 groups.append((group_refs, group_name))
2069 2068
2070 2069 if not selected:
2071 2070 ref = commit_id or branch or bookmark
2072 2071 if ref:
2073 2072 raise CommitDoesNotExistError(
2074 2073 u'No commit refs could be found matching: {}'.format(ref))
2075 2074 elif repo.DEFAULT_BRANCH_NAME in repo.branches:
2076 2075 selected = u'branch:{}:{}'.format(
2077 2076 safe_str(repo.DEFAULT_BRANCH_NAME),
2078 2077 safe_str(repo.branches[repo.DEFAULT_BRANCH_NAME])
2079 2078 )
2080 2079 elif repo.commit_ids:
2081 2080 # make the user select in this case
2082 2081 selected = None
2083 2082 else:
2084 2083 raise EmptyRepositoryError()
2085 2084 return groups, selected
2086 2085
2087 2086 def get_diff(self, source_repo, source_ref_id, target_ref_id,
2088 2087 hide_whitespace_changes, diff_context):
2089 2088
2090 2089 return self._get_diff_from_pr_or_version(
2091 2090 source_repo, source_ref_id, target_ref_id,
2092 2091 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
2093 2092
2094 2093 def _get_diff_from_pr_or_version(
2095 2094 self, source_repo, source_ref_id, target_ref_id,
2096 2095 hide_whitespace_changes, diff_context):
2097 2096
2098 2097 target_commit = source_repo.get_commit(
2099 2098 commit_id=safe_str(target_ref_id))
2100 2099 source_commit = source_repo.get_commit(
2101 2100 commit_id=safe_str(source_ref_id), maybe_unreachable=True)
2102 2101 if isinstance(source_repo, Repository):
2103 2102 vcs_repo = source_repo.scm_instance()
2104 2103 else:
2105 2104 vcs_repo = source_repo
2106 2105
2107 2106 # TODO: johbo: In the context of an update, we cannot reach
2108 2107 # the old commit anymore with our normal mechanisms. It needs
2109 2108 # some sort of special support in the vcs layer to avoid this
2110 2109 # workaround.
2111 2110 if (source_commit.raw_id == vcs_repo.EMPTY_COMMIT_ID and
2112 2111 vcs_repo.alias == 'git'):
2113 2112 source_commit.raw_id = safe_str(source_ref_id)
2114 2113
2115 2114 log.debug('calculating diff between '
2116 2115 'source_ref:%s and target_ref:%s for repo `%s`',
2117 2116 target_ref_id, source_ref_id,
2118 2117 safe_str(vcs_repo.path))
2119 2118
2120 2119 vcs_diff = vcs_repo.get_diff(
2121 2120 commit1=target_commit, commit2=source_commit,
2122 2121 ignore_whitespace=hide_whitespace_changes, context=diff_context)
2123 2122 return vcs_diff
2124 2123
2125 2124 def _is_merge_enabled(self, pull_request):
2126 2125 return self._get_general_setting(
2127 2126 pull_request, 'rhodecode_pr_merge_enabled')
2128 2127
2129 2128 def _use_rebase_for_merging(self, pull_request):
2130 2129 repo_type = pull_request.target_repo.repo_type
2131 2130 if repo_type == 'hg':
2132 2131 return self._get_general_setting(
2133 2132 pull_request, 'rhodecode_hg_use_rebase_for_merging')
2134 2133 elif repo_type == 'git':
2135 2134 return self._get_general_setting(
2136 2135 pull_request, 'rhodecode_git_use_rebase_for_merging')
2137 2136
2138 2137 return False
2139 2138
2140 2139 def _user_name_for_merging(self, pull_request, user):
2141 2140 env_user_name_attr = os.environ.get('RC_MERGE_USER_NAME_ATTR', '')
2142 2141 if env_user_name_attr and hasattr(user, env_user_name_attr):
2143 2142 user_name_attr = env_user_name_attr
2144 2143 else:
2145 2144 user_name_attr = 'short_contact'
2146 2145
2147 2146 user_name = getattr(user, user_name_attr)
2148 2147 return user_name
2149 2148
2150 2149 def _close_branch_before_merging(self, pull_request):
2151 2150 repo_type = pull_request.target_repo.repo_type
2152 2151 if repo_type == 'hg':
2153 2152 return self._get_general_setting(
2154 2153 pull_request, 'rhodecode_hg_close_branch_before_merging')
2155 2154 elif repo_type == 'git':
2156 2155 return self._get_general_setting(
2157 2156 pull_request, 'rhodecode_git_close_branch_before_merging')
2158 2157
2159 2158 return False
2160 2159
2161 2160 def _get_general_setting(self, pull_request, settings_key, default=False):
2162 2161 settings_model = VcsSettingsModel(repo=pull_request.target_repo)
2163 2162 settings = settings_model.get_general_settings()
2164 2163 return settings.get(settings_key, default)
2165 2164
2166 2165 def _log_audit_action(self, action, action_data, user, pull_request):
2167 2166 audit_logger.store(
2168 2167 action=action,
2169 2168 action_data=action_data,
2170 2169 user=user,
2171 2170 repo=pull_request.target_repo)
2172 2171
2173 2172 def get_reviewer_functions(self):
2174 2173 """
2175 2174 Fetches functions for validation and fetching default reviewers.
2176 2175 If available we use the EE package, else we fallback to CE
2177 2176 package functions
2178 2177 """
2179 2178 try:
2180 2179 from rc_reviewers.utils import get_default_reviewers_data
2181 2180 from rc_reviewers.utils import validate_default_reviewers
2182 2181 from rc_reviewers.utils import validate_observers
2183 2182 except ImportError:
2184 2183 from rhodecode.apps.repository.utils import get_default_reviewers_data
2185 2184 from rhodecode.apps.repository.utils import validate_default_reviewers
2186 2185 from rhodecode.apps.repository.utils import validate_observers
2187 2186
2188 2187 return get_default_reviewers_data, validate_default_reviewers, validate_observers
2189 2188
2190 2189
2191 2190 class MergeCheck(object):
2192 2191 """
2193 2192 Perform Merge Checks and returns a check object which stores information
2194 2193 about merge errors, and merge conditions
2195 2194 """
2196 2195 TODO_CHECK = 'todo'
2197 2196 PERM_CHECK = 'perm'
2198 2197 REVIEW_CHECK = 'review'
2199 2198 MERGE_CHECK = 'merge'
2200 2199 WIP_CHECK = 'wip'
2201 2200
2202 2201 def __init__(self):
2203 2202 self.review_status = None
2204 2203 self.merge_possible = None
2205 2204 self.merge_msg = ''
2206 2205 self.merge_response = None
2207 2206 self.failed = None
2208 2207 self.errors = []
2209 2208 self.error_details = OrderedDict()
2210 2209 self.source_commit = AttributeDict()
2211 2210 self.target_commit = AttributeDict()
2212 2211 self.reviewers_count = 0
2213 2212 self.observers_count = 0
2214 2213
2215 2214 def __repr__(self):
2216 2215 return '<MergeCheck(possible:{}, failed:{}, errors:{})>'.format(
2217 2216 self.merge_possible, self.failed, self.errors)
2218 2217
2219 2218 def push_error(self, error_type, message, error_key, details):
2220 2219 self.failed = True
2221 2220 self.errors.append([error_type, message])
2222 2221 self.error_details[error_key] = dict(
2223 2222 details=details,
2224 2223 error_type=error_type,
2225 2224 message=message
2226 2225 )
2227 2226
2228 2227 @classmethod
2229 2228 def validate(cls, pull_request, auth_user, translator, fail_early=False,
2230 2229 force_shadow_repo_refresh=False):
2231 2230 _ = translator
2232 2231 merge_check = cls()
2233 2232
2234 2233 # title has WIP:
2235 2234 if pull_request.work_in_progress:
2236 2235 log.debug("MergeCheck: cannot merge, title has wip: marker.")
2237 2236
2238 2237 msg = _('WIP marker in title prevents from accidental merge.')
2239 2238 merge_check.push_error('error', msg, cls.WIP_CHECK, pull_request.title)
2240 2239 if fail_early:
2241 2240 return merge_check
2242 2241
2243 2242 # permissions to merge
2244 2243 user_allowed_to_merge = PullRequestModel().check_user_merge(pull_request, auth_user)
2245 2244 if not user_allowed_to_merge:
2246 2245 log.debug("MergeCheck: cannot merge, approval is pending.")
2247 2246
2248 2247 msg = _('User `{}` not allowed to perform merge.').format(auth_user.username)
2249 2248 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
2250 2249 if fail_early:
2251 2250 return merge_check
2252 2251
2253 2252 # permission to merge into the target branch
2254 2253 target_commit_id = pull_request.target_ref_parts.commit_id
2255 2254 if pull_request.target_ref_parts.type == 'branch':
2256 2255 branch_name = pull_request.target_ref_parts.name
2257 2256 else:
2258 2257 # for mercurial we can always figure out the branch from the commit
2259 2258 # in case of bookmark
2260 2259 target_commit = pull_request.target_repo.get_commit(target_commit_id)
2261 2260 branch_name = target_commit.branch
2262 2261
2263 2262 rule, branch_perm = auth_user.get_rule_and_branch_permission(
2264 2263 pull_request.target_repo.repo_name, branch_name)
2265 2264 if branch_perm and branch_perm == 'branch.none':
2266 2265 msg = _('Target branch `{}` changes rejected by rule {}.').format(
2267 2266 branch_name, rule)
2268 2267 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
2269 2268 if fail_early:
2270 2269 return merge_check
2271 2270
2272 2271 # review status, must be always present
2273 2272 review_status = pull_request.calculated_review_status()
2274 2273 merge_check.review_status = review_status
2275 2274 merge_check.reviewers_count = pull_request.reviewers_count
2276 2275 merge_check.observers_count = pull_request.observers_count
2277 2276
2278 2277 status_approved = review_status == ChangesetStatus.STATUS_APPROVED
2279 2278 if not status_approved and merge_check.reviewers_count:
2280 2279 log.debug("MergeCheck: cannot merge, approval is pending.")
2281 2280 msg = _('Pull request reviewer approval is pending.')
2282 2281
2283 2282 merge_check.push_error('warning', msg, cls.REVIEW_CHECK, review_status)
2284 2283
2285 2284 if fail_early:
2286 2285 return merge_check
2287 2286
2288 2287 # left over TODOs
2289 2288 todos = CommentsModel().get_pull_request_unresolved_todos(pull_request)
2290 2289 if todos:
2291 2290 log.debug("MergeCheck: cannot merge, {} "
2292 2291 "unresolved TODOs left.".format(len(todos)))
2293 2292
2294 2293 if len(todos) == 1:
2295 2294 msg = _('Cannot merge, {} TODO still not resolved.').format(
2296 2295 len(todos))
2297 2296 else:
2298 2297 msg = _('Cannot merge, {} TODOs still not resolved.').format(
2299 2298 len(todos))
2300 2299
2301 2300 merge_check.push_error('warning', msg, cls.TODO_CHECK, todos)
2302 2301
2303 2302 if fail_early:
2304 2303 return merge_check
2305 2304
2306 2305 # merge possible, here is the filesystem simulation + shadow repo
2307 2306 merge_response, merge_status, msg = PullRequestModel().merge_status(
2308 2307 pull_request, translator=translator,
2309 2308 force_shadow_repo_refresh=force_shadow_repo_refresh)
2310 2309
2311 2310 merge_check.merge_possible = merge_status
2312 2311 merge_check.merge_msg = msg
2313 2312 merge_check.merge_response = merge_response
2314 2313
2315 2314 source_ref_id = pull_request.source_ref_parts.commit_id
2316 2315 target_ref_id = pull_request.target_ref_parts.commit_id
2317 2316
2318 2317 try:
2319 2318 source_commit, target_commit = PullRequestModel().get_flow_commits(pull_request)
2320 2319 merge_check.source_commit.changed = source_ref_id != source_commit.raw_id
2321 2320 merge_check.source_commit.ref_spec = pull_request.source_ref_parts
2322 2321 merge_check.source_commit.current_raw_id = source_commit.raw_id
2323 2322 merge_check.source_commit.previous_raw_id = source_ref_id
2324 2323
2325 2324 merge_check.target_commit.changed = target_ref_id != target_commit.raw_id
2326 2325 merge_check.target_commit.ref_spec = pull_request.target_ref_parts
2327 2326 merge_check.target_commit.current_raw_id = target_commit.raw_id
2328 2327 merge_check.target_commit.previous_raw_id = target_ref_id
2329 2328 except (SourceRefMissing, TargetRefMissing):
2330 2329 pass
2331 2330
2332 2331 if not merge_status:
2333 2332 log.debug("MergeCheck: cannot merge, pull request merge not possible.")
2334 2333 merge_check.push_error('warning', msg, cls.MERGE_CHECK, None)
2335 2334
2336 2335 if fail_early:
2337 2336 return merge_check
2338 2337
2339 2338 log.debug('MergeCheck: is failed: %s', merge_check.failed)
2340 2339 return merge_check
2341 2340
2342 2341 @classmethod
2343 2342 def get_merge_conditions(cls, pull_request, translator):
2344 2343 _ = translator
2345 2344 merge_details = {}
2346 2345
2347 2346 model = PullRequestModel()
2348 2347 use_rebase = model._use_rebase_for_merging(pull_request)
2349 2348
2350 2349 if use_rebase:
2351 2350 merge_details['merge_strategy'] = dict(
2352 2351 details={},
2353 2352 message=_('Merge strategy: rebase')
2354 2353 )
2355 2354 else:
2356 2355 merge_details['merge_strategy'] = dict(
2357 2356 details={},
2358 2357 message=_('Merge strategy: explicit merge commit')
2359 2358 )
2360 2359
2361 2360 close_branch = model._close_branch_before_merging(pull_request)
2362 2361 if close_branch:
2363 2362 repo_type = pull_request.target_repo.repo_type
2364 2363 close_msg = ''
2365 2364 if repo_type == 'hg':
2366 2365 close_msg = _('Source branch will be closed before the merge.')
2367 2366 elif repo_type == 'git':
2368 2367 close_msg = _('Source branch will be deleted after the merge.')
2369 2368
2370 2369 merge_details['close_branch'] = dict(
2371 2370 details={},
2372 2371 message=close_msg
2373 2372 )
2374 2373
2375 2374 return merge_details
2376 2375
2377 2376
2378 2377 @dataclasses.dataclass
2379 2378 class ChangeTuple:
2380 2379 added: list
2381 2380 common: list
2382 2381 removed: list
2383 2382 total: list
2384 2383
2385 2384
2386 2385 @dataclasses.dataclass
2387 2386 class FileChangeTuple:
2388 2387 added: list
2389 2388 modified: list
2390 2389 removed: list
@@ -1,487 +1,451 b''
1 1
2 2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 import mock
21 21 import pytest
22 22
23 23 from rhodecode.lib.str_utils import base64_to_str
24 24 from rhodecode.lib.utils2 import AttributeDict
25 25 from rhodecode.tests.utils import CustomTestApp
26 26
27 27 from rhodecode.lib.caching_query import FromCache
28 from rhodecode.lib.hooks_daemon import DummyHooksCallbackDaemon
29 28 from rhodecode.lib.middleware import simplevcs
30 29 from rhodecode.lib.middleware.https_fixup import HttpsFixup
31 30 from rhodecode.lib.middleware.utils import scm_app_http
32 31 from rhodecode.model.db import User, _hash_key
33 32 from rhodecode.model.meta import Session, cache as db_cache
34 33 from rhodecode.tests import (
35 34 HG_REPO, TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS)
36 35 from rhodecode.tests.lib.middleware import mock_scm_app
37 36
38 37
39 38 class StubVCSController(simplevcs.SimpleVCS):
40 39
41 40 SCM = 'hg'
42 41 stub_response_body = tuple()
43 42
44 43 def __init__(self, *args, **kwargs):
45 44 super(StubVCSController, self).__init__(*args, **kwargs)
46 45 self._action = 'pull'
47 46 self._is_shadow_repo_dir = True
48 47 self._name = HG_REPO
49 48 self.set_repo_names(None)
50 49
51 50 @property
52 51 def is_shadow_repo_dir(self):
53 52 return self._is_shadow_repo_dir
54 53
55 54 def _get_repository_name(self, environ):
56 55 return self._name
57 56
58 57 def _get_action(self, environ):
59 58 return self._action
60 59
61 60 def _create_wsgi_app(self, repo_path, repo_name, config):
62 61 def fake_app(environ, start_response):
63 62 headers = [
64 63 ('Http-Accept', 'application/mercurial')
65 64 ]
66 65 start_response('200 OK', headers)
67 66 return self.stub_response_body
68 67 return fake_app
69 68
70 69 def _create_config(self, extras, repo_name, scheme='http'):
71 70 return None
72 71
73 72
74 73 @pytest.fixture()
75 74 def vcscontroller(baseapp, config_stub, request_stub):
76 75 from rhodecode.config.middleware import ce_auth_resources
77 76
78 77 config_stub.testing_securitypolicy()
79 78 config_stub.include('rhodecode.authentication')
80 79
81 80 for resource in ce_auth_resources:
82 81 config_stub.include(resource)
83 82
84 83 controller = StubVCSController(
85 84 baseapp.config.get_settings(), request_stub.registry)
86 85 app = HttpsFixup(controller, baseapp.config.get_settings())
87 86 app = CustomTestApp(app)
88 87
89 88 _remove_default_user_from_query_cache()
90 89
91 90 # Sanity checks that things are set up correctly
92 91 app.get('/' + HG_REPO, status=200)
93 92
94 93 app.controller = controller
95 94 return app
96 95
97 96
98 97 def _remove_default_user_from_query_cache():
99 98 user = User.get_default_user(cache=True)
100 99 query = Session().query(User).filter(User.username == user.username)
101 100 query = query.options(
102 101 FromCache("sql_cache_short", f"get_user_{_hash_key(user.username)}"))
103 102
104 103 db_cache.invalidate(
105 104 query, {},
106 105 FromCache("sql_cache_short", f"get_user_{_hash_key(user.username)}"))
107 106
108 107 Session().expire(user)
109 108
110 109
111 110 def test_handles_exceptions_during_permissions_checks(
112 111 vcscontroller, disable_anonymous_user, enable_auth_plugins, test_user_factory):
113 112
114 113 test_password = 'qweqwe'
115 114 test_user = test_user_factory(password=test_password, extern_type='headers', extern_name='headers')
116 115 test_username = test_user.username
117 116
118 117 enable_auth_plugins.enable([
119 118 'egg:rhodecode-enterprise-ce#headers',
120 119 'egg:rhodecode-enterprise-ce#token',
121 120 'egg:rhodecode-enterprise-ce#rhodecode'],
122 121 override={
123 122 'egg:rhodecode-enterprise-ce#headers': {'auth_headers_header': 'REMOTE_USER'}
124 123 })
125 124
126 125 user_and_pass = f'{test_username}:{test_password}'
127 126 auth_password = base64_to_str(user_and_pass)
128 127
129 128 extra_environ = {
130 129 'AUTH_TYPE': 'Basic',
131 130 'HTTP_AUTHORIZATION': f'Basic {auth_password}',
132 131 'REMOTE_USER': test_username,
133 132 }
134 133
135 134 # Verify that things are hooked up correctly, we pass user with headers bound auth, and headers filled in
136 135 vcscontroller.get('/', status=200, extra_environ=extra_environ)
137 136
138 137 # Simulate trouble during permission checks
139 138 with mock.patch('rhodecode.model.db.User.get_by_username',
140 139 side_effect=Exception('permission_error_test')) as get_user:
141 140 # Verify that a correct 500 is returned and check that the expected
142 141 # code path was hit.
143 142 vcscontroller.get('/', status=500, extra_environ=extra_environ)
144 143 assert get_user.called
145 144
146 145
147 146 class StubFailVCSController(simplevcs.SimpleVCS):
148 147 def _handle_request(self, environ, start_response):
149 148 raise Exception("BOOM")
150 149
151 150
152 151 @pytest.fixture(scope='module')
153 152 def fail_controller(baseapp):
154 153 controller = StubFailVCSController(
155 154 baseapp.config.get_settings(), baseapp.config)
156 155 controller = HttpsFixup(controller, baseapp.config.get_settings())
157 156 controller = CustomTestApp(controller)
158 157 return controller
159 158
160 159
161 160 def test_handles_exceptions_as_internal_server_error(fail_controller):
162 161 fail_controller.get('/', status=500)
163 162
164 163
165 164 def test_provides_traceback_for_appenlight(fail_controller):
166 165 response = fail_controller.get(
167 166 '/', status=500, extra_environ={'appenlight.client': 'fake'})
168 167 assert 'appenlight.__traceback' in response.request.environ
169 168
170 169
171 170 def test_provides_utils_scm_app_as_scm_app_by_default(baseapp, request_stub):
172 171 controller = StubVCSController(baseapp.config.get_settings(), request_stub.registry)
173 172 assert controller.scm_app is scm_app_http
174 173
175 174
176 175 def test_allows_to_override_scm_app_via_config(baseapp, request_stub):
177 176 config = baseapp.config.get_settings().copy()
178 177 config['vcs.scm_app_implementation'] = (
179 178 'rhodecode.tests.lib.middleware.mock_scm_app')
180 179 controller = StubVCSController(config, request_stub.registry)
181 180 assert controller.scm_app is mock_scm_app
182 181
183 182
184 183 @pytest.mark.parametrize('query_string, expected', [
185 184 ('cmd=stub_command', True),
186 185 ('cmd=listkeys', False),
187 186 ])
188 187 def test_should_check_locking(query_string, expected):
189 188 result = simplevcs._should_check_locking(query_string)
190 189 assert result == expected
191 190
192 191
193 192 class TestShadowRepoRegularExpression(object):
194 193 pr_segment = 'pull-request'
195 194 shadow_segment = 'repository'
196 195
197 196 @pytest.mark.parametrize('url, expected', [
198 197 # repo with/without groups
199 198 ('My-Repo/{pr_segment}/1/{shadow_segment}', True),
200 199 ('Group/My-Repo/{pr_segment}/2/{shadow_segment}', True),
201 200 ('Group/Sub-Group/My-Repo/{pr_segment}/3/{shadow_segment}', True),
202 201 ('Group/Sub-Group1/Sub-Group2/My-Repo/{pr_segment}/3/{shadow_segment}', True),
203 202
204 203 # pull request ID
205 204 ('MyRepo/{pr_segment}/1/{shadow_segment}', True),
206 205 ('MyRepo/{pr_segment}/1234567890/{shadow_segment}', True),
207 206 ('MyRepo/{pr_segment}/-1/{shadow_segment}', False),
208 207 ('MyRepo/{pr_segment}/invalid/{shadow_segment}', False),
209 208
210 209 # unicode
211 210 (u'Sp€çîál-Repö/{pr_segment}/1/{shadow_segment}', True),
212 211 (u'Sp€çîál-Gröüp/Sp€çîál-Repö/{pr_segment}/1/{shadow_segment}', True),
213 212
214 213 # trailing/leading slash
215 214 ('/My-Repo/{pr_segment}/1/{shadow_segment}', False),
216 215 ('My-Repo/{pr_segment}/1/{shadow_segment}/', False),
217 216 ('/My-Repo/{pr_segment}/1/{shadow_segment}/', False),
218 217
219 218 # misc
220 219 ('My-Repo/{pr_segment}/1/{shadow_segment}/extra', False),
221 220 ('My-Repo/{pr_segment}/1/{shadow_segment}extra', False),
222 221 ])
223 222 def test_shadow_repo_regular_expression(self, url, expected):
224 223 from rhodecode.lib.middleware.simplevcs import SimpleVCS
225 224 url = url.format(
226 225 pr_segment=self.pr_segment,
227 226 shadow_segment=self.shadow_segment)
228 227 match_obj = SimpleVCS.shadow_repo_re.match(url)
229 228 assert (match_obj is not None) == expected
230 229
231 230
232 231 @pytest.mark.backends('git', 'hg')
233 232 class TestShadowRepoExposure(object):
234 233
235 234 def test_pull_on_shadow_repo_propagates_to_wsgi_app(
236 235 self, baseapp, request_stub):
237 236 """
238 237 Check that a pull action to a shadow repo is propagated to the
239 238 underlying wsgi app.
240 239 """
241 240 controller = StubVCSController(
242 241 baseapp.config.get_settings(), request_stub.registry)
243 242 controller._check_ssl = mock.Mock()
244 243 controller.is_shadow_repo = True
245 244 controller._action = 'pull'
246 245 controller._is_shadow_repo_dir = True
247 246 controller.stub_response_body = (b'dummy body value',)
248 247 controller._get_default_cache_ttl = mock.Mock(
249 248 return_value=(False, 0))
250 249
251 250 environ_stub = {
252 251 'HTTP_HOST': 'test.example.com',
253 252 'HTTP_ACCEPT': 'application/mercurial',
254 253 'REQUEST_METHOD': 'GET',
255 254 'wsgi.url_scheme': 'http',
256 255 }
257 256
258 257 response = controller(environ_stub, mock.Mock())
259 258 response_body = b''.join(response)
260 259
261 260 # Assert that we got the response from the wsgi app.
262 261 assert response_body == b''.join(controller.stub_response_body)
263 262
264 263 def test_pull_on_shadow_repo_that_is_missing(self, baseapp, request_stub):
265 264 """
266 265 Check that a pull action to a shadow repo is propagated to the
267 266 underlying wsgi app.
268 267 """
269 268 controller = StubVCSController(
270 269 baseapp.config.get_settings(), request_stub.registry)
271 270 controller._check_ssl = mock.Mock()
272 271 controller.is_shadow_repo = True
273 272 controller._action = 'pull'
274 273 controller._is_shadow_repo_dir = False
275 274 controller.stub_response_body = (b'dummy body value',)
276 275 environ_stub = {
277 276 'HTTP_HOST': 'test.example.com',
278 277 'HTTP_ACCEPT': 'application/mercurial',
279 278 'REQUEST_METHOD': 'GET',
280 279 'wsgi.url_scheme': 'http',
281 280 }
282 281
283 282 response = controller(environ_stub, mock.Mock())
284 283 response_body = b''.join(response)
285 284
286 285 # Assert that we got the response from the wsgi app.
287 286 assert b'404 Not Found' in response_body
288 287
289 288 def test_push_on_shadow_repo_raises(self, baseapp, request_stub):
290 289 """
291 290 Check that a push action to a shadow repo is aborted.
292 291 """
293 292 controller = StubVCSController(
294 293 baseapp.config.get_settings(), request_stub.registry)
295 294 controller._check_ssl = mock.Mock()
296 295 controller.is_shadow_repo = True
297 296 controller._action = 'push'
298 297 controller.stub_response_body = (b'dummy body value',)
299 298 environ_stub = {
300 299 'HTTP_HOST': 'test.example.com',
301 300 'HTTP_ACCEPT': 'application/mercurial',
302 301 'REQUEST_METHOD': 'GET',
303 302 'wsgi.url_scheme': 'http',
304 303 }
305 304
306 305 response = controller(environ_stub, mock.Mock())
307 306 response_body = b''.join(response)
308 307
309 308 assert response_body != controller.stub_response_body
310 309 # Assert that a 406 error is returned.
311 310 assert b'406 Not Acceptable' in response_body
312 311
313 312 def test_set_repo_names_no_shadow(self, baseapp, request_stub):
314 313 """
315 314 Check that the set_repo_names method sets all names to the one returned
316 315 by the _get_repository_name method on a request to a non shadow repo.
317 316 """
318 317 environ_stub = {}
319 318 controller = StubVCSController(
320 319 baseapp.config.get_settings(), request_stub.registry)
321 320 controller._name = 'RepoGroup/MyRepo'
322 321 controller.set_repo_names(environ_stub)
323 322 assert not controller.is_shadow_repo
324 323 assert (controller.url_repo_name ==
325 324 controller.acl_repo_name ==
326 325 controller.vcs_repo_name ==
327 326 controller._get_repository_name(environ_stub))
328 327
329 328 def test_set_repo_names_with_shadow(
330 329 self, baseapp, pr_util, config_stub, request_stub):
331 330 """
332 331 Check that the set_repo_names method sets correct names on a request
333 332 to a shadow repo.
334 333 """
335 334 from rhodecode.model.pull_request import PullRequestModel
336 335
337 336 pull_request = pr_util.create_pull_request()
338 337 shadow_url = '{target}/{pr_segment}/{pr_id}/{shadow_segment}'.format(
339 338 target=pull_request.target_repo.repo_name,
340 339 pr_id=pull_request.pull_request_id,
341 340 pr_segment=TestShadowRepoRegularExpression.pr_segment,
342 341 shadow_segment=TestShadowRepoRegularExpression.shadow_segment)
343 342 controller = StubVCSController(
344 343 baseapp.config.get_settings(), request_stub.registry)
345 344 controller._name = shadow_url
346 345 controller.set_repo_names({})
347 346
348 347 # Get file system path to shadow repo for assertions.
349 348 workspace_id = PullRequestModel()._workspace_id(pull_request)
350 349 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
351 350
352 351 assert controller.vcs_repo_name == vcs_repo_name
353 352 assert controller.url_repo_name == shadow_url
354 353 assert controller.acl_repo_name == pull_request.target_repo.repo_name
355 354 assert controller.is_shadow_repo
356 355
357 356 def test_set_repo_names_with_shadow_but_missing_pr(
358 357 self, baseapp, pr_util, config_stub, request_stub):
359 358 """
360 359 Checks that the set_repo_names method enforces matching target repos
361 360 and pull request IDs.
362 361 """
363 362 pull_request = pr_util.create_pull_request()
364 363 shadow_url = '{target}/{pr_segment}/{pr_id}/{shadow_segment}'.format(
365 364 target=pull_request.target_repo.repo_name,
366 365 pr_id=999999999,
367 366 pr_segment=TestShadowRepoRegularExpression.pr_segment,
368 367 shadow_segment=TestShadowRepoRegularExpression.shadow_segment)
369 368 controller = StubVCSController(
370 369 baseapp.config.get_settings(), request_stub.registry)
371 370 controller._name = shadow_url
372 371 controller.set_repo_names({})
373 372
374 373 assert not controller.is_shadow_repo
375 374 assert (controller.url_repo_name ==
376 375 controller.acl_repo_name ==
377 376 controller.vcs_repo_name)
378 377
379 378
380 379 @pytest.mark.usefixtures('baseapp')
381 380 class TestGenerateVcsResponse(object):
382 381
383 382 def test_ensures_that_start_response_is_called_early_enough(self):
384 383 self.call_controller_with_response_body(iter(['a', 'b']))
385 384 assert self.start_response.called
386 385
387 386 def test_invalidates_cache_after_body_is_consumed(self):
388 387 result = self.call_controller_with_response_body(iter(['a', 'b']))
389 388 assert not self.was_cache_invalidated()
390 389 # Consume the result
391 390 list(result)
392 391 assert self.was_cache_invalidated()
393 392
394 393 def test_raises_unknown_exceptions(self):
395 394 result = self.call_controller_with_response_body(
396 395 self.raise_result_iter(vcs_kind='unknown'))
397 396 with pytest.raises(Exception):
398 397 list(result)
399 398
400 def test_prepare_callback_daemon_is_called(self):
401 def side_effect(extras, environ, action, txn_id=None):
402 return DummyHooksCallbackDaemon(), extras
403
404 prepare_patcher = mock.patch.object(
405 StubVCSController, '_prepare_callback_daemon')
406 with prepare_patcher as prepare_mock:
407 prepare_mock.side_effect = side_effect
408 self.call_controller_with_response_body(iter(['a', 'b']))
409 assert prepare_mock.called
410 assert prepare_mock.call_count == 1
411
412 399 def call_controller_with_response_body(self, response_body):
413 400 settings = {
414 401 'base_path': 'fake_base_path',
415 402 'vcs.hooks.protocol': 'http',
416 403 'vcs.hooks.direct_calls': False,
417 404 }
418 405 registry = AttributeDict()
419 406 controller = StubVCSController(settings, registry)
420 407 controller._invalidate_cache = mock.Mock()
421 408 controller.stub_response_body = response_body
422 409 self.start_response = mock.Mock()
423 410 result = controller._generate_vcs_response(
424 411 environ={}, start_response=self.start_response,
425 412 repo_path='fake_repo_path',
426 413 extras={}, action='push')
427 414 self.controller = controller
428 415 return result
429 416
430 417 def raise_result_iter(self, vcs_kind='repo_locked'):
431 418 """
432 419 Simulates an exception due to a vcs raised exception if kind vcs_kind
433 420 """
434 421 raise self.vcs_exception(vcs_kind=vcs_kind)
435 422 yield "never_reached"
436 423
437 424 def vcs_exception(self, vcs_kind='repo_locked'):
438 425 locked_exception = Exception('TEST_MESSAGE')
439 426 locked_exception._vcs_kind = vcs_kind
440 427 return locked_exception
441 428
442 429 def was_cache_invalidated(self):
443 430 return self.controller._invalidate_cache.called
444 431
445 432
446 433 class TestInitializeGenerator(object):
447 434
448 435 def test_drains_first_element(self):
449 436 gen = self.factory(['__init__', 1, 2])
450 437 result = list(gen)
451 438 assert result == [1, 2]
452 439
453 440 @pytest.mark.parametrize('values', [
454 441 [],
455 442 [1, 2],
456 443 ])
457 444 def test_raises_value_error(self, values):
458 445 with pytest.raises(ValueError):
459 446 self.factory(values)
460 447
461 448 @simplevcs.initialize_generator
462 449 def factory(self, iterable):
463 450 for elem in iterable:
464 451 yield elem
465
466
467 class TestPrepareHooksDaemon(object):
468 def test_calls_imported_prepare_callback_daemon(self, app_settings, request_stub):
469 expected_extras = {'extra1': 'value1'}
470 daemon = DummyHooksCallbackDaemon()
471
472 controller = StubVCSController(app_settings, request_stub.registry)
473 prepare_patcher = mock.patch.object(
474 simplevcs, 'prepare_callback_daemon',
475 return_value=(daemon, expected_extras))
476 with prepare_patcher as prepare_mock:
477 callback_daemon, extras = controller._prepare_callback_daemon(
478 expected_extras.copy(), {}, 'push')
479 prepare_mock.assert_called_once_with(
480 expected_extras,
481 protocol=app_settings['vcs.hooks.protocol'],
482 host=app_settings['vcs.hooks.host'],
483 txn_id=None,
484 use_direct_calls=app_settings['vcs.hooks.direct_calls'])
485
486 assert callback_daemon == daemon
487 assert extras == extras
@@ -1,376 +1,355 b''
1 1
2 2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 import logging
21 21 import io
22 22
23 23 import mock
24 24 import msgpack
25 25 import pytest
26 import tempfile
26 27
27 28 from rhodecode.lib import hooks_daemon
28 29 from rhodecode.lib.str_utils import safe_bytes
29 30 from rhodecode.tests.utils import assert_message_in_log
30 31 from rhodecode.lib.ext_json import json
31 32
32 33 test_proto = hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO
33 34
34 35
35 class TestDummyHooksCallbackDaemon(object):
36 def test_hooks_module_path_set_properly(self):
37 daemon = hooks_daemon.DummyHooksCallbackDaemon()
38 assert daemon.hooks_module == 'rhodecode.lib.hooks_daemon'
39
40 def test_logs_entering_the_hook(self):
41 daemon = hooks_daemon.DummyHooksCallbackDaemon()
42 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
43 with daemon as return_value:
44 log_mock.assert_called_once_with(
45 'Running `%s` callback daemon', 'DummyHooksCallbackDaemon')
46 assert return_value == daemon
47
48 def test_logs_exiting_the_hook(self):
49 daemon = hooks_daemon.DummyHooksCallbackDaemon()
50 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
51 with daemon:
52 pass
53 log_mock.assert_called_with(
54 'Exiting `%s` callback daemon', 'DummyHooksCallbackDaemon')
55
56
57 36 class TestHooks(object):
58 37 def test_hooks_can_be_used_as_a_context_processor(self):
59 38 hooks = hooks_daemon.Hooks()
60 39 with hooks as return_value:
61 40 pass
62 41 assert hooks == return_value
63 42
64 43
65 44 class TestHooksHttpHandler(object):
66 45 def test_read_request_parses_method_name_and_arguments(self):
67 46 data = {
68 47 'method': 'test',
69 48 'extras': {
70 49 'param1': 1,
71 50 'param2': 'a'
72 51 }
73 52 }
74 53 request = self._generate_post_request(data)
75 54 hooks_patcher = mock.patch.object(
76 55 hooks_daemon.Hooks, data['method'], create=True, return_value=1)
77 56
78 57 with hooks_patcher as hooks_mock:
79 58 handler = hooks_daemon.HooksHttpHandler
80 59 handler.DEFAULT_HOOKS_PROTO = test_proto
81 60 handler.wbufsize = 10240
82 61 MockServer(handler, request)
83 62
84 63 hooks_mock.assert_called_once_with(data['extras'])
85 64
86 65 def test_hooks_serialized_result_is_returned(self):
87 66 request = self._generate_post_request({})
88 67 rpc_method = 'test'
89 68 hook_result = {
90 69 'first': 'one',
91 70 'second': 2
92 71 }
93 72 extras = {}
94 73
95 74 # patching our _read to return test method and proto used
96 75 read_patcher = mock.patch.object(
97 76 hooks_daemon.HooksHttpHandler, '_read_request',
98 77 return_value=(test_proto, rpc_method, extras))
99 78
100 79 # patch Hooks instance to return hook_result data on 'test' call
101 80 hooks_patcher = mock.patch.object(
102 81 hooks_daemon.Hooks, rpc_method, create=True,
103 82 return_value=hook_result)
104 83
105 84 with read_patcher, hooks_patcher:
106 85 handler = hooks_daemon.HooksHttpHandler
107 86 handler.DEFAULT_HOOKS_PROTO = test_proto
108 87 handler.wbufsize = 10240
109 88 server = MockServer(handler, request)
110 89
111 90 expected_result = hooks_daemon.HooksHttpHandler.serialize_data(hook_result)
112 91
113 92 server.request.output_stream.seek(0)
114 93 assert server.request.output_stream.readlines()[-1] == expected_result
115 94
116 95 def test_exception_is_returned_in_response(self):
117 96 request = self._generate_post_request({})
118 97 rpc_method = 'test'
119 98
120 99 read_patcher = mock.patch.object(
121 100 hooks_daemon.HooksHttpHandler, '_read_request',
122 101 return_value=(test_proto, rpc_method, {}))
123 102
124 103 hooks_patcher = mock.patch.object(
125 104 hooks_daemon.Hooks, rpc_method, create=True,
126 105 side_effect=Exception('Test exception'))
127 106
128 107 with read_patcher, hooks_patcher:
129 108 handler = hooks_daemon.HooksHttpHandler
130 109 handler.DEFAULT_HOOKS_PROTO = test_proto
131 110 handler.wbufsize = 10240
132 111 server = MockServer(handler, request)
133 112
134 113 server.request.output_stream.seek(0)
135 114 data = server.request.output_stream.readlines()
136 115 msgpack_data = b''.join(data[5:])
137 116 org_exc = hooks_daemon.HooksHttpHandler.deserialize_data(msgpack_data)
138 117 expected_result = {
139 118 'exception': 'Exception',
140 119 'exception_traceback': org_exc['exception_traceback'],
141 120 'exception_args': ['Test exception']
142 121 }
143 122 assert org_exc == expected_result
144 123
145 124 def test_log_message_writes_to_debug_log(self, caplog):
146 125 ip_port = ('0.0.0.0', 8888)
147 126 handler = hooks_daemon.HooksHttpHandler(
148 127 MockRequest('POST /'), ip_port, mock.Mock())
149 128 fake_date = '1/Nov/2015 00:00:00'
150 129 date_patcher = mock.patch.object(
151 130 handler, 'log_date_time_string', return_value=fake_date)
152 131
153 132 with date_patcher, caplog.at_level(logging.DEBUG):
154 133 handler.log_message('Some message %d, %s', 123, 'string')
155 134
156 135 expected_message = f"HOOKS: client={ip_port} - - [{fake_date}] Some message 123, string"
157 136
158 137 assert_message_in_log(
159 138 caplog.records, expected_message,
160 139 levelno=logging.DEBUG, module='hooks_daemon')
161 140
162 141 def _generate_post_request(self, data, proto=test_proto):
163 142 if proto == hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO:
164 143 payload = msgpack.packb(data)
165 144 else:
166 145 payload = json.dumps(data)
167 146
168 147 return b'POST / HTTP/1.0\nContent-Length: %d\n\n%b' % (
169 148 len(payload), payload)
170 149
171 150
172 151 class ThreadedHookCallbackDaemon(object):
173 152 def test_constructor_calls_prepare(self):
174 153 prepare_daemon_patcher = mock.patch.object(
175 154 hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
176 155 with prepare_daemon_patcher as prepare_daemon_mock:
177 156 hooks_daemon.ThreadedHookCallbackDaemon()
178 157 prepare_daemon_mock.assert_called_once_with()
179 158
180 159 def test_run_is_called_on_context_start(self):
181 160 patchers = mock.patch.multiple(
182 161 hooks_daemon.ThreadedHookCallbackDaemon,
183 162 _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)
184 163
185 164 with patchers as mocks:
186 165 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
187 166 with daemon as daemon_context:
188 167 pass
189 168 mocks['_run'].assert_called_once_with()
190 169 assert daemon_context == daemon
191 170
192 171 def test_stop_is_called_on_context_exit(self):
193 172 patchers = mock.patch.multiple(
194 173 hooks_daemon.ThreadedHookCallbackDaemon,
195 174 _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)
196 175
197 176 with patchers as mocks:
198 177 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
199 178 with daemon as daemon_context:
200 179 assert mocks['_stop'].call_count == 0
201 180
202 181 mocks['_stop'].assert_called_once_with()
203 182 assert daemon_context == daemon
204 183
205 184
206 185 class TestHttpHooksCallbackDaemon(object):
207 186 def test_hooks_callback_generates_new_port(self, caplog):
208 187 with caplog.at_level(logging.DEBUG):
209 188 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
210 189 assert daemon._daemon.server_address == ('127.0.0.1', 8881)
211 190
212 191 with caplog.at_level(logging.DEBUG):
213 192 daemon = hooks_daemon.HttpHooksCallbackDaemon(host=None, port=None)
214 193 assert daemon._daemon.server_address[1] in range(0, 66000)
215 194 assert daemon._daemon.server_address[0] != '127.0.0.1'
216 195
217 196 def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
218 197 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
219 198 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
220 199 assert daemon._daemon == tcp_server
221 200
222 201 _, port = tcp_server.server_address
223 202
224 203 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
225 204 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
226 205 assert_message_in_log(
227 206 caplog.records, msg, levelno=logging.DEBUG, module='hooks_daemon')
228 207
229 208 def test_prepare_inits_hooks_uri_and_logs_it(
230 209 self, tcp_server, caplog):
231 210 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
232 211 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
233 212
234 213 _, port = tcp_server.server_address
235 214 expected_uri = '{}:{}'.format('127.0.0.1', port)
236 215 assert daemon.hooks_uri == expected_uri
237 216
238 217 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
239 218 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
240 219 assert_message_in_log(
241 220 caplog.records, msg,
242 221 levelno=logging.DEBUG, module='hooks_daemon')
243 222
244 223 def test_run_creates_a_thread(self, tcp_server):
245 224 thread = mock.Mock()
246 225
247 226 with self._tcp_patcher(tcp_server):
248 227 daemon = hooks_daemon.HttpHooksCallbackDaemon()
249 228
250 229 with self._thread_patcher(thread) as thread_mock:
251 230 daemon._run()
252 231
253 232 thread_mock.assert_called_once_with(
254 233 target=tcp_server.serve_forever,
255 234 kwargs={'poll_interval': daemon.POLL_INTERVAL})
256 235 assert thread.daemon is True
257 236 thread.start.assert_called_once_with()
258 237
259 238 def test_run_logs(self, tcp_server, caplog):
260 239
261 240 with self._tcp_patcher(tcp_server):
262 241 daemon = hooks_daemon.HttpHooksCallbackDaemon()
263 242
264 243 with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
265 244 daemon._run()
266 245
267 246 assert_message_in_log(
268 247 caplog.records,
269 248 'Running thread-based loop of callback daemon in background',
270 249 levelno=logging.DEBUG, module='hooks_daemon')
271 250
272 251 def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
273 252 thread = mock.Mock()
274 253
275 254 with self._tcp_patcher(tcp_server):
276 255 daemon = hooks_daemon.HttpHooksCallbackDaemon()
277 256
278 257 with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
279 258 with daemon:
280 259 assert daemon._daemon == tcp_server
281 260 assert daemon._callback_thread == thread
282 261
283 262 assert daemon._daemon is None
284 263 assert daemon._callback_thread is None
285 264 tcp_server.shutdown.assert_called_with()
286 265 thread.join.assert_called_once_with()
287 266
288 267 assert_message_in_log(
289 268 caplog.records, 'Waiting for background thread to finish.',
290 269 levelno=logging.DEBUG, module='hooks_daemon')
291 270
292 271 def _tcp_patcher(self, tcp_server):
293 272 return mock.patch.object(
294 273 hooks_daemon, 'TCPServer', return_value=tcp_server)
295 274
296 275 def _thread_patcher(self, thread):
297 276 return mock.patch.object(
298 277 hooks_daemon.threading, 'Thread', return_value=thread)
299 278
300 279
301 280 class TestPrepareHooksDaemon(object):
302 @pytest.mark.parametrize('protocol', ('http',))
303 def test_returns_dummy_hooks_callback_daemon_when_using_direct_calls(
281 @pytest.mark.parametrize('protocol', ('celery',))
282 def test_returns_celery_hooks_callback_daemon_when_celery_protocol_specified(
304 283 self, protocol):
305 expected_extras = {'extra1': 'value1'}
284 with tempfile.NamedTemporaryFile(mode='w') as temp_file:
285 temp_file.write("[app:main]\ncelery.broker_url = redis://redis/0\n"
286 "celery.result_backend = redis://redis/0")
287 temp_file.flush()
288 expected_extras = {'config': temp_file.name}
306 289 callback, extras = hooks_daemon.prepare_callback_daemon(
307 expected_extras.copy(), protocol=protocol,
308 host='127.0.0.1', use_direct_calls=True)
309 assert isinstance(callback, hooks_daemon.DummyHooksCallbackDaemon)
310 expected_extras['hooks_module'] = 'rhodecode.lib.hooks_daemon'
311 expected_extras['time'] = extras['time']
312 assert 'extra1' in extras
290 expected_extras, protocol=protocol, host='')
291 assert isinstance(callback, hooks_daemon.CeleryHooksCallbackDaemon)
313 292
314 293 @pytest.mark.parametrize('protocol, expected_class', (
315 294 ('http', hooks_daemon.HttpHooksCallbackDaemon),
316 295 ))
317 296 def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
318 297 self, protocol, expected_class):
319 298 expected_extras = {
320 299 'extra1': 'value1',
321 300 'txn_id': 'txnid2',
322 'hooks_protocol': protocol.lower()
301 'hooks_protocol': protocol.lower(),
302 'task_backend': '',
303 'task_queue': ''
323 304 }
324 305 callback, extras = hooks_daemon.prepare_callback_daemon(
325 306 expected_extras.copy(), protocol=protocol, host='127.0.0.1',
326 use_direct_calls=False,
327 307 txn_id='txnid2')
328 308 assert isinstance(callback, expected_class)
329 309 extras.pop('hooks_uri')
330 310 expected_extras['time'] = extras['time']
331 311 assert extras == expected_extras
332 312
333 313 @pytest.mark.parametrize('protocol', (
334 314 'invalid',
335 315 'Http',
336 316 'HTTP',
337 317 ))
338 318 def test_raises_on_invalid_protocol(self, protocol):
339 319 expected_extras = {
340 320 'extra1': 'value1',
341 321 'hooks_protocol': protocol.lower()
342 322 }
343 323 with pytest.raises(Exception):
344 324 callback, extras = hooks_daemon.prepare_callback_daemon(
345 325 expected_extras.copy(),
346 protocol=protocol, host='127.0.0.1',
347 use_direct_calls=False)
326 protocol=protocol, host='127.0.0.1')
348 327
349 328
350 329 class MockRequest(object):
351 330
352 331 def __init__(self, request):
353 332 self.request = request
354 333 self.input_stream = io.BytesIO(safe_bytes(self.request))
355 334 self.output_stream = io.BytesIO() # make it un-closable for testing invesitagion
356 335 self.output_stream.close = lambda: None
357 336
358 337 def makefile(self, mode, *args, **kwargs):
359 338 return self.output_stream if mode == 'wb' else self.input_stream
360 339
361 340
362 341 class MockServer(object):
363 342
364 343 def __init__(self, handler_cls, request):
365 344 ip_port = ('0.0.0.0', 8888)
366 345 self.request = MockRequest(request)
367 346 self.server_address = ip_port
368 347 self.handler = handler_cls(self.request, ip_port, self)
369 348
370 349
371 350 @pytest.fixture()
372 351 def tcp_server():
373 352 server = mock.Mock()
374 353 server.server_address = ('127.0.0.1', 8881)
375 354 server.wbufsize = 1024
376 355 return server
General Comments 0
You need to be logged in to leave comments. Login now