##// END OF EJS Templates
feat(celery-hooks): added all needed changes to support new celery backend, removed DummyHooksCallbackDaemon, updated tests. Fixes: RCCE-55
ilin.s -
r5298:25044729 default
parent child Browse files
Show More
@@ -1,262 +1,263 b''
1 # Copyright (C) 2016-2023 RhodeCode GmbH
1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19 import os
19 import os
20 import re
20 import re
21 import logging
21 import logging
22 import datetime
22 import datetime
23 import configparser
23 import configparser
24 from sqlalchemy import Table
24 from sqlalchemy import Table
25
25
26 from rhodecode.lib.utils2 import AttributeDict
26 from rhodecode.lib.utils2 import AttributeDict
27 from rhodecode.model.scm import ScmModel
27 from rhodecode.model.scm import ScmModel
28
28
29 from .hg import MercurialServer
29 from .hg import MercurialServer
30 from .git import GitServer
30 from .git import GitServer
31 from .svn import SubversionServer
31 from .svn import SubversionServer
32 log = logging.getLogger(__name__)
32 log = logging.getLogger(__name__)
33
33
34
34
class SshWrapper(object):
    """
    Dispatcher for SSH-originated VCS traffic.

    Parses the command handed over by the SSH daemon, determines which VCS
    backend (hg/git/svn) should handle it, and serves the request through
    the matching tunnel server implementation.
    """
    hg_cmd_pat = re.compile(r'^hg\s+\-R\s+(\S+)\s+serve\s+\-\-stdio$')
    git_cmd_pat = re.compile(r'^git-(receive-pack|upload-pack)\s\'[/]?(\S+?)(|\.git)\'$')
    svn_cmd_pat = re.compile(r'^svnserve -t')

    def __init__(self, command, connection_info, mode,
                 user, user_id, key_id: int, shell, ini_path: str, env):
        self.command = command
        self.connection_info = connection_info
        self.mode = mode
        self.username = user
        self.user_id = user_id
        self.key_id = key_id
        self.shell = shell
        self.ini_path = ini_path
        self.env = env

        self.config = self.parse_config(ini_path)
        self.server_impl = None

    def parse_config(self, config_path):
        # Plain configparser read; a missing file simply yields an empty parser.
        parser = configparser.ConfigParser()
        parser.read(config_path)
        return parser

    def update_key_access_time(self, key_id):
        """Stamp ``accessed_on`` of the given SSH key row with current UTC time."""
        # lazy load db imports
        from rhodecode.model.meta import raw_query_executor, Base

        table = Table('user_ssh_keys', Base.metadata, autoload=False)
        atime = datetime.datetime.utcnow()
        stmt = (
            table.update()
            .where(table.c.ssh_key_id == key_id)
            .values(accessed_on=atime)
            # no MySQL Support for .returning :((
            #.returning(table.c.accessed_on, table.c.ssh_key_fingerprint)
        )

        updated_rows = None
        with raw_query_executor() as session:
            result = session.execute(stmt)
            if result.rowcount:
                updated_rows = result.rowcount

        if updated_rows:
            log.debug('Update key id:`%s` access time', key_id)

    def get_user(self, user_id):
        """Load a light-weight user object for ``user_id``, or None if missing."""
        # lazy load db imports
        from rhodecode.model.db import User

        dbuser = User.get(user_id)
        if not dbuser:
            return None

        user = AttributeDict()
        user.user_id = dbuser.user_id
        user.username = dbuser.username
        user.auth_user = dbuser.AuthUser()
        return user

    def get_connection_info(self):
        """
        connection_info

        Identifies the client and server ends of the connection.
        The variable contains four space-separated values: client IP address,
        client port number, server IP address, and server port number.
        """
        conn = dict(
            client_ip=None,
            client_port=None,
            server_ip=None,
            server_port=None,
        )

        parts = self.connection_info.split(' ')
        if len(parts) == 4:
            (conn['client_ip'], conn['client_port'],
             conn['server_ip'], conn['server_port']) = parts

        return conn

    def maybe_translate_repo_uid(self, repo_name):
        """Resolve a ``_<ID>`` style repo reference to its real repo name."""
        _org_name = repo_name
        if _org_name.startswith('_'):
            # remove format of _ID/subrepo
            _org_name = _org_name.split('/', 1)[0]

        if repo_name.startswith('_'):
            from rhodecode.model.repo import RepoModel
            org_repo_name = repo_name
            log.debug('translating UID repo %s', org_repo_name)
            by_id_match = RepoModel().get_repo_by_id(repo_name)
            if by_id_match:
                repo_name = by_id_match.repo_name
                log.debug('translation of UID repo %s got `%s`', org_repo_name, repo_name)

        return repo_name, _org_name

    def get_repo_details(self, mode):
        """Detect VCS type, repo name and mode from the raw SSH command."""
        vcs_type = mode if mode in ['svn', 'hg', 'git'] else None
        repo_name = None

        hg_match = self.hg_cmd_pat.match(self.command)
        if hg_match is not None:
            repo_id = hg_match.group(1).strip('/')
            repo_name, _org_name = self.maybe_translate_repo_uid(repo_id)
            return 'hg', repo_name, mode

        git_match = self.git_cmd_pat.match(self.command)
        if git_match is not None:
            mode = git_match.group(1)
            repo_id = git_match.group(2).strip('/')
            repo_name, _org_name = self.maybe_translate_repo_uid(repo_id)
            return 'git', repo_name, mode

        if self.svn_cmd_pat.match(self.command) is not None:
            # Repo name should be extracted from the input stream, we're unable to
            # extract it at this point in execution
            return 'svn', repo_name, mode

        return vcs_type, repo_name, mode

    def serve(self, vcs, repo, mode, user, permissions, branch_permissions):
        """Run the tunnel server matching ``vcs`` and return its exit data."""
        store = ScmModel().repos_path

        # branch permission rules imply force-push detection as well
        check_branch_perms = bool(branch_permissions)
        detect_force_push = bool(branch_permissions)

        log.debug(
            'VCS detected:`%s` mode: `%s` repo_name: %s, branch_permission_checks:%s',
            vcs, mode, repo, check_branch_perms)

        # detect if we have to check branch permissions
        extras = {
            'detect_force_push': detect_force_push,
            'check_branch_perms': check_branch_perms,
            'config': self.ini_path
        }

        common_kwargs = dict(
            store=store, ini_path=self.ini_path, user=user,
            user_permissions=permissions, config=self.config, env=self.env)

        if vcs == 'hg':
            server = MercurialServer(repo_name=repo, **common_kwargs)
        elif vcs == 'git':
            server = GitServer(repo_name=repo, repo_mode=mode, **common_kwargs)
        elif vcs == 'svn':
            # svn resolves the repo name from its own protocol stream
            server = SubversionServer(repo_name=None, **common_kwargs)
        else:
            raise Exception(f'Unrecognised VCS: {vcs}')

        self.server_impl = server
        return server.run(tunnel_extras=extras)

    def wrap(self):
        """Main entry point: authenticate, dispatch and return an exit code."""
        mode = self.mode
        username = self.username
        user_id = self.user_id
        key_id = self.key_id
        shell = self.shell

        scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)

        log.debug(
            'Mode: `%s` User: `name:%s : id:%s` Shell: `%s` SSH Command: `\"%s\"` '
            'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
            mode, username, user_id, shell, self.command,
            scm_detected, scm_mode, scm_repo)

        log.debug('SSH Connection info %s', self.get_connection_info())

        # update last access time for this key
        if key_id:
            self.update_key_access_time(key_id)

        if shell and self.command is None:
            log.info('Dropping to shell, no command given and shell is allowed')
            os.execl('/bin/bash', '-l')
            exit_code = 1

        elif scm_detected:
            user = self.get_user(user_id)
            if not user:
                log.warning('User with id %s not found', user_id)
                exit_code = -1
                return exit_code

            auth_user = user.auth_user
            permissions = auth_user.permissions['repositories']
            repo_branch_permissions = auth_user.get_branch_permissions(scm_repo)
            try:
                exit_code, is_updated = self.serve(
                    scm_detected, scm_repo, scm_mode, user, permissions,
                    repo_branch_permissions)
            except Exception:
                log.exception('Error occurred during execution of SshWrapper')
                exit_code = -1

        elif self.command is None and shell is False:
            log.error('No Command given.')
            exit_code = -1

        else:
            log.error('Unhandled Command: "%s" Aborting.', self.command)
            exit_code = -1

        return exit_code
@@ -1,161 +1,160 b''
1 # Copyright (C) 2016-2023 RhodeCode GmbH
1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19 import os
19 import os
20 import sys
20 import sys
21 import logging
21 import logging
22
22
23 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
23 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
24 from rhodecode.lib.ext_json import sjson as json
24 from rhodecode.lib.ext_json import sjson as json
25 from rhodecode.lib.vcs.conf import settings as vcs_settings
25 from rhodecode.lib.vcs.conf import settings as vcs_settings
26 from rhodecode.model.scm import ScmModel
26 from rhodecode.model.scm import ScmModel
27
27
28 log = logging.getLogger(__name__)
28 log = logging.getLogger(__name__)
29
29
30
30
class VcsServer(object):
    """
    Base tunnel server for SSH-originated VCS requests.

    Child classes provide ``repo_user_agent``, ``backend``, ``_path`` and
    ``tunnel``; this base class handles permission checks, hook extras and
    the callback-daemon lifecycle around the actual tunnel run.
    """
    repo_user_agent = None  # set in child classes
    _path = None  # set executable path for hg/git/svn binary
    backend = None  # set in child classes
    tunnel = None  # subprocess handling tunnel
    write_perms = ['repository.admin', 'repository.write']
    read_perms = ['repository.read', 'repository.admin', 'repository.write']

    def __init__(self, user, user_permissions, config, env):
        self.user = user
        self.user_permissions = user_permissions
        self.config = config
        self.env = env
        self.stdin = sys.stdin

        self.repo_name = None
        self.repo_mode = None
        self.store = ''
        self.ini_path = ''

    def _invalidate_cache(self, repo_name):
        """
        Set's cache for this repository for invalidation on next access

        :param repo_name: full repo name, also a cache key
        """
        ScmModel().mark_for_invalidation(repo_name)

    def has_write_perm(self):
        # write access requires repository.write or repository.admin
        permission = self.user_permissions.get(self.repo_name)
        return permission in ('repository.write', 'repository.admin')

    def _check_permissions(self, action):
        """Return 0 when ``action`` is allowed for this repo, -2 otherwise."""
        permission = self.user_permissions.get(self.repo_name)
        log.debug('permission for %s on %s are: %s',
                  self.user, self.repo_name, permission)

        if not permission:
            log.error('user `%s` permissions to repo:%s are empty. Forbidding access.',
                      self.user, self.repo_name)
            return -2

        if action == 'pull':
            if permission in self.read_perms:
                log.info(
                    'READ Permissions for User "%s" detected to repo "%s"!',
                    self.user, self.repo_name)
                return 0
        else:
            if permission in self.write_perms:
                log.info(
                    'WRITE, or Higher Permissions for User "%s" detected to repo "%s"!',
                    self.user, self.repo_name)
                return 0

        log.error('Cannot properly fetch or verify user `%s` permissions. '
                  'Permissions: %s, vcs action: %s',
                  self.user, permission, action)
        return -2

    def update_environment(self, action, extras=None):
        """Export the hook payload (RC_SCM_DATA) for the tunnel subprocess."""
        scm_data = {
            'ip': os.environ['SSH_CLIENT'].split()[0],
            'username': self.user.username,
            'user_id': self.user.user_id,
            'action': action,
            'repository': self.repo_name,
            'scm': self.backend,
            'config': self.ini_path,
            'repo_store': self.store,
            'make_lock': None,
            'locked_by': [None, None],
            'server_url': None,
            'user_agent': f'{self.repo_user_agent}/ssh-user-agent',
            'hooks': ['push', 'pull'],
            'hooks_module': 'rhodecode.lib.hooks_daemon',
            'is_shadow_repo': False,
            'detect_force_push': False,
            'check_branch_perms': False,

            'SSH': True,
            'SSH_PERMISSIONS': self.user_permissions.get(self.repo_name),
        }
        if extras:
            scm_data.update(extras)
        # NOTE: os.putenv only affects spawned children, not os.environ of
        # this process — that is intentional here, the tunnel is a subprocess.
        os.putenv("RC_SCM_DATA", json.dumps(scm_data))

    def get_root_store(self):
        # always append trailing slash
        root_store = self.store
        return root_store if root_store.endswith('/') else root_store + '/'

    def _handle_tunnel(self, extras):
        """Authorize, then delegate to the concrete tunnel implementation."""
        # pre-auth
        action = 'pull'
        exit_code = self._check_permissions(action)
        if exit_code:
            return exit_code, False

        req = self.env['request']
        extras['server_url'] = req.host_url + req.script_name

        log.debug('Using %s binaries from path %s', self.backend, self._path)
        exit_code = self.tunnel.run(extras)

        return exit_code, action == "push"

    def run(self, tunnel_extras=None):
        """Run the tunnel wrapped in a hooks callback daemon; invalidate caches after."""
        extras = dict(tunnel_extras or {})

        callback_daemon, extras = prepare_callback_daemon(
            extras, protocol=vcs_settings.HOOKS_PROTOCOL,
            host=vcs_settings.HOOKS_HOST)

        with callback_daemon:
            try:
                return self._handle_tunnel(extras)
            finally:
                log.debug('Running cleanup with cache invalidation')
                if self.repo_name:
                    self._invalidate_cache(self.repo_name)
@@ -1,118 +1,117 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19 import os
19 import os
20 import platform
20 import platform
21
21
22 from rhodecode.model import init_model
22 from rhodecode.model import init_model
23
23
24
24
def configure_vcs(config):
    """
    Patch VCS config with some RhodeCode specific stuff
    """
    from rhodecode.lib.vcs import conf
    import rhodecode.lib.vcs.conf.settings

    conf.settings.BACKENDS = {
        'hg': 'rhodecode.lib.vcs.backends.hg.MercurialRepository',
        'git': 'rhodecode.lib.vcs.backends.git.GitRepository',
        'svn': 'rhodecode.lib.vcs.backends.svn.SubversionRepository',
    }

    # straight attribute <- config-key copies
    for attr_name, config_key in (
            ('HOOKS_PROTOCOL', 'vcs.hooks.protocol'),
            ('HOOKS_HOST', 'vcs.hooks.host'),
            ('DEFAULT_ENCODINGS', 'default_encoding'),
            ('SVN_COMPATIBLE_VERSION', 'vcs.svn.compatible_version')):
        setattr(conf.settings, attr_name, config[config_key])

    # in-place slice assignment keeps existing references to ALIASES valid
    conf.settings.ALIASES[:] = config['vcs.backends']
44
43
45
44
def initialize_database(config):
    """Create the SQLAlchemy engine from ``config`` and wire up the models."""
    from rhodecode.lib.utils2 import engine_from_config, get_encryption_key
    db_engine = engine_from_config(config, 'sqlalchemy.db1.')
    init_model(db_engine, encryption_key=get_encryption_key(config))
50
49
51
50
def initialize_test_environment(settings, test_env=None):
    """Prepare test directories, database, repositories and search index."""
    if test_env is None:
        # RC_NO_TMP_PATH=1 disables creation of the temporary test directories
        test_env = not int(os.environ.get('RC_NO_TMP_PATH', 0))

    from rhodecode.lib.utils import (
        create_test_directory, create_test_database, create_test_repositories,
        create_test_index)
    from rhodecode.tests import TESTS_TMP_PATH
    from rhodecode.lib.vcs.backends.hg import largefiles_store
    from rhodecode.lib.vcs.backends.git import lfs_store

    if test_env:
        # test repos
        create_test_directory(TESTS_TMP_PATH)
        # large object stores
        create_test_directory(largefiles_store(TESTS_TMP_PATH))
        create_test_directory(lfs_store(TESTS_TMP_PATH))

    create_test_database(TESTS_TMP_PATH, settings)
    create_test_repositories(TESTS_TMP_PATH, settings)
    create_test_index(TESTS_TMP_PATH, settings)
73
72
74
73
def get_vcs_server_protocol(config):
    """Return the configured VCSServer wire protocol."""
    return config['vcs.server.protocol']
77
76
78
77
def set_instance_id(config):
    """
    Sets a dynamic generated config['instance_id'] if missing or '*'
    E.g instance_id = *cluster-1 or instance_id = *
    """
    instance_id = config.get('instance_id') or ''
    config['instance_id'] = instance_id

    if not instance_id or instance_id.startswith('*'):
        # generate `<prefix>uname:<hostname>-pid:<pid>`; a leading '*' in the
        # configured value acts as a marker and is stripped into the prefix
        prefix = instance_id.lstrip('*')
        node_name = platform.uname()[1] or 'instance'
        config['instance_id'] = f'{prefix}uname:{node_name}-pid:{os.getpid()}'
94
93
95
94
def get_default_user_id():
    """Return the database id of the built-in `default` user."""
    DEFAULT_USER = 'default'
    from sqlalchemy import text
    from rhodecode.model import meta

    db_engine = meta.get_engine()
    with meta.SA_Session(db_engine) as session:
        row = session.execute(
            text("SELECT user_id from users where username = :uname"),
            {'uname': DEFAULT_USER}).first()

    return row[0]
107
106
108
107
def get_default_base_path():
    """Return the repository store root path recorded in rhodecode_ui."""
    from sqlalchemy import text
    from rhodecode.model import meta

    db_engine = meta.get_engine()
    with meta.SA_Session(db_engine) as session:
        row = session.execute(
            text("SELECT ui_value from rhodecode_ui where ui_key = '/'")).first()

    return row[0]
@@ -1,412 +1,448 b''
1 # Copyright (C) 2012-2023 RhodeCode GmbH
1 # Copyright (C) 2012-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19 """
19 """
20 RhodeCode task modules, containing all task that suppose to be run
20 RhodeCode task modules, containing all task that suppose to be run
21 by celery daemon
21 by celery daemon
22 """
22 """
23
23
24 import os
24 import os
25 import time
25 import time
26
26
27 from pyramid_mailer.mailer import Mailer
27 from pyramid_mailer.mailer import Mailer
28 from pyramid_mailer.message import Message
28 from pyramid_mailer.message import Message
29 from email.utils import formatdate
29 from email.utils import formatdate
30
30
31 import rhodecode
31 import rhodecode
32 from rhodecode.lib import audit_logger
32 from rhodecode.lib import audit_logger
33 from rhodecode.lib.celerylib import get_logger, async_task, RequestContextTask, run_task
33 from rhodecode.lib.celerylib import get_logger, async_task, RequestContextTask, run_task
34 from rhodecode.lib import hooks_base
34 from rhodecode.lib import hooks_base
35 from rhodecode.lib.utils import adopt_for_celery
35 from rhodecode.lib.utils2 import safe_int, str2bool, aslist
36 from rhodecode.lib.utils2 import safe_int, str2bool, aslist
36 from rhodecode.lib.statsd_client import StatsdClient
37 from rhodecode.lib.statsd_client import StatsdClient
37 from rhodecode.model.db import (
38 from rhodecode.model.db import (
38 true, null, Session, IntegrityError, Repository, RepoGroup, User)
39 true, null, Session, IntegrityError, Repository, RepoGroup, User)
39 from rhodecode.model.permission import PermissionModel
40 from rhodecode.model.permission import PermissionModel
40
41
41
42
42 @async_task(ignore_result=True, base=RequestContextTask)
43 @async_task(ignore_result=True, base=RequestContextTask)
43 def send_email(recipients, subject, body='', html_body='', email_config=None,
44 def send_email(recipients, subject, body='', html_body='', email_config=None,
44 extra_headers=None):
45 extra_headers=None):
45 """
46 """
46 Sends an email with defined parameters from the .ini files.
47 Sends an email with defined parameters from the .ini files.
47
48
48 :param recipients: list of recipients, it this is empty the defined email
49 :param recipients: list of recipients, it this is empty the defined email
49 address from field 'email_to' is used instead
50 address from field 'email_to' is used instead
50 :param subject: subject of the mail
51 :param subject: subject of the mail
51 :param body: body of the mail
52 :param body: body of the mail
52 :param html_body: html version of body
53 :param html_body: html version of body
53 :param email_config: specify custom configuration for mailer
54 :param email_config: specify custom configuration for mailer
54 :param extra_headers: specify custom headers
55 :param extra_headers: specify custom headers
55 """
56 """
56 log = get_logger(send_email)
57 log = get_logger(send_email)
57
58
58 email_config = email_config or rhodecode.CONFIG
59 email_config = email_config or rhodecode.CONFIG
59
60
60 mail_server = email_config.get('smtp_server') or None
61 mail_server = email_config.get('smtp_server') or None
61 if mail_server is None:
62 if mail_server is None:
62 log.error("SMTP server information missing. Sending email failed. "
63 log.error("SMTP server information missing. Sending email failed. "
63 "Make sure that `smtp_server` variable is configured "
64 "Make sure that `smtp_server` variable is configured "
64 "inside the .ini file")
65 "inside the .ini file")
65 return False
66 return False
66
67
67 subject = "%s %s" % (email_config.get('email_prefix', ''), subject)
68 subject = "%s %s" % (email_config.get('email_prefix', ''), subject)
68
69
69 if recipients:
70 if recipients:
70 if isinstance(recipients, str):
71 if isinstance(recipients, str):
71 recipients = recipients.split(',')
72 recipients = recipients.split(',')
72 else:
73 else:
73 # if recipients are not defined we send to email_config + all admins
74 # if recipients are not defined we send to email_config + all admins
74 admins = []
75 admins = []
75 for u in User.query().filter(User.admin == true()).all():
76 for u in User.query().filter(User.admin == true()).all():
76 if u.email:
77 if u.email:
77 admins.append(u.email)
78 admins.append(u.email)
78 recipients = []
79 recipients = []
79 config_email = email_config.get('email_to')
80 config_email = email_config.get('email_to')
80 if config_email:
81 if config_email:
81 recipients += [config_email]
82 recipients += [config_email]
82 recipients += admins
83 recipients += admins
83
84
84 # translate our LEGACY config into the one that pyramid_mailer supports
85 # translate our LEGACY config into the one that pyramid_mailer supports
85 email_conf = dict(
86 email_conf = dict(
86 host=mail_server,
87 host=mail_server,
87 port=email_config.get('smtp_port', 25),
88 port=email_config.get('smtp_port', 25),
88 username=email_config.get('smtp_username'),
89 username=email_config.get('smtp_username'),
89 password=email_config.get('smtp_password'),
90 password=email_config.get('smtp_password'),
90
91
91 tls=str2bool(email_config.get('smtp_use_tls')),
92 tls=str2bool(email_config.get('smtp_use_tls')),
92 ssl=str2bool(email_config.get('smtp_use_ssl')),
93 ssl=str2bool(email_config.get('smtp_use_ssl')),
93
94
94 # SSL key file
95 # SSL key file
95 # keyfile='',
96 # keyfile='',
96
97
97 # SSL certificate file
98 # SSL certificate file
98 # certfile='',
99 # certfile='',
99
100
100 # Location of maildir
101 # Location of maildir
101 # queue_path='',
102 # queue_path='',
102
103
103 default_sender=email_config.get('app_email_from', 'RhodeCode-noreply@rhodecode.com'),
104 default_sender=email_config.get('app_email_from', 'RhodeCode-noreply@rhodecode.com'),
104
105
105 debug=str2bool(email_config.get('smtp_debug')),
106 debug=str2bool(email_config.get('smtp_debug')),
106 # /usr/sbin/sendmail Sendmail executable
107 # /usr/sbin/sendmail Sendmail executable
107 # sendmail_app='',
108 # sendmail_app='',
108
109
109 # {sendmail_app} -t -i -f {sender} Template for sendmail execution
110 # {sendmail_app} -t -i -f {sender} Template for sendmail execution
110 # sendmail_template='',
111 # sendmail_template='',
111 )
112 )
112
113
113 if extra_headers is None:
114 if extra_headers is None:
114 extra_headers = {}
115 extra_headers = {}
115
116
116 extra_headers.setdefault('Date', formatdate(time.time()))
117 extra_headers.setdefault('Date', formatdate(time.time()))
117
118
118 if 'thread_ids' in extra_headers:
119 if 'thread_ids' in extra_headers:
119 thread_ids = extra_headers.pop('thread_ids')
120 thread_ids = extra_headers.pop('thread_ids')
120 extra_headers['References'] = ' '.join('<{}>'.format(t) for t in thread_ids)
121 extra_headers['References'] = ' '.join('<{}>'.format(t) for t in thread_ids)
121
122
122 try:
123 try:
123 mailer = Mailer(**email_conf)
124 mailer = Mailer(**email_conf)
124
125
125 message = Message(subject=subject,
126 message = Message(subject=subject,
126 sender=email_conf['default_sender'],
127 sender=email_conf['default_sender'],
127 recipients=recipients,
128 recipients=recipients,
128 body=body, html=html_body,
129 body=body, html=html_body,
129 extra_headers=extra_headers)
130 extra_headers=extra_headers)
130 mailer.send_immediately(message)
131 mailer.send_immediately(message)
131 statsd = StatsdClient.statsd
132 statsd = StatsdClient.statsd
132 if statsd:
133 if statsd:
133 statsd.incr('rhodecode_email_sent_total')
134 statsd.incr('rhodecode_email_sent_total')
134
135
135 except Exception:
136 except Exception:
136 log.exception('Mail sending failed')
137 log.exception('Mail sending failed')
137 return False
138 return False
138 return True
139 return True
139
140
140
141
141 @async_task(ignore_result=True, base=RequestContextTask)
142 @async_task(ignore_result=True, base=RequestContextTask)
142 def create_repo(form_data, cur_user):
143 def create_repo(form_data, cur_user):
143 from rhodecode.model.repo import RepoModel
144 from rhodecode.model.repo import RepoModel
144 from rhodecode.model.user import UserModel
145 from rhodecode.model.user import UserModel
145 from rhodecode.model.scm import ScmModel
146 from rhodecode.model.scm import ScmModel
146 from rhodecode.model.settings import SettingsModel
147 from rhodecode.model.settings import SettingsModel
147
148
148 log = get_logger(create_repo)
149 log = get_logger(create_repo)
149
150
150 cur_user = UserModel()._get_user(cur_user)
151 cur_user = UserModel()._get_user(cur_user)
151 owner = cur_user
152 owner = cur_user
152
153
153 repo_name = form_data['repo_name']
154 repo_name = form_data['repo_name']
154 repo_name_full = form_data['repo_name_full']
155 repo_name_full = form_data['repo_name_full']
155 repo_type = form_data['repo_type']
156 repo_type = form_data['repo_type']
156 description = form_data['repo_description']
157 description = form_data['repo_description']
157 private = form_data['repo_private']
158 private = form_data['repo_private']
158 clone_uri = form_data.get('clone_uri')
159 clone_uri = form_data.get('clone_uri')
159 repo_group = safe_int(form_data['repo_group'])
160 repo_group = safe_int(form_data['repo_group'])
160 copy_fork_permissions = form_data.get('copy_permissions')
161 copy_fork_permissions = form_data.get('copy_permissions')
161 copy_group_permissions = form_data.get('repo_copy_permissions')
162 copy_group_permissions = form_data.get('repo_copy_permissions')
162 fork_of = form_data.get('fork_parent_id')
163 fork_of = form_data.get('fork_parent_id')
163 state = form_data.get('repo_state', Repository.STATE_PENDING)
164 state = form_data.get('repo_state', Repository.STATE_PENDING)
164
165
165 # repo creation defaults, private and repo_type are filled in form
166 # repo creation defaults, private and repo_type are filled in form
166 defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
167 defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
167 enable_statistics = form_data.get(
168 enable_statistics = form_data.get(
168 'enable_statistics', defs.get('repo_enable_statistics'))
169 'enable_statistics', defs.get('repo_enable_statistics'))
169 enable_locking = form_data.get(
170 enable_locking = form_data.get(
170 'enable_locking', defs.get('repo_enable_locking'))
171 'enable_locking', defs.get('repo_enable_locking'))
171 enable_downloads = form_data.get(
172 enable_downloads = form_data.get(
172 'enable_downloads', defs.get('repo_enable_downloads'))
173 'enable_downloads', defs.get('repo_enable_downloads'))
173
174
174 # set landing rev based on default branches for SCM
175 # set landing rev based on default branches for SCM
175 landing_ref, _label = ScmModel.backend_landing_ref(repo_type)
176 landing_ref, _label = ScmModel.backend_landing_ref(repo_type)
176
177
177 try:
178 try:
178 RepoModel()._create_repo(
179 RepoModel()._create_repo(
179 repo_name=repo_name_full,
180 repo_name=repo_name_full,
180 repo_type=repo_type,
181 repo_type=repo_type,
181 description=description,
182 description=description,
182 owner=owner,
183 owner=owner,
183 private=private,
184 private=private,
184 clone_uri=clone_uri,
185 clone_uri=clone_uri,
185 repo_group=repo_group,
186 repo_group=repo_group,
186 landing_rev=landing_ref,
187 landing_rev=landing_ref,
187 fork_of=fork_of,
188 fork_of=fork_of,
188 copy_fork_permissions=copy_fork_permissions,
189 copy_fork_permissions=copy_fork_permissions,
189 copy_group_permissions=copy_group_permissions,
190 copy_group_permissions=copy_group_permissions,
190 enable_statistics=enable_statistics,
191 enable_statistics=enable_statistics,
191 enable_locking=enable_locking,
192 enable_locking=enable_locking,
192 enable_downloads=enable_downloads,
193 enable_downloads=enable_downloads,
193 state=state
194 state=state
194 )
195 )
195 Session().commit()
196 Session().commit()
196
197
197 # now create this repo on Filesystem
198 # now create this repo on Filesystem
198 RepoModel()._create_filesystem_repo(
199 RepoModel()._create_filesystem_repo(
199 repo_name=repo_name,
200 repo_name=repo_name,
200 repo_type=repo_type,
201 repo_type=repo_type,
201 repo_group=RepoModel()._get_repo_group(repo_group),
202 repo_group=RepoModel()._get_repo_group(repo_group),
202 clone_uri=clone_uri,
203 clone_uri=clone_uri,
203 )
204 )
204 repo = Repository.get_by_repo_name(repo_name_full)
205 repo = Repository.get_by_repo_name(repo_name_full)
205 hooks_base.create_repository(created_by=owner.username, **repo.get_dict())
206 hooks_base.create_repository(created_by=owner.username, **repo.get_dict())
206
207
207 # update repo commit caches initially
208 # update repo commit caches initially
208 repo.update_commit_cache()
209 repo.update_commit_cache()
209
210
210 # set new created state
211 # set new created state
211 repo.set_state(Repository.STATE_CREATED)
212 repo.set_state(Repository.STATE_CREATED)
212 repo_id = repo.repo_id
213 repo_id = repo.repo_id
213 repo_data = repo.get_api_data()
214 repo_data = repo.get_api_data()
214
215
215 audit_logger.store(
216 audit_logger.store(
216 'repo.create', action_data={'data': repo_data},
217 'repo.create', action_data={'data': repo_data},
217 user=cur_user,
218 user=cur_user,
218 repo=audit_logger.RepoWrap(repo_name=repo_name, repo_id=repo_id))
219 repo=audit_logger.RepoWrap(repo_name=repo_name, repo_id=repo_id))
219
220
220 Session().commit()
221 Session().commit()
221
222
222 PermissionModel().trigger_permission_flush()
223 PermissionModel().trigger_permission_flush()
223
224
224 except Exception as e:
225 except Exception as e:
225 log.warning('Exception occurred when creating repository, '
226 log.warning('Exception occurred when creating repository, '
226 'doing cleanup...', exc_info=True)
227 'doing cleanup...', exc_info=True)
227 if isinstance(e, IntegrityError):
228 if isinstance(e, IntegrityError):
228 Session().rollback()
229 Session().rollback()
229
230
230 # rollback things manually !
231 # rollback things manually !
231 repo = Repository.get_by_repo_name(repo_name_full)
232 repo = Repository.get_by_repo_name(repo_name_full)
232 if repo:
233 if repo:
233 Repository.delete(repo.repo_id)
234 Repository.delete(repo.repo_id)
234 Session().commit()
235 Session().commit()
235 RepoModel()._delete_filesystem_repo(repo)
236 RepoModel()._delete_filesystem_repo(repo)
236 log.info('Cleanup of repo %s finished', repo_name_full)
237 log.info('Cleanup of repo %s finished', repo_name_full)
237 raise
238 raise
238
239
239 return True
240 return True
240
241
241
242
242 @async_task(ignore_result=True, base=RequestContextTask)
243 @async_task(ignore_result=True, base=RequestContextTask)
243 def create_repo_fork(form_data, cur_user):
244 def create_repo_fork(form_data, cur_user):
244 """
245 """
245 Creates a fork of repository using internal VCS methods
246 Creates a fork of repository using internal VCS methods
246 """
247 """
247 from rhodecode.model.repo import RepoModel
248 from rhodecode.model.repo import RepoModel
248 from rhodecode.model.user import UserModel
249 from rhodecode.model.user import UserModel
249
250
250 log = get_logger(create_repo_fork)
251 log = get_logger(create_repo_fork)
251
252
252 cur_user = UserModel()._get_user(cur_user)
253 cur_user = UserModel()._get_user(cur_user)
253 owner = cur_user
254 owner = cur_user
254
255
255 repo_name = form_data['repo_name'] # fork in this case
256 repo_name = form_data['repo_name'] # fork in this case
256 repo_name_full = form_data['repo_name_full']
257 repo_name_full = form_data['repo_name_full']
257 repo_type = form_data['repo_type']
258 repo_type = form_data['repo_type']
258 description = form_data['description']
259 description = form_data['description']
259 private = form_data['private']
260 private = form_data['private']
260 clone_uri = form_data.get('clone_uri')
261 clone_uri = form_data.get('clone_uri')
261 repo_group = safe_int(form_data['repo_group'])
262 repo_group = safe_int(form_data['repo_group'])
262 landing_ref = form_data['landing_rev']
263 landing_ref = form_data['landing_rev']
263 copy_fork_permissions = form_data.get('copy_permissions')
264 copy_fork_permissions = form_data.get('copy_permissions')
264 fork_id = safe_int(form_data.get('fork_parent_id'))
265 fork_id = safe_int(form_data.get('fork_parent_id'))
265
266
266 try:
267 try:
267 fork_of = RepoModel()._get_repo(fork_id)
268 fork_of = RepoModel()._get_repo(fork_id)
268 RepoModel()._create_repo(
269 RepoModel()._create_repo(
269 repo_name=repo_name_full,
270 repo_name=repo_name_full,
270 repo_type=repo_type,
271 repo_type=repo_type,
271 description=description,
272 description=description,
272 owner=owner,
273 owner=owner,
273 private=private,
274 private=private,
274 clone_uri=clone_uri,
275 clone_uri=clone_uri,
275 repo_group=repo_group,
276 repo_group=repo_group,
276 landing_rev=landing_ref,
277 landing_rev=landing_ref,
277 fork_of=fork_of,
278 fork_of=fork_of,
278 copy_fork_permissions=copy_fork_permissions
279 copy_fork_permissions=copy_fork_permissions
279 )
280 )
280
281
281 Session().commit()
282 Session().commit()
282
283
283 base_path = Repository.base_path()
284 base_path = Repository.base_path()
284 source_repo_path = os.path.join(base_path, fork_of.repo_name)
285 source_repo_path = os.path.join(base_path, fork_of.repo_name)
285
286
286 # now create this repo on Filesystem
287 # now create this repo on Filesystem
287 RepoModel()._create_filesystem_repo(
288 RepoModel()._create_filesystem_repo(
288 repo_name=repo_name,
289 repo_name=repo_name,
289 repo_type=repo_type,
290 repo_type=repo_type,
290 repo_group=RepoModel()._get_repo_group(repo_group),
291 repo_group=RepoModel()._get_repo_group(repo_group),
291 clone_uri=source_repo_path,
292 clone_uri=source_repo_path,
292 )
293 )
293 repo = Repository.get_by_repo_name(repo_name_full)
294 repo = Repository.get_by_repo_name(repo_name_full)
294 hooks_base.create_repository(created_by=owner.username, **repo.get_dict())
295 hooks_base.create_repository(created_by=owner.username, **repo.get_dict())
295
296
296 # update repo commit caches initially
297 # update repo commit caches initially
297 config = repo._config
298 config = repo._config
298 config.set('extensions', 'largefiles', '')
299 config.set('extensions', 'largefiles', '')
299 repo.update_commit_cache(config=config)
300 repo.update_commit_cache(config=config)
300
301
301 # set new created state
302 # set new created state
302 repo.set_state(Repository.STATE_CREATED)
303 repo.set_state(Repository.STATE_CREATED)
303
304
304 repo_id = repo.repo_id
305 repo_id = repo.repo_id
305 repo_data = repo.get_api_data()
306 repo_data = repo.get_api_data()
306 audit_logger.store(
307 audit_logger.store(
307 'repo.fork', action_data={'data': repo_data},
308 'repo.fork', action_data={'data': repo_data},
308 user=cur_user,
309 user=cur_user,
309 repo=audit_logger.RepoWrap(repo_name=repo_name, repo_id=repo_id))
310 repo=audit_logger.RepoWrap(repo_name=repo_name, repo_id=repo_id))
310
311
311 Session().commit()
312 Session().commit()
312 except Exception as e:
313 except Exception as e:
313 log.warning('Exception occurred when forking repository, '
314 log.warning('Exception occurred when forking repository, '
314 'doing cleanup...', exc_info=True)
315 'doing cleanup...', exc_info=True)
315 if isinstance(e, IntegrityError):
316 if isinstance(e, IntegrityError):
316 Session().rollback()
317 Session().rollback()
317
318
318 # rollback things manually !
319 # rollback things manually !
319 repo = Repository.get_by_repo_name(repo_name_full)
320 repo = Repository.get_by_repo_name(repo_name_full)
320 if repo:
321 if repo:
321 Repository.delete(repo.repo_id)
322 Repository.delete(repo.repo_id)
322 Session().commit()
323 Session().commit()
323 RepoModel()._delete_filesystem_repo(repo)
324 RepoModel()._delete_filesystem_repo(repo)
324 log.info('Cleanup of repo %s finished', repo_name_full)
325 log.info('Cleanup of repo %s finished', repo_name_full)
325 raise
326 raise
326
327
327 return True
328 return True
328
329
329
330
330 @async_task(ignore_result=True, base=RequestContextTask)
331 @async_task(ignore_result=True, base=RequestContextTask)
331 def repo_maintenance(repoid):
332 def repo_maintenance(repoid):
332 from rhodecode.lib import repo_maintenance as repo_maintenance_lib
333 from rhodecode.lib import repo_maintenance as repo_maintenance_lib
333 log = get_logger(repo_maintenance)
334 log = get_logger(repo_maintenance)
334 repo = Repository.get_by_id_or_repo_name(repoid)
335 repo = Repository.get_by_id_or_repo_name(repoid)
335 if repo:
336 if repo:
336 maintenance = repo_maintenance_lib.RepoMaintenance()
337 maintenance = repo_maintenance_lib.RepoMaintenance()
337 tasks = maintenance.get_tasks_for_repo(repo)
338 tasks = maintenance.get_tasks_for_repo(repo)
338 log.debug('Executing %s tasks on repo `%s`', tasks, repoid)
339 log.debug('Executing %s tasks on repo `%s`', tasks, repoid)
339 executed_types = maintenance.execute(repo)
340 executed_types = maintenance.execute(repo)
340 log.debug('Got execution results %s', executed_types)
341 log.debug('Got execution results %s', executed_types)
341 else:
342 else:
342 log.debug('Repo `%s` not found or without a clone_url', repoid)
343 log.debug('Repo `%s` not found or without a clone_url', repoid)
343
344
344
345
345 @async_task(ignore_result=True, base=RequestContextTask)
346 @async_task(ignore_result=True, base=RequestContextTask)
346 def check_for_update(send_email_notification=True, email_recipients=None):
347 def check_for_update(send_email_notification=True, email_recipients=None):
347 from rhodecode.model.update import UpdateModel
348 from rhodecode.model.update import UpdateModel
348 from rhodecode.model.notification import EmailNotificationModel
349 from rhodecode.model.notification import EmailNotificationModel
349
350
350 log = get_logger(check_for_update)
351 log = get_logger(check_for_update)
351 update_url = UpdateModel().get_update_url()
352 update_url = UpdateModel().get_update_url()
352 cur_ver = rhodecode.__version__
353 cur_ver = rhodecode.__version__
353
354
354 try:
355 try:
355 data = UpdateModel().get_update_data(update_url)
356 data = UpdateModel().get_update_data(update_url)
356
357
357 current_ver = UpdateModel().get_stored_version(fallback=cur_ver)
358 current_ver = UpdateModel().get_stored_version(fallback=cur_ver)
358 latest_ver = data['versions'][0]['version']
359 latest_ver = data['versions'][0]['version']
359 UpdateModel().store_version(latest_ver)
360 UpdateModel().store_version(latest_ver)
360
361
361 if send_email_notification:
362 if send_email_notification:
362 log.debug('Send email notification is enabled. '
363 log.debug('Send email notification is enabled. '
363 'Current RhodeCode version: %s, latest known: %s', current_ver, latest_ver)
364 'Current RhodeCode version: %s, latest known: %s', current_ver, latest_ver)
364 if UpdateModel().is_outdated(current_ver, latest_ver):
365 if UpdateModel().is_outdated(current_ver, latest_ver):
365
366
366 email_kwargs = {
367 email_kwargs = {
367 'current_ver': current_ver,
368 'current_ver': current_ver,
368 'latest_ver': latest_ver,
369 'latest_ver': latest_ver,
369 }
370 }
370
371
371 (subject, email_body, email_body_plaintext) = EmailNotificationModel().render_email(
372 (subject, email_body, email_body_plaintext) = EmailNotificationModel().render_email(
372 EmailNotificationModel.TYPE_UPDATE_AVAILABLE, **email_kwargs)
373 EmailNotificationModel.TYPE_UPDATE_AVAILABLE, **email_kwargs)
373
374
374 email_recipients = aslist(email_recipients, sep=',') or \
375 email_recipients = aslist(email_recipients, sep=',') or \
375 [user.email for user in User.get_all_super_admins()]
376 [user.email for user in User.get_all_super_admins()]
376 run_task(send_email, email_recipients, subject,
377 run_task(send_email, email_recipients, subject,
377 email_body_plaintext, email_body)
378 email_body_plaintext, email_body)
378
379
379 except Exception:
380 except Exception:
380 log.exception('Failed to check for update')
381 log.exception('Failed to check for update')
381 raise
382 raise
382
383
383
384
384 def sync_last_update_for_objects(*args, **kwargs):
385 def sync_last_update_for_objects(*args, **kwargs):
385 skip_repos = kwargs.get('skip_repos')
386 skip_repos = kwargs.get('skip_repos')
386 if not skip_repos:
387 if not skip_repos:
387 repos = Repository.query() \
388 repos = Repository.query() \
388 .order_by(Repository.group_id.asc())
389 .order_by(Repository.group_id.asc())
389
390
390 for repo in repos:
391 for repo in repos:
391 repo.update_commit_cache()
392 repo.update_commit_cache()
392
393
393 skip_groups = kwargs.get('skip_groups')
394 skip_groups = kwargs.get('skip_groups')
394 if not skip_groups:
395 if not skip_groups:
395 repo_groups = RepoGroup.query() \
396 repo_groups = RepoGroup.query() \
396 .filter(RepoGroup.group_parent_id == null())
397 .filter(RepoGroup.group_parent_id == null())
397
398
398 for root_gr in repo_groups:
399 for root_gr in repo_groups:
399 for repo_gr in reversed(root_gr.recursive_groups()):
400 for repo_gr in reversed(root_gr.recursive_groups()):
400 repo_gr.update_commit_cache()
401 repo_gr.update_commit_cache()
401
402
402
403
403 @async_task(ignore_result=True, base=RequestContextTask)
404 @async_task(ignore_result=True, base=RequestContextTask)
404 def sync_last_update(*args, **kwargs):
405 def sync_last_update(*args, **kwargs):
405 sync_last_update_for_objects(*args, **kwargs)
406 sync_last_update_for_objects(*args, **kwargs)
406
407
407
408
408 @async_task(ignore_result=False)
409 @async_task(ignore_result=False)
409 def beat_check(*args, **kwargs):
410 def beat_check(*args, **kwargs):
410 log = get_logger(beat_check)
411 log = get_logger(beat_check)
411 log.info('%r: Got args: %r and kwargs %r', beat_check, args, kwargs)
412 log.info('%r: Got args: %r and kwargs %r', beat_check, args, kwargs)
412 return time.time()
413 return time.time()
414
415
416 @async_task
417 @adopt_for_celery
418 def repo_size(extras):
419 from rhodecode.lib.hooks_base import repo_size
420 return repo_size(extras)
421
422
423 @async_task
424 @adopt_for_celery
425 def pre_pull(extras):
426 from rhodecode.lib.hooks_base import pre_pull
427 return pre_pull(extras)
428
429
430 @async_task
431 @adopt_for_celery
432 def post_pull(extras):
433 from rhodecode.lib.hooks_base import post_pull
434 return post_pull(extras)
435
436
437 @async_task
438 @adopt_for_celery
439 def pre_push(extras):
440 from rhodecode.lib.hooks_base import pre_push
441 return pre_push(extras)
442
443
444 @async_task
445 @adopt_for_celery
446 def post_push(extras):
447 from rhodecode.lib.hooks_base import post_push
448 return post_push(extras)
@@ -1,531 +1,534 b''
1 # Copyright (C) 2013-2023 RhodeCode GmbH
1 # Copyright (C) 2013-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19
19
20 """
20 """
21 Set of hooks run by RhodeCode Enterprise
21 Set of hooks run by RhodeCode Enterprise
22 """
22 """
23
23
24 import os
24 import os
25 import logging
25 import logging
26
26
27 import rhodecode
27 import rhodecode
28 from rhodecode import events
28 from rhodecode import events
29 from rhodecode.lib import helpers as h
29 from rhodecode.lib import helpers as h
30 from rhodecode.lib import audit_logger
30 from rhodecode.lib import audit_logger
31 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
31 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
32 from rhodecode.lib.exceptions import (
32 from rhodecode.lib.exceptions import (
33 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
33 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
34 from rhodecode.model.db import Repository, User
34 from rhodecode.model.db import Repository, User
35 from rhodecode.lib.statsd_client import StatsdClient
35 from rhodecode.lib.statsd_client import StatsdClient
36
36
37 log = logging.getLogger(__name__)
37 log = logging.getLogger(__name__)
38
38
39
39
class HookResponse(object):
    """
    Aggregatable result of a hook invocation: an exit ``status`` (0 means
    success) plus accumulated text ``output`` shown to the VCS client.
    """

    def __init__(self, status, output):
        self.status = status
        self.output = output

    def __add__(self, other):
        # Combining two responses keeps the worst (highest) status and
        # concatenates the outputs; non-HookResponse operands contribute
        # status 0 and empty output.
        combined_status = max(self.status, getattr(other, 'status', 0))
        combined_output = self.output + getattr(other, 'output', '')
        return HookResponse(combined_status, combined_output)

    def __bool__(self):
        # Truthy only on success.
        return self.status == 0

    def to_json(self):
        """Return a plain-dict representation suitable for serialization."""
        return {'status': self.status, 'output': self.output}
56
59
def is_shadow_repo(extras):
    """
    Returns ``True`` if this is an action executed against a shadow repository.
    """
    # `extras` is the hooks extras mapping; the flag is expected to be present
    # (KeyError otherwise) — presumably set by the vcs middleware for
    # pull-request shadow repos; verify against caller.
    return extras['is_shadow_repo']
62
65
63
66
def _get_scm_size(alias, root_path):
    """
    Walk ``root_path`` and split the on-disk size into the SCM metadata part
    (paths containing ``alias``, e.g. ``.git``) and the working-tree part.

    Returns a tuple of human-readable sizes: (scm, root, total).
    """
    if not alias.startswith('.'):
        alias += '.'

    size_scm = size_root = 0
    for path, _dirs, files in os.walk(safe_str(root_path)):
        # a path containing the alias belongs to the SCM metadata dir
        inside_scm_dir = path.find(alias) != -1
        for name in files:
            try:
                file_size = os.path.getsize(os.path.join(path, name))
            except OSError:
                # file vanished or is unreadable -- skip it, best effort
                continue
            if inside_scm_dir:
                size_scm += file_size
            else:
                size_root += file_size

    return (h.format_byte_size_binary(size_scm),
            h.format_byte_size_binary(size_root),
            h.format_byte_size_binary(size_root + size_scm))
89
92
90
93
91 # actual hooks called by Mercurial internally, and GIT by our Python Hooks
94 # actual hooks called by Mercurial internally, and GIT by our Python Hooks
def repo_size(extras):
    """Present size of repository after push."""
    repo = Repository.get_by_repo_name(extras.repository)
    # e.g. '.git' / '.hg' -- the SCM metadata dir name used to split sizes
    vcs_part = f'.{repo.repo_type}'
    size_vcs, size_root, size_total = _get_scm_size(vcs_part, repo.repo_full_path)
    msg = (f'RhodeCode: `{repo.repo_name}` size summary {vcs_part}:{size_vcs} repo:{size_root} total:{size_total}\n')
    # status 0: purely informational, never blocks the operation
    return HookResponse(0, msg)
99
102
100
103
def pre_push(extras):
    """
    Hook executed before pushing code.

    It bans pushing when the repository is locked.

    Also enforces per-branch push permissions (reject push / reject force
    push) and forwards the event to rcextensions and the event system.
    Returns a HookResponse combined with the rcextensions response.
    """

    user = User.get_by_username(extras.username)
    output = ''
    # locked_by is a 3-tuple: (locking user id, lock date, reason).
    # A push by anyone other than the lock holder is rejected.
    if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
        locked_by = User.get(extras.locked_by[0]).username
        reason = extras.locked_by[2]
        # this exception is interpreted in git/hg middlewares and based
        # on that proper return code is server to client
        _http_ret = HTTPLockedRC(
            _locked_by_explanation(extras.repository, locked_by, reason))
        if str(_http_ret.code).startswith('2'):
            # 2xx Codes don't raise exceptions
            output = _http_ret.title
        else:
            raise _http_ret

    hook_response = ''
    if not is_shadow_repo(extras):

        if extras.commit_ids and extras.check_branch_perms:
            auth_user = user.AuthUser()
            repo = Repository.get_by_repo_name(extras.repository)
            # collect [branch_name, is_forced] pairs; the "forced" detection
            # differs per VCS type (hg: multiple heads, git: pruned sha)
            affected_branches = []
            if repo.repo_type == 'hg':
                for entry in extras.commit_ids:
                    if entry['type'] == 'branch':
                        is_forced = bool(entry['multiple_heads'])
                        affected_branches.append([entry['name'], is_forced])
            elif repo.repo_type == 'git':
                for entry in extras.commit_ids:
                    if entry['type'] == 'heads':
                        is_forced = bool(entry['pruned_sha'])
                        affected_branches.append([entry['name'], is_forced])

            for branch_name, is_forced in affected_branches:

                rule, branch_perm = auth_user.get_rule_and_branch_permission(
                    extras.repository, branch_name)
                if not branch_perm:
                    # no branch permission found for this branch, just keep checking
                    continue

                # branch.push_force allows anything; plain branch.push allows
                # only non-forced updates; anything else is rejected
                if branch_perm == 'branch.push_force':
                    continue
                elif branch_perm == 'branch.push' and is_forced is False:
                    continue
                elif branch_perm == 'branch.push' and is_forced is True:
                    halt_message = f'Branch `{branch_name}` changes rejected by rule {rule}. ' \
                                   f'FORCE PUSH FORBIDDEN.'
                else:
                    halt_message = f'Branch `{branch_name}` changes rejected by rule {rule}.'

                if halt_message:
                    # interpreted by the vcs middleware, aborts the push
                    _http_ret = HTTPBranchProtected(halt_message)
                    raise _http_ret

        # Propagate to external components. This is done after checking the
        # lock, for consistent behavior.
        hook_response = pre_push_extension(
            repo_store_path=Repository.base_path(), **extras)
        events.trigger(events.RepoPrePushEvent(
            repo_name=extras.repository, extras=extras))

    return HookResponse(0, output) + hook_response
171
174
172
175
def pre_pull(extras):
    """
    Hook executed before pulling the code.

    It bans pulling when the repository is locked.

    Returns a HookResponse combined with the rcextensions response.
    """

    output = ''
    # locked_by is a 3-tuple: (locking user id, lock date, reason)
    if extras.locked_by[0]:
        locked_by = User.get(extras.locked_by[0]).username
        reason = extras.locked_by[2]
        # this exception is interpreted in git/hg middlewares and based
        # on that proper return code is server to client
        _http_ret = HTTPLockedRC(
            _locked_by_explanation(extras.repository, locked_by, reason))
        if str(_http_ret.code).startswith('2'):
            # 2xx Codes don't raise exceptions
            output = _http_ret.title
        else:
            raise _http_ret

    # Propagate to external components. This is done after checking the
    # lock, for consistent behavior.
    hook_response = ''
    if not is_shadow_repo(extras):
        # default the hook type if the caller did not set one
        extras.hook_type = extras.hook_type or 'pre_pull'
        hook_response = pre_pull_extension(
            repo_store_path=Repository.base_path(), **extras)
        events.trigger(events.RepoPrePullEvent(
            repo_name=extras.repository, extras=extras))

    return HookResponse(0, output) + hook_response
205
208
206
209
def post_pull(extras):
    """Hook executed after client pulls the code.

    Records an audit-log entry and a statsd metric, optionally takes a
    pull-lock on the repository, and forwards the event to rcextensions
    and the event system.
    """

    # audit-log the pull action
    audit_user = audit_logger.UserWrap(
        username=extras.username,
        ip_addr=extras.ip)
    repo = audit_logger.RepoWrap(repo_name=extras.repository)
    audit_logger.store(
        'user.pull', action_data={'user_agent': extras.user_agent},
        user=audit_user, repo=repo, commit=True)

    statsd = StatsdClient.statsd
    if statsd:
        statsd.incr('rhodecode_pull_total', tags=[
            f'user-agent:{user_agent_normalizer(extras.user_agent)}',
        ])
    output = ''
    # make lock is a tri state False, True, None. We only make lock on True
    if extras.make_lock is True and not is_shadow_repo(extras):
        user = User.get_by_username(extras.username)
        Repository.lock(Repository.get_by_repo_name(extras.repository),
                        user.user_id,
                        lock_reason=Repository.LOCK_PULL)
        msg = 'Made lock on repo `{}`'.format(extras.repository)
        output += msg

    # if the repo is (still) locked, surface the lock message to the client
    if extras.locked_by[0]:
        locked_by = User.get(extras.locked_by[0]).username
        reason = extras.locked_by[2]
        _http_ret = HTTPLockedRC(
            _locked_by_explanation(extras.repository, locked_by, reason))
        if str(_http_ret.code).startswith('2'):
            # 2xx Codes don't raise exceptions
            output += _http_ret.title

    # Propagate to external components.
    hook_response = ''
    if not is_shadow_repo(extras):
        extras.hook_type = extras.hook_type or 'post_pull'
        hook_response = post_pull_extension(
            repo_store_path=Repository.base_path(), **extras)
        events.trigger(events.RepoPullEvent(
            repo_name=extras.repository, extras=extras))

    return HookResponse(0, output) + hook_response
252
255
253
256
def post_push(extras):
    """Hook executed after user pushes to the repository.

    Audit-logs the push, emits a statsd metric, releases a pull-lock when
    requested, prints open-pull-request links for newly pushed refs, and
    forwards the event to rcextensions and the event system.
    """
    commit_ids = extras.commit_ids

    # log the push call
    audit_user = audit_logger.UserWrap(
        username=extras.username, ip_addr=extras.ip)
    repo = audit_logger.RepoWrap(repo_name=extras.repository)
    audit_logger.store(
        'user.push', action_data={
            'user_agent': extras.user_agent,
            # cap the stored ids to keep the audit entry bounded
            'commit_ids': commit_ids[:400]},
        user=audit_user, repo=repo, commit=True)

    statsd = StatsdClient.statsd
    if statsd:
        statsd.incr('rhodecode_push_total', tags=[
            f'user-agent:{user_agent_normalizer(extras.user_agent)}',
        ])

    # Propagate to external components.
    output = ''
    # make lock is a tri state False, True, None. We only release lock on False
    if extras.make_lock is False and not is_shadow_repo(extras):
        Repository.unlock(Repository.get_by_repo_name(extras.repository))
        msg = f'Released lock on repo `{extras.repository}`\n'
        output += msg

    if extras.locked_by[0]:
        locked_by = User.get(extras.locked_by[0]).username
        reason = extras.locked_by[2]
        _http_ret = HTTPLockedRC(
            _locked_by_explanation(extras.repository, locked_by, reason))
        # TODO: johbo: if not?
        if str(_http_ret.code).startswith('2'):
            # 2xx Codes don't raise exceptions
            output += _http_ret.title

    # advertise "open pull request" links for freshly created refs
    if extras.new_refs:
        tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
            safe_str(extras.server_url), safe_str(extras.repository))

        for branch_name in extras.new_refs['branches']:
            pr_link = tmpl.format(ref_type='branch', ref_name=safe_str(branch_name))
            output += f'RhodeCode: open pull request link: {pr_link}\n'

        for book_name in extras.new_refs['bookmarks']:
            pr_link = tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name))
            output += f'RhodeCode: open pull request link: {pr_link}\n'

    hook_response = ''
    if not is_shadow_repo(extras):
        hook_response = post_push_extension(
            repo_store_path=Repository.base_path(),
            **extras)
        events.trigger(events.RepoPushEvent(
            repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))

    output += 'RhodeCode: push completed\n'
    return HookResponse(0, output) + hook_response
314
317
315
318
316 def _locked_by_explanation(repo_name, user_name, reason):
319 def _locked_by_explanation(repo_name, user_name, reason):
317 message = f'Repository `{repo_name}` locked by user `{user_name}`. Reason:`{reason}`'
320 message = f'Repository `{repo_name}` locked by user `{user_name}`. Reason:`{reason}`'
318 return message
321 return message
319
322
320
323
def check_allowed_create_user(user_dict, created_by, **kwargs):
    """
    Run the optional PRE_CREATE_USER rcextension and raise
    ``UserCreationError`` when the extension vetoes the creation
    (non-zero hook status); a no-op when no extension is registered.
    """
    if not pre_create_user.is_active():
        return

    hook_result = pre_create_user(created_by=created_by, **user_dict)
    if hook_result.status != 0:
        # the hook output carries the human-readable rejection reason
        raise UserCreationError(hook_result.output)
329
332
330
333
class ExtensionCallback(object):
    """
    Forwards a given call to rcextensions, sanitizes keyword arguments.

    Does check if there is an extension active for that hook. If it is
    there, it will forward all `kwargs_keys` keyword arguments to the
    extension callback.
    """

    def __init__(self, hook_name, kwargs_keys):
        self._hook_name = hook_name
        self._kwargs_keys = set(kwargs_keys)

    def __call__(self, *args, **kwargs):
        log.debug('Calling extension callback for `%s`', self._hook_name)
        callback = self._get_callback()
        if not callback:
            log.debug('extension callback `%s` not found, skipping...', self._hook_name)
            return

        # only the declared keys are forwarded; a missing key is a
        # programming error and re-raises after being logged
        forwarded = {}
        for key in self._kwargs_keys:
            try:
                forwarded[key] = kwargs[key]
            except KeyError:
                log.error('Failed to fetch %s key from given kwargs. '
                          'Expected keys: %s', key, self._kwargs_keys)
                raise

        # backward compat for removed api_key for old hooks. This was it works
        # with older rcextensions that require api_key present
        if self._hook_name in ('CREATE_USER_HOOK', 'DELETE_USER_HOOK'):
            forwarded['api_key'] = '_DEPRECATED_'
        return callback(**forwarded)

    def is_active(self):
        """True when an extension is registered under this hook name."""
        return hasattr(rhodecode.EXTENSIONS, self._hook_name)

    def _get_callback(self):
        # returns None when the extension is not registered
        return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
371
374
372
375
# --- rcextensions callback instances ---------------------------------------
# Each ExtensionCallback below binds a hook attribute name (looked up on
# rhodecode.EXTENSIONS) to the exact set of keyword arguments forwarded to
# the user-defined rcextensions function.

pre_pull_extension = ExtensionCallback(
    hook_name='PRE_PULL_HOOK',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'hook_type', 'user_agent', 'repo_store_path',))


post_pull_extension = ExtensionCallback(
    hook_name='PULL_HOOK',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'hook_type', 'user_agent', 'repo_store_path',))


# push variants additionally receive the pushed commit ids
pre_push_extension = ExtensionCallback(
    hook_name='PRE_PUSH_HOOK',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))


post_push_extension = ExtensionCallback(
    hook_name='PUSH_HOOK',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))


# user creation veto hook, see check_allowed_create_user()
pre_create_user = ExtensionCallback(
    hook_name='PRE_CREATE_USER_HOOK',
    kwargs_keys=(
        'username', 'password', 'email', 'firstname', 'lastname', 'active',
        'admin', 'created_by'))


# pull-request lifecycle hooks -- all share the same argument set,
# except the comment hooks which also carry 'comment'
create_pull_request = ExtensionCallback(
    hook_name='CREATE_PULL_REQUEST',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'pull_request_id', 'url', 'title', 'description',
        'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
        'mergeable', 'source', 'target', 'author', 'reviewers'))


merge_pull_request = ExtensionCallback(
    hook_name='MERGE_PULL_REQUEST',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'pull_request_id', 'url', 'title', 'description',
        'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
        'mergeable', 'source', 'target', 'author', 'reviewers'))


close_pull_request = ExtensionCallback(
    hook_name='CLOSE_PULL_REQUEST',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'pull_request_id', 'url', 'title', 'description',
        'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
        'mergeable', 'source', 'target', 'author', 'reviewers'))


review_pull_request = ExtensionCallback(
    hook_name='REVIEW_PULL_REQUEST',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'pull_request_id', 'url', 'title', 'description',
        'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
        'mergeable', 'source', 'target', 'author', 'reviewers'))


comment_pull_request = ExtensionCallback(
    hook_name='COMMENT_PULL_REQUEST',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'pull_request_id', 'url', 'title', 'description',
        'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
        'mergeable', 'source', 'target', 'author', 'reviewers'))


comment_edit_pull_request = ExtensionCallback(
    hook_name='COMMENT_EDIT_PULL_REQUEST',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'pull_request_id', 'url', 'title', 'description',
        'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
        'mergeable', 'source', 'target', 'author', 'reviewers'))


update_pull_request = ExtensionCallback(
    hook_name='UPDATE_PULL_REQUEST',
    kwargs_keys=(
        'server_url', 'config', 'scm', 'username', 'ip', 'action',
        'repository', 'pull_request_id', 'url', 'title', 'description',
        'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
        'mergeable', 'source', 'target', 'author', 'reviewers'))


# user lifecycle hooks (api_key is injected as '_DEPRECATED_' for
# backward compatibility, see ExtensionCallback.__call__)
create_user = ExtensionCallback(
    hook_name='CREATE_USER_HOOK',
    kwargs_keys=(
        'username', 'full_name_or_username', 'full_contact', 'user_id',
        'name', 'firstname', 'short_contact', 'admin', 'lastname',
        'ip_addresses', 'extern_type', 'extern_name',
        'email', 'api_keys', 'last_login',
        'full_name', 'active', 'password', 'emails',
        'inherit_default_permissions', 'created_by', 'created_on'))
480
483
481
484
482 delete_user = ExtensionCallback(
485 delete_user = ExtensionCallback(
483 hook_name='DELETE_USER_HOOK',
486 hook_name='DELETE_USER_HOOK',
484 kwargs_keys=(
487 kwargs_keys=(
485 'username', 'full_name_or_username', 'full_contact', 'user_id',
488 'username', 'full_name_or_username', 'full_contact', 'user_id',
486 'name', 'firstname', 'short_contact', 'admin', 'lastname',
489 'name', 'firstname', 'short_contact', 'admin', 'lastname',
487 'ip_addresses',
490 'ip_addresses',
488 'email', 'last_login',
491 'email', 'last_login',
489 'full_name', 'active', 'password', 'emails',
492 'full_name', 'active', 'password', 'emails',
490 'inherit_default_permissions', 'deleted_by'))
493 'inherit_default_permissions', 'deleted_by'))
491
494
492
495
493 create_repository = ExtensionCallback(
496 create_repository = ExtensionCallback(
494 hook_name='CREATE_REPO_HOOK',
497 hook_name='CREATE_REPO_HOOK',
495 kwargs_keys=(
498 kwargs_keys=(
496 'repo_name', 'repo_type', 'description', 'private', 'created_on',
499 'repo_name', 'repo_type', 'description', 'private', 'created_on',
497 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
500 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
498 'clone_uri', 'fork_id', 'group_id', 'created_by'))
501 'clone_uri', 'fork_id', 'group_id', 'created_by'))
499
502
500
503
501 delete_repository = ExtensionCallback(
504 delete_repository = ExtensionCallback(
502 hook_name='DELETE_REPO_HOOK',
505 hook_name='DELETE_REPO_HOOK',
503 kwargs_keys=(
506 kwargs_keys=(
504 'repo_name', 'repo_type', 'description', 'private', 'created_on',
507 'repo_name', 'repo_type', 'description', 'private', 'created_on',
505 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
508 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
506 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
509 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
507
510
508
511
509 comment_commit_repository = ExtensionCallback(
512 comment_commit_repository = ExtensionCallback(
510 hook_name='COMMENT_COMMIT_REPO_HOOK',
513 hook_name='COMMENT_COMMIT_REPO_HOOK',
511 kwargs_keys=(
514 kwargs_keys=(
512 'repo_name', 'repo_type', 'description', 'private', 'created_on',
515 'repo_name', 'repo_type', 'description', 'private', 'created_on',
513 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
516 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
514 'clone_uri', 'fork_id', 'group_id',
517 'clone_uri', 'fork_id', 'group_id',
515 'repository', 'created_by', 'comment', 'commit'))
518 'repository', 'created_by', 'comment', 'commit'))
516
519
517 comment_edit_commit_repository = ExtensionCallback(
520 comment_edit_commit_repository = ExtensionCallback(
518 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
521 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
519 kwargs_keys=(
522 kwargs_keys=(
520 'repo_name', 'repo_type', 'description', 'private', 'created_on',
523 'repo_name', 'repo_type', 'description', 'private', 'created_on',
521 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
524 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
522 'clone_uri', 'fork_id', 'group_id',
525 'clone_uri', 'fork_id', 'group_id',
523 'repository', 'created_by', 'comment', 'commit'))
526 'repository', 'created_by', 'comment', 'commit'))
524
527
525
528
526 create_repository_group = ExtensionCallback(
529 create_repository_group = ExtensionCallback(
527 hook_name='CREATE_REPO_GROUP_HOOK',
530 hook_name='CREATE_REPO_GROUP_HOOK',
528 kwargs_keys=(
531 kwargs_keys=(
529 'group_name', 'group_parent_id', 'group_description',
532 'group_name', 'group_parent_id', 'group_description',
530 'group_id', 'user_id', 'created_by', 'created_on',
533 'group_id', 'user_id', 'created_by', 'created_on',
531 'enable_locking'))
534 'enable_locking'))
@@ -1,436 +1,451 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19 import os
19 import os
20 import time
20 import time
21 import logging
21 import logging
22 import tempfile
22 import tempfile
23 import traceback
23 import traceback
24 import threading
24 import threading
25 import socket
25 import socket
26 import msgpack
26 import msgpack
27 import gevent
27 import gevent
28
28
29 from http.server import BaseHTTPRequestHandler
29 from http.server import BaseHTTPRequestHandler
30 from socketserver import TCPServer
30 from socketserver import TCPServer
31
31
32 import rhodecode
32 import rhodecode
33 from rhodecode.lib.exceptions import HTTPLockedRC, HTTPBranchProtected
33 from rhodecode.lib.exceptions import HTTPLockedRC, HTTPBranchProtected
34 from rhodecode.model import meta
34 from rhodecode.model import meta
35 from rhodecode.lib import hooks_base
35 from rhodecode.lib import hooks_base
36 from rhodecode.lib.utils2 import AttributeDict
36 from rhodecode.lib.utils2 import AttributeDict
37 from rhodecode.lib.pyramid_utils import get_config
37 from rhodecode.lib.ext_json import json
38 from rhodecode.lib.ext_json import json
38 from rhodecode.lib import rc_cache
39 from rhodecode.lib import rc_cache
39
40
40 log = logging.getLogger(__name__)
41 log = logging.getLogger(__name__)
41
42
42
43
43 class HooksHttpHandler(BaseHTTPRequestHandler):
44 class HooksHttpHandler(BaseHTTPRequestHandler):
44
45
45 JSON_HOOKS_PROTO = 'json.v1'
46 JSON_HOOKS_PROTO = 'json.v1'
46 MSGPACK_HOOKS_PROTO = 'msgpack.v1'
47 MSGPACK_HOOKS_PROTO = 'msgpack.v1'
47 # starting with RhodeCode 5.0.0 MsgPack is the default, prior it used json
48 # starting with RhodeCode 5.0.0 MsgPack is the default, prior it used json
48 DEFAULT_HOOKS_PROTO = MSGPACK_HOOKS_PROTO
49 DEFAULT_HOOKS_PROTO = MSGPACK_HOOKS_PROTO
49
50
50 @classmethod
51 @classmethod
51 def serialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
52 def serialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
52 if proto == cls.MSGPACK_HOOKS_PROTO:
53 if proto == cls.MSGPACK_HOOKS_PROTO:
53 return msgpack.packb(data)
54 return msgpack.packb(data)
54 return json.dumps(data)
55 return json.dumps(data)
55
56
56 @classmethod
57 @classmethod
57 def deserialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
58 def deserialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
58 if proto == cls.MSGPACK_HOOKS_PROTO:
59 if proto == cls.MSGPACK_HOOKS_PROTO:
59 return msgpack.unpackb(data)
60 return msgpack.unpackb(data)
60 return json.loads(data)
61 return json.loads(data)
61
62
62 def do_POST(self):
63 def do_POST(self):
63 hooks_proto, method, extras = self._read_request()
64 hooks_proto, method, extras = self._read_request()
64 log.debug('Handling HooksHttpHandler %s with %s proto', method, hooks_proto)
65 log.debug('Handling HooksHttpHandler %s with %s proto', method, hooks_proto)
65
66
66 txn_id = getattr(self.server, 'txn_id', None)
67 txn_id = getattr(self.server, 'txn_id', None)
67 if txn_id:
68 if txn_id:
68 log.debug('Computing TXN_ID based on `%s`:`%s`',
69 log.debug('Computing TXN_ID based on `%s`:`%s`',
69 extras['repository'], extras['txn_id'])
70 extras['repository'], extras['txn_id'])
70 computed_txn_id = rc_cache.utils.compute_key_from_params(
71 computed_txn_id = rc_cache.utils.compute_key_from_params(
71 extras['repository'], extras['txn_id'])
72 extras['repository'], extras['txn_id'])
72 if txn_id != computed_txn_id:
73 if txn_id != computed_txn_id:
73 raise Exception(
74 raise Exception(
74 'TXN ID fail: expected {} got {} instead'.format(
75 'TXN ID fail: expected {} got {} instead'.format(
75 txn_id, computed_txn_id))
76 txn_id, computed_txn_id))
76
77
77 request = getattr(self.server, 'request', None)
78 request = getattr(self.server, 'request', None)
78 try:
79 try:
79 hooks = Hooks(request=request, log_prefix='HOOKS: {} '.format(self.server.server_address))
80 hooks = Hooks(request=request, log_prefix='HOOKS: {} '.format(self.server.server_address))
80 result = self._call_hook_method(hooks, method, extras)
81 result = self._call_hook_method(hooks, method, extras)
81
82
82 except Exception as e:
83 except Exception as e:
83 exc_tb = traceback.format_exc()
84 exc_tb = traceback.format_exc()
84 result = {
85 result = {
85 'exception': e.__class__.__name__,
86 'exception': e.__class__.__name__,
86 'exception_traceback': exc_tb,
87 'exception_traceback': exc_tb,
87 'exception_args': e.args
88 'exception_args': e.args
88 }
89 }
89 self._write_response(hooks_proto, result)
90 self._write_response(hooks_proto, result)
90
91
91 def _read_request(self):
92 def _read_request(self):
92 length = int(self.headers['Content-Length'])
93 length = int(self.headers['Content-Length'])
93 # respect sent headers, fallback to OLD proto for compatability
94 # respect sent headers, fallback to OLD proto for compatability
94 hooks_proto = self.headers.get('rc-hooks-protocol') or self.JSON_HOOKS_PROTO
95 hooks_proto = self.headers.get('rc-hooks-protocol') or self.JSON_HOOKS_PROTO
95 if hooks_proto == self.MSGPACK_HOOKS_PROTO:
96 if hooks_proto == self.MSGPACK_HOOKS_PROTO:
96 # support for new vcsserver msgpack based protocol hooks
97 # support for new vcsserver msgpack based protocol hooks
97 body = self.rfile.read(length)
98 body = self.rfile.read(length)
98 data = self.deserialize_data(body)
99 data = self.deserialize_data(body)
99 else:
100 else:
100 body = self.rfile.read(length)
101 body = self.rfile.read(length)
101 data = self.deserialize_data(body)
102 data = self.deserialize_data(body)
102
103
103 return hooks_proto, data['method'], data['extras']
104 return hooks_proto, data['method'], data['extras']
104
105
105 def _write_response(self, hooks_proto, result):
106 def _write_response(self, hooks_proto, result):
106 self.send_response(200)
107 self.send_response(200)
107 if hooks_proto == self.MSGPACK_HOOKS_PROTO:
108 if hooks_proto == self.MSGPACK_HOOKS_PROTO:
108 self.send_header("Content-type", "application/msgpack")
109 self.send_header("Content-type", "application/msgpack")
109 self.end_headers()
110 self.end_headers()
110 data = self.serialize_data(result)
111 data = self.serialize_data(result)
111 self.wfile.write(data)
112 self.wfile.write(data)
112 else:
113 else:
113 self.send_header("Content-type", "text/json")
114 self.send_header("Content-type", "text/json")
114 self.end_headers()
115 self.end_headers()
115 data = self.serialize_data(result)
116 data = self.serialize_data(result)
116 self.wfile.write(data)
117 self.wfile.write(data)
117
118
118 def _call_hook_method(self, hooks, method, extras):
119 def _call_hook_method(self, hooks, method, extras):
119 try:
120 try:
120 result = getattr(hooks, method)(extras)
121 result = getattr(hooks, method)(extras)
121 finally:
122 finally:
122 meta.Session.remove()
123 meta.Session.remove()
123 return result
124 return result
124
125
125 def log_message(self, format, *args):
126 def log_message(self, format, *args):
126 """
127 """
127 This is an overridden method of BaseHTTPRequestHandler which logs using
128 This is an overridden method of BaseHTTPRequestHandler which logs using
128 logging library instead of writing directly to stderr.
129 logging library instead of writing directly to stderr.
129 """
130 """
130
131
131 message = format % args
132 message = format % args
132
133
133 log.debug(
134 log.debug(
134 "HOOKS: client=%s - - [%s] %s", self.client_address,
135 "HOOKS: client=%s - - [%s] %s", self.client_address,
135 self.log_date_time_string(), message)
136 self.log_date_time_string(), message)
136
137
137
138
138 class DummyHooksCallbackDaemon(object):
139 class BaseHooksCallbackDaemon:
139 hooks_uri = ''
140 """
140
141 Basic context manager for actions that don't require some extra
142 """
141 def __init__(self):
143 def __init__(self):
142 self.hooks_module = Hooks.__module__
144 self.hooks_module = Hooks.__module__
143
145
144 def __enter__(self):
146 def __enter__(self):
145 log.debug('Running `%s` callback daemon', self.__class__.__name__)
147 log.debug('Running `%s` callback daemon', self.__class__.__name__)
146 return self
148 return self
147
149
148 def __exit__(self, exc_type, exc_val, exc_tb):
150 def __exit__(self, exc_type, exc_val, exc_tb):
149 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
151 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
150
152
151
153
154 class CeleryHooksCallbackDaemon(BaseHooksCallbackDaemon):
155 """
156 Context manger for achieving a compatibility with celery backend
157 """
158
159 def __init__(self, config):
160 self.task_queue = config.get('app:main', 'celery.broker_url')
161 self.task_backend = config.get('app:main', 'celery.result_backend')
162
163
152 class ThreadedHookCallbackDaemon(object):
164 class ThreadedHookCallbackDaemon(object):
153
165
154 _callback_thread = None
166 _callback_thread = None
155 _daemon = None
167 _daemon = None
156 _done = False
168 _done = False
157 use_gevent = False
169 use_gevent = False
158
170
159 def __init__(self, txn_id=None, host=None, port=None):
171 def __init__(self, txn_id=None, host=None, port=None):
160 self._prepare(txn_id=txn_id, host=host, port=port)
172 self._prepare(txn_id=txn_id, host=host, port=port)
161 if self.use_gevent:
173 if self.use_gevent:
162 self._run_func = self._run_gevent
174 self._run_func = self._run_gevent
163 self._stop_func = self._stop_gevent
175 self._stop_func = self._stop_gevent
164 else:
176 else:
165 self._run_func = self._run
177 self._run_func = self._run
166 self._stop_func = self._stop
178 self._stop_func = self._stop
167
179
168 def __enter__(self):
180 def __enter__(self):
169 log.debug('Running `%s` callback daemon', self.__class__.__name__)
181 log.debug('Running `%s` callback daemon', self.__class__.__name__)
170 self._run_func()
182 self._run_func()
171 return self
183 return self
172
184
173 def __exit__(self, exc_type, exc_val, exc_tb):
185 def __exit__(self, exc_type, exc_val, exc_tb):
174 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
186 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
175 self._stop_func()
187 self._stop_func()
176
188
177 def _prepare(self, txn_id=None, host=None, port=None):
189 def _prepare(self, txn_id=None, host=None, port=None):
178 raise NotImplementedError()
190 raise NotImplementedError()
179
191
180 def _run(self):
192 def _run(self):
181 raise NotImplementedError()
193 raise NotImplementedError()
182
194
183 def _stop(self):
195 def _stop(self):
184 raise NotImplementedError()
196 raise NotImplementedError()
185
197
186 def _run_gevent(self):
198 def _run_gevent(self):
187 raise NotImplementedError()
199 raise NotImplementedError()
188
200
189 def _stop_gevent(self):
201 def _stop_gevent(self):
190 raise NotImplementedError()
202 raise NotImplementedError()
191
203
192
204
193 class HttpHooksCallbackDaemon(ThreadedHookCallbackDaemon):
205 class HttpHooksCallbackDaemon(ThreadedHookCallbackDaemon):
194 """
206 """
195 Context manager which will run a callback daemon in a background thread.
207 Context manager which will run a callback daemon in a background thread.
196 """
208 """
197
209
198 hooks_uri = None
210 hooks_uri = None
199
211
200 # From Python docs: Polling reduces our responsiveness to a shutdown
212 # From Python docs: Polling reduces our responsiveness to a shutdown
201 # request and wastes cpu at all other times.
213 # request and wastes cpu at all other times.
202 POLL_INTERVAL = 0.01
214 POLL_INTERVAL = 0.01
203
215
204 use_gevent = False
216 use_gevent = False
205
217
206 @property
218 @property
207 def _hook_prefix(self):
219 def _hook_prefix(self):
208 return 'HOOKS: {} '.format(self.hooks_uri)
220 return 'HOOKS: {} '.format(self.hooks_uri)
209
221
210 def get_hostname(self):
222 def get_hostname(self):
211 return socket.gethostname() or '127.0.0.1'
223 return socket.gethostname() or '127.0.0.1'
212
224
213 def get_available_port(self, min_port=20000, max_port=65535):
225 def get_available_port(self, min_port=20000, max_port=65535):
214 from rhodecode.lib.utils2 import get_available_port as _get_port
226 from rhodecode.lib.utils2 import get_available_port as _get_port
215 return _get_port(min_port, max_port)
227 return _get_port(min_port, max_port)
216
228
217 def _prepare(self, txn_id=None, host=None, port=None):
229 def _prepare(self, txn_id=None, host=None, port=None):
218 from pyramid.threadlocal import get_current_request
230 from pyramid.threadlocal import get_current_request
219
231
220 if not host or host == "*":
232 if not host or host == "*":
221 host = self.get_hostname()
233 host = self.get_hostname()
222 if not port:
234 if not port:
223 port = self.get_available_port()
235 port = self.get_available_port()
224
236
225 server_address = (host, port)
237 server_address = (host, port)
226 self.hooks_uri = '{}:{}'.format(host, port)
238 self.hooks_uri = '{}:{}'.format(host, port)
227 self.txn_id = txn_id
239 self.txn_id = txn_id
228 self._done = False
240 self._done = False
229
241
230 log.debug(
242 log.debug(
231 "%s Preparing HTTP callback daemon registering hook object: %s",
243 "%s Preparing HTTP callback daemon registering hook object: %s",
232 self._hook_prefix, HooksHttpHandler)
244 self._hook_prefix, HooksHttpHandler)
233
245
234 self._daemon = TCPServer(server_address, HooksHttpHandler)
246 self._daemon = TCPServer(server_address, HooksHttpHandler)
235 # inject transaction_id for later verification
247 # inject transaction_id for later verification
236 self._daemon.txn_id = self.txn_id
248 self._daemon.txn_id = self.txn_id
237
249
238 # pass the WEB app request into daemon
250 # pass the WEB app request into daemon
239 self._daemon.request = get_current_request()
251 self._daemon.request = get_current_request()
240
252
241 def _run(self):
253 def _run(self):
242 log.debug("Running thread-based loop of callback daemon in background")
254 log.debug("Running thread-based loop of callback daemon in background")
243 callback_thread = threading.Thread(
255 callback_thread = threading.Thread(
244 target=self._daemon.serve_forever,
256 target=self._daemon.serve_forever,
245 kwargs={'poll_interval': self.POLL_INTERVAL})
257 kwargs={'poll_interval': self.POLL_INTERVAL})
246 callback_thread.daemon = True
258 callback_thread.daemon = True
247 callback_thread.start()
259 callback_thread.start()
248 self._callback_thread = callback_thread
260 self._callback_thread = callback_thread
249
261
250 def _run_gevent(self):
262 def _run_gevent(self):
251 log.debug("Running gevent-based loop of callback daemon in background")
263 log.debug("Running gevent-based loop of callback daemon in background")
252 # create a new greenlet for the daemon's serve_forever method
264 # create a new greenlet for the daemon's serve_forever method
253 callback_greenlet = gevent.spawn(
265 callback_greenlet = gevent.spawn(
254 self._daemon.serve_forever,
266 self._daemon.serve_forever,
255 poll_interval=self.POLL_INTERVAL)
267 poll_interval=self.POLL_INTERVAL)
256
268
257 # store reference to greenlet
269 # store reference to greenlet
258 self._callback_greenlet = callback_greenlet
270 self._callback_greenlet = callback_greenlet
259
271
260 # switch to this greenlet
272 # switch to this greenlet
261 gevent.sleep(0.01)
273 gevent.sleep(0.01)
262
274
263 def _stop(self):
275 def _stop(self):
264 log.debug("Waiting for background thread to finish.")
276 log.debug("Waiting for background thread to finish.")
265 self._daemon.shutdown()
277 self._daemon.shutdown()
266 self._callback_thread.join()
278 self._callback_thread.join()
267 self._daemon = None
279 self._daemon = None
268 self._callback_thread = None
280 self._callback_thread = None
269 if self.txn_id:
281 if self.txn_id:
270 txn_id_file = get_txn_id_data_path(self.txn_id)
282 txn_id_file = get_txn_id_data_path(self.txn_id)
271 log.debug('Cleaning up TXN ID %s', txn_id_file)
283 log.debug('Cleaning up TXN ID %s', txn_id_file)
272 if os.path.isfile(txn_id_file):
284 if os.path.isfile(txn_id_file):
273 os.remove(txn_id_file)
285 os.remove(txn_id_file)
274
286
275 log.debug("Background thread done.")
287 log.debug("Background thread done.")
276
288
277 def _stop_gevent(self):
289 def _stop_gevent(self):
278 log.debug("Waiting for background greenlet to finish.")
290 log.debug("Waiting for background greenlet to finish.")
279
291
280 # if greenlet exists and is running
292 # if greenlet exists and is running
281 if self._callback_greenlet and not self._callback_greenlet.dead:
293 if self._callback_greenlet and not self._callback_greenlet.dead:
282 # shutdown daemon if it exists
294 # shutdown daemon if it exists
283 if self._daemon:
295 if self._daemon:
284 self._daemon.shutdown()
296 self._daemon.shutdown()
285
297
286 # kill the greenlet
298 # kill the greenlet
287 self._callback_greenlet.kill()
299 self._callback_greenlet.kill()
288
300
289 self._daemon = None
301 self._daemon = None
290 self._callback_greenlet = None
302 self._callback_greenlet = None
291
303
292 if self.txn_id:
304 if self.txn_id:
293 txn_id_file = get_txn_id_data_path(self.txn_id)
305 txn_id_file = get_txn_id_data_path(self.txn_id)
294 log.debug('Cleaning up TXN ID %s', txn_id_file)
306 log.debug('Cleaning up TXN ID %s', txn_id_file)
295 if os.path.isfile(txn_id_file):
307 if os.path.isfile(txn_id_file):
296 os.remove(txn_id_file)
308 os.remove(txn_id_file)
297
309
298 log.debug("Background greenlet done.")
310 log.debug("Background greenlet done.")
299
311
300
312
301 def get_txn_id_data_path(txn_id):
313 def get_txn_id_data_path(txn_id):
302 import rhodecode
314 import rhodecode
303
315
304 root = rhodecode.CONFIG.get('cache_dir') or tempfile.gettempdir()
316 root = rhodecode.CONFIG.get('cache_dir') or tempfile.gettempdir()
305 final_dir = os.path.join(root, 'svn_txn_id')
317 final_dir = os.path.join(root, 'svn_txn_id')
306
318
307 if not os.path.isdir(final_dir):
319 if not os.path.isdir(final_dir):
308 os.makedirs(final_dir)
320 os.makedirs(final_dir)
309 return os.path.join(final_dir, 'rc_txn_id_{}'.format(txn_id))
321 return os.path.join(final_dir, 'rc_txn_id_{}'.format(txn_id))
310
322
311
323
312 def store_txn_id_data(txn_id, data_dict):
324 def store_txn_id_data(txn_id, data_dict):
313 if not txn_id:
325 if not txn_id:
314 log.warning('Cannot store txn_id because it is empty')
326 log.warning('Cannot store txn_id because it is empty')
315 return
327 return
316
328
317 path = get_txn_id_data_path(txn_id)
329 path = get_txn_id_data_path(txn_id)
318 try:
330 try:
319 with open(path, 'wb') as f:
331 with open(path, 'wb') as f:
320 f.write(json.dumps(data_dict))
332 f.write(json.dumps(data_dict))
321 except Exception:
333 except Exception:
322 log.exception('Failed to write txn_id metadata')
334 log.exception('Failed to write txn_id metadata')
323
335
324
336
325 def get_txn_id_from_store(txn_id):
337 def get_txn_id_from_store(txn_id):
326 """
338 """
327 Reads txn_id from store and if present returns the data for callback manager
339 Reads txn_id from store and if present returns the data for callback manager
328 """
340 """
329 path = get_txn_id_data_path(txn_id)
341 path = get_txn_id_data_path(txn_id)
330 try:
342 try:
331 with open(path, 'rb') as f:
343 with open(path, 'rb') as f:
332 return json.loads(f.read())
344 return json.loads(f.read())
333 except Exception:
345 except Exception:
334 return {}
346 return {}
335
347
336
348
337 def prepare_callback_daemon(extras, protocol, host, use_direct_calls, txn_id=None):
349 def prepare_callback_daemon(extras, protocol, host, txn_id=None):
338 txn_details = get_txn_id_from_store(txn_id)
350 txn_details = get_txn_id_from_store(txn_id)
339 port = txn_details.get('port', 0)
351 port = txn_details.get('port', 0)
340 if use_direct_calls:
352 match protocol:
341 callback_daemon = DummyHooksCallbackDaemon()
353 case 'http':
342 extras['hooks_module'] = callback_daemon.hooks_module
343 else:
344 if protocol == 'http':
345 callback_daemon = HttpHooksCallbackDaemon(
354 callback_daemon = HttpHooksCallbackDaemon(
346 txn_id=txn_id, host=host, port=port)
355 txn_id=txn_id, host=host, port=port)
347 else:
356 case 'celery':
357 callback_daemon = CeleryHooksCallbackDaemon(get_config(extras['config']))
358 case 'local':
359 callback_daemon = BaseHooksCallbackDaemon()
360 case _:
348 log.error('Unsupported callback daemon protocol "%s"', protocol)
361 log.error('Unsupported callback daemon protocol "%s"', protocol)
349 raise Exception('Unsupported callback daemon protocol.')
362 raise Exception('Unsupported callback daemon protocol.')
350
363
351 extras['hooks_uri'] = callback_daemon.hooks_uri
364 extras['hooks_uri'] = getattr(callback_daemon, 'hooks_uri', '')
365 extras['task_queue'] = getattr(callback_daemon, 'task_queue', '')
366 extras['task_backend'] = getattr(callback_daemon, 'task_backend', '')
352 extras['hooks_protocol'] = protocol
367 extras['hooks_protocol'] = protocol
353 extras['time'] = time.time()
368 extras['time'] = time.time()
354
369
355 # register txn_id
370 # register txn_id
356 extras['txn_id'] = txn_id
371 extras['txn_id'] = txn_id
357 log.debug('Prepared a callback daemon: %s at url `%s`',
372 log.debug('Prepared a callback daemon: %s',
358 callback_daemon.__class__.__name__, callback_daemon.hooks_uri)
373 callback_daemon.__class__.__name__)
359 return callback_daemon, extras
374 return callback_daemon, extras
360
375
361
376
362 class Hooks(object):
377 class Hooks(object):
363 """
378 """
364 Exposes the hooks for remote call backs
379 Exposes the hooks for remote call backs
365 """
380 """
366 def __init__(self, request=None, log_prefix=''):
381 def __init__(self, request=None, log_prefix=''):
367 self.log_prefix = log_prefix
382 self.log_prefix = log_prefix
368 self.request = request
383 self.request = request
369
384
370 def repo_size(self, extras):
385 def repo_size(self, extras):
371 log.debug("%sCalled repo_size of %s object", self.log_prefix, self)
386 log.debug("%sCalled repo_size of %s object", self.log_prefix, self)
372 return self._call_hook(hooks_base.repo_size, extras)
387 return self._call_hook(hooks_base.repo_size, extras)
373
388
374 def pre_pull(self, extras):
389 def pre_pull(self, extras):
375 log.debug("%sCalled pre_pull of %s object", self.log_prefix, self)
390 log.debug("%sCalled pre_pull of %s object", self.log_prefix, self)
376 return self._call_hook(hooks_base.pre_pull, extras)
391 return self._call_hook(hooks_base.pre_pull, extras)
377
392
378 def post_pull(self, extras):
393 def post_pull(self, extras):
379 log.debug("%sCalled post_pull of %s object", self.log_prefix, self)
394 log.debug("%sCalled post_pull of %s object", self.log_prefix, self)
380 return self._call_hook(hooks_base.post_pull, extras)
395 return self._call_hook(hooks_base.post_pull, extras)
381
396
382 def pre_push(self, extras):
397 def pre_push(self, extras):
383 log.debug("%sCalled pre_push of %s object", self.log_prefix, self)
398 log.debug("%sCalled pre_push of %s object", self.log_prefix, self)
384 return self._call_hook(hooks_base.pre_push, extras)
399 return self._call_hook(hooks_base.pre_push, extras)
385
400
386 def post_push(self, extras):
401 def post_push(self, extras):
387 log.debug("%sCalled post_push of %s object", self.log_prefix, self)
402 log.debug("%sCalled post_push of %s object", self.log_prefix, self)
388 return self._call_hook(hooks_base.post_push, extras)
403 return self._call_hook(hooks_base.post_push, extras)
389
404
390 def _call_hook(self, hook, extras):
405 def _call_hook(self, hook, extras):
391 extras = AttributeDict(extras)
406 extras = AttributeDict(extras)
392 server_url = extras['server_url']
407 server_url = extras['server_url']
393
408
394 extras.request = self.request
409 extras.request = self.request
395
410
396 try:
411 try:
397 result = hook(extras)
412 result = hook(extras)
398 if result is None:
413 if result is None:
399 raise Exception(
414 raise Exception(
400 'Failed to obtain hook result from func: {}'.format(hook))
415 'Failed to obtain hook result from func: {}'.format(hook))
401 except HTTPBranchProtected as handled_error:
416 except HTTPBranchProtected as handled_error:
402 # Those special cases doesn't need error reporting. It's a case of
417 # Those special cases doesn't need error reporting. It's a case of
403 # locked repo or protected branch
418 # locked repo or protected branch
404 result = AttributeDict({
419 result = AttributeDict({
405 'status': handled_error.code,
420 'status': handled_error.code,
406 'output': handled_error.explanation
421 'output': handled_error.explanation
407 })
422 })
408 except (HTTPLockedRC, Exception) as error:
423 except (HTTPLockedRC, Exception) as error:
409 # locked needs different handling since we need to also
424 # locked needs different handling since we need to also
410 # handle PULL operations
425 # handle PULL operations
411 exc_tb = ''
426 exc_tb = ''
412 if not isinstance(error, HTTPLockedRC):
427 if not isinstance(error, HTTPLockedRC):
413 exc_tb = traceback.format_exc()
428 exc_tb = traceback.format_exc()
414 log.exception('%sException when handling hook %s', self.log_prefix, hook)
429 log.exception('%sException when handling hook %s', self.log_prefix, hook)
415 error_args = error.args
430 error_args = error.args
416 return {
431 return {
417 'status': 128,
432 'status': 128,
418 'output': '',
433 'output': '',
419 'exception': type(error).__name__,
434 'exception': type(error).__name__,
420 'exception_traceback': exc_tb,
435 'exception_traceback': exc_tb,
421 'exception_args': error_args,
436 'exception_args': error_args,
422 }
437 }
423 finally:
438 finally:
424 meta.Session.remove()
439 meta.Session.remove()
425
440
426 log.debug('%sGot hook call response %s', self.log_prefix, result)
441 log.debug('%sGot hook call response %s', self.log_prefix, result)
427 return {
442 return {
428 'status': result.status,
443 'status': result.status,
429 'output': result.output,
444 'output': result.output,
430 }
445 }
431
446
432 def __enter__(self):
447 def __enter__(self):
433 return self
448 return self
434
449
435 def __exit__(self, exc_type, exc_val, exc_tb):
450 def __exit__(self, exc_type, exc_val, exc_tb):
436 pass
451 pass
@@ -1,701 +1,701 b''
1
1
2
2
3 # Copyright (C) 2014-2023 RhodeCode GmbH
3 # Copyright (C) 2014-2023 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 SimpleVCS middleware for handling protocol request (push/clone etc.)
22 SimpleVCS middleware for handling protocol request (push/clone etc.)
23 It's implemented with basic auth function
23 It's implemented with basic auth function
24 """
24 """
25
25
26 import os
26 import os
27 import re
27 import re
28 import io
28 import io
29 import logging
29 import logging
30 import importlib
30 import importlib
31 from functools import wraps
31 from functools import wraps
32 from lxml import etree
32 from lxml import etree
33
33
34 import time
34 import time
35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
36
36
37 from pyramid.httpexceptions import (
37 from pyramid.httpexceptions import (
38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
39 from zope.cachedescriptors.property import Lazy as LazyProperty
39 from zope.cachedescriptors.property import Lazy as LazyProperty
40
40
41 import rhodecode
41 import rhodecode
42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
43 from rhodecode.lib import rc_cache
43 from rhodecode.lib import rc_cache
44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
45 from rhodecode.lib.base import (
45 from rhodecode.lib.base import (
46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
49 from rhodecode.lib.middleware import appenlight
49 from rhodecode.lib.middleware import appenlight
50 from rhodecode.lib.middleware.utils import scm_app_http
50 from rhodecode.lib.middleware.utils import scm_app_http
51 from rhodecode.lib.str_utils import safe_bytes
51 from rhodecode.lib.str_utils import safe_bytes
52 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
52 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
53 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool
53 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool
54 from rhodecode.lib.vcs.conf import settings as vcs_settings
54 from rhodecode.lib.vcs.conf import settings as vcs_settings
55 from rhodecode.lib.vcs.backends import base
55 from rhodecode.lib.vcs.backends import base
56
56
57 from rhodecode.model import meta
57 from rhodecode.model import meta
58 from rhodecode.model.db import User, Repository, PullRequest
58 from rhodecode.model.db import User, Repository, PullRequest
59 from rhodecode.model.scm import ScmModel
59 from rhodecode.model.scm import ScmModel
60 from rhodecode.model.pull_request import PullRequestModel
60 from rhodecode.model.pull_request import PullRequestModel
61 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
61 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
62
62
63 log = logging.getLogger(__name__)
63 log = logging.getLogger(__name__)
64
64
65
65
66 def extract_svn_txn_id(acl_repo_name, data: bytes):
66 def extract_svn_txn_id(acl_repo_name, data: bytes):
67 """
67 """
68 Helper method for extraction of svn txn_id from submitted XML data during
68 Helper method for extraction of svn txn_id from submitted XML data during
69 POST operations
69 POST operations
70 """
70 """
71
71
72 try:
72 try:
73 root = etree.fromstring(data)
73 root = etree.fromstring(data)
74 pat = re.compile(r'/txn/(?P<txn_id>.*)')
74 pat = re.compile(r'/txn/(?P<txn_id>.*)')
75 for el in root:
75 for el in root:
76 if el.tag == '{DAV:}source':
76 if el.tag == '{DAV:}source':
77 for sub_el in el:
77 for sub_el in el:
78 if sub_el.tag == '{DAV:}href':
78 if sub_el.tag == '{DAV:}href':
79 match = pat.search(sub_el.text)
79 match = pat.search(sub_el.text)
80 if match:
80 if match:
81 svn_tx_id = match.groupdict()['txn_id']
81 svn_tx_id = match.groupdict()['txn_id']
82 txn_id = rc_cache.utils.compute_key_from_params(
82 txn_id = rc_cache.utils.compute_key_from_params(
83 acl_repo_name, svn_tx_id)
83 acl_repo_name, svn_tx_id)
84 return txn_id
84 return txn_id
85 except Exception:
85 except Exception:
86 log.exception('Failed to extract txn_id')
86 log.exception('Failed to extract txn_id')
87
87
88
88
89 def initialize_generator(factory):
89 def initialize_generator(factory):
90 """
90 """
91 Initializes the returned generator by draining its first element.
91 Initializes the returned generator by draining its first element.
92
92
93 This can be used to give a generator an initializer, which is the code
93 This can be used to give a generator an initializer, which is the code
94 up to the first yield statement. This decorator enforces that the first
94 up to the first yield statement. This decorator enforces that the first
95 produced element has the value ``"__init__"`` to make its special
95 produced element has the value ``"__init__"`` to make its special
96 purpose very explicit in the using code.
96 purpose very explicit in the using code.
97 """
97 """
98
98
99 @wraps(factory)
99 @wraps(factory)
100 def wrapper(*args, **kwargs):
100 def wrapper(*args, **kwargs):
101 gen = factory(*args, **kwargs)
101 gen = factory(*args, **kwargs)
102 try:
102 try:
103 init = next(gen)
103 init = next(gen)
104 except StopIteration:
104 except StopIteration:
105 raise ValueError('Generator must yield at least one element.')
105 raise ValueError('Generator must yield at least one element.')
106 if init != "__init__":
106 if init != "__init__":
107 raise ValueError('First yielded element must be "__init__".')
107 raise ValueError('First yielded element must be "__init__".')
108 return gen
108 return gen
109 return wrapper
109 return wrapper
110
110
111
111
112 class SimpleVCS(object):
112 class SimpleVCS(object):
113 """Common functionality for SCM HTTP handlers."""
113 """Common functionality for SCM HTTP handlers."""
114
114
115 SCM = 'unknown'
115 SCM = 'unknown'
116
116
117 acl_repo_name = None
117 acl_repo_name = None
118 url_repo_name = None
118 url_repo_name = None
119 vcs_repo_name = None
119 vcs_repo_name = None
120 rc_extras = {}
120 rc_extras = {}
121
121
122 # We have to handle requests to shadow repositories different than requests
122 # We have to handle requests to shadow repositories different than requests
123 # to normal repositories. Therefore we have to distinguish them. To do this
123 # to normal repositories. Therefore we have to distinguish them. To do this
124 # we use this regex which will match only on URLs pointing to shadow
124 # we use this regex which will match only on URLs pointing to shadow
125 # repositories.
125 # repositories.
126 shadow_repo_re = re.compile(
126 shadow_repo_re = re.compile(
127 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
127 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
128 '(?P<target>{slug_pat})/' # target repo
128 '(?P<target>{slug_pat})/' # target repo
129 'pull-request/(?P<pr_id>\\d+)/' # pull request
129 'pull-request/(?P<pr_id>\\d+)/' # pull request
130 'repository$' # shadow repo
130 'repository$' # shadow repo
131 .format(slug_pat=SLUG_RE.pattern))
131 .format(slug_pat=SLUG_RE.pattern))
132
132
133 def __init__(self, config, registry):
133 def __init__(self, config, registry):
134 self.registry = registry
134 self.registry = registry
135 self.config = config
135 self.config = config
136 # re-populated by specialized middleware
136 # re-populated by specialized middleware
137 self.repo_vcs_config = base.Config()
137 self.repo_vcs_config = base.Config()
138
138
139 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
139 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
140 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
140 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
141
141
142 # authenticate this VCS request using authfunc
142 # authenticate this VCS request using authfunc
143 auth_ret_code_detection = \
143 auth_ret_code_detection = \
144 str2bool(self.config.get('auth_ret_code_detection', False))
144 str2bool(self.config.get('auth_ret_code_detection', False))
145 self.authenticate = BasicAuth(
145 self.authenticate = BasicAuth(
146 '', authenticate, registry, config.get('auth_ret_code'),
146 '', authenticate, registry, config.get('auth_ret_code'),
147 auth_ret_code_detection, rc_realm=realm)
147 auth_ret_code_detection, rc_realm=realm)
148 self.ip_addr = '0.0.0.0'
148 self.ip_addr = '0.0.0.0'
149
149
150 @LazyProperty
150 @LazyProperty
151 def global_vcs_config(self):
151 def global_vcs_config(self):
152 try:
152 try:
153 return VcsSettingsModel().get_ui_settings_as_config_obj()
153 return VcsSettingsModel().get_ui_settings_as_config_obj()
154 except Exception:
154 except Exception:
155 return base.Config()
155 return base.Config()
156
156
157 @property
157 @property
158 def base_path(self):
158 def base_path(self):
159 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
159 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
160
160
161 if not settings_path:
161 if not settings_path:
162 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
162 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
163
163
164 if not settings_path:
164 if not settings_path:
165 # try, maybe we passed in explicitly as config option
165 # try, maybe we passed in explicitly as config option
166 settings_path = self.config.get('base_path')
166 settings_path = self.config.get('base_path')
167
167
168 if not settings_path:
168 if not settings_path:
169 raise ValueError('FATAL: base_path is empty')
169 raise ValueError('FATAL: base_path is empty')
170 return settings_path
170 return settings_path
171
171
172 def set_repo_names(self, environ):
172 def set_repo_names(self, environ):
173 """
173 """
174 This will populate the attributes acl_repo_name, url_repo_name,
174 This will populate the attributes acl_repo_name, url_repo_name,
175 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
175 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
176 shadow) repositories all names are equal. In case of requests to a
176 shadow) repositories all names are equal. In case of requests to a
177 shadow repository the acl-name points to the target repo of the pull
177 shadow repository the acl-name points to the target repo of the pull
178 request and the vcs-name points to the shadow repo file system path.
178 request and the vcs-name points to the shadow repo file system path.
179 The url-name is always the URL used by the vcs client program.
179 The url-name is always the URL used by the vcs client program.
180
180
181 Example in case of a shadow repo:
181 Example in case of a shadow repo:
182 acl_repo_name = RepoGroup/MyRepo
182 acl_repo_name = RepoGroup/MyRepo
183 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
183 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
184 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3'
184 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3'
185 """
185 """
186 # First we set the repo name from URL for all attributes. This is the
186 # First we set the repo name from URL for all attributes. This is the
187 # default if handling normal (non shadow) repo requests.
187 # default if handling normal (non shadow) repo requests.
188 self.url_repo_name = self._get_repository_name(environ)
188 self.url_repo_name = self._get_repository_name(environ)
189 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
189 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
190 self.is_shadow_repo = False
190 self.is_shadow_repo = False
191
191
192 # Check if this is a request to a shadow repository.
192 # Check if this is a request to a shadow repository.
193 match = self.shadow_repo_re.match(self.url_repo_name)
193 match = self.shadow_repo_re.match(self.url_repo_name)
194 if match:
194 if match:
195 match_dict = match.groupdict()
195 match_dict = match.groupdict()
196
196
197 # Build acl repo name from regex match.
197 # Build acl repo name from regex match.
198 acl_repo_name = safe_str('{groups}{target}'.format(
198 acl_repo_name = safe_str('{groups}{target}'.format(
199 groups=match_dict['groups'] or '',
199 groups=match_dict['groups'] or '',
200 target=match_dict['target']))
200 target=match_dict['target']))
201
201
202 # Retrieve pull request instance by ID from regex match.
202 # Retrieve pull request instance by ID from regex match.
203 pull_request = PullRequest.get(match_dict['pr_id'])
203 pull_request = PullRequest.get(match_dict['pr_id'])
204
204
205 # Only proceed if we got a pull request and if acl repo name from
205 # Only proceed if we got a pull request and if acl repo name from
206 # URL equals the target repo name of the pull request.
206 # URL equals the target repo name of the pull request.
207 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
207 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
208
208
209 # Get file system path to shadow repository.
209 # Get file system path to shadow repository.
210 workspace_id = PullRequestModel()._workspace_id(pull_request)
210 workspace_id = PullRequestModel()._workspace_id(pull_request)
211 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
211 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
212
212
213 # Store names for later usage.
213 # Store names for later usage.
214 self.vcs_repo_name = vcs_repo_name
214 self.vcs_repo_name = vcs_repo_name
215 self.acl_repo_name = acl_repo_name
215 self.acl_repo_name = acl_repo_name
216 self.is_shadow_repo = True
216 self.is_shadow_repo = True
217
217
218 log.debug('Setting all VCS repository names: %s', {
218 log.debug('Setting all VCS repository names: %s', {
219 'acl_repo_name': self.acl_repo_name,
219 'acl_repo_name': self.acl_repo_name,
220 'url_repo_name': self.url_repo_name,
220 'url_repo_name': self.url_repo_name,
221 'vcs_repo_name': self.vcs_repo_name,
221 'vcs_repo_name': self.vcs_repo_name,
222 })
222 })
223
223
224 @property
224 @property
225 def scm_app(self):
225 def scm_app(self):
226 custom_implementation = self.config['vcs.scm_app_implementation']
226 custom_implementation = self.config['vcs.scm_app_implementation']
227 if custom_implementation == 'http':
227 if custom_implementation == 'http':
228 log.debug('Using HTTP implementation of scm app.')
228 log.debug('Using HTTP implementation of scm app.')
229 scm_app_impl = scm_app_http
229 scm_app_impl = scm_app_http
230 else:
230 else:
231 log.debug('Using custom implementation of scm_app: "{}"'.format(
231 log.debug('Using custom implementation of scm_app: "{}"'.format(
232 custom_implementation))
232 custom_implementation))
233 scm_app_impl = importlib.import_module(custom_implementation)
233 scm_app_impl = importlib.import_module(custom_implementation)
234 return scm_app_impl
234 return scm_app_impl
235
235
236 def _get_by_id(self, repo_name):
236 def _get_by_id(self, repo_name):
237 """
237 """
238 Gets a special pattern _<ID> from clone url and tries to replace it
238 Gets a special pattern _<ID> from clone url and tries to replace it
239 with a repository_name for support of _<ID> non changeable urls
239 with a repository_name for support of _<ID> non changeable urls
240 """
240 """
241
241
242 data = repo_name.split('/')
242 data = repo_name.split('/')
243 if len(data) >= 2:
243 if len(data) >= 2:
244 from rhodecode.model.repo import RepoModel
244 from rhodecode.model.repo import RepoModel
245 by_id_match = RepoModel().get_repo_by_id(repo_name)
245 by_id_match = RepoModel().get_repo_by_id(repo_name)
246 if by_id_match:
246 if by_id_match:
247 data[1] = by_id_match.repo_name
247 data[1] = by_id_match.repo_name
248
248
249 # Because PEP-3333-WSGI uses bytes-tunneled-in-latin-1 as PATH_INFO
249 # Because PEP-3333-WSGI uses bytes-tunneled-in-latin-1 as PATH_INFO
250 # and we use this data
250 # and we use this data
251 maybe_new_path = '/'.join(data)
251 maybe_new_path = '/'.join(data)
252 return safe_bytes(maybe_new_path).decode('latin1')
252 return safe_bytes(maybe_new_path).decode('latin1')
253
253
254 def _invalidate_cache(self, repo_name):
254 def _invalidate_cache(self, repo_name):
255 """
255 """
256 Set's cache for this repository for invalidation on next access
256 Set's cache for this repository for invalidation on next access
257
257
258 :param repo_name: full repo name, also a cache key
258 :param repo_name: full repo name, also a cache key
259 """
259 """
260 ScmModel().mark_for_invalidation(repo_name)
260 ScmModel().mark_for_invalidation(repo_name)
261
261
262 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
262 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
263 db_repo = Repository.get_by_repo_name(repo_name)
263 db_repo = Repository.get_by_repo_name(repo_name)
264 if not db_repo:
264 if not db_repo:
265 log.debug('Repository `%s` not found inside the database.',
265 log.debug('Repository `%s` not found inside the database.',
266 repo_name)
266 repo_name)
267 return False
267 return False
268
268
269 if db_repo.repo_type != scm_type:
269 if db_repo.repo_type != scm_type:
270 log.warning(
270 log.warning(
271 'Repository `%s` have incorrect scm_type, expected %s got %s',
271 'Repository `%s` have incorrect scm_type, expected %s got %s',
272 repo_name, db_repo.repo_type, scm_type)
272 repo_name, db_repo.repo_type, scm_type)
273 return False
273 return False
274
274
275 config = db_repo._config
275 config = db_repo._config
276 config.set('extensions', 'largefiles', '')
276 config.set('extensions', 'largefiles', '')
277 return is_valid_repo(
277 return is_valid_repo(
278 repo_name, base_path,
278 repo_name, base_path,
279 explicit_scm=scm_type, expect_scm=scm_type, config=config)
279 explicit_scm=scm_type, expect_scm=scm_type, config=config)
280
280
281 def valid_and_active_user(self, user):
281 def valid_and_active_user(self, user):
282 """
282 """
283 Checks if that user is not empty, and if it's actually object it checks
283 Checks if that user is not empty, and if it's actually object it checks
284 if he's active.
284 if he's active.
285
285
286 :param user: user object or None
286 :param user: user object or None
287 :return: boolean
287 :return: boolean
288 """
288 """
289 if user is None:
289 if user is None:
290 return False
290 return False
291
291
292 elif user.active:
292 elif user.active:
293 return True
293 return True
294
294
295 return False
295 return False
296
296
297 @property
297 @property
298 def is_shadow_repo_dir(self):
298 def is_shadow_repo_dir(self):
299 return os.path.isdir(self.vcs_repo_name)
299 return os.path.isdir(self.vcs_repo_name)
300
300
301 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
301 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
302 plugin_id='', plugin_cache_active=False, cache_ttl=0):
302 plugin_id='', plugin_cache_active=False, cache_ttl=0):
303 """
303 """
304 Checks permissions using action (push/pull) user and repository
304 Checks permissions using action (push/pull) user and repository
305 name. If plugin_cache and ttl is set it will use the plugin which
305 name. If plugin_cache and ttl is set it will use the plugin which
306 authenticated the user to store the cached permissions result for N
306 authenticated the user to store the cached permissions result for N
307 amount of seconds as in cache_ttl
307 amount of seconds as in cache_ttl
308
308
309 :param action: push or pull action
309 :param action: push or pull action
310 :param user: user instance
310 :param user: user instance
311 :param repo_name: repository name
311 :param repo_name: repository name
312 """
312 """
313
313
314 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
314 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
315 plugin_id, plugin_cache_active, cache_ttl)
315 plugin_id, plugin_cache_active, cache_ttl)
316
316
317 user_id = user.user_id
317 user_id = user.user_id
318 cache_namespace_uid = f'cache_user_auth.{rc_cache.PERMISSIONS_CACHE_VER}.{user_id}'
318 cache_namespace_uid = f'cache_user_auth.{rc_cache.PERMISSIONS_CACHE_VER}.{user_id}'
319 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
319 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
320
320
321 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
321 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
322 expiration_time=cache_ttl,
322 expiration_time=cache_ttl,
323 condition=plugin_cache_active)
323 condition=plugin_cache_active)
324 def compute_perm_vcs(
324 def compute_perm_vcs(
325 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
325 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
326
326
327 log.debug('auth: calculating permission access now...')
327 log.debug('auth: calculating permission access now...')
328 # check IP
328 # check IP
329 inherit = user.inherit_default_permissions
329 inherit = user.inherit_default_permissions
330 ip_allowed = AuthUser.check_ip_allowed(
330 ip_allowed = AuthUser.check_ip_allowed(
331 user_id, ip_addr, inherit_from_default=inherit)
331 user_id, ip_addr, inherit_from_default=inherit)
332 if ip_allowed:
332 if ip_allowed:
333 log.info('Access for IP:%s allowed', ip_addr)
333 log.info('Access for IP:%s allowed', ip_addr)
334 else:
334 else:
335 return False
335 return False
336
336
337 if action == 'push':
337 if action == 'push':
338 perms = ('repository.write', 'repository.admin')
338 perms = ('repository.write', 'repository.admin')
339 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
339 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
340 return False
340 return False
341
341
342 else:
342 else:
343 # any other action need at least read permission
343 # any other action need at least read permission
344 perms = (
344 perms = (
345 'repository.read', 'repository.write', 'repository.admin')
345 'repository.read', 'repository.write', 'repository.admin')
346 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
346 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
347 return False
347 return False
348
348
349 return True
349 return True
350
350
351 start = time.time()
351 start = time.time()
352 log.debug('Running plugin `%s` permissions check', plugin_id)
352 log.debug('Running plugin `%s` permissions check', plugin_id)
353
353
354 # for environ based auth, password can be empty, but then the validation is
354 # for environ based auth, password can be empty, but then the validation is
355 # on the server that fills in the env data needed for authentication
355 # on the server that fills in the env data needed for authentication
356 perm_result = compute_perm_vcs(
356 perm_result = compute_perm_vcs(
357 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
357 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
358
358
359 auth_time = time.time() - start
359 auth_time = time.time() - start
360 log.debug('Permissions for plugin `%s` completed in %.4fs, '
360 log.debug('Permissions for plugin `%s` completed in %.4fs, '
361 'expiration time of fetched cache %.1fs.',
361 'expiration time of fetched cache %.1fs.',
362 plugin_id, auth_time, cache_ttl)
362 plugin_id, auth_time, cache_ttl)
363
363
364 return perm_result
364 return perm_result
365
365
366 def _get_http_scheme(self, environ):
366 def _get_http_scheme(self, environ):
367 try:
367 try:
368 return environ['wsgi.url_scheme']
368 return environ['wsgi.url_scheme']
369 except Exception:
369 except Exception:
370 log.exception('Failed to read http scheme')
370 log.exception('Failed to read http scheme')
371 return 'http'
371 return 'http'
372
372
373 def _check_ssl(self, environ, start_response):
373 def _check_ssl(self, environ, start_response):
374 """
374 """
375 Checks the SSL check flag and returns False if SSL is not present
375 Checks the SSL check flag and returns False if SSL is not present
376 and required True otherwise
376 and required True otherwise
377 """
377 """
378 org_proto = environ['wsgi._org_proto']
378 org_proto = environ['wsgi._org_proto']
379 # check if we have SSL required ! if not it's a bad request !
379 # check if we have SSL required ! if not it's a bad request !
380 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
380 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
381 if require_ssl and org_proto == 'http':
381 if require_ssl and org_proto == 'http':
382 log.debug(
382 log.debug(
383 'Bad request: detected protocol is `%s` and '
383 'Bad request: detected protocol is `%s` and '
384 'SSL/HTTPS is required.', org_proto)
384 'SSL/HTTPS is required.', org_proto)
385 return False
385 return False
386 return True
386 return True
387
387
388 def _get_default_cache_ttl(self):
388 def _get_default_cache_ttl(self):
389 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
389 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
390 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
390 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
391 plugin_settings = plugin.get_settings()
391 plugin_settings = plugin.get_settings()
392 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
392 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
393 plugin_settings) or (False, 0)
393 plugin_settings) or (False, 0)
394 return plugin_cache_active, cache_ttl
394 return plugin_cache_active, cache_ttl
395
395
396 def __call__(self, environ, start_response):
396 def __call__(self, environ, start_response):
397 try:
397 try:
398 return self._handle_request(environ, start_response)
398 return self._handle_request(environ, start_response)
399 except Exception:
399 except Exception:
400 log.exception("Exception while handling request")
400 log.exception("Exception while handling request")
401 appenlight.track_exception(environ)
401 appenlight.track_exception(environ)
402 return HTTPInternalServerError()(environ, start_response)
402 return HTTPInternalServerError()(environ, start_response)
403 finally:
403 finally:
404 meta.Session.remove()
404 meta.Session.remove()
405
405
406 def _handle_request(self, environ, start_response):
406 def _handle_request(self, environ, start_response):
407 if not self._check_ssl(environ, start_response):
407 if not self._check_ssl(environ, start_response):
408 reason = ('SSL required, while RhodeCode was unable '
408 reason = ('SSL required, while RhodeCode was unable '
409 'to detect this as SSL request')
409 'to detect this as SSL request')
410 log.debug('User not allowed to proceed, %s', reason)
410 log.debug('User not allowed to proceed, %s', reason)
411 return HTTPNotAcceptable(reason)(environ, start_response)
411 return HTTPNotAcceptable(reason)(environ, start_response)
412
412
413 if not self.url_repo_name:
413 if not self.url_repo_name:
414 log.warning('Repository name is empty: %s', self.url_repo_name)
414 log.warning('Repository name is empty: %s', self.url_repo_name)
415 # failed to get repo name, we fail now
415 # failed to get repo name, we fail now
416 return HTTPNotFound()(environ, start_response)
416 return HTTPNotFound()(environ, start_response)
417 log.debug('Extracted repo name is %s', self.url_repo_name)
417 log.debug('Extracted repo name is %s', self.url_repo_name)
418
418
419 ip_addr = get_ip_addr(environ)
419 ip_addr = get_ip_addr(environ)
420 user_agent = get_user_agent(environ)
420 user_agent = get_user_agent(environ)
421 username = None
421 username = None
422
422
423 # skip passing error to error controller
423 # skip passing error to error controller
424 environ['pylons.status_code_redirect'] = True
424 environ['pylons.status_code_redirect'] = True
425
425
426 # ======================================================================
426 # ======================================================================
427 # GET ACTION PULL or PUSH
427 # GET ACTION PULL or PUSH
428 # ======================================================================
428 # ======================================================================
429 action = self._get_action(environ)
429 action = self._get_action(environ)
430
430
431 # ======================================================================
431 # ======================================================================
432 # Check if this is a request to a shadow repository of a pull request.
432 # Check if this is a request to a shadow repository of a pull request.
433 # In this case only pull action is allowed.
433 # In this case only pull action is allowed.
434 # ======================================================================
434 # ======================================================================
435 if self.is_shadow_repo and action != 'pull':
435 if self.is_shadow_repo and action != 'pull':
436 reason = 'Only pull action is allowed for shadow repositories.'
436 reason = 'Only pull action is allowed for shadow repositories.'
437 log.debug('User not allowed to proceed, %s', reason)
437 log.debug('User not allowed to proceed, %s', reason)
438 return HTTPNotAcceptable(reason)(environ, start_response)
438 return HTTPNotAcceptable(reason)(environ, start_response)
439
439
440 # Check if the shadow repo actually exists, in case someone refers
440 # Check if the shadow repo actually exists, in case someone refers
441 # to it, and it has been deleted because of successful merge.
441 # to it, and it has been deleted because of successful merge.
442 if self.is_shadow_repo and not self.is_shadow_repo_dir:
442 if self.is_shadow_repo and not self.is_shadow_repo_dir:
443 log.debug(
443 log.debug(
444 'Shadow repo detected, and shadow repo dir `%s` is missing',
444 'Shadow repo detected, and shadow repo dir `%s` is missing',
445 self.is_shadow_repo_dir)
445 self.is_shadow_repo_dir)
446 return HTTPNotFound()(environ, start_response)
446 return HTTPNotFound()(environ, start_response)
447
447
448 # ======================================================================
448 # ======================================================================
449 # CHECK ANONYMOUS PERMISSION
449 # CHECK ANONYMOUS PERMISSION
450 # ======================================================================
450 # ======================================================================
451 detect_force_push = False
451 detect_force_push = False
452 check_branch_perms = False
452 check_branch_perms = False
453 if action in ['pull', 'push']:
453 if action in ['pull', 'push']:
454 user_obj = anonymous_user = User.get_default_user()
454 user_obj = anonymous_user = User.get_default_user()
455 auth_user = user_obj.AuthUser()
455 auth_user = user_obj.AuthUser()
456 username = anonymous_user.username
456 username = anonymous_user.username
457 if anonymous_user.active:
457 if anonymous_user.active:
458 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
458 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
459 # ONLY check permissions if the user is activated
459 # ONLY check permissions if the user is activated
460 anonymous_perm = self._check_permission(
460 anonymous_perm = self._check_permission(
461 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
461 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
462 plugin_id='anonymous_access',
462 plugin_id='anonymous_access',
463 plugin_cache_active=plugin_cache_active,
463 plugin_cache_active=plugin_cache_active,
464 cache_ttl=cache_ttl,
464 cache_ttl=cache_ttl,
465 )
465 )
466 else:
466 else:
467 anonymous_perm = False
467 anonymous_perm = False
468
468
469 if not anonymous_user.active or not anonymous_perm:
469 if not anonymous_user.active or not anonymous_perm:
470 if not anonymous_user.active:
470 if not anonymous_user.active:
471 log.debug('Anonymous access is disabled, running '
471 log.debug('Anonymous access is disabled, running '
472 'authentication')
472 'authentication')
473
473
474 if not anonymous_perm:
474 if not anonymous_perm:
475 log.debug('Not enough credentials to access repo: `%s` '
475 log.debug('Not enough credentials to access repo: `%s` '
476 'repository as anonymous user', self.acl_repo_name)
476 'repository as anonymous user', self.acl_repo_name)
477
477
478
478
479 username = None
479 username = None
480 # ==============================================================
480 # ==============================================================
481 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
481 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
482 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
482 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
483 # ==============================================================
483 # ==============================================================
484
484
485 # try to auth based on environ, container auth methods
485 # try to auth based on environ, container auth methods
486 log.debug('Running PRE-AUTH for container|headers based authentication')
486 log.debug('Running PRE-AUTH for container|headers based authentication')
487
487
488 # headers auth, by just reading special headers and bypass the auth with user/passwd
488 # headers auth, by just reading special headers and bypass the auth with user/passwd
489 pre_auth = authenticate(
489 pre_auth = authenticate(
490 '', '', environ, VCS_TYPE, registry=self.registry,
490 '', '', environ, VCS_TYPE, registry=self.registry,
491 acl_repo_name=self.acl_repo_name)
491 acl_repo_name=self.acl_repo_name)
492
492
493 if pre_auth and pre_auth.get('username'):
493 if pre_auth and pre_auth.get('username'):
494 username = pre_auth['username']
494 username = pre_auth['username']
495 log.debug('PRE-AUTH got `%s` as username', username)
495 log.debug('PRE-AUTH got `%s` as username', username)
496 if pre_auth:
496 if pre_auth:
497 log.debug('PRE-AUTH successful from %s',
497 log.debug('PRE-AUTH successful from %s',
498 pre_auth.get('auth_data', {}).get('_plugin'))
498 pre_auth.get('auth_data', {}).get('_plugin'))
499
499
500 # If not authenticated by the container, running basic auth
500 # If not authenticated by the container, running basic auth
501 # before inject the calling repo_name for special scope checks
501 # before inject the calling repo_name for special scope checks
502 self.authenticate.acl_repo_name = self.acl_repo_name
502 self.authenticate.acl_repo_name = self.acl_repo_name
503
503
504 plugin_cache_active, cache_ttl = False, 0
504 plugin_cache_active, cache_ttl = False, 0
505 plugin = None
505 plugin = None
506
506
507 # regular auth chain
507 # regular auth chain
508 if not username:
508 if not username:
509 self.authenticate.realm = self.authenticate.get_rc_realm()
509 self.authenticate.realm = self.authenticate.get_rc_realm()
510
510
511 try:
511 try:
512 auth_result = self.authenticate(environ)
512 auth_result = self.authenticate(environ)
513 except (UserCreationError, NotAllowedToCreateUserError) as e:
513 except (UserCreationError, NotAllowedToCreateUserError) as e:
514 log.error(e)
514 log.error(e)
515 reason = safe_str(e)
515 reason = safe_str(e)
516 return HTTPNotAcceptable(reason)(environ, start_response)
516 return HTTPNotAcceptable(reason)(environ, start_response)
517
517
518 if isinstance(auth_result, dict):
518 if isinstance(auth_result, dict):
519 AUTH_TYPE.update(environ, 'basic')
519 AUTH_TYPE.update(environ, 'basic')
520 REMOTE_USER.update(environ, auth_result['username'])
520 REMOTE_USER.update(environ, auth_result['username'])
521 username = auth_result['username']
521 username = auth_result['username']
522 plugin = auth_result.get('auth_data', {}).get('_plugin')
522 plugin = auth_result.get('auth_data', {}).get('_plugin')
523 log.info(
523 log.info(
524 'MAIN-AUTH successful for user `%s` from %s plugin',
524 'MAIN-AUTH successful for user `%s` from %s plugin',
525 username, plugin)
525 username, plugin)
526
526
527 plugin_cache_active, cache_ttl = auth_result.get(
527 plugin_cache_active, cache_ttl = auth_result.get(
528 'auth_data', {}).get('_ttl_cache') or (False, 0)
528 'auth_data', {}).get('_ttl_cache') or (False, 0)
529 else:
529 else:
530 return auth_result.wsgi_application(environ, start_response)
530 return auth_result.wsgi_application(environ, start_response)
531
531
532 # ==============================================================
532 # ==============================================================
533 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
533 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
534 # ==============================================================
534 # ==============================================================
535 user = User.get_by_username(username)
535 user = User.get_by_username(username)
536 if not self.valid_and_active_user(user):
536 if not self.valid_and_active_user(user):
537 return HTTPForbidden()(environ, start_response)
537 return HTTPForbidden()(environ, start_response)
538 username = user.username
538 username = user.username
539 user_id = user.user_id
539 user_id = user.user_id
540
540
541 # check user attributes for password change flag
541 # check user attributes for password change flag
542 user_obj = user
542 user_obj = user
543 auth_user = user_obj.AuthUser()
543 auth_user = user_obj.AuthUser()
544 if user_obj and user_obj.username != User.DEFAULT_USER and \
544 if user_obj and user_obj.username != User.DEFAULT_USER and \
545 user_obj.user_data.get('force_password_change'):
545 user_obj.user_data.get('force_password_change'):
546 reason = 'password change required'
546 reason = 'password change required'
547 log.debug('User not allowed to authenticate, %s', reason)
547 log.debug('User not allowed to authenticate, %s', reason)
548 return HTTPNotAcceptable(reason)(environ, start_response)
548 return HTTPNotAcceptable(reason)(environ, start_response)
549
549
550 # check permissions for this repository
550 # check permissions for this repository
551 perm = self._check_permission(
551 perm = self._check_permission(
552 action, user, auth_user, self.acl_repo_name, ip_addr,
552 action, user, auth_user, self.acl_repo_name, ip_addr,
553 plugin, plugin_cache_active, cache_ttl)
553 plugin, plugin_cache_active, cache_ttl)
554 if not perm:
554 if not perm:
555 return HTTPForbidden()(environ, start_response)
555 return HTTPForbidden()(environ, start_response)
556 environ['rc_auth_user_id'] = str(user_id)
556 environ['rc_auth_user_id'] = str(user_id)
557
557
558 if action == 'push':
558 if action == 'push':
559 perms = auth_user.get_branch_permissions(self.acl_repo_name)
559 perms = auth_user.get_branch_permissions(self.acl_repo_name)
560 if perms:
560 if perms:
561 check_branch_perms = True
561 check_branch_perms = True
562 detect_force_push = True
562 detect_force_push = True
563
563
564 # extras are injected into UI object and later available
564 # extras are injected into UI object and later available
565 # in hooks executed by RhodeCode
565 # in hooks executed by RhodeCode
566 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
566 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
567
567
568 extras = vcs_operation_context(
568 extras = vcs_operation_context(
569 environ, repo_name=self.acl_repo_name, username=username,
569 environ, repo_name=self.acl_repo_name, username=username,
570 action=action, scm=self.SCM, check_locking=check_locking,
570 action=action, scm=self.SCM, check_locking=check_locking,
571 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
571 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
572 detect_force_push=detect_force_push
572 detect_force_push=detect_force_push
573 )
573 )
574
574
575 # ======================================================================
575 # ======================================================================
576 # REQUEST HANDLING
576 # REQUEST HANDLING
577 # ======================================================================
577 # ======================================================================
578 repo_path = os.path.join(
578 repo_path = os.path.join(
579 safe_str(self.base_path), safe_str(self.vcs_repo_name))
579 safe_str(self.base_path), safe_str(self.vcs_repo_name))
580 log.debug('Repository path is %s', repo_path)
580 log.debug('Repository path is %s', repo_path)
581
581
582 fix_PATH()
582 fix_PATH()
583
583
584 log.info(
584 log.info(
585 '%s action on %s repo "%s" by "%s" from %s %s',
585 '%s action on %s repo "%s" by "%s" from %s %s',
586 action, self.SCM, safe_str(self.url_repo_name),
586 action, self.SCM, safe_str(self.url_repo_name),
587 safe_str(username), ip_addr, user_agent)
587 safe_str(username), ip_addr, user_agent)
588
588
589 return self._generate_vcs_response(
589 return self._generate_vcs_response(
590 environ, start_response, repo_path, extras, action)
590 environ, start_response, repo_path, extras, action)
591
591
592 @initialize_generator
592 @initialize_generator
593 def _generate_vcs_response(
593 def _generate_vcs_response(
594 self, environ, start_response, repo_path, extras, action):
594 self, environ, start_response, repo_path, extras, action):
595 """
595 """
596 Returns a generator for the response content.
596 Returns a generator for the response content.
597
597
598 This method is implemented as a generator, so that it can trigger
598 This method is implemented as a generator, so that it can trigger
599 the cache validation after all content sent back to the client. It
599 the cache validation after all content sent back to the client. It
600 also handles the locking exceptions which will be triggered when
600 also handles the locking exceptions which will be triggered when
601 the first chunk is produced by the underlying WSGI application.
601 the first chunk is produced by the underlying WSGI application.
602 """
602 """
603
603
604 txn_id = ''
604 txn_id = ''
605 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
605 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
606 # case for SVN, we want to re-use the callback daemon port
606 # case for SVN, we want to re-use the callback daemon port
607 # so we use the txn_id, for this we peek the body, and still save
607 # so we use the txn_id, for this we peek the body, and still save
608 # it as wsgi.input
608 # it as wsgi.input
609
609
610 stream = environ['wsgi.input']
610 stream = environ['wsgi.input']
611
611
612 if isinstance(stream, io.BytesIO):
612 if isinstance(stream, io.BytesIO):
613 data: bytes = stream.getvalue()
613 data: bytes = stream.getvalue()
614 elif hasattr(stream, 'buf'): # most likely gunicorn.http.body.Body
614 elif hasattr(stream, 'buf'): # most likely gunicorn.http.body.Body
615 data: bytes = stream.buf.getvalue()
615 data: bytes = stream.buf.getvalue()
616 else:
616 else:
617 # fallback to the crudest way, copy the iterator
617 # fallback to the crudest way, copy the iterator
618 data = safe_bytes(stream.read())
618 data = safe_bytes(stream.read())
619 environ['wsgi.input'] = io.BytesIO(data)
619 environ['wsgi.input'] = io.BytesIO(data)
620
620
621 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
621 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
622
622
623 callback_daemon, extras = self._prepare_callback_daemon(
623 callback_daemon, extras = self._prepare_callback_daemon(
624 extras, environ, action, txn_id=txn_id)
624 extras, environ, action, txn_id=txn_id)
625 log.debug('HOOKS extras is %s', extras)
625 log.debug('HOOKS extras is %s', extras)
626
626
627 http_scheme = self._get_http_scheme(environ)
627 http_scheme = self._get_http_scheme(environ)
628
628
629 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
629 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
630 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
630 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
631 with callback_daemon:
631 with callback_daemon:
632 app.rc_extras = extras
632 app.rc_extras = extras
633
633
634 try:
634 try:
635 response = app(environ, start_response)
635 response = app(environ, start_response)
636 finally:
636 finally:
637 # This statement works together with the decorator
637 # This statement works together with the decorator
638 # "initialize_generator" above. The decorator ensures that
638 # "initialize_generator" above. The decorator ensures that
639 # we hit the first yield statement before the generator is
639 # we hit the first yield statement before the generator is
640 # returned back to the WSGI server. This is needed to
640 # returned back to the WSGI server. This is needed to
641 # ensure that the call to "app" above triggers the
641 # ensure that the call to "app" above triggers the
642 # needed callback to "start_response" before the
642 # needed callback to "start_response" before the
643 # generator is actually used.
643 # generator is actually used.
644 yield "__init__"
644 yield "__init__"
645
645
646 # iter content
646 # iter content
647 for chunk in response:
647 for chunk in response:
648 yield chunk
648 yield chunk
649
649
650 try:
650 try:
651 # invalidate cache on push
651 # invalidate cache on push
652 if action == 'push':
652 if action == 'push':
653 self._invalidate_cache(self.url_repo_name)
653 self._invalidate_cache(self.url_repo_name)
654 finally:
654 finally:
655 meta.Session.remove()
655 meta.Session.remove()
656
656
657 def _get_repository_name(self, environ):
657 def _get_repository_name(self, environ):
658 """Get repository name out of the environmnent
658 """Get repository name out of the environmnent
659
659
660 :param environ: WSGI environment
660 :param environ: WSGI environment
661 """
661 """
662 raise NotImplementedError()
662 raise NotImplementedError()
663
663
664 def _get_action(self, environ):
664 def _get_action(self, environ):
665 """Map request commands into a pull or push command.
665 """Map request commands into a pull or push command.
666
666
667 :param environ: WSGI environment
667 :param environ: WSGI environment
668 """
668 """
669 raise NotImplementedError()
669 raise NotImplementedError()
670
670
671 def _create_wsgi_app(self, repo_path, repo_name, config):
671 def _create_wsgi_app(self, repo_path, repo_name, config):
672 """Return the WSGI app that will finally handle the request."""
672 """Return the WSGI app that will finally handle the request."""
673 raise NotImplementedError()
673 raise NotImplementedError()
674
674
675 def _create_config(self, extras, repo_name, scheme='http'):
675 def _create_config(self, extras, repo_name, scheme='http'):
676 """Create a safe config representation."""
676 """Create a safe config representation."""
677 raise NotImplementedError()
677 raise NotImplementedError()
678
678
679 def _should_use_callback_daemon(self, extras, environ, action):
679 def _should_use_callback_daemon(self, extras, environ, action):
680 if extras.get('is_shadow_repo'):
680 if extras.get('is_shadow_repo'):
681 # we don't want to execute hooks, and callback daemon for shadow repos
681 # we don't want to execute hooks, and callback daemon for shadow repos
682 return False
682 return False
683 return True
683 return True
684
684
685 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
685 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
686 direct_calls = vcs_settings.HOOKS_DIRECT_CALLS
686 protocol = vcs_settings.HOOKS_PROTOCOL
687 if not self._should_use_callback_daemon(extras, environ, action):
687 if not self._should_use_callback_daemon(extras, environ, action):
688 # disable callback daemon for actions that don't require it
688 # disable callback daemon for actions that don't require it
689 direct_calls = True
689 protocol = 'local'
690
690
691 return prepare_callback_daemon(
691 return prepare_callback_daemon(
692 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
692 extras, protocol=protocol,
693 host=vcs_settings.HOOKS_HOST, use_direct_calls=direct_calls, txn_id=txn_id)
693 host=vcs_settings.HOOKS_HOST, txn_id=txn_id)
694
694
695
695
696 def _should_check_locking(query_string):
696 def _should_check_locking(query_string):
697 # this is kind of hacky, but due to how mercurial handles client-server
697 # this is kind of hacky, but due to how mercurial handles client-server
698 # server see all operation on commit; bookmarks, phases and
698 # server see all operation on commit; bookmarks, phases and
699 # obsolescence marker in different transaction, we don't want to check
699 # obsolescence marker in different transaction, we don't want to check
700 # locking on those
700 # locking on those
701 return query_string not in ['cmd=listkeys']
701 return query_string not in ['cmd=listkeys']
@@ -1,808 +1,823 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19 """
19 """
20 Utilities library for RhodeCode
20 Utilities library for RhodeCode
21 """
21 """
22
22
23 import datetime
23 import datetime
24 import decorator
24 import decorator
25 import logging
25 import logging
26 import os
26 import os
27 import re
27 import re
28 import sys
28 import sys
29 import shutil
29 import shutil
30 import socket
30 import socket
31 import tempfile
31 import tempfile
32 import traceback
32 import traceback
33 import tarfile
33 import tarfile
34 import warnings
34 import warnings
35 from functools import wraps
35 from os.path import join as jn
36 from os.path import join as jn
36
37
37 import paste
38 import paste
38 import pkg_resources
39 import pkg_resources
39 from webhelpers2.text import collapse, strip_tags, convert_accented_entities, convert_misc_entities
40 from webhelpers2.text import collapse, strip_tags, convert_accented_entities, convert_misc_entities
40
41
41 from mako import exceptions
42 from mako import exceptions
42
43
43 from rhodecode.lib.hash_utils import sha256_safe, md5, sha1
44 from rhodecode.lib.hash_utils import sha256_safe, md5, sha1
45 from rhodecode.lib.type_utils import AttributeDict
44 from rhodecode.lib.str_utils import safe_bytes, safe_str
46 from rhodecode.lib.str_utils import safe_bytes, safe_str
45 from rhodecode.lib.vcs.backends.base import Config
47 from rhodecode.lib.vcs.backends.base import Config
46 from rhodecode.lib.vcs.exceptions import VCSError
48 from rhodecode.lib.vcs.exceptions import VCSError
47 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
49 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
48 from rhodecode.lib.ext_json import sjson as json
50 from rhodecode.lib.ext_json import sjson as json
49 from rhodecode.model import meta
51 from rhodecode.model import meta
50 from rhodecode.model.db import (
52 from rhodecode.model.db import (
51 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
53 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
52 from rhodecode.model.meta import Session
54 from rhodecode.model.meta import Session
53
55
54
56
55 log = logging.getLogger(__name__)
57 log = logging.getLogger(__name__)
56
58
57 REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
59 REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
58
60
59 # String which contains characters that are not allowed in slug names for
61 # String which contains characters that are not allowed in slug names for
60 # repositories or repository groups. It is properly escaped to use it in
62 # repositories or repository groups. It is properly escaped to use it in
61 # regular expressions.
63 # regular expressions.
62 SLUG_BAD_CHARS = re.escape(r'`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')
64 SLUG_BAD_CHARS = re.escape(r'`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')
63
65
64 # Regex that matches forbidden characters in repo/group slugs.
66 # Regex that matches forbidden characters in repo/group slugs.
65 SLUG_BAD_CHAR_RE = re.compile(r'[{}\x00-\x08\x0b-\x0c\x0e-\x1f]'.format(SLUG_BAD_CHARS))
67 SLUG_BAD_CHAR_RE = re.compile(r'[{}\x00-\x08\x0b-\x0c\x0e-\x1f]'.format(SLUG_BAD_CHARS))
66
68
67 # Regex that matches allowed characters in repo/group slugs.
69 # Regex that matches allowed characters in repo/group slugs.
68 SLUG_GOOD_CHAR_RE = re.compile(r'[^{}]'.format(SLUG_BAD_CHARS))
70 SLUG_GOOD_CHAR_RE = re.compile(r'[^{}]'.format(SLUG_BAD_CHARS))
69
71
70 # Regex that matches whole repo/group slugs.
72 # Regex that matches whole repo/group slugs.
71 SLUG_RE = re.compile(r'[^{}]+'.format(SLUG_BAD_CHARS))
73 SLUG_RE = re.compile(r'[^{}]+'.format(SLUG_BAD_CHARS))
72
74
73 _license_cache = None
75 _license_cache = None
74
76
75
77
78 def adopt_for_celery(func):
79 """
80 Decorator designed to adopt hooks (from rhodecode.lib.hooks_base)
81 for further usage as a celery tasks.
82 """
83 @wraps(func)
84 def wrapper(extras):
85 extras = AttributeDict(extras)
86 # HooksResponse implements to_json method which must be used there.
87 return func(extras).to_json()
88 return wrapper
89
90
76 def repo_name_slug(value):
91 def repo_name_slug(value):
77 """
92 """
78 Return slug of name of repository
93 Return slug of name of repository
79 This function is called on each creation/modification
94 This function is called on each creation/modification
80 of repository to prevent bad names in repo
95 of repository to prevent bad names in repo
81 """
96 """
82
97
83 replacement_char = '-'
98 replacement_char = '-'
84
99
85 slug = strip_tags(value)
100 slug = strip_tags(value)
86 slug = convert_accented_entities(slug)
101 slug = convert_accented_entities(slug)
87 slug = convert_misc_entities(slug)
102 slug = convert_misc_entities(slug)
88
103
89 slug = SLUG_BAD_CHAR_RE.sub('', slug)
104 slug = SLUG_BAD_CHAR_RE.sub('', slug)
90 slug = re.sub(r'[\s]+', '-', slug)
105 slug = re.sub(r'[\s]+', '-', slug)
91 slug = collapse(slug, replacement_char)
106 slug = collapse(slug, replacement_char)
92
107
93 return slug
108 return slug
94
109
95
110
96 #==============================================================================
111 #==============================================================================
97 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
112 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
98 #==============================================================================
113 #==============================================================================
99 def get_repo_slug(request):
114 def get_repo_slug(request):
100 _repo = ''
115 _repo = ''
101
116
102 if hasattr(request, 'db_repo_name'):
117 if hasattr(request, 'db_repo_name'):
103 # if our requests has set db reference use it for name, this
118 # if our requests has set db reference use it for name, this
104 # translates the example.com/_<id> into proper repo names
119 # translates the example.com/_<id> into proper repo names
105 _repo = request.db_repo_name
120 _repo = request.db_repo_name
106 elif getattr(request, 'matchdict', None):
121 elif getattr(request, 'matchdict', None):
107 # pyramid
122 # pyramid
108 _repo = request.matchdict.get('repo_name')
123 _repo = request.matchdict.get('repo_name')
109
124
110 if _repo:
125 if _repo:
111 _repo = _repo.rstrip('/')
126 _repo = _repo.rstrip('/')
112 return _repo
127 return _repo
113
128
114
129
115 def get_repo_group_slug(request):
130 def get_repo_group_slug(request):
116 _group = ''
131 _group = ''
117 if hasattr(request, 'db_repo_group'):
132 if hasattr(request, 'db_repo_group'):
118 # if our requests has set db reference use it for name, this
133 # if our requests has set db reference use it for name, this
119 # translates the example.com/_<id> into proper repo group names
134 # translates the example.com/_<id> into proper repo group names
120 _group = request.db_repo_group.group_name
135 _group = request.db_repo_group.group_name
121 elif getattr(request, 'matchdict', None):
136 elif getattr(request, 'matchdict', None):
122 # pyramid
137 # pyramid
123 _group = request.matchdict.get('repo_group_name')
138 _group = request.matchdict.get('repo_group_name')
124
139
125 if _group:
140 if _group:
126 _group = _group.rstrip('/')
141 _group = _group.rstrip('/')
127 return _group
142 return _group
128
143
129
144
130 def get_user_group_slug(request):
145 def get_user_group_slug(request):
131 _user_group = ''
146 _user_group = ''
132
147
133 if hasattr(request, 'db_user_group'):
148 if hasattr(request, 'db_user_group'):
134 _user_group = request.db_user_group.users_group_name
149 _user_group = request.db_user_group.users_group_name
135 elif getattr(request, 'matchdict', None):
150 elif getattr(request, 'matchdict', None):
136 # pyramid
151 # pyramid
137 _user_group = request.matchdict.get('user_group_id')
152 _user_group = request.matchdict.get('user_group_id')
138 _user_group_name = request.matchdict.get('user_group_name')
153 _user_group_name = request.matchdict.get('user_group_name')
139 try:
154 try:
140 if _user_group:
155 if _user_group:
141 _user_group = UserGroup.get(_user_group)
156 _user_group = UserGroup.get(_user_group)
142 elif _user_group_name:
157 elif _user_group_name:
143 _user_group = UserGroup.get_by_group_name(_user_group_name)
158 _user_group = UserGroup.get_by_group_name(_user_group_name)
144
159
145 if _user_group:
160 if _user_group:
146 _user_group = _user_group.users_group_name
161 _user_group = _user_group.users_group_name
147 except Exception:
162 except Exception:
148 log.exception('Failed to get user group by id and name')
163 log.exception('Failed to get user group by id and name')
149 # catch all failures here
164 # catch all failures here
150 return None
165 return None
151
166
152 return _user_group
167 return _user_group
153
168
154
169
155 def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
    """
    Scan ``path`` for repositories and yield ``(name, (type, path))`` tuples.

    :param path: filesystem location to scan for repositories
    :param recursive: when True, descend into sub-directories and report
        repository names prefixed with their sub-directory path
    :param skip_removed_repos: when True, ignore directories matching the
        removed-repo naming pattern
    """

    # normalize: a trailing separator would break the name split below
    path = path.rstrip(os.sep)
    log.debug('now scanning in %s location recursive:%s...', path, recursive)

    def _walk(location):
        entries = get_dirpaths(location)
        if not _is_dir_writable(location):
            log.warning('repo path without write access: %s', location)

        for entry in entries:
            candidate = os.path.join(location, entry)
            if os.path.isfile(candidate):
                continue

            # skip removed repos
            if skip_removed_repos and REMOVED_REPO_PAT.match(entry):
                continue

            # skip hidden (dot-prefixed) directories
            if entry.startswith('.'):
                continue

            try:
                scm_info = get_scm(candidate)
            except VCSError:
                if recursive and os.path.isdir(candidate):
                    # not a repo itself; it may contain repos deeper down
                    yield from _walk(candidate)
                continue
            # repo name is the candidate path relative to the scan root
            repo_name = scm_info[1].split(path, 1)[-1].lstrip(os.sep)
            yield repo_name, scm_info

    return _walk(path)
197
212
198
213
def get_dirpaths(p: str) -> list:
    """
    List the directory entries of ``p``, or return an empty list when the
    location cannot be read.

    Entries whose type does not match the type of ``p`` (names the OS could
    not decode consistently) are dropped with an error log entry.
    """
    try:
        # listing also doubles as an OS-independent read-access probe
        entries = os.listdir(p)
    except OSError:
        log.warning('ignoring repo path without read access: %s', p)
        return []

    # os.listdir may return a mix of decoded and undecodable names depending
    # on the type passed in; keep only items matching the input type until a
    # solid solution for path handling has been built.
    wanted_type = type(p)

    def _keep(entry):
        if type(entry) is wanted_type:
            return True
        log.error(
            "Ignoring path %s since it cannot be decoded into str.",
            # repr makes the raw byte value visible in the log
            repr(entry))
        return False

    return [entry for entry in entries if _keep(entry)]
229
244
230
245
def _is_dir_writable(path):
    """
    Return True when a file can be created inside ``path``.

    stat-based checks are unreliable on Cygwin / Windows, so instead we
    probe by actually creating (and immediately discarding) a temp file.
    """
    try:
        probe = tempfile.TemporaryFile(dir=path)
    except OSError:
        return False
    probe.close()
    return True
245
260
246
261
def is_valid_repo(repo_name, base_path, expect_scm=None, explicit_scm=None, config=None):
    """
    Returns True if given path is a valid repository False otherwise.
    If expect_scm param is given also, compare if given scm is the same
    as expected from scm parameter. If explicit_scm is given don't try to
    detect the scm, just use the given one to check if repo is valid

    :param repo_name: repository name, relative to `base_path`
    :param base_path: filesystem root of the repository store
    :param expect_scm: optional scm alias the repo must match
    :param explicit_scm: optional scm alias to use instead of detection
    :param config: optional config passed to the explicit backend

    :return True: if given path is a valid repository
    """
    full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
    log.debug('Checking if `%s` is a valid path for repository. '
              'Explicit type: %s', repo_name, explicit_scm)

    try:
        if explicit_scm:
            # trust the caller-provided type; just instantiate that backend
            backend = get_scm_backend(explicit_scm)
            detected_scms = [backend(full_path, config=config).alias]
        else:
            detected_scms = get_scm(full_path)
    except VCSError:
        log.debug('path: %s is not a valid repo !', full_path)
        return False

    if expect_scm:
        return detected_scms[0] == expect_scm
    log.debug('path: %s is an vcs object:%s', full_path, detected_scms)
    return True
280
295
281
296
def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
    """
    Returns True if a given path is a repository group, False otherwise

    A valid group is a directory that is neither a repository itself nor
    located inside a (bare) repository.

    :param repo_group_name: group name, relative to `base_path`
    :param base_path: filesystem root of the repository store
    :param skip_path_check: when True, skip the final on-disk directory
        existence check
    """
    full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
    log.debug('Checking if `%s` is a valid path for repository group',
              repo_group_name)

    # check if it's not a repo
    if is_valid_repo(repo_group_name, base_path):
        log.debug('Repo called %s exist, it is not a valid repo group', repo_group_name)
        return False

    try:
        # we need to check bare git repos at higher level
        # since we might match branches/hooks/info/objects or possible
        # other things inside bare git repo
        maybe_repo = os.path.dirname(full_path)
        if maybe_repo == base_path:
            # skip root level repo check; we know root location CANNOT BE a repo group
            return False

        # if the parent is itself a vcs object, this path lies inside a repo
        scm_ = get_scm(maybe_repo)
        log.debug('path: %s is a vcs object:%s, not valid repo group', full_path, scm_)
        return False
    except VCSError:
        # parent is not a repository -- candidate may still be a valid group
        pass

    # check if it's a valid path
    if skip_path_check or os.path.isdir(full_path):
        log.debug('path: %s is a valid repo group !', full_path)
        return True

    log.debug('path: %s is not a valid repo group !', full_path)
    return False
320
335
321
336
def ask_ok(prompt, retries=4, complaint='[y]es or [n]o please!'):
    """
    Interactively ask a yes/no question on stdin.

    Returns True for a "yes"-style answer and False for a "no"-style one.
    Each unrecognized answer prints ``complaint``; after ``retries``
    unrecognized answers an :class:`OSError` is raised.
    """
    attempts_left = retries
    while True:
        answer = input(prompt).lower()
        if answer in ('y', 'ye', 'yes'):
            return True
        if answer in ('n', 'no', 'nop', 'nope'):
            return False
        attempts_left -= 1
        if attempts_left < 0:
            raise OSError
        print(complaint)
333
348
# propagated from mercurial documentation
# (the config sections mercurial recognizes in hgrc-style files; used to
# validate/limit which ui sections are handled)
ui_sections = [
    'alias', 'auth',
    'decode/encode', 'defaults',
    'diff', 'email',
    'extensions', 'format',
    'merge-patterns', 'merge-tools',
    'hooks', 'http_proxy',
    'smtp', 'patch',
    'paths', 'profiling',
    'server', 'trusted',
    'ui', 'web', ]
346
361
347
362
def config_data_from_db(clear_session=True, repo=None):
    """
    Read the configuration data from the database and return configuration
    tuples.

    :param clear_session: remove the db session once settings are read
    :param repo: optional repository to scope the ui settings to; global
        settings are used when None
    :return: list of ``(section, key, value)`` tuples for active settings
    """
    from rhodecode.model.settings import VcsSettingsModel

    config = []

    sa = meta.Session()
    settings_model = VcsSettingsModel(repo=repo, sa=sa)

    ui_settings = settings_model.get_ui_settings()

    # ui_data is collected only for the debug log line below
    ui_data = []
    for setting in ui_settings:
        if setting.active:
            ui_data.append((setting.section, setting.key, setting.value))
            config.append((
                safe_str(setting.section), safe_str(setting.key),
                safe_str(setting.value)))
            if setting.key == 'push_ssl':
                # force set push_ssl requirement to False, rhodecode
                # handles that
                config.append((
                    safe_str(setting.section), safe_str(setting.key), False))
    log.debug(
        'settings ui from db@repo[%s]: %s',
        repo,
        ','.join(['[{}] {}={}'.format(*s) for s in ui_data]))
    if clear_session:
        meta.Session.remove()

    # TODO: mikhail: probably it makes no sense to re-read hooks information.
    # It's already there and activated/deactivated
    skip_entries = []
    enabled_hook_classes = get_enabled_hook_classes(ui_settings)
    if 'pull' not in enabled_hook_classes:
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PULL))
    if 'push' not in enabled_hook_classes:
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PUSH))
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRETX_PUSH))
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PUSH_KEY))

    # drop hook entries whose hook class is currently disabled
    config = [entry for entry in config if entry[:2] not in skip_entries]

    return config
395
410
396
411
def make_db_config(clear_session=True, repo=None):
    """
    Create a :class:`Config` instance based on the values in the database.

    :param clear_session: remove the db session after reading the settings
    :param repo: optional repository to scope the settings to
    """
    db_config = Config()
    for section, option, value in config_data_from_db(
            clear_session=clear_session, repo=repo):
        db_config.set(section, option, value)
    return db_config
406
421
407
422
def get_enabled_hook_classes(ui_settings):
    """
    Return the enabled hook classes.

    :param ui_settings: List of ui_settings as returned
        by :meth:`VcsSettingsModel.get_ui_settings`

    :return: a list with the enabled hook classes. The order is not guaranteed.
    :rtype: list
    """
    # map the known hook keys to their short class names
    hook_names = {
        RhodeCodeUi.HOOK_PUSH: 'push',
        RhodeCodeUi.HOOK_PULL: 'pull',
        RhodeCodeUi.HOOK_REPO_SIZE: 'repo_size',
    }

    enabled_hooks = []
    for section, key, value, active in ui_settings:
        # only active entries from the 'hooks' section matter
        if section != 'hooks' or not active:
            continue
        hook = hook_names.get(key)
        if hook:
            enabled_hooks.append(hook)

    return enabled_hooks
435
450
436
451
def set_rhodecode_config(config):
    """
    Updates pyramid config with new settings from database

    :param config: mutable mapping that receives every application setting
    """
    from rhodecode.model.settings import SettingsModel

    for key, value in list(SettingsModel().get_all_settings().items()):
        config[key] = value
448
463
449
464
def get_rhodecode_realm():
    """
    Return the rhodecode realm from database.

    :return: the realm setting value coerced to str
    """
    from rhodecode.model.settings import SettingsModel

    realm_setting = SettingsModel().get_setting_by_name('realm')
    return safe_str(realm_setting.app_settings_value)
457
472
458
473
def get_rhodecode_base_path():
    """
    Returns the base path. The base path is the filesystem path which points
    to the repository store.
    """
    import rhodecode

    # the store location lives in the global runtime CONFIG mapping
    return rhodecode.CONFIG['default_base_path']
467
482
468
483
def map_groups(path):
    """
    Given a full path to a repository, create all nested groups that this
    repo is inside. This function creates parent-child relationships between
    groups and creates default perms for all new groups.

    :param path: full path to repository
    :return: the innermost created/found group (direct parent of the repo),
        or None when the repository sits at the store root
    """
    from rhodecode.model.repo_group import RepoGroupModel
    sa = meta.Session()
    groups = path.split(Repository.NAME_SEP)
    parent = None
    group = None

    # last element is repo in nested groups structure
    groups = groups[:-1]
    rgm = RepoGroupModel(sa)
    owner = User.get_first_super_admin()
    for lvl, group_name in enumerate(groups):
        # rebuild the fully-qualified group name for this nesting level
        group_name = '/'.join(groups[:lvl] + [group_name])
        group = RepoGroup.get_by_group_name(group_name)
        desc = '%s group' % group_name

        # skip folders that are now removed repos
        if REMOVED_REPO_PAT.match(group_name):
            break

        if group is None:
            log.debug('creating group level: %s group_name: %s',
                      lvl, group_name)
            group = RepoGroup(group_name, parent)
            group.group_description = desc
            group.user = owner
            sa.add(group)
            perm_obj = rgm._create_default_perms(group)
            sa.add(perm_obj)
            # flush so the group gets an id before children reference it
            sa.flush()

        parent = group
    return group
509
524
510
525
def repo2db_mapper(initial_repo_list, remove_obsolete=False, force_hooks_rebuild=False):
    """
    maps all repos given in initial_repo_list, non existing repositories
    are created, if remove_obsolete is True it also checks for db entries
    that are not in initial_repo_list and removes them.

    :param initial_repo_list: list of repositories found by scanning methods
    :param remove_obsolete: check for obsolete entries in database
    :param force_hooks_rebuild: force re-installation of vcs hooks for every
        mapped repository
    :return: tuple of (added, removed) repository/group name lists
    """
    from rhodecode.model.repo import RepoModel
    from rhodecode.model.repo_group import RepoGroupModel
    from rhodecode.model.settings import SettingsModel

    sa = meta.Session()
    repo_model = RepoModel()
    user = User.get_first_super_admin()
    added = []

    # creation defaults
    defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
    enable_statistics = defs.get('repo_enable_statistics')
    enable_locking = defs.get('repo_enable_locking')
    enable_downloads = defs.get('repo_enable_downloads')
    private = defs.get('repo_private')

    for name, repo in list(initial_repo_list.items()):
        # ensure all parent groups exist before creating the repo entry
        group = map_groups(name)
        str_name = safe_str(name)
        db_repo = repo_model.get_by_repo_name(str_name)

        # found repo that is on filesystem not in RhodeCode database
        if not db_repo:
            log.info('repository `%s` not found in the database, creating now', name)
            added.append(name)
            desc = (repo.description
                    if repo.description != 'unknown'
                    else '%s repository' % name)

            db_repo = repo_model._create_repo(
                repo_name=name,
                repo_type=repo.alias,
                description=desc,
                repo_group=getattr(group, 'group_id', None),
                owner=user,
                enable_locking=enable_locking,
                enable_downloads=enable_downloads,
                enable_statistics=enable_statistics,
                private=private,
                state=Repository.STATE_CREATED
            )
            sa.commit()
            # we added that repo just now, and make sure we updated server info
            if db_repo.repo_type == 'git':
                git_repo = db_repo.scm_instance()
                # update repository server-info
                log.debug('Running update server info')
                git_repo._update_server_info(force=True)

            db_repo.update_commit_cache()

        # (re)install hooks for every mapped repo, not only new ones
        config = db_repo._config
        config.set('extensions', 'largefiles', '')
        repo = db_repo.scm_instance(config=config)
        repo.install_hooks(force=force_hooks_rebuild)

    removed = []
    if remove_obsolete:
        # remove from database those repositories that are not in the filesystem
        for repo in sa.query(Repository).all():
            if repo.repo_name not in list(initial_repo_list.keys()):
                log.debug("Removing non-existing repository found in db `%s`",
                          repo.repo_name)
                try:
                    RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
                    sa.commit()
                    removed.append(repo.repo_name)
                except Exception:
                    # don't hold further removals on error
                    log.error(traceback.format_exc())
                    sa.rollback()

        def splitter(full_repo_name):
            # return the group part of `group/name`, or None for root repos
            _parts = full_repo_name.rsplit(RepoGroup.url_sep(), 1)
            gr_name = None
            if len(_parts) == 2:
                gr_name = _parts[0]
            return gr_name

        initial_repo_group_list = [splitter(x) for x in
                                   list(initial_repo_list.keys()) if splitter(x)]

        # remove from database those repository groups that are not in the
        # filesystem due to parent child relationships we need to delete them
        # in a specific order of most nested first
        all_groups = [x.group_name for x in sa.query(RepoGroup).all()]
        def nested_sort(gr):
            return len(gr.split('/'))
        for group_name in sorted(all_groups, key=nested_sort, reverse=True):
            if group_name not in initial_repo_group_list:
                repo_group = RepoGroup.get_by_group_name(group_name)
                if (repo_group.children.all() or
                        not RepoGroupModel().check_exist_filesystem(
                            group_name=group_name, exc_on_failure=False)):
                    continue

                log.info(
                    'Removing non-existing repository group found in db `%s`',
                    group_name)
                try:
                    RepoGroupModel(sa).delete(group_name, fs_remove=False)
                    sa.commit()
                    removed.append(group_name)
                except Exception:
                    # don't hold further removals on error
                    log.exception(
                        'Unable to remove repository group `%s`',
                        group_name)
                    sa.rollback()
                    raise

    return added, removed
632
647
633
648
def load_rcextensions(root_path):
    """
    Load the optional ``rcextensions`` package found under ``root_path`` and
    register it on the ``rhodecode`` module, extending the known language
    extension mappings.
    """
    import rhodecode
    from rhodecode.config import conf

    path = os.path.join(root_path)
    # make the rcextensions package importable
    sys.path.append(path)

    rcextensions = None
    try:
        rcextensions = __import__('rcextensions')
    except ImportError:
        # warn only when the package directory exists but failed to import
        if os.path.isdir(os.path.join(path, 'rcextensions')):
            log.warning('Unable to load rcextensions from %s', path)

    if rcextensions:
        log.info('Loaded rcextensions from %s...', rcextensions)
        rhodecode.EXTENSIONS = rcextensions

        # Additional mappings that are not present in the pygments lexers
        conf.LANGUAGES_EXTENSIONS_MAP.update(
            getattr(rhodecode.EXTENSIONS, 'EXTRA_MAPPINGS', {}))
655
670
656
671
def get_custom_lexer(extension):
    """
    returns a custom lexer if it is defined in rcextensions module, or None
    if there's no custom lexer defined
    """
    import rhodecode
    from pygments import lexers

    # custom override made by RhodeCode
    if extension == 'mako':
        return lexers.get_lexer_by_name('html+mako')

    # check if we didn't define this extension as other lexer
    extra_lexers = rhodecode.EXTENSIONS and getattr(
        rhodecode.EXTENSIONS, 'EXTRA_LEXERS', None)
    if extra_lexers and extension in extra_lexers:
        lexer_name = extra_lexers[extension]
        return lexers.get_lexer_by_name(lexer_name)
    return None
674
689
675
690
676 #==============================================================================
691 #==============================================================================
677 # TEST FUNCTIONS AND CREATORS
692 # TEST FUNCTIONS AND CREATORS
678 #==============================================================================
693 #==============================================================================
def create_test_index(repo_location, config):
    """
    Makes default test index.

    :param repo_location: repository store location (unused in this
        implementation; kept for call-site compatibility)
    :param config: configuration dict providing ``search.location``
    :raises ImportError: when the ``rc_testdata`` test package is missing
    """
    try:
        import rc_testdata
    except ImportError as e:
        # chain the original failure so the real import error stays visible
        raise ImportError('Failed to import rc_testdata, '
                          'please make sure this package is installed from requirements_test.txt') from e
    rc_testdata.extract_search_index(
        'vcs_search_index', os.path.dirname(config['search.location']))
690
705
691
706
def create_test_directory(test_path):
    """
    Create test directory if it doesn't exist.

    :param test_path: filesystem path to create; intermediate directories
        are created as needed
    """
    if not os.path.isdir(test_path):
        log.debug('Creating testdir %s', test_path)
        # exist_ok guards against the race where the directory appears
        # between the isdir() check above and this makedirs() call
        os.makedirs(test_path, exist_ok=True)
699
714
700
715
def create_test_database(test_path, config):
    """
    Makes a fresh database.

    Drops/recreates all tables via DbManage and seeds the default settings,
    users and permissions needed by the test suite, then commits the session.

    :param test_path: base path used to derive the test root paths stored
        in the settings (fed to ``config_prompt``).
    :param config: app config mapping; must provide 'sqlalchemy.db1.url',
        'here' and the encryption-key settings read by get_encryption_key.
    """
    # imported lazily to avoid import cycles at module load time
    from rhodecode.lib.db_manage import DbManage
    from rhodecode.lib.utils2 import get_encryption_key

    # PART ONE create db
    dbconf = config['sqlalchemy.db1.url']
    enc_key = get_encryption_key(config)

    log.debug('making test db %s', dbconf)

    # force_ask=True skips interactive confirmation when wiping tables
    dbmanage = DbManage(log_sql=False, dbconf=dbconf, root=config['here'],
                        tests=True, cli_args={'force_ask': True}, enc_key=enc_key)
    dbmanage.create_tables(override=True)
    dbmanage.set_db_version()
    # for tests dynamically set new root paths based on generated content
    dbmanage.create_settings(dbmanage.config_prompt(test_path))
    dbmanage.create_default_user()
    dbmanage.create_test_admin_and_users()
    dbmanage.create_permissions()
    dbmanage.populate_default_permissions()
    # single commit at the end so the seeded data lands atomically
    Session().commit()
725
740
726
741
def create_test_repositories(test_path, config):
    """
    Creates test repositories in the temporary directory. Repositories are
    extracted from archives within the rc_testdata package.

    Also wipes the search index and cache directories so each test run
    starts from a clean slate.

    :param test_path: directory the hg/git/svn fixture repos are extracted into.
    :param config: app config mapping providing 'search.location' and 'cache_dir'.
    """
    import rc_testdata
    from rhodecode.tests import HG_REPO, GIT_REPO, SVN_REPO

    log.debug('making test vcs repositories')

    idx_path = config['search.location']
    data_path = config['cache_dir']

    # clean index and data
    if idx_path and os.path.exists(idx_path):
        log.debug('remove %s', idx_path)
        shutil.rmtree(idx_path)

    if data_path and os.path.exists(data_path):
        log.debug('remove %s', data_path)
        shutil.rmtree(data_path)

    rc_testdata.extract_hg_dump('vcs_test_hg', jn(test_path, HG_REPO))
    rc_testdata.extract_git_dump('vcs_test_git', jn(test_path, GIT_REPO))

    # Note: Subversion is in the process of being integrated with the system,
    # until we have a properly packed version of the test svn repository, this
    # tries to copy over the repo from a package "rc_testdata"
    svn_repo_path = rc_testdata.get_svn_repo_archive()
    # NOTE(review): extractall() without a member filter is vulnerable to
    # path traversal on hostile archives; the archive here ships with
    # rc_testdata (trusted), but consider tarfile's filter= argument.
    with tarfile.open(svn_repo_path) as tar:
        tar.extractall(jn(test_path, SVN_REPO))
758
773
759
774
def password_changed(auth_user, session):
    """
    Return True when the user's current password hash no longer matches
    the one recorded in the session (i.e. the password was changed).
    """
    # Never report password change in case of default user or anonymous user.
    if auth_user.username == User.DEFAULT_USER or auth_user.user_id is None:
        return False

    current_hash = None
    if auth_user.password:
        current_hash = md5(safe_bytes(auth_user.password))

    stored_hash = session.get('rhodecode_user', {}).get('password', '')
    return current_hash != stored_hash
769
784
770
785
def read_opensource_licenses():
    """
    Return the parsed contents of the bundled ``config/licenses.json``,
    loading it once and serving subsequent calls from the module cache.
    """
    global _license_cache

    if not _license_cache:
        raw_licenses = pkg_resources.resource_string(
            'rhodecode', 'config/licenses.json')
        _license_cache = json.loads(raw_licenses)

    return _license_cache
780
795
781
796
def generate_platform_uuid():
    """
    Return a stable identifier for this host derived from its platform name.

    Falls back to the literal 'UNDEFINED' if platform detection fails.
    """
    import platform

    try:
        platform_parts = [platform.platform()]
        return sha256_safe(':'.join(platform_parts))
    except Exception as e:
        log.error('Failed to generate host uuid: %s', e)
        return 'UNDEFINED'
794
809
795
810
def send_test_email(recipients, email_body='TEST EMAIL'):
    """
    Simple code for generating test emails.
    Usage::

        from rhodecode.lib import utils
        utils.send_test_email()

    :param recipients: list of addresses the test email is delivered to.
    :param email_body: body used for both the plaintext and the HTML part.
    """
    # NOTE: dropped the unused `run_task` import; send_email is called directly
    from rhodecode.lib.celerylib import tasks

    # same body is reused for the plaintext variant (the former
    # `email_body = email_body_plaintext = email_body` was a no-op)
    email_body_plaintext = email_body
    subject = f'SUBJECT FROM: {socket.gethostname()}'
    tasks.send_email(recipients, subject, email_body_plaintext, email_body)
@@ -1,74 +1,73 b''
1 # Copyright (C) 2014-2023 RhodeCode GmbH
1 # Copyright (C) 2014-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
"""
Internal settings for vcs-lib
"""

# list of default encoding used in safe_str methods
DEFAULT_ENCODINGS = ['utf8']


# Compatibility version when creating SVN repositories. None means newest.
# Other available options are: pre-1.4-compatible, pre-1.5-compatible,
# pre-1.6-compatible, pre-1.8-compatible
SVN_COMPATIBLE_VERSION = None

# vcs backend aliases supported by the system
ALIASES = ['hg', 'git', 'svn']

# dotted import path of the repository implementation for each alias
BACKENDS = {
    'hg': 'rhodecode.lib.vcs.backends.hg.MercurialRepository',
    'git': 'rhodecode.lib.vcs.backends.git.GitRepository',
    'svn': 'rhodecode.lib.vcs.backends.svn.SubversionRepository',
}


# (archive type, mimetype, file extension) triples; .tbz2/.tar.bz2 and
# .tgz/.tar.gz resolve to the same archive type
ARCHIVE_SPECS = [
    ('tbz2', 'application/x-bzip2', '.tbz2'),
    ('tbz2', 'application/x-bzip2', '.tar.bz2'),

    ('tgz', 'application/x-gzip', '.tgz'),
    ('tgz', 'application/x-gzip', '.tar.gz'),

    ('zip', 'application/zip', '.zip'),
]

# hooks protocol and callback host; None here, presumably assigned at
# application startup from the ini settings -- TODO confirm against config code
HOOKS_PROTOCOL = None
HOOKS_HOST = '127.0.0.1'


# commit-message template used when merging a pull request
MERGE_MESSAGE_TMPL = (
    'Merge pull request !{pr_id} from {source_repo} {source_ref_name}\n\n '
    '{pr_title}')
# sentinel values identifying a dry-run merge attempt
MERGE_DRY_RUN_MESSAGE = 'dry_run_merge_message_from_rhodecode'
MERGE_DRY_RUN_USER = 'Dry-Run User'
MERGE_DRY_RUN_EMAIL = 'dry-run-merge@rhodecode.com'
62
61
63
62
def available_aliases():
    """
    Return the list of enabled vcs backend aliases.

    Mercurial is required for the system to work, so in case vcs.backends does
    not include it, we make sure it will be available internally
    TODO: anderson: refactor vcs.backends so it won't be necessary, VCS server
    should be responsible to dictate available backends.
    """
    aliases = list(ALIASES)
    if 'hg' not in aliases:
        aliases.append('hg')
    return aliases
@@ -1,2390 +1,2389 b''
1 # Copyright (C) 2012-2023 RhodeCode GmbH
1 # Copyright (C) 2012-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19
19
20 """
20 """
21 pull request model for RhodeCode
21 pull request model for RhodeCode
22 """
22 """
23
23
24 import logging
24 import logging
25 import os
25 import os
26
26
27 import datetime
27 import datetime
28 import urllib.request
28 import urllib.request
29 import urllib.parse
29 import urllib.parse
30 import urllib.error
30 import urllib.error
31 import collections
31 import collections
32
32
33 import dataclasses as dataclasses
33 import dataclasses as dataclasses
34 from pyramid.threadlocal import get_current_request
34 from pyramid.threadlocal import get_current_request
35
35
36 from rhodecode.lib.vcs.nodes import FileNode
36 from rhodecode.lib.vcs.nodes import FileNode
37 from rhodecode.translation import lazy_ugettext
37 from rhodecode.translation import lazy_ugettext
38 from rhodecode.lib import helpers as h, hooks_utils, diffs
38 from rhodecode.lib import helpers as h, hooks_utils, diffs
39 from rhodecode.lib import audit_logger
39 from rhodecode.lib import audit_logger
40 from collections import OrderedDict
40 from collections import OrderedDict
41 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
41 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
42 from rhodecode.lib.ext_json import sjson as json
42 from rhodecode.lib.ext_json import sjson as json
43 from rhodecode.lib.markup_renderer import (
43 from rhodecode.lib.markup_renderer import (
44 DEFAULT_COMMENTS_RENDERER, RstTemplateRenderer)
44 DEFAULT_COMMENTS_RENDERER, RstTemplateRenderer)
45 from rhodecode.lib.hash_utils import md5_safe
45 from rhodecode.lib.hash_utils import md5_safe
46 from rhodecode.lib.str_utils import safe_str
46 from rhodecode.lib.str_utils import safe_str
47 from rhodecode.lib.utils2 import AttributeDict, get_current_rhodecode_user
47 from rhodecode.lib.utils2 import AttributeDict, get_current_rhodecode_user
48 from rhodecode.lib.vcs.backends.base import (
48 from rhodecode.lib.vcs.backends.base import (
49 Reference, MergeResponse, MergeFailureReason, UpdateFailureReason,
49 Reference, MergeResponse, MergeFailureReason, UpdateFailureReason,
50 TargetRefMissing, SourceRefMissing)
50 TargetRefMissing, SourceRefMissing)
51 from rhodecode.lib.vcs.conf import settings as vcs_settings
51 from rhodecode.lib.vcs.conf import settings as vcs_settings
52 from rhodecode.lib.vcs.exceptions import (
52 from rhodecode.lib.vcs.exceptions import (
53 CommitDoesNotExistError, EmptyRepositoryError)
53 CommitDoesNotExistError, EmptyRepositoryError)
54 from rhodecode.model import BaseModel
54 from rhodecode.model import BaseModel
55 from rhodecode.model.changeset_status import ChangesetStatusModel
55 from rhodecode.model.changeset_status import ChangesetStatusModel
56 from rhodecode.model.comment import CommentsModel
56 from rhodecode.model.comment import CommentsModel
57 from rhodecode.model.db import (
57 from rhodecode.model.db import (
58 aliased, null, lazyload, and_, or_, select, func, String, cast, PullRequest, PullRequestReviewers, ChangesetStatus,
58 aliased, null, lazyload, and_, or_, select, func, String, cast, PullRequest, PullRequestReviewers, ChangesetStatus,
59 PullRequestVersion, ChangesetComment, Repository, RepoReviewRule, User)
59 PullRequestVersion, ChangesetComment, Repository, RepoReviewRule, User)
60 from rhodecode.model.meta import Session
60 from rhodecode.model.meta import Session
61 from rhodecode.model.notification import NotificationModel, \
61 from rhodecode.model.notification import NotificationModel, \
62 EmailNotificationModel
62 EmailNotificationModel
63 from rhodecode.model.scm import ScmModel
63 from rhodecode.model.scm import ScmModel
64 from rhodecode.model.settings import VcsSettingsModel
64 from rhodecode.model.settings import VcsSettingsModel
65
65
66
66
67 log = logging.getLogger(__name__)
67 log = logging.getLogger(__name__)
68
68
69
69
70 # Data structure to hold the response data when updating commits during a pull
70 # Data structure to hold the response data when updating commits during a pull
71 # request update.
71 # request update.
class UpdateResponse(object):
    """
    Value object describing the outcome of updating the commits of a
    pull request (see the module-level comment above this class).
    """

    def __init__(self, executed, reason, new, old, common_ancestor_id,
                 commit_changes, source_changed, target_changed):

        # whether the update was actually carried out
        self.executed = executed
        # why the update succeeded/failed; presumably an UpdateFailureReason
        # value (cf. UPDATE_STATUS_MESSAGES below) -- confirm at call sites
        self.reason = reason
        # state after the update
        self.new = new
        # state before the update
        self.old = old
        self.common_ancestor_id = common_ancestor_id
        # NOTE: exposed under the shorter name `changes`
        self.changes = commit_changes
        self.source_changed = source_changed
        self.target_changed = target_changed
85
85
86
86
def get_diff_info(
        source_repo, source_ref, target_repo, target_ref, get_authors=False,
        get_commit_authors=True):
    """
    Calculates detailed diff information for usage in preview of creation of a pull-request.
    This is also used for default reviewers logic

    :param source_repo: repository the changes originate from
    :param source_ref: commit id in the source repo (used via get_commit below)
    :param target_repo: repository the changes would land in
    :param target_ref: reference in the target repo compared against
    :param get_authors: when True, annotate changed files to count the
        original authors of the modified lines (extra vcsserver calls)
    :param get_commit_authors: when True, collect the authors of incoming commits
    :return: dict with keys 'commits', 'files', 'stats', 'ancestor',
        'original_authors' and 'commit_authors'
    :raises ValueError: when no common ancestor between the refs exists
    """

    source_scm = source_repo.scm_instance()
    target_scm = target_repo.scm_instance()

    ancestor_id = target_scm.get_common_ancestor(target_ref, source_ref, source_scm)
    if not ancestor_id:
        raise ValueError(
            'cannot calculate diff info without a common ancestor. '
            'Make sure both repositories are related, and have a common forking commit.')

    # case here is that want a simple diff without incoming commits,
    # previewing what will be merged based only on commits in the source.
    log.debug('Using ancestor %s as source_ref instead of %s',
              ancestor_id, source_ref)

    # source of changes now is the common ancestor
    source_commit = source_scm.get_commit(commit_id=ancestor_id)
    # target commit becomes the source ref as it is the last commit
    # for diff generation this logic gives proper diff
    target_commit = source_scm.get_commit(commit_id=source_ref)

    vcs_diff = \
        source_scm.get_diff(commit1=source_commit, commit2=target_commit,
                            ignore_whitespace=False, context=3)

    # diff_limit/file_limit of 0 + show_full_diff: never truncate the diff
    diff_processor = diffs.DiffProcessor(vcs_diff, diff_format='newdiff',
                                         diff_limit=0, file_limit=0, show_full_diff=True)

    _parsed = diff_processor.prepare()

    # accumulate per-file stats and the old line numbers touched per file
    all_files = []
    all_files_changes = []
    changed_lines = {}
    stats = [0, 0]  # [added, deleted] across all files
    for f in _parsed:
        all_files.append(f['filename'])
        all_files_changes.append({
            'filename': f['filename'],
            'stats': f['stats']
        })
        stats[0] += f['stats']['added']
        stats[1] += f['stats']['deleted']

        changed_lines[f['filename']] = []
        if len(f['chunks']) < 2:
            continue
        # first line is "context" information
        for chunks in f['chunks'][1:]:
            for chunk in chunks['lines']:
                # only deletions/modifications map to pre-existing lines
                if chunk['action'] not in ('del', 'mod'):
                    continue
                changed_lines[f['filename']].append(chunk['old_lineno'])

    commit_authors = []
    user_counts = {}
    email_counts = {}
    author_counts = {}
    _commit_cache = {}

    commits = []
    if get_commit_authors:
        log.debug('Obtaining commit authors from set of commits')
        _compare_data = target_scm.compare(
            target_ref, source_ref, source_scm, merge=True,
            pre_load=["author", "date", "message"]
        )

        for commit in _compare_data:
            # NOTE(marcink): we serialize here, so we don't produce more vcsserver calls on data returned
            # at this function which is later called via JSON serialization
            serialized_commit = dict(
                author=commit.author,
                date=commit.date,
                message=commit.message,
                commit_id=commit.raw_id,
                raw_id=commit.raw_id
            )
            commits.append(serialized_commit)
            user = User.get_from_cs_author(serialized_commit['author'])
            if user and user not in commit_authors:
                commit_authors.append(user)

    # lines
    if get_authors:
        log.debug('Calculating authors of changed files')
        target_commit = source_repo.get_commit(ancestor_id)

        for fname, lines in changed_lines.items():

            try:
                node = target_commit.get_node(fname, pre_load=["is_binary"])
            except Exception:
                # file may not exist at the ancestor (e.g. newly added); skip it
                log.exception("Failed to load node with path %s", fname)
                continue

            if not isinstance(node, FileNode):
                continue

            # NOTE(marcink): for binary node we don't do annotation, just use last author
            if node.is_binary:
                author = node.last_commit.author
                email = node.last_commit.author_email

                user = User.get_from_cs_author(author)
                if user:
                    user_counts[user.user_id] = user_counts.get(user.user_id, 0) + 1
                author_counts[author] = author_counts.get(author, 0) + 1
                email_counts[email] = email_counts.get(email, 0) + 1

                continue

            # blame the changed lines and tally their original authors;
            # commits are cached so each blame commit is fetched only once
            for annotation in node.annotate:
                line_no, commit_id, get_commit_func, line_text = annotation
                if line_no in lines:
                    if commit_id not in _commit_cache:
                        _commit_cache[commit_id] = get_commit_func()
                    commit = _commit_cache[commit_id]
                    author = commit.author
                    email = commit.author_email
                    user = User.get_from_cs_author(author)
                    if user:
                        user_counts[user.user_id] = user_counts.get(user.user_id, 0) + 1
                    author_counts[author] = author_counts.get(author, 0) + 1
                    email_counts[email] = email_counts.get(email, 0) + 1

    log.debug('Default reviewers processing finished')

    return {
        'commits': commits,
        'files': all_files_changes,
        'stats': stats,
        'ancestor': ancestor_id,
        # original authors of modified files
        'original_authors': {
            'users': user_counts,
            'authors': author_counts,
            'emails': email_counts,
        },
        'commit_authors': commit_authors
    }
234
234
235
235
236 class PullRequestModel(BaseModel):
236 class PullRequestModel(BaseModel):
237
237
    # model class this BaseModel subclass operates on
    cls = PullRequest

    # number of context lines used when rendering diffs
    DIFF_CONTEXT = diffs.DEFAULT_CONTEXT

    # user-facing (lazily translated) messages per update-failure reason
    UPDATE_STATUS_MESSAGES = {
        UpdateFailureReason.NONE: lazy_ugettext(
            'Pull request update successful.'),
        UpdateFailureReason.UNKNOWN: lazy_ugettext(
            'Pull request update failed because of an unknown error.'),
        UpdateFailureReason.NO_CHANGE: lazy_ugettext(
            'No update needed because the source and target have not changed.'),
        UpdateFailureReason.WRONG_REF_TYPE: lazy_ugettext(
            'Pull request cannot be updated because the reference type is '
            'not supported for an update. Only Branch, Tag or Bookmark is allowed.'),
        UpdateFailureReason.MISSING_TARGET_REF: lazy_ugettext(
            'This pull request cannot be updated because the target '
            'reference is missing.'),
        UpdateFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
            'This pull request cannot be updated because the source '
            'reference is missing.'),
    }
    # reference types recognized in pull requests
    REF_TYPES = ['bookmark', 'book', 'tag', 'branch']
    # subset of REF_TYPES that can move and therefore be updated (tags cannot)
    UPDATABLE_REF_TYPES = ['bookmark', 'book', 'branch']
261
261
262 def __get_pull_request(self, pull_request):
262 def __get_pull_request(self, pull_request):
263 return self._get_instance((
263 return self._get_instance((
264 PullRequest, PullRequestVersion), pull_request)
264 PullRequest, PullRequestVersion), pull_request)
265
265
266 def _check_perms(self, perms, pull_request, user, api=False):
266 def _check_perms(self, perms, pull_request, user, api=False):
267 if not api:
267 if not api:
268 return h.HasRepoPermissionAny(*perms)(
268 return h.HasRepoPermissionAny(*perms)(
269 user=user, repo_name=pull_request.target_repo.repo_name)
269 user=user, repo_name=pull_request.target_repo.repo_name)
270 else:
270 else:
271 return h.HasRepoPermissionAnyApi(*perms)(
271 return h.HasRepoPermissionAnyApi(*perms)(
272 user=user, repo_name=pull_request.target_repo.repo_name)
272 user=user, repo_name=pull_request.target_repo.repo_name)
273
273
274 def check_user_read(self, pull_request, user, api=False):
274 def check_user_read(self, pull_request, user, api=False):
275 _perms = ('repository.admin', 'repository.write', 'repository.read',)
275 _perms = ('repository.admin', 'repository.write', 'repository.read',)
276 return self._check_perms(_perms, pull_request, user, api)
276 return self._check_perms(_perms, pull_request, user, api)
277
277
278 def check_user_merge(self, pull_request, user, api=False):
278 def check_user_merge(self, pull_request, user, api=False):
279 _perms = ('repository.admin', 'repository.write', 'hg.admin',)
279 _perms = ('repository.admin', 'repository.write', 'hg.admin',)
280 return self._check_perms(_perms, pull_request, user, api)
280 return self._check_perms(_perms, pull_request, user, api)
281
281
282 def check_user_update(self, pull_request, user, api=False):
282 def check_user_update(self, pull_request, user, api=False):
283 owner = user.user_id == pull_request.user_id
283 owner = user.user_id == pull_request.user_id
284 return self.check_user_merge(pull_request, user, api) or owner
284 return self.check_user_merge(pull_request, user, api) or owner
285
285
286 def check_user_delete(self, pull_request, user):
286 def check_user_delete(self, pull_request, user):
287 owner = user.user_id == pull_request.user_id
287 owner = user.user_id == pull_request.user_id
288 _perms = ('repository.admin',)
288 _perms = ('repository.admin',)
289 return self._check_perms(_perms, pull_request, user) or owner
289 return self._check_perms(_perms, pull_request, user) or owner
290
290
291 def is_user_reviewer(self, pull_request, user):
291 def is_user_reviewer(self, pull_request, user):
292 return user.user_id in [
292 return user.user_id in [
293 x.user_id for x in
293 x.user_id for x in
294 pull_request.get_pull_request_reviewers(PullRequestReviewers.ROLE_REVIEWER)
294 pull_request.get_pull_request_reviewers(PullRequestReviewers.ROLE_REVIEWER)
295 if x.user
295 if x.user
296 ]
296 ]
297
297
298 def check_user_change_status(self, pull_request, user, api=False):
298 def check_user_change_status(self, pull_request, user, api=False):
299 return self.check_user_update(pull_request, user, api) \
299 return self.check_user_update(pull_request, user, api) \
300 or self.is_user_reviewer(pull_request, user)
300 or self.is_user_reviewer(pull_request, user)
301
301
302 def check_user_comment(self, pull_request, user):
302 def check_user_comment(self, pull_request, user):
303 owner = user.user_id == pull_request.user_id
303 owner = user.user_id == pull_request.user_id
304 return self.check_user_read(pull_request, user) or owner
304 return self.check_user_read(pull_request, user) or owner
305
305
    def get(self, pull_request):
        """Return the PullRequest/PullRequestVersion instance for
        *pull_request* (accepts an id or an instance)."""
        return self.__get_pull_request(pull_request)
308
308
    def _prepare_get_all_query(self, repo_name, search_q=None, source=False,
                               statuses=None, opened_by=None, order_by=None,
                               order_dir='desc', only_created=False):
        """
        Build the base SQLAlchemy query used by count_all()/get_all().

        :param repo_name: repository to filter on (target by default)
        :param search_q: free-text filter matched against id, author username,
            title and description (case-insensitive LIKE)
        :param source: when True, *repo_name* is matched as the source repo
            instead of the target repo
        :param statuses: restrict to these pull request statuses
        :param opened_by: iterable of author user ids to restrict to
        :param order_by: one of the keys of ``order_map`` below; anything
            else leaves the query unordered
        :param order_dir: 'asc' or 'desc' (default)
        :param only_created: restrict to PRs still in STATE_CREATED
        :returns: unexecuted Query over PullRequest
        """
        repo = None
        if repo_name:
            repo = self._get_repo(repo_name)

        q = PullRequest.query()

        # free-text search across id / author / title / description
        if search_q:
            like_expression = u'%{}%'.format(safe_str(search_q))
            # join needed only for the username match
            q = q.join(User, User.user_id == PullRequest.user_id)
            q = q.filter(or_(
                # pull_request_id is numeric; cast so ilike can match it
                cast(PullRequest.pull_request_id, String).ilike(like_expression),
                User.username.ilike(like_expression),
                PullRequest.title.ilike(like_expression),
                PullRequest.description.ilike(like_expression),
            ))

        # source or target
        if repo and source:
            q = q.filter(PullRequest.source_repo == repo)
        elif repo:
            q = q.filter(PullRequest.target_repo == repo)

        # closed,opened
        if statuses:
            q = q.filter(PullRequest.status.in_(statuses))

        # opened by filter
        if opened_by:
            q = q.filter(PullRequest.user_id.in_(opened_by))

        # only get those that are in "created" state
        if only_created:
            q = q.filter(PullRequest.pull_request_state == PullRequest.STATE_CREATED)

        # whitelist of sortable columns; unknown keys are silently ignored
        order_map = {
            'name_raw': PullRequest.pull_request_id,
            'id': PullRequest.pull_request_id,
            'title': PullRequest.title,
            'updated_on_raw': PullRequest.updated_on,
            'target_repo': PullRequest.target_repo_id
        }
        if order_by and order_by in order_map:
            if order_dir == 'asc':
                q = q.order_by(order_map[order_by].asc())
            else:
                q = q.order_by(order_map[order_by].desc())

        return q
360
360
361 def count_all(self, repo_name, search_q=None, source=False, statuses=None,
361 def count_all(self, repo_name, search_q=None, source=False, statuses=None,
362 opened_by=None):
362 opened_by=None):
363 """
363 """
364 Count the number of pull requests for a specific repository.
364 Count the number of pull requests for a specific repository.
365
365
366 :param repo_name: target or source repo
366 :param repo_name: target or source repo
367 :param search_q: filter by text
367 :param search_q: filter by text
368 :param source: boolean flag to specify if repo_name refers to source
368 :param source: boolean flag to specify if repo_name refers to source
369 :param statuses: list of pull request statuses
369 :param statuses: list of pull request statuses
370 :param opened_by: author user of the pull request
370 :param opened_by: author user of the pull request
371 :returns: int number of pull requests
371 :returns: int number of pull requests
372 """
372 """
373 q = self._prepare_get_all_query(
373 q = self._prepare_get_all_query(
374 repo_name, search_q=search_q, source=source, statuses=statuses,
374 repo_name, search_q=search_q, source=source, statuses=statuses,
375 opened_by=opened_by)
375 opened_by=opened_by)
376
376
377 return q.count()
377 return q.count()
378
378
379 def get_all(self, repo_name, search_q=None, source=False, statuses=None,
379 def get_all(self, repo_name, search_q=None, source=False, statuses=None,
380 opened_by=None, offset=0, length=None, order_by=None, order_dir='desc'):
380 opened_by=None, offset=0, length=None, order_by=None, order_dir='desc'):
381 """
381 """
382 Get all pull requests for a specific repository.
382 Get all pull requests for a specific repository.
383
383
384 :param repo_name: target or source repo
384 :param repo_name: target or source repo
385 :param search_q: filter by text
385 :param search_q: filter by text
386 :param source: boolean flag to specify if repo_name refers to source
386 :param source: boolean flag to specify if repo_name refers to source
387 :param statuses: list of pull request statuses
387 :param statuses: list of pull request statuses
388 :param opened_by: author user of the pull request
388 :param opened_by: author user of the pull request
389 :param offset: pagination offset
389 :param offset: pagination offset
390 :param length: length of returned list
390 :param length: length of returned list
391 :param order_by: order of the returned list
391 :param order_by: order of the returned list
392 :param order_dir: 'asc' or 'desc' ordering direction
392 :param order_dir: 'asc' or 'desc' ordering direction
393 :returns: list of pull requests
393 :returns: list of pull requests
394 """
394 """
395 q = self._prepare_get_all_query(
395 q = self._prepare_get_all_query(
396 repo_name, search_q=search_q, source=source, statuses=statuses,
396 repo_name, search_q=search_q, source=source, statuses=statuses,
397 opened_by=opened_by, order_by=order_by, order_dir=order_dir)
397 opened_by=opened_by, order_by=order_by, order_dir=order_dir)
398
398
399 if length:
399 if length:
400 pull_requests = q.limit(length).offset(offset).all()
400 pull_requests = q.limit(length).offset(offset).all()
401 else:
401 else:
402 pull_requests = q.all()
402 pull_requests = q.all()
403
403
404 return pull_requests
404 return pull_requests
405
405
406 def count_awaiting_review(self, repo_name, search_q=None, statuses=None):
406 def count_awaiting_review(self, repo_name, search_q=None, statuses=None):
407 """
407 """
408 Count the number of pull requests for a specific repository that are
408 Count the number of pull requests for a specific repository that are
409 awaiting review.
409 awaiting review.
410
410
411 :param repo_name: target or source repo
411 :param repo_name: target or source repo
412 :param search_q: filter by text
412 :param search_q: filter by text
413 :param statuses: list of pull request statuses
413 :param statuses: list of pull request statuses
414 :returns: int number of pull requests
414 :returns: int number of pull requests
415 """
415 """
416 pull_requests = self.get_awaiting_review(
416 pull_requests = self.get_awaiting_review(
417 repo_name, search_q=search_q, statuses=statuses)
417 repo_name, search_q=search_q, statuses=statuses)
418
418
419 return len(pull_requests)
419 return len(pull_requests)
420
420
421 def get_awaiting_review(self, repo_name, search_q=None, statuses=None,
421 def get_awaiting_review(self, repo_name, search_q=None, statuses=None,
422 offset=0, length=None, order_by=None, order_dir='desc'):
422 offset=0, length=None, order_by=None, order_dir='desc'):
423 """
423 """
424 Get all pull requests for a specific repository that are awaiting
424 Get all pull requests for a specific repository that are awaiting
425 review.
425 review.
426
426
427 :param repo_name: target or source repo
427 :param repo_name: target or source repo
428 :param search_q: filter by text
428 :param search_q: filter by text
429 :param statuses: list of pull request statuses
429 :param statuses: list of pull request statuses
430 :param offset: pagination offset
430 :param offset: pagination offset
431 :param length: length of returned list
431 :param length: length of returned list
432 :param order_by: order of the returned list
432 :param order_by: order of the returned list
433 :param order_dir: 'asc' or 'desc' ordering direction
433 :param order_dir: 'asc' or 'desc' ordering direction
434 :returns: list of pull requests
434 :returns: list of pull requests
435 """
435 """
436 pull_requests = self.get_all(
436 pull_requests = self.get_all(
437 repo_name, search_q=search_q, statuses=statuses,
437 repo_name, search_q=search_q, statuses=statuses,
438 order_by=order_by, order_dir=order_dir)
438 order_by=order_by, order_dir=order_dir)
439
439
440 _filtered_pull_requests = []
440 _filtered_pull_requests = []
441 for pr in pull_requests:
441 for pr in pull_requests:
442 status = pr.calculated_review_status()
442 status = pr.calculated_review_status()
443 if status in [ChangesetStatus.STATUS_NOT_REVIEWED,
443 if status in [ChangesetStatus.STATUS_NOT_REVIEWED,
444 ChangesetStatus.STATUS_UNDER_REVIEW]:
444 ChangesetStatus.STATUS_UNDER_REVIEW]:
445 _filtered_pull_requests.append(pr)
445 _filtered_pull_requests.append(pr)
446 if length:
446 if length:
447 return _filtered_pull_requests[offset:offset+length]
447 return _filtered_pull_requests[offset:offset+length]
448 else:
448 else:
449 return _filtered_pull_requests
449 return _filtered_pull_requests
450
450
    def _prepare_awaiting_my_review_review_query(
            self, repo_name, user_id, search_q=None, statuses=None,
            order_by=None, order_dir='desc'):
        """
        Build the query for pull requests in *repo_name* on which *user_id*
        is a reviewer and whose latest review status from that reviewer is
        still "not reviewed"/"under review" (or no status was recorded yet).

        :param repo_name: target repository name to restrict to
        :param user_id: reviewer user id
        :param search_q: free-text filter (id/author/title/description)
        :param statuses: restrict to these pull request statuses
        :param order_by: key of ``order_map`` below; other values ignored
        :param order_dir: 'asc' or 'desc' (default)
        :returns: unexecuted Query over the aliased PullRequest
        """
        for_review_statuses = [
            ChangesetStatus.STATUS_UNDER_REVIEW, ChangesetStatus.STATUS_NOT_REVIEWED
        ]

        # aliases keep the joins unambiguous and let the search join User later
        pull_request_alias = aliased(PullRequest)
        status_alias = aliased(ChangesetStatus)
        reviewers_alias = aliased(PullRequestReviewers)
        repo_alias = aliased(Repository)

        # Correlated subquery: smallest status version per (PR, reviewer)
        # pair — presumably selects the latest recorded status; confirm
        # against ChangesetStatus version numbering.
        last_ver_subq = Session()\
            .query(func.min(ChangesetStatus.version)) \
            .filter(ChangesetStatus.pull_request_id == reviewers_alias.pull_request_id)\
            .filter(ChangesetStatus.user_id == reviewers_alias.user_id) \
            .subquery()

        # outerjoin keeps PRs where the reviewer never set a status
        # (status columns come back NULL, handled by the or_(... == null())
        # filters below)
        q = Session().query(pull_request_alias) \
            .options(lazyload(pull_request_alias.author)) \
            .join(reviewers_alias,
                  reviewers_alias.pull_request_id == pull_request_alias.pull_request_id) \
            .join(repo_alias,
                  repo_alias.repo_id == pull_request_alias.target_repo_id) \
            .outerjoin(status_alias,
                       and_(status_alias.user_id == reviewers_alias.user_id,
                            status_alias.pull_request_id == reviewers_alias.pull_request_id)) \
            .filter(or_(status_alias.version == null(),
                        status_alias.version == last_ver_subq)) \
            .filter(reviewers_alias.user_id == user_id) \
            .filter(repo_alias.repo_name == repo_name) \
            .filter(or_(status_alias.status == null(), status_alias.status.in_(for_review_statuses))) \
            .group_by(pull_request_alias)

        # closed,opened
        if statuses:
            q = q.filter(pull_request_alias.status.in_(statuses))

        if search_q:
            like_expression = u'%{}%'.format(safe_str(search_q))
            q = q.join(User, User.user_id == pull_request_alias.user_id)
            q = q.filter(or_(
                cast(pull_request_alias.pull_request_id, String).ilike(like_expression),
                User.username.ilike(like_expression),
                pull_request_alias.title.ilike(like_expression),
                pull_request_alias.description.ilike(like_expression),
            ))

        # whitelist of sortable columns; unknown keys are silently ignored
        order_map = {
            'name_raw': pull_request_alias.pull_request_id,
            'title': pull_request_alias.title,
            'updated_on_raw': pull_request_alias.updated_on,
            'target_repo': pull_request_alias.target_repo_id
        }
        if order_by and order_by in order_map:
            if order_dir == 'asc':
                q = q.order_by(order_map[order_by].asc())
            else:
                q = q.order_by(order_map[order_by].desc())

        return q
513
513
514 def count_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None):
514 def count_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None):
515 """
515 """
516 Count the number of pull requests for a specific repository that are
516 Count the number of pull requests for a specific repository that are
517 awaiting review from a specific user.
517 awaiting review from a specific user.
518
518
519 :param repo_name: target or source repo
519 :param repo_name: target or source repo
520 :param user_id: reviewer user of the pull request
520 :param user_id: reviewer user of the pull request
521 :param search_q: filter by text
521 :param search_q: filter by text
522 :param statuses: list of pull request statuses
522 :param statuses: list of pull request statuses
523 :returns: int number of pull requests
523 :returns: int number of pull requests
524 """
524 """
525 q = self._prepare_awaiting_my_review_review_query(
525 q = self._prepare_awaiting_my_review_review_query(
526 repo_name, user_id, search_q=search_q, statuses=statuses)
526 repo_name, user_id, search_q=search_q, statuses=statuses)
527 return q.count()
527 return q.count()
528
528
529 def get_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None,
529 def get_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None,
530 offset=0, length=None, order_by=None, order_dir='desc'):
530 offset=0, length=None, order_by=None, order_dir='desc'):
531 """
531 """
532 Get all pull requests for a specific repository that are awaiting
532 Get all pull requests for a specific repository that are awaiting
533 review from a specific user.
533 review from a specific user.
534
534
535 :param repo_name: target or source repo
535 :param repo_name: target or source repo
536 :param user_id: reviewer user of the pull request
536 :param user_id: reviewer user of the pull request
537 :param search_q: filter by text
537 :param search_q: filter by text
538 :param statuses: list of pull request statuses
538 :param statuses: list of pull request statuses
539 :param offset: pagination offset
539 :param offset: pagination offset
540 :param length: length of returned list
540 :param length: length of returned list
541 :param order_by: order of the returned list
541 :param order_by: order of the returned list
542 :param order_dir: 'asc' or 'desc' ordering direction
542 :param order_dir: 'asc' or 'desc' ordering direction
543 :returns: list of pull requests
543 :returns: list of pull requests
544 """
544 """
545
545
546 q = self._prepare_awaiting_my_review_review_query(
546 q = self._prepare_awaiting_my_review_review_query(
547 repo_name, user_id, search_q=search_q, statuses=statuses,
547 repo_name, user_id, search_q=search_q, statuses=statuses,
548 order_by=order_by, order_dir=order_dir)
548 order_by=order_by, order_dir=order_dir)
549
549
550 if length:
550 if length:
551 pull_requests = q.limit(length).offset(offset).all()
551 pull_requests = q.limit(length).offset(offset).all()
552 else:
552 else:
553 pull_requests = q.all()
553 pull_requests = q.all()
554
554
555 return pull_requests
555 return pull_requests
556
556
    def _prepare_im_participating_query(self, user_id=None, statuses=None, query='',
                                        order_by=None, order_dir='desc'):
        """
        Build a query of pull requests where the user is the creator, or is
        added as a reviewer.

        :param user_id: user to match as creator or reviewer; when falsy the
            query is unfiltered by user
        :param statuses: restrict to these pull request statuses
        :param query: free-text filter (id/author/title/description)
        :param order_by: key of ``order_map`` below; other values ignored
        :param order_dir: 'asc' or 'desc' (default)
        :returns: unexecuted Query over PullRequest
        """
        q = PullRequest.query()
        if user_id:

            # subquery of PR ids on which the user is listed as a reviewer
            base_query = select(PullRequestReviewers)\
                .where(PullRequestReviewers.user_id == user_id)\
                .with_only_columns(PullRequestReviewers.pull_request_id)

            # creator OR reviewer
            user_filter = or_(
                PullRequest.user_id == user_id,
                PullRequest.pull_request_id.in_(base_query)
            )
            q = PullRequest.query().filter(user_filter)

        # closed,opened
        if statuses:
            q = q.filter(PullRequest.status.in_(statuses))

        if query:
            like_expression = u'%{}%'.format(safe_str(query))
            # join needed only for the username match
            q = q.join(User, User.user_id == PullRequest.user_id)
            q = q.filter(or_(
                cast(PullRequest.pull_request_id, String).ilike(like_expression),
                User.username.ilike(like_expression),
                PullRequest.title.ilike(like_expression),
                PullRequest.description.ilike(like_expression),
            ))

        # whitelist of sortable columns; unknown keys are silently ignored
        order_map = {
            'name_raw': PullRequest.pull_request_id,
            'title': PullRequest.title,
            'updated_on_raw': PullRequest.updated_on,
            'target_repo': PullRequest.target_repo_id
        }
        if order_by and order_by in order_map:
            if order_dir == 'asc':
                q = q.order_by(order_map[order_by].asc())
            else:
                q = q.order_by(order_map[order_by].desc())

        return q
602
602
603 def count_im_participating_in(self, user_id=None, statuses=None, query=''):
603 def count_im_participating_in(self, user_id=None, statuses=None, query=''):
604 q = self._prepare_im_participating_query(user_id, statuses=statuses, query=query)
604 q = self._prepare_im_participating_query(user_id, statuses=statuses, query=query)
605 return q.count()
605 return q.count()
606
606
607 def get_im_participating_in(
607 def get_im_participating_in(
608 self, user_id=None, statuses=None, query='', offset=0,
608 self, user_id=None, statuses=None, query='', offset=0,
609 length=None, order_by=None, order_dir='desc'):
609 length=None, order_by=None, order_dir='desc'):
610 """
610 """
611 Get all Pull requests that i'm participating in as a reviewer, or i have opened
611 Get all Pull requests that i'm participating in as a reviewer, or i have opened
612 """
612 """
613
613
614 q = self._prepare_im_participating_query(
614 q = self._prepare_im_participating_query(
615 user_id, statuses=statuses, query=query, order_by=order_by,
615 user_id, statuses=statuses, query=query, order_by=order_by,
616 order_dir=order_dir)
616 order_dir=order_dir)
617
617
618 if length:
618 if length:
619 pull_requests = q.limit(length).offset(offset).all()
619 pull_requests = q.limit(length).offset(offset).all()
620 else:
620 else:
621 pull_requests = q.all()
621 pull_requests = q.all()
622
622
623 return pull_requests
623 return pull_requests
624
624
    def _prepare_participating_in_for_review_query(
            self, user_id, statuses=None, query='', order_by=None, order_dir='desc'):
        """
        Build the query for pull requests (in any repository) on which
        *user_id* is a reviewer and whose latest review status from that
        reviewer is still "not reviewed"/"under review" (or no status was
        recorded yet).

        Same shape as _prepare_awaiting_my_review_review_query but without
        the repository restriction.

        :param user_id: reviewer user id
        :param statuses: restrict to these pull request statuses
        :param query: free-text filter (id/author/title/description)
        :param order_by: key of ``order_map`` below; other values ignored
        :param order_dir: 'asc' or 'desc' (default)
        :returns: unexecuted Query over the aliased PullRequest
        """
        for_review_statuses = [
            ChangesetStatus.STATUS_UNDER_REVIEW, ChangesetStatus.STATUS_NOT_REVIEWED
        ]

        pull_request_alias = aliased(PullRequest)
        status_alias = aliased(ChangesetStatus)
        reviewers_alias = aliased(PullRequestReviewers)

        # Correlated subquery: smallest status version per (PR, reviewer)
        # pair — presumably selects the latest recorded status; confirm
        # against ChangesetStatus version numbering.
        last_ver_subq = Session()\
            .query(func.min(ChangesetStatus.version)) \
            .filter(ChangesetStatus.pull_request_id == reviewers_alias.pull_request_id)\
            .filter(ChangesetStatus.user_id == reviewers_alias.user_id) \
            .subquery()

        # outerjoin keeps PRs where the reviewer never set a status
        q = Session().query(pull_request_alias) \
            .options(lazyload(pull_request_alias.author)) \
            .join(reviewers_alias,
                  reviewers_alias.pull_request_id == pull_request_alias.pull_request_id) \
            .outerjoin(status_alias,
                       and_(status_alias.user_id == reviewers_alias.user_id,
                            status_alias.pull_request_id == reviewers_alias.pull_request_id)) \
            .filter(or_(status_alias.version == null(),
                        status_alias.version == last_ver_subq)) \
            .filter(reviewers_alias.user_id == user_id) \
            .filter(or_(status_alias.status == null(), status_alias.status.in_(for_review_statuses))) \
            .group_by(pull_request_alias)

        # closed,opened
        if statuses:
            q = q.filter(pull_request_alias.status.in_(statuses))

        if query:
            like_expression = u'%{}%'.format(safe_str(query))
            q = q.join(User, User.user_id == pull_request_alias.user_id)
            q = q.filter(or_(
                cast(pull_request_alias.pull_request_id, String).ilike(like_expression),
                User.username.ilike(like_expression),
                pull_request_alias.title.ilike(like_expression),
                pull_request_alias.description.ilike(like_expression),
            ))

        # whitelist of sortable columns; unknown keys are silently ignored
        order_map = {
            'name_raw': pull_request_alias.pull_request_id,
            'title': pull_request_alias.title,
            'updated_on_raw': pull_request_alias.updated_on,
            'target_repo': pull_request_alias.target_repo_id
        }
        if order_by and order_by in order_map:
            if order_dir == 'asc':
                q = q.order_by(order_map[order_by].asc())
            else:
                q = q.order_by(order_map[order_by].desc())

        return q
682
682
683 def count_im_participating_in_for_review(self, user_id, statuses=None, query=''):
683 def count_im_participating_in_for_review(self, user_id, statuses=None, query=''):
684 q = self._prepare_participating_in_for_review_query(user_id, statuses=statuses, query=query)
684 q = self._prepare_participating_in_for_review_query(user_id, statuses=statuses, query=query)
685 return q.count()
685 return q.count()
686
686
687 def get_im_participating_in_for_review(
687 def get_im_participating_in_for_review(
688 self, user_id, statuses=None, query='', offset=0,
688 self, user_id, statuses=None, query='', offset=0,
689 length=None, order_by=None, order_dir='desc'):
689 length=None, order_by=None, order_dir='desc'):
690 """
690 """
691 Get all Pull requests that needs user approval or rejection
691 Get all Pull requests that needs user approval or rejection
692 """
692 """
693
693
694 q = self._prepare_participating_in_for_review_query(
694 q = self._prepare_participating_in_for_review_query(
695 user_id, statuses=statuses, query=query, order_by=order_by,
695 user_id, statuses=statuses, query=query, order_by=order_by,
696 order_dir=order_dir)
696 order_dir=order_dir)
697
697
698 if length:
698 if length:
699 pull_requests = q.limit(length).offset(offset).all()
699 pull_requests = q.limit(length).offset(offset).all()
700 else:
700 else:
701 pull_requests = q.all()
701 pull_requests = q.all()
702
702
703 return pull_requests
703 return pull_requests
704
704
    def get_versions(self, pull_request):
        """
        Return all versions of *pull_request*, sorted by version id
        ascending (oldest first).

        NOTE(review): the previous docstring said "descending", but the
        query orders by ``pull_request_version_id.asc()`` — docstring
        corrected to match the code.
        """
        return PullRequestVersion.query()\
            .filter(PullRequestVersion.pull_request == pull_request)\
            .order_by(PullRequestVersion.pull_request_version_id.asc())\
            .all()
713
713
    def get_pr_version(self, pull_request_id, version=None):
        """
        Resolve a pull request, optionally at a specific version.

        :param pull_request_id: id of the pull request
        :param version: None for the current PR, the literal string
            'latest', or a PullRequestVersion id
        :returns: 4-tuple of (original PullRequest, resolved object,
            display object, at_version) where at_version is None, 'latest'
            or the numeric version id
        :raises HTTPNotFound: via get_or_404 when the PR/version is missing
        """
        at_version = None

        if version and version == 'latest':
            # 'latest' maps to the live pull request itself
            pull_request_ver = PullRequest.get(pull_request_id)
            pull_request_obj = pull_request_ver
            _org_pull_request_obj = pull_request_obj
            at_version = 'latest'
        elif version:
            # a concrete historical version; note pull_request_id is not
            # consulted here — the version record carries its own PR link
            pull_request_ver = PullRequestVersion.get_or_404(version)
            pull_request_obj = pull_request_ver
            _org_pull_request_obj = pull_request_ver.pull_request
            at_version = pull_request_ver.pull_request_version_id
        else:
            _org_pull_request_obj = pull_request_obj = PullRequest.get_or_404(
                pull_request_id)

        # wrap for display purposes (merges version data with the original)
        pull_request_display_obj = PullRequest.get_pr_display_object(
            pull_request_obj, _org_pull_request_obj)

        return _org_pull_request_obj, pull_request_obj, \
            pull_request_display_obj, at_version
736
736
737 def pr_commits_versions(self, versions):
737 def pr_commits_versions(self, versions):
738 """
738 """
739 Maps the pull-request commits into all known PR versions. This way we can obtain
739 Maps the pull-request commits into all known PR versions. This way we can obtain
740 each pr version the commit was introduced in.
740 each pr version the commit was introduced in.
741 """
741 """
742 commit_versions = collections.defaultdict(list)
742 commit_versions = collections.defaultdict(list)
743 num_versions = [x.pull_request_version_id for x in versions]
743 num_versions = [x.pull_request_version_id for x in versions]
744 for ver in versions:
744 for ver in versions:
745 for commit_id in ver.revisions:
745 for commit_id in ver.revisions:
746 ver_idx = ChangesetComment.get_index_from_version(
746 ver_idx = ChangesetComment.get_index_from_version(
747 ver.pull_request_version_id, num_versions=num_versions)
747 ver.pull_request_version_id, num_versions=num_versions)
748 commit_versions[commit_id].append(ver_idx)
748 commit_versions[commit_id].append(ver_idx)
749 return commit_versions
749 return commit_versions
750
750
    def create(self, created_by, source_repo, source_ref, target_repo,
               target_ref, revisions, reviewers, observers, title, description=None,
               common_ancestor_id=None,
               description_renderer=None,
               reviewer_data=None, translator=None, auth_user=None):
        """
        Create a new pull request, including its reviewer and observer
        entries, initial commit statuses, and an initial merge simulation.

        :param created_by: user (id/username/object) creating the pull request
        :param source_repo: repository the changes come from
        :param source_ref: source reference string
        :param target_repo: repository the changes should land in
        :param target_ref: target reference string
        :param revisions: list of commit ids included in the pull request
        :param reviewers: iterable of 5-tuples
            ``(user_id, reasons, mandatory, role, rules)``
        :param observers: iterable of the same 5-tuple shape as reviewers
        :param title: pull request title
        :param description: optional description text
        :param common_ancestor_id: optional pre-computed common ancestor commit
        :param description_renderer: renderer used for the description
        :param reviewer_data: extra reviewer rule data stored on the PR
        :param translator: translation function; defaults to the current
            request's translator
        :param auth_user: acting AuthUser; defaults to the creator's AuthUser
        :return: the created (and re-fetched) PullRequest instance
        """
        translator = translator or get_current_request().translate

        created_by_user = self._get_user(created_by)
        auth_user = auth_user or created_by_user.AuthUser()
        source_repo = self._get_repo(source_repo)
        target_repo = self._get_repo(target_repo)

        # build the PR row; state starts at STATE_CREATING and is only moved
        # to STATE_CREATED after the merge simulation below succeeds
        pull_request = PullRequest()
        pull_request.source_repo = source_repo
        pull_request.source_ref = source_ref
        pull_request.target_repo = target_repo
        pull_request.target_ref = target_ref
        pull_request.revisions = revisions
        pull_request.title = title
        pull_request.description = description
        pull_request.description_renderer = description_renderer
        pull_request.author = created_by_user
        pull_request.reviewer_data = reviewer_data
        pull_request.pull_request_state = pull_request.STATE_CREATING
        pull_request.common_ancestor_id = common_ancestor_id

        # flush so pull_request gets its primary key for the FK rows below
        Session().add(pull_request)
        Session().flush()

        reviewer_ids = set()
        # members / reviewers
        for reviewer_object in reviewers:
            user_id, reasons, mandatory, role, rules = reviewer_object
            user = self._get_user(user_id)

            # skip duplicates
            if user.user_id in reviewer_ids:
                continue

            reviewer_ids.add(user.user_id)

            reviewer = PullRequestReviewers()
            reviewer.user = user
            reviewer.pull_request = pull_request
            reviewer.reasons = reasons
            reviewer.mandatory = mandatory
            reviewer.role = role

            # NOTE(marcink): pick only first rule for now
            rule_id = list(rules)[0] if rules else None
            rule = RepoReviewRule.get(rule_id) if rule_id else None
            if rule:
                review_group = rule.user_group_vote_rule(user_id)
                # we check if this particular reviewer is member of a voting group
                if review_group:
                    # NOTE(marcink):
                    # can be that user is member of more but we pick the first same,
                    # same as default reviewers algo
                    review_group = review_group[0]

                    rule_data = {
                        'rule_name':
                            rule.review_rule_name,
                        'rule_user_group_entry_id':
                            review_group.repo_review_rule_users_group_id,
                        'rule_user_group_name':
                            review_group.users_group.users_group_name,
                        'rule_user_group_members':
                            [x.user.username for x in review_group.users_group.members],
                        'rule_user_group_members_id':
                            [x.user.user_id for x in review_group.users_group.members],
                    }
                    # e.g {'vote_rule': -1, 'mandatory': True}
                    rule_data.update(review_group.rule_data())

                    reviewer.rule_data = rule_data

            Session().add(reviewer)
            Session().flush()

        # observers get the same row type as reviewers, but with their own
        # role; anyone already added as a reviewer is skipped
        for observer_object in observers:
            user_id, reasons, mandatory, role, rules = observer_object
            user = self._get_user(user_id)

            # skip duplicates from reviewers
            if user.user_id in reviewer_ids:
                continue

            #reviewer_ids.add(user.user_id)

            observer = PullRequestReviewers()
            observer.user = user
            observer.pull_request = pull_request
            observer.reasons = reasons
            observer.mandatory = mandatory
            observer.role = role

            # NOTE(marcink): pick only first rule for now
            rule_id = list(rules)[0] if rules else None
            rule = RepoReviewRule.get(rule_id) if rule_id else None
            if rule:
                # TODO(marcink): do we need this for observers ??
                pass

            Session().add(observer)
            Session().flush()

        # Set approval status to "Under Review" for all commits which are
        # part of this pull request.
        ChangesetStatusModel().set_status(
            repo=target_repo,
            status=ChangesetStatus.STATUS_UNDER_REVIEW,
            user=created_by_user,
            pull_request=pull_request
        )
        # we commit early at this point. This has to do with a fact
        # that before queries do some row-locking. And because of that
        # we need to commit and finish transaction before below validate call
        # that for large repos could be long resulting in long row locks
        Session().commit()

        # prepare workspace, and run initial merge simulation. Set state during that
        # operation
        pull_request = PullRequest.get(pull_request.pull_request_id)

        # set as merging, for merge simulation, and if finished to created so we mark
        # simulation is working fine
        with pull_request.set_state(PullRequest.STATE_MERGING,
                                    final_state=PullRequest.STATE_CREATED) as state_obj:
            MergeCheck.validate(
                pull_request, auth_user=auth_user, translator=translator)

        self.notify_reviewers(pull_request, reviewer_ids, created_by_user)
        self.trigger_pull_request_hook(pull_request, created_by_user, 'create')

        creation_data = pull_request.get_api_data(with_merge_state=False)
        self._log_audit_action(
            'repo.pull_request.create', {'data': creation_data},
            auth_user, pull_request)

        return pull_request
892
892
893 def trigger_pull_request_hook(self, pull_request, user, action, data=None):
893 def trigger_pull_request_hook(self, pull_request, user, action, data=None):
894 pull_request = self.__get_pull_request(pull_request)
894 pull_request = self.__get_pull_request(pull_request)
895 target_scm = pull_request.target_repo.scm_instance()
895 target_scm = pull_request.target_repo.scm_instance()
896 if action == 'create':
896 if action == 'create':
897 trigger_hook = hooks_utils.trigger_create_pull_request_hook
897 trigger_hook = hooks_utils.trigger_create_pull_request_hook
898 elif action == 'merge':
898 elif action == 'merge':
899 trigger_hook = hooks_utils.trigger_merge_pull_request_hook
899 trigger_hook = hooks_utils.trigger_merge_pull_request_hook
900 elif action == 'close':
900 elif action == 'close':
901 trigger_hook = hooks_utils.trigger_close_pull_request_hook
901 trigger_hook = hooks_utils.trigger_close_pull_request_hook
902 elif action == 'review_status_change':
902 elif action == 'review_status_change':
903 trigger_hook = hooks_utils.trigger_review_pull_request_hook
903 trigger_hook = hooks_utils.trigger_review_pull_request_hook
904 elif action == 'update':
904 elif action == 'update':
905 trigger_hook = hooks_utils.trigger_update_pull_request_hook
905 trigger_hook = hooks_utils.trigger_update_pull_request_hook
906 elif action == 'comment':
906 elif action == 'comment':
907 trigger_hook = hooks_utils.trigger_comment_pull_request_hook
907 trigger_hook = hooks_utils.trigger_comment_pull_request_hook
908 elif action == 'comment_edit':
908 elif action == 'comment_edit':
909 trigger_hook = hooks_utils.trigger_comment_pull_request_edit_hook
909 trigger_hook = hooks_utils.trigger_comment_pull_request_edit_hook
910 else:
910 else:
911 return
911 return
912
912
913 log.debug('Handling pull_request %s trigger_pull_request_hook with action %s and hook: %s',
913 log.debug('Handling pull_request %s trigger_pull_request_hook with action %s and hook: %s',
914 pull_request, action, trigger_hook)
914 pull_request, action, trigger_hook)
915 trigger_hook(
915 trigger_hook(
916 username=user.username,
916 username=user.username,
917 repo_name=pull_request.target_repo.repo_name,
917 repo_name=pull_request.target_repo.repo_name,
918 repo_type=target_scm.alias,
918 repo_type=target_scm.alias,
919 pull_request=pull_request,
919 pull_request=pull_request,
920 data=data)
920 data=data)
921
921
922 def _get_commit_ids(self, pull_request):
922 def _get_commit_ids(self, pull_request):
923 """
923 """
924 Return the commit ids of the merged pull request.
924 Return the commit ids of the merged pull request.
925
925
926 This method is not dealing correctly yet with the lack of autoupdates
926 This method is not dealing correctly yet with the lack of autoupdates
927 nor with the implicit target updates.
927 nor with the implicit target updates.
928 For example: if a commit in the source repo is already in the target it
928 For example: if a commit in the source repo is already in the target it
929 will be reported anyways.
929 will be reported anyways.
930 """
930 """
931 merge_rev = pull_request.merge_rev
931 merge_rev = pull_request.merge_rev
932 if merge_rev is None:
932 if merge_rev is None:
933 raise ValueError('This pull request was not merged yet')
933 raise ValueError('This pull request was not merged yet')
934
934
935 commit_ids = list(pull_request.revisions)
935 commit_ids = list(pull_request.revisions)
936 if merge_rev not in commit_ids:
936 if merge_rev not in commit_ids:
937 commit_ids.append(merge_rev)
937 commit_ids.append(merge_rev)
938
938
939 return commit_ids
939 return commit_ids
940
940
941 def merge_repo(self, pull_request, user, extras):
941 def merge_repo(self, pull_request, user, extras):
942 repo_type = pull_request.source_repo.repo_type
942 repo_type = pull_request.source_repo.repo_type
943 log.debug("Merging pull request %s", pull_request)
943 log.debug("Merging pull request %s", pull_request)
944
944
945 extras['user_agent'] = '{}/internal-merge'.format(repo_type)
945 extras['user_agent'] = '{}/internal-merge'.format(repo_type)
946 merge_state = self._merge_pull_request(pull_request, user, extras)
946 merge_state = self._merge_pull_request(pull_request, user, extras)
947 if merge_state.executed:
947 if merge_state.executed:
948 log.debug("Merge was successful, updating the pull request comments.")
948 log.debug("Merge was successful, updating the pull request comments.")
949 self._comment_and_close_pr(pull_request, user, merge_state)
949 self._comment_and_close_pr(pull_request, user, merge_state)
950
950
951 self._log_audit_action(
951 self._log_audit_action(
952 'repo.pull_request.merge',
952 'repo.pull_request.merge',
953 {'merge_state': merge_state.__dict__},
953 {'merge_state': merge_state.__dict__},
954 user, pull_request)
954 user, pull_request)
955
955
956 else:
956 else:
957 log.warning("Merge failed, not updating the pull request.")
957 log.warning("Merge failed, not updating the pull request.")
958 return merge_state
958 return merge_state
959
959
    def _merge_pull_request(self, pull_request, user, extras, merge_msg=None):
        """
        Perform the actual VCS merge of *pull_request* into its target repo.

        Runs inside a prepared hooks callback daemon so that server-side
        hooks fire for the merge commit.

        :param pull_request: pull request to merge
        :param user: user whose name/email is used for the merge commit
        :param extras: hook extras dict, enriched by the callback daemon and
            passed to the VCS layer via the ``RC_SCM_DATA`` config entry
        :param merge_msg: optional merge message template; falls back to
            ``vcs_settings.MERGE_MESSAGE_TMPL``
        :return: merge state object returned by ``target_vcs.merge``
        """
        target_vcs = pull_request.target_repo.scm_instance()
        source_vcs = pull_request.source_repo.scm_instance()

        # fill the merge message template with PR metadata
        message = safe_str(merge_msg or vcs_settings.MERGE_MESSAGE_TMPL).format(
            pr_id=pull_request.pull_request_id,
            pr_title=pull_request.title,
            pr_desc=pull_request.description,
            source_repo=source_vcs.name,
            source_ref_name=pull_request.source_ref_parts.name,
            target_repo=target_vcs.name,
            target_ref_name=pull_request.target_ref_parts.name,
        )

        workspace_id = self._workspace_id(pull_request)
        repo_id = pull_request.target_repo.repo_id
        use_rebase = self._use_rebase_for_merging(pull_request)
        close_branch = self._close_branch_before_merging(pull_request)
        user_name = self._user_name_for_merging(pull_request, user)

        # make sure the target reference points at a fresh commit id
        target_ref = self._refresh_reference(
            pull_request.target_ref_parts, target_vcs)

        callback_daemon, extras = prepare_callback_daemon(
            extras, protocol=vcs_settings.HOOKS_PROTOCOL,
            host=vcs_settings.HOOKS_HOST)

        with callback_daemon:
            # TODO: johbo: Implement a clean way to run a config_override
            # for a single call.
            target_vcs.config.set(
                'rhodecode', 'RC_SCM_DATA', json.dumps(extras))

            merge_state = target_vcs.merge(
                repo_id, workspace_id, target_ref, source_vcs,
                pull_request.source_ref_parts,
                user_name=user_name, user_email=user.email,
                message=message, use_rebase=use_rebase,
                close_branch=close_branch)

        return merge_state
1002
1001
1003 def _comment_and_close_pr(self, pull_request, user, merge_state, close_msg=None):
1002 def _comment_and_close_pr(self, pull_request, user, merge_state, close_msg=None):
1004 pull_request.merge_rev = merge_state.merge_ref.commit_id
1003 pull_request.merge_rev = merge_state.merge_ref.commit_id
1005 pull_request.updated_on = datetime.datetime.now()
1004 pull_request.updated_on = datetime.datetime.now()
1006 close_msg = close_msg or 'Pull request merged and closed'
1005 close_msg = close_msg or 'Pull request merged and closed'
1007
1006
1008 CommentsModel().create(
1007 CommentsModel().create(
1009 text=safe_str(close_msg),
1008 text=safe_str(close_msg),
1010 repo=pull_request.target_repo.repo_id,
1009 repo=pull_request.target_repo.repo_id,
1011 user=user.user_id,
1010 user=user.user_id,
1012 pull_request=pull_request.pull_request_id,
1011 pull_request=pull_request.pull_request_id,
1013 f_path=None,
1012 f_path=None,
1014 line_no=None,
1013 line_no=None,
1015 closing_pr=True
1014 closing_pr=True
1016 )
1015 )
1017
1016
1018 Session().add(pull_request)
1017 Session().add(pull_request)
1019 Session().flush()
1018 Session().flush()
1020 # TODO: paris: replace invalidation with less radical solution
1019 # TODO: paris: replace invalidation with less radical solution
1021 ScmModel().mark_for_invalidation(
1020 ScmModel().mark_for_invalidation(
1022 pull_request.target_repo.repo_name)
1021 pull_request.target_repo.repo_name)
1023 self.trigger_pull_request_hook(pull_request, user, 'merge')
1022 self.trigger_pull_request_hook(pull_request, user, 'merge')
1024
1023
1025 def has_valid_update_type(self, pull_request):
1024 def has_valid_update_type(self, pull_request):
1026 source_ref_type = pull_request.source_ref_parts.type
1025 source_ref_type = pull_request.source_ref_parts.type
1027 return source_ref_type in self.REF_TYPES
1026 return source_ref_type in self.REF_TYPES
1028
1027
1029 def get_flow_commits(self, pull_request):
1028 def get_flow_commits(self, pull_request):
1030
1029
1031 # source repo
1030 # source repo
1032 source_ref_name = pull_request.source_ref_parts.name
1031 source_ref_name = pull_request.source_ref_parts.name
1033 source_ref_type = pull_request.source_ref_parts.type
1032 source_ref_type = pull_request.source_ref_parts.type
1034 source_ref_id = pull_request.source_ref_parts.commit_id
1033 source_ref_id = pull_request.source_ref_parts.commit_id
1035 source_repo = pull_request.source_repo.scm_instance()
1034 source_repo = pull_request.source_repo.scm_instance()
1036
1035
1037 try:
1036 try:
1038 if source_ref_type in self.REF_TYPES:
1037 if source_ref_type in self.REF_TYPES:
1039 source_commit = source_repo.get_commit(
1038 source_commit = source_repo.get_commit(
1040 source_ref_name, reference_obj=pull_request.source_ref_parts)
1039 source_ref_name, reference_obj=pull_request.source_ref_parts)
1041 else:
1040 else:
1042 source_commit = source_repo.get_commit(source_ref_id)
1041 source_commit = source_repo.get_commit(source_ref_id)
1043 except CommitDoesNotExistError:
1042 except CommitDoesNotExistError:
1044 raise SourceRefMissing()
1043 raise SourceRefMissing()
1045
1044
1046 # target repo
1045 # target repo
1047 target_ref_name = pull_request.target_ref_parts.name
1046 target_ref_name = pull_request.target_ref_parts.name
1048 target_ref_type = pull_request.target_ref_parts.type
1047 target_ref_type = pull_request.target_ref_parts.type
1049 target_ref_id = pull_request.target_ref_parts.commit_id
1048 target_ref_id = pull_request.target_ref_parts.commit_id
1050 target_repo = pull_request.target_repo.scm_instance()
1049 target_repo = pull_request.target_repo.scm_instance()
1051
1050
1052 try:
1051 try:
1053 if target_ref_type in self.REF_TYPES:
1052 if target_ref_type in self.REF_TYPES:
1054 target_commit = target_repo.get_commit(
1053 target_commit = target_repo.get_commit(
1055 target_ref_name, reference_obj=pull_request.target_ref_parts)
1054 target_ref_name, reference_obj=pull_request.target_ref_parts)
1056 else:
1055 else:
1057 target_commit = target_repo.get_commit(target_ref_id)
1056 target_commit = target_repo.get_commit(target_ref_id)
1058 except CommitDoesNotExistError:
1057 except CommitDoesNotExistError:
1059 raise TargetRefMissing()
1058 raise TargetRefMissing()
1060
1059
1061 return source_commit, target_commit
1060 return source_commit, target_commit
1062
1061
1063 def update_commits(self, pull_request, updating_user):
1062 def update_commits(self, pull_request, updating_user):
1064 """
1063 """
1065 Get the updated list of commits for the pull request
1064 Get the updated list of commits for the pull request
1066 and return the new pull request version and the list
1065 and return the new pull request version and the list
1067 of commits processed by this update action
1066 of commits processed by this update action
1068
1067
1069 updating_user is the user_object who triggered the update
1068 updating_user is the user_object who triggered the update
1070 """
1069 """
1071 pull_request = self.__get_pull_request(pull_request)
1070 pull_request = self.__get_pull_request(pull_request)
1072 source_ref_type = pull_request.source_ref_parts.type
1071 source_ref_type = pull_request.source_ref_parts.type
1073 source_ref_name = pull_request.source_ref_parts.name
1072 source_ref_name = pull_request.source_ref_parts.name
1074 source_ref_id = pull_request.source_ref_parts.commit_id
1073 source_ref_id = pull_request.source_ref_parts.commit_id
1075
1074
1076 target_ref_type = pull_request.target_ref_parts.type
1075 target_ref_type = pull_request.target_ref_parts.type
1077 target_ref_name = pull_request.target_ref_parts.name
1076 target_ref_name = pull_request.target_ref_parts.name
1078 target_ref_id = pull_request.target_ref_parts.commit_id
1077 target_ref_id = pull_request.target_ref_parts.commit_id
1079
1078
1080 if not self.has_valid_update_type(pull_request):
1079 if not self.has_valid_update_type(pull_request):
1081 log.debug("Skipping update of pull request %s due to ref type: %s",
1080 log.debug("Skipping update of pull request %s due to ref type: %s",
1082 pull_request, source_ref_type)
1081 pull_request, source_ref_type)
1083 return UpdateResponse(
1082 return UpdateResponse(
1084 executed=False,
1083 executed=False,
1085 reason=UpdateFailureReason.WRONG_REF_TYPE,
1084 reason=UpdateFailureReason.WRONG_REF_TYPE,
1086 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1085 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1087 source_changed=False, target_changed=False)
1086 source_changed=False, target_changed=False)
1088
1087
1089 try:
1088 try:
1090 source_commit, target_commit = self.get_flow_commits(pull_request)
1089 source_commit, target_commit = self.get_flow_commits(pull_request)
1091 except SourceRefMissing:
1090 except SourceRefMissing:
1092 return UpdateResponse(
1091 return UpdateResponse(
1093 executed=False,
1092 executed=False,
1094 reason=UpdateFailureReason.MISSING_SOURCE_REF,
1093 reason=UpdateFailureReason.MISSING_SOURCE_REF,
1095 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1094 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1096 source_changed=False, target_changed=False)
1095 source_changed=False, target_changed=False)
1097 except TargetRefMissing:
1096 except TargetRefMissing:
1098 return UpdateResponse(
1097 return UpdateResponse(
1099 executed=False,
1098 executed=False,
1100 reason=UpdateFailureReason.MISSING_TARGET_REF,
1099 reason=UpdateFailureReason.MISSING_TARGET_REF,
1101 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1100 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1102 source_changed=False, target_changed=False)
1101 source_changed=False, target_changed=False)
1103
1102
1104 source_changed = source_ref_id != source_commit.raw_id
1103 source_changed = source_ref_id != source_commit.raw_id
1105 target_changed = target_ref_id != target_commit.raw_id
1104 target_changed = target_ref_id != target_commit.raw_id
1106
1105
1107 if not (source_changed or target_changed):
1106 if not (source_changed or target_changed):
1108 log.debug("Nothing changed in pull request %s", pull_request)
1107 log.debug("Nothing changed in pull request %s", pull_request)
1109 return UpdateResponse(
1108 return UpdateResponse(
1110 executed=False,
1109 executed=False,
1111 reason=UpdateFailureReason.NO_CHANGE,
1110 reason=UpdateFailureReason.NO_CHANGE,
1112 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1111 old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
1113 source_changed=target_changed, target_changed=source_changed)
1112 source_changed=target_changed, target_changed=source_changed)
1114
1113
1115 change_in_found = 'target repo' if target_changed else 'source repo'
1114 change_in_found = 'target repo' if target_changed else 'source repo'
1116 log.debug('Updating pull request because of change in %s detected',
1115 log.debug('Updating pull request because of change in %s detected',
1117 change_in_found)
1116 change_in_found)
1118
1117
1119 # Finally there is a need for an update, in case of source change
1118 # Finally there is a need for an update, in case of source change
1120 # we create a new version, else just an update
1119 # we create a new version, else just an update
1121 if source_changed:
1120 if source_changed:
1122 pull_request_version = self._create_version_from_snapshot(pull_request)
1121 pull_request_version = self._create_version_from_snapshot(pull_request)
1123 self._link_comments_to_version(pull_request_version)
1122 self._link_comments_to_version(pull_request_version)
1124 else:
1123 else:
1125 try:
1124 try:
1126 ver = pull_request.versions[-1]
1125 ver = pull_request.versions[-1]
1127 except IndexError:
1126 except IndexError:
1128 ver = None
1127 ver = None
1129
1128
1130 pull_request.pull_request_version_id = \
1129 pull_request.pull_request_version_id = \
1131 ver.pull_request_version_id if ver else None
1130 ver.pull_request_version_id if ver else None
1132 pull_request_version = pull_request
1131 pull_request_version = pull_request
1133
1132
1134 source_repo = pull_request.source_repo.scm_instance()
1133 source_repo = pull_request.source_repo.scm_instance()
1135 target_repo = pull_request.target_repo.scm_instance()
1134 target_repo = pull_request.target_repo.scm_instance()
1136
1135
1137 # re-compute commit ids
1136 # re-compute commit ids
1138 old_commit_ids = pull_request.revisions
1137 old_commit_ids = pull_request.revisions
1139 pre_load = ["author", "date", "message", "branch"]
1138 pre_load = ["author", "date", "message", "branch"]
1140 commit_ranges = target_repo.compare(
1139 commit_ranges = target_repo.compare(
1141 target_commit.raw_id, source_commit.raw_id, source_repo, merge=True,
1140 target_commit.raw_id, source_commit.raw_id, source_repo, merge=True,
1142 pre_load=pre_load)
1141 pre_load=pre_load)
1143
1142
1144 target_ref = target_commit.raw_id
1143 target_ref = target_commit.raw_id
1145 source_ref = source_commit.raw_id
1144 source_ref = source_commit.raw_id
1146 ancestor_commit_id = target_repo.get_common_ancestor(
1145 ancestor_commit_id = target_repo.get_common_ancestor(
1147 target_ref, source_ref, source_repo)
1146 target_ref, source_ref, source_repo)
1148
1147
1149 if not ancestor_commit_id:
1148 if not ancestor_commit_id:
1150 raise ValueError(
1149 raise ValueError(
1151 'cannot calculate diff info without a common ancestor. '
1150 'cannot calculate diff info without a common ancestor. '
1152 'Make sure both repositories are related, and have a common forking commit.')
1151 'Make sure both repositories are related, and have a common forking commit.')
1153
1152
1154 pull_request.common_ancestor_id = ancestor_commit_id
1153 pull_request.common_ancestor_id = ancestor_commit_id
1155
1154
1156 pull_request.source_ref = f'{source_ref_type}:{source_ref_name}:{source_commit.raw_id}'
1155 pull_request.source_ref = f'{source_ref_type}:{source_ref_name}:{source_commit.raw_id}'
1157 pull_request.target_ref = f'{target_ref_type}:{target_ref_name}:{ancestor_commit_id}'
1156 pull_request.target_ref = f'{target_ref_type}:{target_ref_name}:{ancestor_commit_id}'
1158
1157
1159 pull_request.revisions = [
1158 pull_request.revisions = [
1160 commit.raw_id for commit in reversed(commit_ranges)]
1159 commit.raw_id for commit in reversed(commit_ranges)]
1161 pull_request.updated_on = datetime.datetime.now()
1160 pull_request.updated_on = datetime.datetime.now()
1162 Session().add(pull_request)
1161 Session().add(pull_request)
1163 new_commit_ids = pull_request.revisions
1162 new_commit_ids = pull_request.revisions
1164
1163
1165 old_diff_data, new_diff_data = self._generate_update_diffs(
1164 old_diff_data, new_diff_data = self._generate_update_diffs(
1166 pull_request, pull_request_version)
1165 pull_request, pull_request_version)
1167
1166
1168 # calculate commit and file changes
1167 # calculate commit and file changes
1169 commit_changes = self._calculate_commit_id_changes(
1168 commit_changes = self._calculate_commit_id_changes(
1170 old_commit_ids, new_commit_ids)
1169 old_commit_ids, new_commit_ids)
1171 file_changes = self._calculate_file_changes(
1170 file_changes = self._calculate_file_changes(
1172 old_diff_data, new_diff_data)
1171 old_diff_data, new_diff_data)
1173
1172
1174 # set comments as outdated if DIFFS changed
1173 # set comments as outdated if DIFFS changed
1175 CommentsModel().outdate_comments(
1174 CommentsModel().outdate_comments(
1176 pull_request, old_diff_data=old_diff_data,
1175 pull_request, old_diff_data=old_diff_data,
1177 new_diff_data=new_diff_data)
1176 new_diff_data=new_diff_data)
1178
1177
1179 valid_commit_changes = (commit_changes.added or commit_changes.removed)
1178 valid_commit_changes = (commit_changes.added or commit_changes.removed)
1180 file_node_changes = (
1179 file_node_changes = (
1181 file_changes.added or file_changes.modified or file_changes.removed)
1180 file_changes.added or file_changes.modified or file_changes.removed)
1182 pr_has_changes = valid_commit_changes or file_node_changes
1181 pr_has_changes = valid_commit_changes or file_node_changes
1183
1182
1184 # Add an automatic comment to the pull request, in case
1183 # Add an automatic comment to the pull request, in case
1185 # anything has changed
1184 # anything has changed
1186 if pr_has_changes:
1185 if pr_has_changes:
1187 update_comment = CommentsModel().create(
1186 update_comment = CommentsModel().create(
1188 text=self._render_update_message(ancestor_commit_id, commit_changes, file_changes),
1187 text=self._render_update_message(ancestor_commit_id, commit_changes, file_changes),
1189 repo=pull_request.target_repo,
1188 repo=pull_request.target_repo,
1190 user=pull_request.author,
1189 user=pull_request.author,
1191 pull_request=pull_request,
1190 pull_request=pull_request,
1192 send_email=False, renderer=DEFAULT_COMMENTS_RENDERER)
1191 send_email=False, renderer=DEFAULT_COMMENTS_RENDERER)
1193
1192
1194 # Update status to "Under Review" for added commits
1193 # Update status to "Under Review" for added commits
1195 for commit_id in commit_changes.added:
1194 for commit_id in commit_changes.added:
1196 ChangesetStatusModel().set_status(
1195 ChangesetStatusModel().set_status(
1197 repo=pull_request.source_repo,
1196 repo=pull_request.source_repo,
1198 status=ChangesetStatus.STATUS_UNDER_REVIEW,
1197 status=ChangesetStatus.STATUS_UNDER_REVIEW,
1199 comment=update_comment,
1198 comment=update_comment,
1200 user=pull_request.author,
1199 user=pull_request.author,
1201 pull_request=pull_request,
1200 pull_request=pull_request,
1202 revision=commit_id)
1201 revision=commit_id)
1203
1202
1204 # initial commit
1203 # initial commit
1205 Session().commit()
1204 Session().commit()
1206
1205
1207 if pr_has_changes:
1206 if pr_has_changes:
1208 # send update email to users
1207 # send update email to users
1209 try:
1208 try:
1210 self.notify_users(pull_request=pull_request, updating_user=updating_user,
1209 self.notify_users(pull_request=pull_request, updating_user=updating_user,
1211 ancestor_commit_id=ancestor_commit_id,
1210 ancestor_commit_id=ancestor_commit_id,
1212 commit_changes=commit_changes,
1211 commit_changes=commit_changes,
1213 file_changes=file_changes)
1212 file_changes=file_changes)
1214 Session().commit()
1213 Session().commit()
1215 except Exception:
1214 except Exception:
1216 log.exception('Failed to send email notification to users')
1215 log.exception('Failed to send email notification to users')
1217 Session().rollback()
1216 Session().rollback()
1218
1217
1219 log.debug(
1218 log.debug(
1220 'Updated pull request %s, added_ids: %s, common_ids: %s, '
1219 'Updated pull request %s, added_ids: %s, common_ids: %s, '
1221 'removed_ids: %s', pull_request.pull_request_id,
1220 'removed_ids: %s', pull_request.pull_request_id,
1222 commit_changes.added, commit_changes.common, commit_changes.removed)
1221 commit_changes.added, commit_changes.common, commit_changes.removed)
1223 log.debug(
1222 log.debug(
1224 'Updated pull request with the following file changes: %s',
1223 'Updated pull request with the following file changes: %s',
1225 file_changes)
1224 file_changes)
1226
1225
1227 log.info(
1226 log.info(
1228 "Updated pull request %s from commit %s to commit %s, "
1227 "Updated pull request %s from commit %s to commit %s, "
1229 "stored new version %s of this pull request.",
1228 "stored new version %s of this pull request.",
1230 pull_request.pull_request_id, source_ref_id,
1229 pull_request.pull_request_id, source_ref_id,
1231 pull_request.source_ref_parts.commit_id,
1230 pull_request.source_ref_parts.commit_id,
1232 pull_request_version.pull_request_version_id)
1231 pull_request_version.pull_request_version_id)
1233
1232
1234 self.trigger_pull_request_hook(pull_request, pull_request.author, 'update')
1233 self.trigger_pull_request_hook(pull_request, pull_request.author, 'update')
1235
1234
1236 return UpdateResponse(
1235 return UpdateResponse(
1237 executed=True, reason=UpdateFailureReason.NONE,
1236 executed=True, reason=UpdateFailureReason.NONE,
1238 old=pull_request, new=pull_request_version,
1237 old=pull_request, new=pull_request_version,
1239 common_ancestor_id=ancestor_commit_id, commit_changes=commit_changes,
1238 common_ancestor_id=ancestor_commit_id, commit_changes=commit_changes,
1240 source_changed=source_changed, target_changed=target_changed)
1239 source_changed=source_changed, target_changed=target_changed)
1241
1240
1242 def _create_version_from_snapshot(self, pull_request):
1241 def _create_version_from_snapshot(self, pull_request):
1243 version = PullRequestVersion()
1242 version = PullRequestVersion()
1244 version.title = pull_request.title
1243 version.title = pull_request.title
1245 version.description = pull_request.description
1244 version.description = pull_request.description
1246 version.status = pull_request.status
1245 version.status = pull_request.status
1247 version.pull_request_state = pull_request.pull_request_state
1246 version.pull_request_state = pull_request.pull_request_state
1248 version.created_on = datetime.datetime.now()
1247 version.created_on = datetime.datetime.now()
1249 version.updated_on = pull_request.updated_on
1248 version.updated_on = pull_request.updated_on
1250 version.user_id = pull_request.user_id
1249 version.user_id = pull_request.user_id
1251 version.source_repo = pull_request.source_repo
1250 version.source_repo = pull_request.source_repo
1252 version.source_ref = pull_request.source_ref
1251 version.source_ref = pull_request.source_ref
1253 version.target_repo = pull_request.target_repo
1252 version.target_repo = pull_request.target_repo
1254 version.target_ref = pull_request.target_ref
1253 version.target_ref = pull_request.target_ref
1255
1254
1256 version._last_merge_source_rev = pull_request._last_merge_source_rev
1255 version._last_merge_source_rev = pull_request._last_merge_source_rev
1257 version._last_merge_target_rev = pull_request._last_merge_target_rev
1256 version._last_merge_target_rev = pull_request._last_merge_target_rev
1258 version.last_merge_status = pull_request.last_merge_status
1257 version.last_merge_status = pull_request.last_merge_status
1259 version.last_merge_metadata = pull_request.last_merge_metadata
1258 version.last_merge_metadata = pull_request.last_merge_metadata
1260 version.shadow_merge_ref = pull_request.shadow_merge_ref
1259 version.shadow_merge_ref = pull_request.shadow_merge_ref
1261 version.merge_rev = pull_request.merge_rev
1260 version.merge_rev = pull_request.merge_rev
1262 version.reviewer_data = pull_request.reviewer_data
1261 version.reviewer_data = pull_request.reviewer_data
1263
1262
1264 version.revisions = pull_request.revisions
1263 version.revisions = pull_request.revisions
1265 version.common_ancestor_id = pull_request.common_ancestor_id
1264 version.common_ancestor_id = pull_request.common_ancestor_id
1266 version.pull_request = pull_request
1265 version.pull_request = pull_request
1267 Session().add(version)
1266 Session().add(version)
1268 Session().flush()
1267 Session().flush()
1269
1268
1270 return version
1269 return version
1271
1270
1272 def _generate_update_diffs(self, pull_request, pull_request_version):
1271 def _generate_update_diffs(self, pull_request, pull_request_version):
1273
1272
1274 diff_context = (
1273 diff_context = (
1275 self.DIFF_CONTEXT +
1274 self.DIFF_CONTEXT +
1276 CommentsModel.needed_extra_diff_context())
1275 CommentsModel.needed_extra_diff_context())
1277 hide_whitespace_changes = False
1276 hide_whitespace_changes = False
1278 source_repo = pull_request_version.source_repo
1277 source_repo = pull_request_version.source_repo
1279 source_ref_id = pull_request_version.source_ref_parts.commit_id
1278 source_ref_id = pull_request_version.source_ref_parts.commit_id
1280 target_ref_id = pull_request_version.target_ref_parts.commit_id
1279 target_ref_id = pull_request_version.target_ref_parts.commit_id
1281 old_diff = self._get_diff_from_pr_or_version(
1280 old_diff = self._get_diff_from_pr_or_version(
1282 source_repo, source_ref_id, target_ref_id,
1281 source_repo, source_ref_id, target_ref_id,
1283 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1282 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1284
1283
1285 source_repo = pull_request.source_repo
1284 source_repo = pull_request.source_repo
1286 source_ref_id = pull_request.source_ref_parts.commit_id
1285 source_ref_id = pull_request.source_ref_parts.commit_id
1287 target_ref_id = pull_request.target_ref_parts.commit_id
1286 target_ref_id = pull_request.target_ref_parts.commit_id
1288
1287
1289 new_diff = self._get_diff_from_pr_or_version(
1288 new_diff = self._get_diff_from_pr_or_version(
1290 source_repo, source_ref_id, target_ref_id,
1289 source_repo, source_ref_id, target_ref_id,
1291 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1290 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1292
1291
1293 # NOTE: this was using diff_format='gitdiff'
1292 # NOTE: this was using diff_format='gitdiff'
1294 old_diff_data = diffs.DiffProcessor(old_diff, diff_format='newdiff')
1293 old_diff_data = diffs.DiffProcessor(old_diff, diff_format='newdiff')
1295 old_diff_data.prepare()
1294 old_diff_data.prepare()
1296 new_diff_data = diffs.DiffProcessor(new_diff, diff_format='newdiff')
1295 new_diff_data = diffs.DiffProcessor(new_diff, diff_format='newdiff')
1297 new_diff_data.prepare()
1296 new_diff_data.prepare()
1298
1297
1299 return old_diff_data, new_diff_data
1298 return old_diff_data, new_diff_data
1300
1299
1301 def _link_comments_to_version(self, pull_request_version):
1300 def _link_comments_to_version(self, pull_request_version):
1302 """
1301 """
1303 Link all unlinked comments of this pull request to the given version.
1302 Link all unlinked comments of this pull request to the given version.
1304
1303
1305 :param pull_request_version: The `PullRequestVersion` to which
1304 :param pull_request_version: The `PullRequestVersion` to which
1306 the comments shall be linked.
1305 the comments shall be linked.
1307
1306
1308 """
1307 """
1309 pull_request = pull_request_version.pull_request
1308 pull_request = pull_request_version.pull_request
1310 comments = ChangesetComment.query()\
1309 comments = ChangesetComment.query()\
1311 .filter(
1310 .filter(
1312 # TODO: johbo: Should we query for the repo at all here?
1311 # TODO: johbo: Should we query for the repo at all here?
1313 # Pending decision on how comments of PRs are to be related
1312 # Pending decision on how comments of PRs are to be related
1314 # to either the source repo, the target repo or no repo at all.
1313 # to either the source repo, the target repo or no repo at all.
1315 ChangesetComment.repo_id == pull_request.target_repo.repo_id,
1314 ChangesetComment.repo_id == pull_request.target_repo.repo_id,
1316 ChangesetComment.pull_request == pull_request,
1315 ChangesetComment.pull_request == pull_request,
1317 ChangesetComment.pull_request_version == null())\
1316 ChangesetComment.pull_request_version == null())\
1318 .order_by(ChangesetComment.comment_id.asc())
1317 .order_by(ChangesetComment.comment_id.asc())
1319
1318
1320 # TODO: johbo: Find out why this breaks if it is done in a bulk
1319 # TODO: johbo: Find out why this breaks if it is done in a bulk
1321 # operation.
1320 # operation.
1322 for comment in comments:
1321 for comment in comments:
1323 comment.pull_request_version_id = (
1322 comment.pull_request_version_id = (
1324 pull_request_version.pull_request_version_id)
1323 pull_request_version.pull_request_version_id)
1325 Session().add(comment)
1324 Session().add(comment)
1326
1325
1327 def _calculate_commit_id_changes(self, old_ids, new_ids):
1326 def _calculate_commit_id_changes(self, old_ids, new_ids):
1328 added = [x for x in new_ids if x not in old_ids]
1327 added = [x for x in new_ids if x not in old_ids]
1329 common = [x for x in new_ids if x in old_ids]
1328 common = [x for x in new_ids if x in old_ids]
1330 removed = [x for x in old_ids if x not in new_ids]
1329 removed = [x for x in old_ids if x not in new_ids]
1331 total = new_ids
1330 total = new_ids
1332 return ChangeTuple(added, common, removed, total)
1331 return ChangeTuple(added, common, removed, total)
1333
1332
1334 def _calculate_file_changes(self, old_diff_data, new_diff_data):
1333 def _calculate_file_changes(self, old_diff_data, new_diff_data):
1335
1334
1336 old_files = OrderedDict()
1335 old_files = OrderedDict()
1337 for diff_data in old_diff_data.parsed_diff:
1336 for diff_data in old_diff_data.parsed_diff:
1338 old_files[diff_data['filename']] = md5_safe(diff_data['raw_diff'])
1337 old_files[diff_data['filename']] = md5_safe(diff_data['raw_diff'])
1339
1338
1340 added_files = []
1339 added_files = []
1341 modified_files = []
1340 modified_files = []
1342 removed_files = []
1341 removed_files = []
1343 for diff_data in new_diff_data.parsed_diff:
1342 for diff_data in new_diff_data.parsed_diff:
1344 new_filename = diff_data['filename']
1343 new_filename = diff_data['filename']
1345 new_hash = md5_safe(diff_data['raw_diff'])
1344 new_hash = md5_safe(diff_data['raw_diff'])
1346
1345
1347 old_hash = old_files.get(new_filename)
1346 old_hash = old_files.get(new_filename)
1348 if not old_hash:
1347 if not old_hash:
1349 # file is not present in old diff, we have to figure out from parsed diff
1348 # file is not present in old diff, we have to figure out from parsed diff
1350 # operation ADD/REMOVE
1349 # operation ADD/REMOVE
1351 operations_dict = diff_data['stats']['ops']
1350 operations_dict = diff_data['stats']['ops']
1352 if diffs.DEL_FILENODE in operations_dict:
1351 if diffs.DEL_FILENODE in operations_dict:
1353 removed_files.append(new_filename)
1352 removed_files.append(new_filename)
1354 else:
1353 else:
1355 added_files.append(new_filename)
1354 added_files.append(new_filename)
1356 else:
1355 else:
1357 if new_hash != old_hash:
1356 if new_hash != old_hash:
1358 modified_files.append(new_filename)
1357 modified_files.append(new_filename)
1359 # now remove a file from old, since we have seen it already
1358 # now remove a file from old, since we have seen it already
1360 del old_files[new_filename]
1359 del old_files[new_filename]
1361
1360
1362 # removed files is when there are present in old, but not in NEW,
1361 # removed files is when there are present in old, but not in NEW,
1363 # since we remove old files that are present in new diff, left-overs
1362 # since we remove old files that are present in new diff, left-overs
1364 # if any should be the removed files
1363 # if any should be the removed files
1365 removed_files.extend(old_files.keys())
1364 removed_files.extend(old_files.keys())
1366
1365
1367 return FileChangeTuple(added_files, modified_files, removed_files)
1366 return FileChangeTuple(added_files, modified_files, removed_files)
1368
1367
1369 def _render_update_message(self, ancestor_commit_id, changes, file_changes):
1368 def _render_update_message(self, ancestor_commit_id, changes, file_changes):
1370 """
1369 """
1371 render the message using DEFAULT_COMMENTS_RENDERER (RST renderer),
1370 render the message using DEFAULT_COMMENTS_RENDERER (RST renderer),
1372 so it's always looking the same disregarding on which default
1371 so it's always looking the same disregarding on which default
1373 renderer system is using.
1372 renderer system is using.
1374
1373
1375 :param ancestor_commit_id: ancestor raw_id
1374 :param ancestor_commit_id: ancestor raw_id
1376 :param changes: changes named tuple
1375 :param changes: changes named tuple
1377 :param file_changes: file changes named tuple
1376 :param file_changes: file changes named tuple
1378
1377
1379 """
1378 """
1380 new_status = ChangesetStatus.get_status_lbl(
1379 new_status = ChangesetStatus.get_status_lbl(
1381 ChangesetStatus.STATUS_UNDER_REVIEW)
1380 ChangesetStatus.STATUS_UNDER_REVIEW)
1382
1381
1383 changed_files = (
1382 changed_files = (
1384 file_changes.added + file_changes.modified + file_changes.removed)
1383 file_changes.added + file_changes.modified + file_changes.removed)
1385
1384
1386 params = {
1385 params = {
1387 'under_review_label': new_status,
1386 'under_review_label': new_status,
1388 'added_commits': changes.added,
1387 'added_commits': changes.added,
1389 'removed_commits': changes.removed,
1388 'removed_commits': changes.removed,
1390 'changed_files': changed_files,
1389 'changed_files': changed_files,
1391 'added_files': file_changes.added,
1390 'added_files': file_changes.added,
1392 'modified_files': file_changes.modified,
1391 'modified_files': file_changes.modified,
1393 'removed_files': file_changes.removed,
1392 'removed_files': file_changes.removed,
1394 'ancestor_commit_id': ancestor_commit_id
1393 'ancestor_commit_id': ancestor_commit_id
1395 }
1394 }
1396 renderer = RstTemplateRenderer()
1395 renderer = RstTemplateRenderer()
1397 return renderer.render('pull_request_update.mako', **params)
1396 return renderer.render('pull_request_update.mako', **params)
1398
1397
1399 def edit(self, pull_request, title, description, description_renderer, user):
1398 def edit(self, pull_request, title, description, description_renderer, user):
1400 pull_request = self.__get_pull_request(pull_request)
1399 pull_request = self.__get_pull_request(pull_request)
1401 old_data = pull_request.get_api_data(with_merge_state=False)
1400 old_data = pull_request.get_api_data(with_merge_state=False)
1402 if pull_request.is_closed():
1401 if pull_request.is_closed():
1403 raise ValueError('This pull request is closed')
1402 raise ValueError('This pull request is closed')
1404 if title:
1403 if title:
1405 pull_request.title = title
1404 pull_request.title = title
1406 pull_request.description = description
1405 pull_request.description = description
1407 pull_request.updated_on = datetime.datetime.now()
1406 pull_request.updated_on = datetime.datetime.now()
1408 pull_request.description_renderer = description_renderer
1407 pull_request.description_renderer = description_renderer
1409 Session().add(pull_request)
1408 Session().add(pull_request)
1410 self._log_audit_action(
1409 self._log_audit_action(
1411 'repo.pull_request.edit', {'old_data': old_data},
1410 'repo.pull_request.edit', {'old_data': old_data},
1412 user, pull_request)
1411 user, pull_request)
1413
1412
1414 def update_reviewers(self, pull_request, reviewer_data, user):
1413 def update_reviewers(self, pull_request, reviewer_data, user):
1415 """
1414 """
1416 Update the reviewers in the pull request
1415 Update the reviewers in the pull request
1417
1416
1418 :param pull_request: the pr to update
1417 :param pull_request: the pr to update
1419 :param reviewer_data: list of tuples
1418 :param reviewer_data: list of tuples
1420 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1419 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1421 :param user: current use who triggers this action
1420 :param user: current use who triggers this action
1422 """
1421 """
1423
1422
1424 pull_request = self.__get_pull_request(pull_request)
1423 pull_request = self.__get_pull_request(pull_request)
1425 if pull_request.is_closed():
1424 if pull_request.is_closed():
1426 raise ValueError('This pull request is closed')
1425 raise ValueError('This pull request is closed')
1427
1426
1428 reviewers = {}
1427 reviewers = {}
1429 for user_id, reasons, mandatory, role, rules in reviewer_data:
1428 for user_id, reasons, mandatory, role, rules in reviewer_data:
1430 if isinstance(user_id, (int, str)):
1429 if isinstance(user_id, (int, str)):
1431 user_id = self._get_user(user_id).user_id
1430 user_id = self._get_user(user_id).user_id
1432 reviewers[user_id] = {
1431 reviewers[user_id] = {
1433 'reasons': reasons, 'mandatory': mandatory, 'role': role}
1432 'reasons': reasons, 'mandatory': mandatory, 'role': role}
1434
1433
1435 reviewers_ids = set(reviewers.keys())
1434 reviewers_ids = set(reviewers.keys())
1436 current_reviewers = PullRequestReviewers.get_pull_request_reviewers(
1435 current_reviewers = PullRequestReviewers.get_pull_request_reviewers(
1437 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_REVIEWER)
1436 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_REVIEWER)
1438
1437
1439 current_reviewers_ids = set([x.user.user_id for x in current_reviewers])
1438 current_reviewers_ids = set([x.user.user_id for x in current_reviewers])
1440
1439
1441 ids_to_add = reviewers_ids.difference(current_reviewers_ids)
1440 ids_to_add = reviewers_ids.difference(current_reviewers_ids)
1442 ids_to_remove = current_reviewers_ids.difference(reviewers_ids)
1441 ids_to_remove = current_reviewers_ids.difference(reviewers_ids)
1443
1442
1444 log.debug("Adding %s reviewers", ids_to_add)
1443 log.debug("Adding %s reviewers", ids_to_add)
1445 log.debug("Removing %s reviewers", ids_to_remove)
1444 log.debug("Removing %s reviewers", ids_to_remove)
1446 changed = False
1445 changed = False
1447 added_audit_reviewers = []
1446 added_audit_reviewers = []
1448 removed_audit_reviewers = []
1447 removed_audit_reviewers = []
1449
1448
1450 for uid in ids_to_add:
1449 for uid in ids_to_add:
1451 changed = True
1450 changed = True
1452 _usr = self._get_user(uid)
1451 _usr = self._get_user(uid)
1453 reviewer = PullRequestReviewers()
1452 reviewer = PullRequestReviewers()
1454 reviewer.user = _usr
1453 reviewer.user = _usr
1455 reviewer.pull_request = pull_request
1454 reviewer.pull_request = pull_request
1456 reviewer.reasons = reviewers[uid]['reasons']
1455 reviewer.reasons = reviewers[uid]['reasons']
1457 # NOTE(marcink): mandatory shouldn't be changed now
1456 # NOTE(marcink): mandatory shouldn't be changed now
1458 # reviewer.mandatory = reviewers[uid]['reasons']
1457 # reviewer.mandatory = reviewers[uid]['reasons']
1459 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1458 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1460 reviewer.role = PullRequestReviewers.ROLE_REVIEWER
1459 reviewer.role = PullRequestReviewers.ROLE_REVIEWER
1461 Session().add(reviewer)
1460 Session().add(reviewer)
1462 added_audit_reviewers.append(reviewer.get_dict())
1461 added_audit_reviewers.append(reviewer.get_dict())
1463
1462
1464 for uid in ids_to_remove:
1463 for uid in ids_to_remove:
1465 changed = True
1464 changed = True
1466 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1465 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1467 # This is an edge case that handles previous state of having the same reviewer twice.
1466 # This is an edge case that handles previous state of having the same reviewer twice.
1468 # this CAN happen due to the lack of DB checks
1467 # this CAN happen due to the lack of DB checks
1469 reviewers = PullRequestReviewers.query()\
1468 reviewers = PullRequestReviewers.query()\
1470 .filter(PullRequestReviewers.user_id == uid,
1469 .filter(PullRequestReviewers.user_id == uid,
1471 PullRequestReviewers.role == PullRequestReviewers.ROLE_REVIEWER,
1470 PullRequestReviewers.role == PullRequestReviewers.ROLE_REVIEWER,
1472 PullRequestReviewers.pull_request == pull_request)\
1471 PullRequestReviewers.pull_request == pull_request)\
1473 .all()
1472 .all()
1474
1473
1475 for obj in reviewers:
1474 for obj in reviewers:
1476 added_audit_reviewers.append(obj.get_dict())
1475 added_audit_reviewers.append(obj.get_dict())
1477 Session().delete(obj)
1476 Session().delete(obj)
1478
1477
1479 if changed:
1478 if changed:
1480 Session().expire_all()
1479 Session().expire_all()
1481 pull_request.updated_on = datetime.datetime.now()
1480 pull_request.updated_on = datetime.datetime.now()
1482 Session().add(pull_request)
1481 Session().add(pull_request)
1483
1482
1484 # finally store audit logs
1483 # finally store audit logs
1485 for user_data in added_audit_reviewers:
1484 for user_data in added_audit_reviewers:
1486 self._log_audit_action(
1485 self._log_audit_action(
1487 'repo.pull_request.reviewer.add', {'data': user_data},
1486 'repo.pull_request.reviewer.add', {'data': user_data},
1488 user, pull_request)
1487 user, pull_request)
1489 for user_data in removed_audit_reviewers:
1488 for user_data in removed_audit_reviewers:
1490 self._log_audit_action(
1489 self._log_audit_action(
1491 'repo.pull_request.reviewer.delete', {'old_data': user_data},
1490 'repo.pull_request.reviewer.delete', {'old_data': user_data},
1492 user, pull_request)
1491 user, pull_request)
1493
1492
1494 self.notify_reviewers(pull_request, ids_to_add, user)
1493 self.notify_reviewers(pull_request, ids_to_add, user)
1495 return ids_to_add, ids_to_remove
1494 return ids_to_add, ids_to_remove
1496
1495
1497 def update_observers(self, pull_request, observer_data, user):
1496 def update_observers(self, pull_request, observer_data, user):
1498 """
1497 """
1499 Update the observers in the pull request
1498 Update the observers in the pull request
1500
1499
1501 :param pull_request: the pr to update
1500 :param pull_request: the pr to update
1502 :param observer_data: list of tuples
1501 :param observer_data: list of tuples
1503 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1502 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1504 :param user: current use who triggers this action
1503 :param user: current use who triggers this action
1505 """
1504 """
1506 pull_request = self.__get_pull_request(pull_request)
1505 pull_request = self.__get_pull_request(pull_request)
1507 if pull_request.is_closed():
1506 if pull_request.is_closed():
1508 raise ValueError('This pull request is closed')
1507 raise ValueError('This pull request is closed')
1509
1508
1510 observers = {}
1509 observers = {}
1511 for user_id, reasons, mandatory, role, rules in observer_data:
1510 for user_id, reasons, mandatory, role, rules in observer_data:
1512 if isinstance(user_id, (int, str)):
1511 if isinstance(user_id, (int, str)):
1513 user_id = self._get_user(user_id).user_id
1512 user_id = self._get_user(user_id).user_id
1514 observers[user_id] = {
1513 observers[user_id] = {
1515 'reasons': reasons, 'observers': mandatory, 'role': role}
1514 'reasons': reasons, 'observers': mandatory, 'role': role}
1516
1515
1517 observers_ids = set(observers.keys())
1516 observers_ids = set(observers.keys())
1518 current_observers = PullRequestReviewers.get_pull_request_reviewers(
1517 current_observers = PullRequestReviewers.get_pull_request_reviewers(
1519 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_OBSERVER)
1518 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_OBSERVER)
1520
1519
1521 current_observers_ids = set([x.user.user_id for x in current_observers])
1520 current_observers_ids = set([x.user.user_id for x in current_observers])
1522
1521
1523 ids_to_add = observers_ids.difference(current_observers_ids)
1522 ids_to_add = observers_ids.difference(current_observers_ids)
1524 ids_to_remove = current_observers_ids.difference(observers_ids)
1523 ids_to_remove = current_observers_ids.difference(observers_ids)
1525
1524
1526 log.debug("Adding %s observer", ids_to_add)
1525 log.debug("Adding %s observer", ids_to_add)
1527 log.debug("Removing %s observer", ids_to_remove)
1526 log.debug("Removing %s observer", ids_to_remove)
1528 changed = False
1527 changed = False
1529 added_audit_observers = []
1528 added_audit_observers = []
1530 removed_audit_observers = []
1529 removed_audit_observers = []
1531
1530
1532 for uid in ids_to_add:
1531 for uid in ids_to_add:
1533 changed = True
1532 changed = True
1534 _usr = self._get_user(uid)
1533 _usr = self._get_user(uid)
1535 observer = PullRequestReviewers()
1534 observer = PullRequestReviewers()
1536 observer.user = _usr
1535 observer.user = _usr
1537 observer.pull_request = pull_request
1536 observer.pull_request = pull_request
1538 observer.reasons = observers[uid]['reasons']
1537 observer.reasons = observers[uid]['reasons']
1539 # NOTE(marcink): mandatory shouldn't be changed now
1538 # NOTE(marcink): mandatory shouldn't be changed now
1540 # observer.mandatory = observer[uid]['reasons']
1539 # observer.mandatory = observer[uid]['reasons']
1541
1540
1542 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1541 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1543 observer.role = PullRequestReviewers.ROLE_OBSERVER
1542 observer.role = PullRequestReviewers.ROLE_OBSERVER
1544 Session().add(observer)
1543 Session().add(observer)
1545 added_audit_observers.append(observer.get_dict())
1544 added_audit_observers.append(observer.get_dict())
1546
1545
1547 for uid in ids_to_remove:
1546 for uid in ids_to_remove:
1548 changed = True
1547 changed = True
1549 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1548 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1550 # This is an edge case that handles previous state of having the same reviewer twice.
1549 # This is an edge case that handles previous state of having the same reviewer twice.
1551 # this CAN happen due to the lack of DB checks
1550 # this CAN happen due to the lack of DB checks
1552 observers = PullRequestReviewers.query()\
1551 observers = PullRequestReviewers.query()\
1553 .filter(PullRequestReviewers.user_id == uid,
1552 .filter(PullRequestReviewers.user_id == uid,
1554 PullRequestReviewers.role == PullRequestReviewers.ROLE_OBSERVER,
1553 PullRequestReviewers.role == PullRequestReviewers.ROLE_OBSERVER,
1555 PullRequestReviewers.pull_request == pull_request)\
1554 PullRequestReviewers.pull_request == pull_request)\
1556 .all()
1555 .all()
1557
1556
1558 for obj in observers:
1557 for obj in observers:
1559 added_audit_observers.append(obj.get_dict())
1558 added_audit_observers.append(obj.get_dict())
1560 Session().delete(obj)
1559 Session().delete(obj)
1561
1560
1562 if changed:
1561 if changed:
1563 Session().expire_all()
1562 Session().expire_all()
1564 pull_request.updated_on = datetime.datetime.now()
1563 pull_request.updated_on = datetime.datetime.now()
1565 Session().add(pull_request)
1564 Session().add(pull_request)
1566
1565
1567 # finally store audit logs
1566 # finally store audit logs
1568 for user_data in added_audit_observers:
1567 for user_data in added_audit_observers:
1569 self._log_audit_action(
1568 self._log_audit_action(
1570 'repo.pull_request.observer.add', {'data': user_data},
1569 'repo.pull_request.observer.add', {'data': user_data},
1571 user, pull_request)
1570 user, pull_request)
1572 for user_data in removed_audit_observers:
1571 for user_data in removed_audit_observers:
1573 self._log_audit_action(
1572 self._log_audit_action(
1574 'repo.pull_request.observer.delete', {'old_data': user_data},
1573 'repo.pull_request.observer.delete', {'old_data': user_data},
1575 user, pull_request)
1574 user, pull_request)
1576
1575
1577 self.notify_observers(pull_request, ids_to_add, user)
1576 self.notify_observers(pull_request, ids_to_add, user)
1578 return ids_to_add, ids_to_remove
1577 return ids_to_add, ids_to_remove
1579
1578
1580 def get_url(self, pull_request, request=None, permalink=False):
1579 def get_url(self, pull_request, request=None, permalink=False):
1581 if not request:
1580 if not request:
1582 request = get_current_request()
1581 request = get_current_request()
1583
1582
1584 if permalink:
1583 if permalink:
1585 return request.route_url(
1584 return request.route_url(
1586 'pull_requests_global',
1585 'pull_requests_global',
1587 pull_request_id=pull_request.pull_request_id,)
1586 pull_request_id=pull_request.pull_request_id,)
1588 else:
1587 else:
1589 return request.route_url('pullrequest_show',
1588 return request.route_url('pullrequest_show',
1590 repo_name=safe_str(pull_request.target_repo.repo_name),
1589 repo_name=safe_str(pull_request.target_repo.repo_name),
1591 pull_request_id=pull_request.pull_request_id,)
1590 pull_request_id=pull_request.pull_request_id,)
1592
1591
1593 def get_shadow_clone_url(self, pull_request, request=None):
1592 def get_shadow_clone_url(self, pull_request, request=None):
1594 """
1593 """
1595 Returns qualified url pointing to the shadow repository. If this pull
1594 Returns qualified url pointing to the shadow repository. If this pull
1596 request is closed there is no shadow repository and ``None`` will be
1595 request is closed there is no shadow repository and ``None`` will be
1597 returned.
1596 returned.
1598 """
1597 """
1599 if pull_request.is_closed():
1598 if pull_request.is_closed():
1600 return None
1599 return None
1601 else:
1600 else:
1602 pr_url = urllib.parse.unquote(self.get_url(pull_request, request=request))
1601 pr_url = urllib.parse.unquote(self.get_url(pull_request, request=request))
1603 return safe_str('{pr_url}/repository'.format(pr_url=pr_url))
1602 return safe_str('{pr_url}/repository'.format(pr_url=pr_url))
1604
1603
1605 def _notify_reviewers(self, pull_request, user_ids, role, user):
1604 def _notify_reviewers(self, pull_request, user_ids, role, user):
1606 # notification to reviewers/observers
1605 # notification to reviewers/observers
1607 if not user_ids:
1606 if not user_ids:
1608 return
1607 return
1609
1608
1610 log.debug('Notify following %s users about pull-request %s', role, user_ids)
1609 log.debug('Notify following %s users about pull-request %s', role, user_ids)
1611
1610
1612 pull_request_obj = pull_request
1611 pull_request_obj = pull_request
1613 # get the current participants of this pull request
1612 # get the current participants of this pull request
1614 recipients = user_ids
1613 recipients = user_ids
1615 notification_type = EmailNotificationModel.TYPE_PULL_REQUEST
1614 notification_type = EmailNotificationModel.TYPE_PULL_REQUEST
1616
1615
1617 pr_source_repo = pull_request_obj.source_repo
1616 pr_source_repo = pull_request_obj.source_repo
1618 pr_target_repo = pull_request_obj.target_repo
1617 pr_target_repo = pull_request_obj.target_repo
1619
1618
1620 pr_url = h.route_url('pullrequest_show',
1619 pr_url = h.route_url('pullrequest_show',
1621 repo_name=pr_target_repo.repo_name,
1620 repo_name=pr_target_repo.repo_name,
1622 pull_request_id=pull_request_obj.pull_request_id,)
1621 pull_request_id=pull_request_obj.pull_request_id,)
1623
1622
1624 # set some variables for email notification
1623 # set some variables for email notification
1625 pr_target_repo_url = h.route_url(
1624 pr_target_repo_url = h.route_url(
1626 'repo_summary', repo_name=pr_target_repo.repo_name)
1625 'repo_summary', repo_name=pr_target_repo.repo_name)
1627
1626
1628 pr_source_repo_url = h.route_url(
1627 pr_source_repo_url = h.route_url(
1629 'repo_summary', repo_name=pr_source_repo.repo_name)
1628 'repo_summary', repo_name=pr_source_repo.repo_name)
1630
1629
1631 # pull request specifics
1630 # pull request specifics
1632 pull_request_commits = [
1631 pull_request_commits = [
1633 (x.raw_id, x.message)
1632 (x.raw_id, x.message)
1634 for x in map(pr_source_repo.get_commit, pull_request.revisions)]
1633 for x in map(pr_source_repo.get_commit, pull_request.revisions)]
1635
1634
1636 current_rhodecode_user = user
1635 current_rhodecode_user = user
1637 kwargs = {
1636 kwargs = {
1638 'user': current_rhodecode_user,
1637 'user': current_rhodecode_user,
1639 'pull_request_author': pull_request.author,
1638 'pull_request_author': pull_request.author,
1640 'pull_request': pull_request_obj,
1639 'pull_request': pull_request_obj,
1641 'pull_request_commits': pull_request_commits,
1640 'pull_request_commits': pull_request_commits,
1642
1641
1643 'pull_request_target_repo': pr_target_repo,
1642 'pull_request_target_repo': pr_target_repo,
1644 'pull_request_target_repo_url': pr_target_repo_url,
1643 'pull_request_target_repo_url': pr_target_repo_url,
1645
1644
1646 'pull_request_source_repo': pr_source_repo,
1645 'pull_request_source_repo': pr_source_repo,
1647 'pull_request_source_repo_url': pr_source_repo_url,
1646 'pull_request_source_repo_url': pr_source_repo_url,
1648
1647
1649 'pull_request_url': pr_url,
1648 'pull_request_url': pr_url,
1650 'thread_ids': [pr_url],
1649 'thread_ids': [pr_url],
1651 'user_role': role
1650 'user_role': role
1652 }
1651 }
1653
1652
1654 # create notification objects, and emails
1653 # create notification objects, and emails
1655 NotificationModel().create(
1654 NotificationModel().create(
1656 created_by=current_rhodecode_user,
1655 created_by=current_rhodecode_user,
1657 notification_subject='', # Filled in based on the notification_type
1656 notification_subject='', # Filled in based on the notification_type
1658 notification_body='', # Filled in based on the notification_type
1657 notification_body='', # Filled in based on the notification_type
1659 notification_type=notification_type,
1658 notification_type=notification_type,
1660 recipients=recipients,
1659 recipients=recipients,
1661 email_kwargs=kwargs,
1660 email_kwargs=kwargs,
1662 )
1661 )
1663
1662
1664 def notify_reviewers(self, pull_request, reviewers_ids, user):
1663 def notify_reviewers(self, pull_request, reviewers_ids, user):
1665 return self._notify_reviewers(pull_request, reviewers_ids,
1664 return self._notify_reviewers(pull_request, reviewers_ids,
1666 PullRequestReviewers.ROLE_REVIEWER, user)
1665 PullRequestReviewers.ROLE_REVIEWER, user)
1667
1666
1668 def notify_observers(self, pull_request, observers_ids, user):
1667 def notify_observers(self, pull_request, observers_ids, user):
1669 return self._notify_reviewers(pull_request, observers_ids,
1668 return self._notify_reviewers(pull_request, observers_ids,
1670 PullRequestReviewers.ROLE_OBSERVER, user)
1669 PullRequestReviewers.ROLE_OBSERVER, user)
1671
1670
1672 def notify_users(self, pull_request, updating_user, ancestor_commit_id,
1671 def notify_users(self, pull_request, updating_user, ancestor_commit_id,
1673 commit_changes, file_changes):
1672 commit_changes, file_changes):
1674
1673
1675 updating_user_id = updating_user.user_id
1674 updating_user_id = updating_user.user_id
1676 reviewers = set([x.user.user_id for x in pull_request.get_pull_request_reviewers()])
1675 reviewers = set([x.user.user_id for x in pull_request.get_pull_request_reviewers()])
1677 # NOTE(marcink): send notification to all other users except to
1676 # NOTE(marcink): send notification to all other users except to
1678 # person who updated the PR
1677 # person who updated the PR
1679 recipients = reviewers.difference(set([updating_user_id]))
1678 recipients = reviewers.difference(set([updating_user_id]))
1680
1679
1681 log.debug('Notify following recipients about pull-request update %s', recipients)
1680 log.debug('Notify following recipients about pull-request update %s', recipients)
1682
1681
1683 pull_request_obj = pull_request
1682 pull_request_obj = pull_request
1684
1683
1685 # send email about the update
1684 # send email about the update
1686 changed_files = (
1685 changed_files = (
1687 file_changes.added + file_changes.modified + file_changes.removed)
1686 file_changes.added + file_changes.modified + file_changes.removed)
1688
1687
1689 pr_source_repo = pull_request_obj.source_repo
1688 pr_source_repo = pull_request_obj.source_repo
1690 pr_target_repo = pull_request_obj.target_repo
1689 pr_target_repo = pull_request_obj.target_repo
1691
1690
1692 pr_url = h.route_url('pullrequest_show',
1691 pr_url = h.route_url('pullrequest_show',
1693 repo_name=pr_target_repo.repo_name,
1692 repo_name=pr_target_repo.repo_name,
1694 pull_request_id=pull_request_obj.pull_request_id,)
1693 pull_request_id=pull_request_obj.pull_request_id,)
1695
1694
1696 # set some variables for email notification
1695 # set some variables for email notification
1697 pr_target_repo_url = h.route_url(
1696 pr_target_repo_url = h.route_url(
1698 'repo_summary', repo_name=pr_target_repo.repo_name)
1697 'repo_summary', repo_name=pr_target_repo.repo_name)
1699
1698
1700 pr_source_repo_url = h.route_url(
1699 pr_source_repo_url = h.route_url(
1701 'repo_summary', repo_name=pr_source_repo.repo_name)
1700 'repo_summary', repo_name=pr_source_repo.repo_name)
1702
1701
1703 email_kwargs = {
1702 email_kwargs = {
1704 'date': datetime.datetime.now(),
1703 'date': datetime.datetime.now(),
1705 'updating_user': updating_user,
1704 'updating_user': updating_user,
1706
1705
1707 'pull_request': pull_request_obj,
1706 'pull_request': pull_request_obj,
1708
1707
1709 'pull_request_target_repo': pr_target_repo,
1708 'pull_request_target_repo': pr_target_repo,
1710 'pull_request_target_repo_url': pr_target_repo_url,
1709 'pull_request_target_repo_url': pr_target_repo_url,
1711
1710
1712 'pull_request_source_repo': pr_source_repo,
1711 'pull_request_source_repo': pr_source_repo,
1713 'pull_request_source_repo_url': pr_source_repo_url,
1712 'pull_request_source_repo_url': pr_source_repo_url,
1714
1713
1715 'pull_request_url': pr_url,
1714 'pull_request_url': pr_url,
1716
1715
1717 'ancestor_commit_id': ancestor_commit_id,
1716 'ancestor_commit_id': ancestor_commit_id,
1718 'added_commits': commit_changes.added,
1717 'added_commits': commit_changes.added,
1719 'removed_commits': commit_changes.removed,
1718 'removed_commits': commit_changes.removed,
1720 'changed_files': changed_files,
1719 'changed_files': changed_files,
1721 'added_files': file_changes.added,
1720 'added_files': file_changes.added,
1722 'modified_files': file_changes.modified,
1721 'modified_files': file_changes.modified,
1723 'removed_files': file_changes.removed,
1722 'removed_files': file_changes.removed,
1724 'thread_ids': [pr_url],
1723 'thread_ids': [pr_url],
1725 }
1724 }
1726
1725
1727 # create notification objects, and emails
1726 # create notification objects, and emails
1728 NotificationModel().create(
1727 NotificationModel().create(
1729 created_by=updating_user,
1728 created_by=updating_user,
1730 notification_subject='', # Filled in based on the notification_type
1729 notification_subject='', # Filled in based on the notification_type
1731 notification_body='', # Filled in based on the notification_type
1730 notification_body='', # Filled in based on the notification_type
1732 notification_type=EmailNotificationModel.TYPE_PULL_REQUEST_UPDATE,
1731 notification_type=EmailNotificationModel.TYPE_PULL_REQUEST_UPDATE,
1733 recipients=recipients,
1732 recipients=recipients,
1734 email_kwargs=email_kwargs,
1733 email_kwargs=email_kwargs,
1735 )
1734 )
1736
1735
1737 def delete(self, pull_request, user=None):
1736 def delete(self, pull_request, user=None):
1738 if not user:
1737 if not user:
1739 user = getattr(get_current_rhodecode_user(), 'username', None)
1738 user = getattr(get_current_rhodecode_user(), 'username', None)
1740
1739
1741 pull_request = self.__get_pull_request(pull_request)
1740 pull_request = self.__get_pull_request(pull_request)
1742 old_data = pull_request.get_api_data(with_merge_state=False)
1741 old_data = pull_request.get_api_data(with_merge_state=False)
1743 self._cleanup_merge_workspace(pull_request)
1742 self._cleanup_merge_workspace(pull_request)
1744 self._log_audit_action(
1743 self._log_audit_action(
1745 'repo.pull_request.delete', {'old_data': old_data},
1744 'repo.pull_request.delete', {'old_data': old_data},
1746 user, pull_request)
1745 user, pull_request)
1747 Session().delete(pull_request)
1746 Session().delete(pull_request)
1748
1747
1749 def close_pull_request(self, pull_request, user):
1748 def close_pull_request(self, pull_request, user):
1750 pull_request = self.__get_pull_request(pull_request)
1749 pull_request = self.__get_pull_request(pull_request)
1751 self._cleanup_merge_workspace(pull_request)
1750 self._cleanup_merge_workspace(pull_request)
1752 pull_request.status = PullRequest.STATUS_CLOSED
1751 pull_request.status = PullRequest.STATUS_CLOSED
1753 pull_request.updated_on = datetime.datetime.now()
1752 pull_request.updated_on = datetime.datetime.now()
1754 Session().add(pull_request)
1753 Session().add(pull_request)
1755 self.trigger_pull_request_hook(pull_request, pull_request.author, 'close')
1754 self.trigger_pull_request_hook(pull_request, pull_request.author, 'close')
1756
1755
1757 pr_data = pull_request.get_api_data(with_merge_state=False)
1756 pr_data = pull_request.get_api_data(with_merge_state=False)
1758 self._log_audit_action(
1757 self._log_audit_action(
1759 'repo.pull_request.close', {'data': pr_data}, user, pull_request)
1758 'repo.pull_request.close', {'data': pr_data}, user, pull_request)
1760
1759
1761 def close_pull_request_with_comment(
1760 def close_pull_request_with_comment(
1762 self, pull_request, user, repo, message=None, auth_user=None):
1761 self, pull_request, user, repo, message=None, auth_user=None):
1763
1762
1764 pull_request_review_status = pull_request.calculated_review_status()
1763 pull_request_review_status = pull_request.calculated_review_status()
1765
1764
1766 if pull_request_review_status == ChangesetStatus.STATUS_APPROVED:
1765 if pull_request_review_status == ChangesetStatus.STATUS_APPROVED:
1767 # approved only if we have voting consent
1766 # approved only if we have voting consent
1768 status = ChangesetStatus.STATUS_APPROVED
1767 status = ChangesetStatus.STATUS_APPROVED
1769 else:
1768 else:
1770 status = ChangesetStatus.STATUS_REJECTED
1769 status = ChangesetStatus.STATUS_REJECTED
1771 status_lbl = ChangesetStatus.get_status_lbl(status)
1770 status_lbl = ChangesetStatus.get_status_lbl(status)
1772
1771
1773 default_message = (
1772 default_message = (
1774 'Closing with status change {transition_icon} {status}.'
1773 'Closing with status change {transition_icon} {status}.'
1775 ).format(transition_icon='>', status=status_lbl)
1774 ).format(transition_icon='>', status=status_lbl)
1776 text = message or default_message
1775 text = message or default_message
1777
1776
1778 # create a comment, and link it to new status
1777 # create a comment, and link it to new status
1779 comment = CommentsModel().create(
1778 comment = CommentsModel().create(
1780 text=text,
1779 text=text,
1781 repo=repo.repo_id,
1780 repo=repo.repo_id,
1782 user=user.user_id,
1781 user=user.user_id,
1783 pull_request=pull_request.pull_request_id,
1782 pull_request=pull_request.pull_request_id,
1784 status_change=status_lbl,
1783 status_change=status_lbl,
1785 status_change_type=status,
1784 status_change_type=status,
1786 closing_pr=True,
1785 closing_pr=True,
1787 auth_user=auth_user,
1786 auth_user=auth_user,
1788 )
1787 )
1789
1788
1790 # calculate old status before we change it
1789 # calculate old status before we change it
1791 old_calculated_status = pull_request.calculated_review_status()
1790 old_calculated_status = pull_request.calculated_review_status()
1792 ChangesetStatusModel().set_status(
1791 ChangesetStatusModel().set_status(
1793 repo.repo_id,
1792 repo.repo_id,
1794 status,
1793 status,
1795 user.user_id,
1794 user.user_id,
1796 comment=comment,
1795 comment=comment,
1797 pull_request=pull_request.pull_request_id
1796 pull_request=pull_request.pull_request_id
1798 )
1797 )
1799
1798
1800 Session().flush()
1799 Session().flush()
1801
1800
1802 self.trigger_pull_request_hook(pull_request, user, 'comment',
1801 self.trigger_pull_request_hook(pull_request, user, 'comment',
1803 data={'comment': comment})
1802 data={'comment': comment})
1804
1803
1805 # we now calculate the status of pull request again, and based on that
1804 # we now calculate the status of pull request again, and based on that
1806 # calculation trigger status change. This might happen in cases
1805 # calculation trigger status change. This might happen in cases
1807 # that non-reviewer admin closes a pr, which means his vote doesn't
1806 # that non-reviewer admin closes a pr, which means his vote doesn't
1808 # change the status, while if he's a reviewer this might change it.
1807 # change the status, while if he's a reviewer this might change it.
1809 calculated_status = pull_request.calculated_review_status()
1808 calculated_status = pull_request.calculated_review_status()
1810 if old_calculated_status != calculated_status:
1809 if old_calculated_status != calculated_status:
1811 self.trigger_pull_request_hook(pull_request, user, 'review_status_change',
1810 self.trigger_pull_request_hook(pull_request, user, 'review_status_change',
1812 data={'status': calculated_status})
1811 data={'status': calculated_status})
1813
1812
1814 # finally close the PR
1813 # finally close the PR
1815 PullRequestModel().close_pull_request(pull_request.pull_request_id, user)
1814 PullRequestModel().close_pull_request(pull_request.pull_request_id, user)
1816
1815
1817 return comment, status
1816 return comment, status
1818
1817
1819 def merge_status(self, pull_request, translator=None, force_shadow_repo_refresh=False):
1818 def merge_status(self, pull_request, translator=None, force_shadow_repo_refresh=False):
1820 _ = translator or get_current_request().translate
1819 _ = translator or get_current_request().translate
1821
1820
1822 if not self._is_merge_enabled(pull_request):
1821 if not self._is_merge_enabled(pull_request):
1823 return None, False, _('Server-side pull request merging is disabled.')
1822 return None, False, _('Server-side pull request merging is disabled.')
1824
1823
1825 if pull_request.is_closed():
1824 if pull_request.is_closed():
1826 return None, False, _('This pull request is closed.')
1825 return None, False, _('This pull request is closed.')
1827
1826
1828 merge_possible, msg = self._check_repo_requirements(
1827 merge_possible, msg = self._check_repo_requirements(
1829 target=pull_request.target_repo, source=pull_request.source_repo,
1828 target=pull_request.target_repo, source=pull_request.source_repo,
1830 translator=_)
1829 translator=_)
1831 if not merge_possible:
1830 if not merge_possible:
1832 return None, merge_possible, msg
1831 return None, merge_possible, msg
1833
1832
1834 try:
1833 try:
1835 merge_response = self._try_merge(
1834 merge_response = self._try_merge(
1836 pull_request, force_shadow_repo_refresh=force_shadow_repo_refresh)
1835 pull_request, force_shadow_repo_refresh=force_shadow_repo_refresh)
1837 log.debug("Merge response: %s", merge_response)
1836 log.debug("Merge response: %s", merge_response)
1838 return merge_response, merge_response.possible, merge_response.merge_status_message
1837 return merge_response, merge_response.possible, merge_response.merge_status_message
1839 except NotImplementedError:
1838 except NotImplementedError:
1840 return None, False, _('Pull request merging is not supported.')
1839 return None, False, _('Pull request merging is not supported.')
1841
1840
1842 def _check_repo_requirements(self, target, source, translator):
1841 def _check_repo_requirements(self, target, source, translator):
1843 """
1842 """
1844 Check if `target` and `source` have compatible requirements.
1843 Check if `target` and `source` have compatible requirements.
1845
1844
1846 Currently this is just checking for largefiles.
1845 Currently this is just checking for largefiles.
1847 """
1846 """
1848 _ = translator
1847 _ = translator
1849 target_has_largefiles = self._has_largefiles(target)
1848 target_has_largefiles = self._has_largefiles(target)
1850 source_has_largefiles = self._has_largefiles(source)
1849 source_has_largefiles = self._has_largefiles(source)
1851 merge_possible = True
1850 merge_possible = True
1852 message = u''
1851 message = u''
1853
1852
1854 if target_has_largefiles != source_has_largefiles:
1853 if target_has_largefiles != source_has_largefiles:
1855 merge_possible = False
1854 merge_possible = False
1856 if source_has_largefiles:
1855 if source_has_largefiles:
1857 message = _(
1856 message = _(
1858 'Target repository large files support is disabled.')
1857 'Target repository large files support is disabled.')
1859 else:
1858 else:
1860 message = _(
1859 message = _(
1861 'Source repository large files support is disabled.')
1860 'Source repository large files support is disabled.')
1862
1861
1863 return merge_possible, message
1862 return merge_possible, message
1864
1863
1865 def _has_largefiles(self, repo):
1864 def _has_largefiles(self, repo):
1866 largefiles_ui = VcsSettingsModel(repo=repo).get_ui_settings(
1865 largefiles_ui = VcsSettingsModel(repo=repo).get_ui_settings(
1867 'extensions', 'largefiles')
1866 'extensions', 'largefiles')
1868 return largefiles_ui and largefiles_ui[0].active
1867 return largefiles_ui and largefiles_ui[0].active
1869
1868
1870 def _try_merge(self, pull_request, force_shadow_repo_refresh=False):
1869 def _try_merge(self, pull_request, force_shadow_repo_refresh=False):
1871 """
1870 """
1872 Try to merge the pull request and return the merge status.
1871 Try to merge the pull request and return the merge status.
1873 """
1872 """
1874 log.debug(
1873 log.debug(
1875 "Trying out if the pull request %s can be merged. Force_refresh=%s",
1874 "Trying out if the pull request %s can be merged. Force_refresh=%s",
1876 pull_request.pull_request_id, force_shadow_repo_refresh)
1875 pull_request.pull_request_id, force_shadow_repo_refresh)
1877 target_vcs = pull_request.target_repo.scm_instance()
1876 target_vcs = pull_request.target_repo.scm_instance()
1878 # Refresh the target reference.
1877 # Refresh the target reference.
1879 try:
1878 try:
1880 target_ref = self._refresh_reference(
1879 target_ref = self._refresh_reference(
1881 pull_request.target_ref_parts, target_vcs)
1880 pull_request.target_ref_parts, target_vcs)
1882 except CommitDoesNotExistError:
1881 except CommitDoesNotExistError:
1883 merge_state = MergeResponse(
1882 merge_state = MergeResponse(
1884 False, False, None, MergeFailureReason.MISSING_TARGET_REF,
1883 False, False, None, MergeFailureReason.MISSING_TARGET_REF,
1885 metadata={'target_ref': pull_request.target_ref_parts})
1884 metadata={'target_ref': pull_request.target_ref_parts})
1886 return merge_state
1885 return merge_state
1887
1886
1888 target_locked = pull_request.target_repo.locked
1887 target_locked = pull_request.target_repo.locked
1889 if target_locked and target_locked[0]:
1888 if target_locked and target_locked[0]:
1890 locked_by = 'user:{}'.format(target_locked[0])
1889 locked_by = 'user:{}'.format(target_locked[0])
1891 log.debug("The target repository is locked by %s.", locked_by)
1890 log.debug("The target repository is locked by %s.", locked_by)
1892 merge_state = MergeResponse(
1891 merge_state = MergeResponse(
1893 False, False, None, MergeFailureReason.TARGET_IS_LOCKED,
1892 False, False, None, MergeFailureReason.TARGET_IS_LOCKED,
1894 metadata={'locked_by': locked_by})
1893 metadata={'locked_by': locked_by})
1895 elif force_shadow_repo_refresh or self._needs_merge_state_refresh(
1894 elif force_shadow_repo_refresh or self._needs_merge_state_refresh(
1896 pull_request, target_ref):
1895 pull_request, target_ref):
1897 log.debug("Refreshing the merge status of the repository.")
1896 log.debug("Refreshing the merge status of the repository.")
1898 merge_state = self._refresh_merge_state(
1897 merge_state = self._refresh_merge_state(
1899 pull_request, target_vcs, target_ref)
1898 pull_request, target_vcs, target_ref)
1900 else:
1899 else:
1901 possible = pull_request.last_merge_status == MergeFailureReason.NONE
1900 possible = pull_request.last_merge_status == MergeFailureReason.NONE
1902 metadata = {
1901 metadata = {
1903 'unresolved_files': '',
1902 'unresolved_files': '',
1904 'target_ref': pull_request.target_ref_parts,
1903 'target_ref': pull_request.target_ref_parts,
1905 'source_ref': pull_request.source_ref_parts,
1904 'source_ref': pull_request.source_ref_parts,
1906 }
1905 }
1907 if pull_request.last_merge_metadata:
1906 if pull_request.last_merge_metadata:
1908 metadata.update(pull_request.last_merge_metadata_parsed)
1907 metadata.update(pull_request.last_merge_metadata_parsed)
1909
1908
1910 if not possible and target_ref.type == 'branch':
1909 if not possible and target_ref.type == 'branch':
1911 # NOTE(marcink): case for mercurial multiple heads on branch
1910 # NOTE(marcink): case for mercurial multiple heads on branch
1912 heads = target_vcs._heads(target_ref.name)
1911 heads = target_vcs._heads(target_ref.name)
1913 if len(heads) != 1:
1912 if len(heads) != 1:
1914 heads = '\n,'.join(target_vcs._heads(target_ref.name))
1913 heads = '\n,'.join(target_vcs._heads(target_ref.name))
1915 metadata.update({
1914 metadata.update({
1916 'heads': heads
1915 'heads': heads
1917 })
1916 })
1918
1917
1919 merge_state = MergeResponse(
1918 merge_state = MergeResponse(
1920 possible, False, None, pull_request.last_merge_status, metadata=metadata)
1919 possible, False, None, pull_request.last_merge_status, metadata=metadata)
1921
1920
1922 return merge_state
1921 return merge_state
1923
1922
1924 def _refresh_reference(self, reference, vcs_repository):
1923 def _refresh_reference(self, reference, vcs_repository):
1925 if reference.type in self.UPDATABLE_REF_TYPES:
1924 if reference.type in self.UPDATABLE_REF_TYPES:
1926 name_or_id = reference.name
1925 name_or_id = reference.name
1927 else:
1926 else:
1928 name_or_id = reference.commit_id
1927 name_or_id = reference.commit_id
1929
1928
1930 refreshed_commit = vcs_repository.get_commit(name_or_id)
1929 refreshed_commit = vcs_repository.get_commit(name_or_id)
1931 refreshed_reference = Reference(
1930 refreshed_reference = Reference(
1932 reference.type, reference.name, refreshed_commit.raw_id)
1931 reference.type, reference.name, refreshed_commit.raw_id)
1933 return refreshed_reference
1932 return refreshed_reference
1934
1933
1935 def _needs_merge_state_refresh(self, pull_request, target_reference):
1934 def _needs_merge_state_refresh(self, pull_request, target_reference):
1936 return not(
1935 return not(
1937 pull_request.revisions and
1936 pull_request.revisions and
1938 pull_request.revisions[0] == pull_request._last_merge_source_rev and
1937 pull_request.revisions[0] == pull_request._last_merge_source_rev and
1939 target_reference.commit_id == pull_request._last_merge_target_rev)
1938 target_reference.commit_id == pull_request._last_merge_target_rev)
1940
1939
1941 def _refresh_merge_state(self, pull_request, target_vcs, target_reference):
1940 def _refresh_merge_state(self, pull_request, target_vcs, target_reference):
1942 workspace_id = self._workspace_id(pull_request)
1941 workspace_id = self._workspace_id(pull_request)
1943 source_vcs = pull_request.source_repo.scm_instance()
1942 source_vcs = pull_request.source_repo.scm_instance()
1944 repo_id = pull_request.target_repo.repo_id
1943 repo_id = pull_request.target_repo.repo_id
1945 use_rebase = self._use_rebase_for_merging(pull_request)
1944 use_rebase = self._use_rebase_for_merging(pull_request)
1946 close_branch = self._close_branch_before_merging(pull_request)
1945 close_branch = self._close_branch_before_merging(pull_request)
1947 merge_state = target_vcs.merge(
1946 merge_state = target_vcs.merge(
1948 repo_id, workspace_id,
1947 repo_id, workspace_id,
1949 target_reference, source_vcs, pull_request.source_ref_parts,
1948 target_reference, source_vcs, pull_request.source_ref_parts,
1950 dry_run=True, use_rebase=use_rebase,
1949 dry_run=True, use_rebase=use_rebase,
1951 close_branch=close_branch)
1950 close_branch=close_branch)
1952
1951
1953 # Do not store the response if there was an unknown error.
1952 # Do not store the response if there was an unknown error.
1954 if merge_state.failure_reason != MergeFailureReason.UNKNOWN:
1953 if merge_state.failure_reason != MergeFailureReason.UNKNOWN:
1955 pull_request._last_merge_source_rev = \
1954 pull_request._last_merge_source_rev = \
1956 pull_request.source_ref_parts.commit_id
1955 pull_request.source_ref_parts.commit_id
1957 pull_request._last_merge_target_rev = target_reference.commit_id
1956 pull_request._last_merge_target_rev = target_reference.commit_id
1958 pull_request.last_merge_status = merge_state.failure_reason
1957 pull_request.last_merge_status = merge_state.failure_reason
1959 pull_request.last_merge_metadata = merge_state.metadata
1958 pull_request.last_merge_metadata = merge_state.metadata
1960
1959
1961 pull_request.shadow_merge_ref = merge_state.merge_ref
1960 pull_request.shadow_merge_ref = merge_state.merge_ref
1962 Session().add(pull_request)
1961 Session().add(pull_request)
1963 Session().commit()
1962 Session().commit()
1964
1963
1965 return merge_state
1964 return merge_state
1966
1965
1967 def _workspace_id(self, pull_request):
1966 def _workspace_id(self, pull_request):
1968 workspace_id = 'pr-%s' % pull_request.pull_request_id
1967 workspace_id = 'pr-%s' % pull_request.pull_request_id
1969 return workspace_id
1968 return workspace_id
1970
1969
1971 def generate_repo_data(self, repo, commit_id=None, branch=None,
1970 def generate_repo_data(self, repo, commit_id=None, branch=None,
1972 bookmark=None, translator=None):
1971 bookmark=None, translator=None):
1973 from rhodecode.model.repo import RepoModel
1972 from rhodecode.model.repo import RepoModel
1974
1973
1975 all_refs, selected_ref = \
1974 all_refs, selected_ref = \
1976 self._get_repo_pullrequest_sources(
1975 self._get_repo_pullrequest_sources(
1977 repo.scm_instance(), commit_id=commit_id,
1976 repo.scm_instance(), commit_id=commit_id,
1978 branch=branch, bookmark=bookmark, translator=translator)
1977 branch=branch, bookmark=bookmark, translator=translator)
1979
1978
1980 refs_select2 = []
1979 refs_select2 = []
1981 for element in all_refs:
1980 for element in all_refs:
1982 children = [{'id': x[0], 'text': x[1]} for x in element[0]]
1981 children = [{'id': x[0], 'text': x[1]} for x in element[0]]
1983 refs_select2.append({'text': element[1], 'children': children})
1982 refs_select2.append({'text': element[1], 'children': children})
1984
1983
1985 return {
1984 return {
1986 'user': {
1985 'user': {
1987 'user_id': repo.user.user_id,
1986 'user_id': repo.user.user_id,
1988 'username': repo.user.username,
1987 'username': repo.user.username,
1989 'firstname': repo.user.first_name,
1988 'firstname': repo.user.first_name,
1990 'lastname': repo.user.last_name,
1989 'lastname': repo.user.last_name,
1991 'gravatar_link': h.gravatar_url(repo.user.email, 14),
1990 'gravatar_link': h.gravatar_url(repo.user.email, 14),
1992 },
1991 },
1993 'name': repo.repo_name,
1992 'name': repo.repo_name,
1994 'link': RepoModel().get_url(repo),
1993 'link': RepoModel().get_url(repo),
1995 'description': h.chop_at_smart(repo.description_safe, '\n'),
1994 'description': h.chop_at_smart(repo.description_safe, '\n'),
1996 'refs': {
1995 'refs': {
1997 'all_refs': all_refs,
1996 'all_refs': all_refs,
1998 'selected_ref': selected_ref,
1997 'selected_ref': selected_ref,
1999 'select2_refs': refs_select2
1998 'select2_refs': refs_select2
2000 }
1999 }
2001 }
2000 }
2002
2001
2003 def generate_pullrequest_title(self, source, source_ref, target):
2002 def generate_pullrequest_title(self, source, source_ref, target):
2004 return u'{source}#{at_ref} to {target}'.format(
2003 return u'{source}#{at_ref} to {target}'.format(
2005 source=source,
2004 source=source,
2006 at_ref=source_ref,
2005 at_ref=source_ref,
2007 target=target,
2006 target=target,
2008 )
2007 )
2009
2008
2010 def _cleanup_merge_workspace(self, pull_request):
2009 def _cleanup_merge_workspace(self, pull_request):
2011 # Merging related cleanup
2010 # Merging related cleanup
2012 repo_id = pull_request.target_repo.repo_id
2011 repo_id = pull_request.target_repo.repo_id
2013 target_scm = pull_request.target_repo.scm_instance()
2012 target_scm = pull_request.target_repo.scm_instance()
2014 workspace_id = self._workspace_id(pull_request)
2013 workspace_id = self._workspace_id(pull_request)
2015
2014
2016 try:
2015 try:
2017 target_scm.cleanup_merge_workspace(repo_id, workspace_id)
2016 target_scm.cleanup_merge_workspace(repo_id, workspace_id)
2018 except NotImplementedError:
2017 except NotImplementedError:
2019 pass
2018 pass
2020
2019
2021 def _get_repo_pullrequest_sources(
2020 def _get_repo_pullrequest_sources(
2022 self, repo, commit_id=None, branch=None, bookmark=None,
2021 self, repo, commit_id=None, branch=None, bookmark=None,
2023 translator=None):
2022 translator=None):
2024 """
2023 """
2025 Return a structure with repo's interesting commits, suitable for
2024 Return a structure with repo's interesting commits, suitable for
2026 the selectors in pullrequest controller
2025 the selectors in pullrequest controller
2027
2026
2028 :param commit_id: a commit that must be in the list somehow
2027 :param commit_id: a commit that must be in the list somehow
2029 and selected by default
2028 and selected by default
2030 :param branch: a branch that must be in the list and selected
2029 :param branch: a branch that must be in the list and selected
2031 by default - even if closed
2030 by default - even if closed
2032 :param bookmark: a bookmark that must be in the list and selected
2031 :param bookmark: a bookmark that must be in the list and selected
2033 """
2032 """
2034 _ = translator or get_current_request().translate
2033 _ = translator or get_current_request().translate
2035
2034
2036 commit_id = safe_str(commit_id) if commit_id else None
2035 commit_id = safe_str(commit_id) if commit_id else None
2037 branch = safe_str(branch) if branch else None
2036 branch = safe_str(branch) if branch else None
2038 bookmark = safe_str(bookmark) if bookmark else None
2037 bookmark = safe_str(bookmark) if bookmark else None
2039
2038
2040 selected = None
2039 selected = None
2041
2040
2042 # order matters: first source that has commit_id in it will be selected
2041 # order matters: first source that has commit_id in it will be selected
2043 sources = []
2042 sources = []
2044 sources.append(('book', repo.bookmarks.items(), _('Bookmarks'), bookmark))
2043 sources.append(('book', repo.bookmarks.items(), _('Bookmarks'), bookmark))
2045 sources.append(('branch', repo.branches.items(), _('Branches'), branch))
2044 sources.append(('branch', repo.branches.items(), _('Branches'), branch))
2046
2045
2047 if commit_id:
2046 if commit_id:
2048 ref_commit = (h.short_id(commit_id), commit_id)
2047 ref_commit = (h.short_id(commit_id), commit_id)
2049 sources.append(('rev', [ref_commit], _('Commit IDs'), commit_id))
2048 sources.append(('rev', [ref_commit], _('Commit IDs'), commit_id))
2050
2049
2051 sources.append(
2050 sources.append(
2052 ('branch', repo.branches_closed.items(), _('Closed Branches'), branch),
2051 ('branch', repo.branches_closed.items(), _('Closed Branches'), branch),
2053 )
2052 )
2054
2053
2055 groups = []
2054 groups = []
2056
2055
2057 for group_key, ref_list, group_name, match in sources:
2056 for group_key, ref_list, group_name, match in sources:
2058 group_refs = []
2057 group_refs = []
2059 for ref_name, ref_id in ref_list:
2058 for ref_name, ref_id in ref_list:
2060 ref_key = u'{}:{}:{}'.format(group_key, ref_name, ref_id)
2059 ref_key = u'{}:{}:{}'.format(group_key, ref_name, ref_id)
2061 group_refs.append((ref_key, ref_name))
2060 group_refs.append((ref_key, ref_name))
2062
2061
2063 if not selected:
2062 if not selected:
2064 if set([commit_id, match]) & set([ref_id, ref_name]):
2063 if set([commit_id, match]) & set([ref_id, ref_name]):
2065 selected = ref_key
2064 selected = ref_key
2066
2065
2067 if group_refs:
2066 if group_refs:
2068 groups.append((group_refs, group_name))
2067 groups.append((group_refs, group_name))
2069
2068
2070 if not selected:
2069 if not selected:
2071 ref = commit_id or branch or bookmark
2070 ref = commit_id or branch or bookmark
2072 if ref:
2071 if ref:
2073 raise CommitDoesNotExistError(
2072 raise CommitDoesNotExistError(
2074 u'No commit refs could be found matching: {}'.format(ref))
2073 u'No commit refs could be found matching: {}'.format(ref))
2075 elif repo.DEFAULT_BRANCH_NAME in repo.branches:
2074 elif repo.DEFAULT_BRANCH_NAME in repo.branches:
2076 selected = u'branch:{}:{}'.format(
2075 selected = u'branch:{}:{}'.format(
2077 safe_str(repo.DEFAULT_BRANCH_NAME),
2076 safe_str(repo.DEFAULT_BRANCH_NAME),
2078 safe_str(repo.branches[repo.DEFAULT_BRANCH_NAME])
2077 safe_str(repo.branches[repo.DEFAULT_BRANCH_NAME])
2079 )
2078 )
2080 elif repo.commit_ids:
2079 elif repo.commit_ids:
2081 # make the user select in this case
2080 # make the user select in this case
2082 selected = None
2081 selected = None
2083 else:
2082 else:
2084 raise EmptyRepositoryError()
2083 raise EmptyRepositoryError()
2085 return groups, selected
2084 return groups, selected
2086
2085
2087 def get_diff(self, source_repo, source_ref_id, target_ref_id,
2086 def get_diff(self, source_repo, source_ref_id, target_ref_id,
2088 hide_whitespace_changes, diff_context):
2087 hide_whitespace_changes, diff_context):
2089
2088
2090 return self._get_diff_from_pr_or_version(
2089 return self._get_diff_from_pr_or_version(
2091 source_repo, source_ref_id, target_ref_id,
2090 source_repo, source_ref_id, target_ref_id,
2092 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
2091 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
2093
2092
2094 def _get_diff_from_pr_or_version(
2093 def _get_diff_from_pr_or_version(
2095 self, source_repo, source_ref_id, target_ref_id,
2094 self, source_repo, source_ref_id, target_ref_id,
2096 hide_whitespace_changes, diff_context):
2095 hide_whitespace_changes, diff_context):
2097
2096
2098 target_commit = source_repo.get_commit(
2097 target_commit = source_repo.get_commit(
2099 commit_id=safe_str(target_ref_id))
2098 commit_id=safe_str(target_ref_id))
2100 source_commit = source_repo.get_commit(
2099 source_commit = source_repo.get_commit(
2101 commit_id=safe_str(source_ref_id), maybe_unreachable=True)
2100 commit_id=safe_str(source_ref_id), maybe_unreachable=True)
2102 if isinstance(source_repo, Repository):
2101 if isinstance(source_repo, Repository):
2103 vcs_repo = source_repo.scm_instance()
2102 vcs_repo = source_repo.scm_instance()
2104 else:
2103 else:
2105 vcs_repo = source_repo
2104 vcs_repo = source_repo
2106
2105
2107 # TODO: johbo: In the context of an update, we cannot reach
2106 # TODO: johbo: In the context of an update, we cannot reach
2108 # the old commit anymore with our normal mechanisms. It needs
2107 # the old commit anymore with our normal mechanisms. It needs
2109 # some sort of special support in the vcs layer to avoid this
2108 # some sort of special support in the vcs layer to avoid this
2110 # workaround.
2109 # workaround.
2111 if (source_commit.raw_id == vcs_repo.EMPTY_COMMIT_ID and
2110 if (source_commit.raw_id == vcs_repo.EMPTY_COMMIT_ID and
2112 vcs_repo.alias == 'git'):
2111 vcs_repo.alias == 'git'):
2113 source_commit.raw_id = safe_str(source_ref_id)
2112 source_commit.raw_id = safe_str(source_ref_id)
2114
2113
2115 log.debug('calculating diff between '
2114 log.debug('calculating diff between '
2116 'source_ref:%s and target_ref:%s for repo `%s`',
2115 'source_ref:%s and target_ref:%s for repo `%s`',
2117 target_ref_id, source_ref_id,
2116 target_ref_id, source_ref_id,
2118 safe_str(vcs_repo.path))
2117 safe_str(vcs_repo.path))
2119
2118
2120 vcs_diff = vcs_repo.get_diff(
2119 vcs_diff = vcs_repo.get_diff(
2121 commit1=target_commit, commit2=source_commit,
2120 commit1=target_commit, commit2=source_commit,
2122 ignore_whitespace=hide_whitespace_changes, context=diff_context)
2121 ignore_whitespace=hide_whitespace_changes, context=diff_context)
2123 return vcs_diff
2122 return vcs_diff
2124
2123
2125 def _is_merge_enabled(self, pull_request):
2124 def _is_merge_enabled(self, pull_request):
2126 return self._get_general_setting(
2125 return self._get_general_setting(
2127 pull_request, 'rhodecode_pr_merge_enabled')
2126 pull_request, 'rhodecode_pr_merge_enabled')
2128
2127
2129 def _use_rebase_for_merging(self, pull_request):
2128 def _use_rebase_for_merging(self, pull_request):
2130 repo_type = pull_request.target_repo.repo_type
2129 repo_type = pull_request.target_repo.repo_type
2131 if repo_type == 'hg':
2130 if repo_type == 'hg':
2132 return self._get_general_setting(
2131 return self._get_general_setting(
2133 pull_request, 'rhodecode_hg_use_rebase_for_merging')
2132 pull_request, 'rhodecode_hg_use_rebase_for_merging')
2134 elif repo_type == 'git':
2133 elif repo_type == 'git':
2135 return self._get_general_setting(
2134 return self._get_general_setting(
2136 pull_request, 'rhodecode_git_use_rebase_for_merging')
2135 pull_request, 'rhodecode_git_use_rebase_for_merging')
2137
2136
2138 return False
2137 return False
2139
2138
2140 def _user_name_for_merging(self, pull_request, user):
2139 def _user_name_for_merging(self, pull_request, user):
2141 env_user_name_attr = os.environ.get('RC_MERGE_USER_NAME_ATTR', '')
2140 env_user_name_attr = os.environ.get('RC_MERGE_USER_NAME_ATTR', '')
2142 if env_user_name_attr and hasattr(user, env_user_name_attr):
2141 if env_user_name_attr and hasattr(user, env_user_name_attr):
2143 user_name_attr = env_user_name_attr
2142 user_name_attr = env_user_name_attr
2144 else:
2143 else:
2145 user_name_attr = 'short_contact'
2144 user_name_attr = 'short_contact'
2146
2145
2147 user_name = getattr(user, user_name_attr)
2146 user_name = getattr(user, user_name_attr)
2148 return user_name
2147 return user_name
2149
2148
2150 def _close_branch_before_merging(self, pull_request):
2149 def _close_branch_before_merging(self, pull_request):
2151 repo_type = pull_request.target_repo.repo_type
2150 repo_type = pull_request.target_repo.repo_type
2152 if repo_type == 'hg':
2151 if repo_type == 'hg':
2153 return self._get_general_setting(
2152 return self._get_general_setting(
2154 pull_request, 'rhodecode_hg_close_branch_before_merging')
2153 pull_request, 'rhodecode_hg_close_branch_before_merging')
2155 elif repo_type == 'git':
2154 elif repo_type == 'git':
2156 return self._get_general_setting(
2155 return self._get_general_setting(
2157 pull_request, 'rhodecode_git_close_branch_before_merging')
2156 pull_request, 'rhodecode_git_close_branch_before_merging')
2158
2157
2159 return False
2158 return False
2160
2159
2161 def _get_general_setting(self, pull_request, settings_key, default=False):
2160 def _get_general_setting(self, pull_request, settings_key, default=False):
2162 settings_model = VcsSettingsModel(repo=pull_request.target_repo)
2161 settings_model = VcsSettingsModel(repo=pull_request.target_repo)
2163 settings = settings_model.get_general_settings()
2162 settings = settings_model.get_general_settings()
2164 return settings.get(settings_key, default)
2163 return settings.get(settings_key, default)
2165
2164
2166 def _log_audit_action(self, action, action_data, user, pull_request):
2165 def _log_audit_action(self, action, action_data, user, pull_request):
2167 audit_logger.store(
2166 audit_logger.store(
2168 action=action,
2167 action=action,
2169 action_data=action_data,
2168 action_data=action_data,
2170 user=user,
2169 user=user,
2171 repo=pull_request.target_repo)
2170 repo=pull_request.target_repo)
2172
2171
2173 def get_reviewer_functions(self):
2172 def get_reviewer_functions(self):
2174 """
2173 """
2175 Fetches functions for validation and fetching default reviewers.
2174 Fetches functions for validation and fetching default reviewers.
2176 If available we use the EE package, else we fallback to CE
2175 If available we use the EE package, else we fallback to CE
2177 package functions
2176 package functions
2178 """
2177 """
2179 try:
2178 try:
2180 from rc_reviewers.utils import get_default_reviewers_data
2179 from rc_reviewers.utils import get_default_reviewers_data
2181 from rc_reviewers.utils import validate_default_reviewers
2180 from rc_reviewers.utils import validate_default_reviewers
2182 from rc_reviewers.utils import validate_observers
2181 from rc_reviewers.utils import validate_observers
2183 except ImportError:
2182 except ImportError:
2184 from rhodecode.apps.repository.utils import get_default_reviewers_data
2183 from rhodecode.apps.repository.utils import get_default_reviewers_data
2185 from rhodecode.apps.repository.utils import validate_default_reviewers
2184 from rhodecode.apps.repository.utils import validate_default_reviewers
2186 from rhodecode.apps.repository.utils import validate_observers
2185 from rhodecode.apps.repository.utils import validate_observers
2187
2186
2188 return get_default_reviewers_data, validate_default_reviewers, validate_observers
2187 return get_default_reviewers_data, validate_default_reviewers, validate_observers
2189
2188
2190
2189
2191 class MergeCheck(object):
2190 class MergeCheck(object):
2192 """
2191 """
2193 Perform Merge Checks and returns a check object which stores information
2192 Perform Merge Checks and returns a check object which stores information
2194 about merge errors, and merge conditions
2193 about merge errors, and merge conditions
2195 """
2194 """
2196 TODO_CHECK = 'todo'
2195 TODO_CHECK = 'todo'
2197 PERM_CHECK = 'perm'
2196 PERM_CHECK = 'perm'
2198 REVIEW_CHECK = 'review'
2197 REVIEW_CHECK = 'review'
2199 MERGE_CHECK = 'merge'
2198 MERGE_CHECK = 'merge'
2200 WIP_CHECK = 'wip'
2199 WIP_CHECK = 'wip'
2201
2200
2202 def __init__(self):
2201 def __init__(self):
2203 self.review_status = None
2202 self.review_status = None
2204 self.merge_possible = None
2203 self.merge_possible = None
2205 self.merge_msg = ''
2204 self.merge_msg = ''
2206 self.merge_response = None
2205 self.merge_response = None
2207 self.failed = None
2206 self.failed = None
2208 self.errors = []
2207 self.errors = []
2209 self.error_details = OrderedDict()
2208 self.error_details = OrderedDict()
2210 self.source_commit = AttributeDict()
2209 self.source_commit = AttributeDict()
2211 self.target_commit = AttributeDict()
2210 self.target_commit = AttributeDict()
2212 self.reviewers_count = 0
2211 self.reviewers_count = 0
2213 self.observers_count = 0
2212 self.observers_count = 0
2214
2213
2215 def __repr__(self):
2214 def __repr__(self):
2216 return '<MergeCheck(possible:{}, failed:{}, errors:{})>'.format(
2215 return '<MergeCheck(possible:{}, failed:{}, errors:{})>'.format(
2217 self.merge_possible, self.failed, self.errors)
2216 self.merge_possible, self.failed, self.errors)
2218
2217
2219 def push_error(self, error_type, message, error_key, details):
2218 def push_error(self, error_type, message, error_key, details):
2220 self.failed = True
2219 self.failed = True
2221 self.errors.append([error_type, message])
2220 self.errors.append([error_type, message])
2222 self.error_details[error_key] = dict(
2221 self.error_details[error_key] = dict(
2223 details=details,
2222 details=details,
2224 error_type=error_type,
2223 error_type=error_type,
2225 message=message
2224 message=message
2226 )
2225 )
2227
2226
2228 @classmethod
2227 @classmethod
2229 def validate(cls, pull_request, auth_user, translator, fail_early=False,
2228 def validate(cls, pull_request, auth_user, translator, fail_early=False,
2230 force_shadow_repo_refresh=False):
2229 force_shadow_repo_refresh=False):
2231 _ = translator
2230 _ = translator
2232 merge_check = cls()
2231 merge_check = cls()
2233
2232
2234 # title has WIP:
2233 # title has WIP:
2235 if pull_request.work_in_progress:
2234 if pull_request.work_in_progress:
2236 log.debug("MergeCheck: cannot merge, title has wip: marker.")
2235 log.debug("MergeCheck: cannot merge, title has wip: marker.")
2237
2236
2238 msg = _('WIP marker in title prevents from accidental merge.')
2237 msg = _('WIP marker in title prevents from accidental merge.')
2239 merge_check.push_error('error', msg, cls.WIP_CHECK, pull_request.title)
2238 merge_check.push_error('error', msg, cls.WIP_CHECK, pull_request.title)
2240 if fail_early:
2239 if fail_early:
2241 return merge_check
2240 return merge_check
2242
2241
2243 # permissions to merge
2242 # permissions to merge
2244 user_allowed_to_merge = PullRequestModel().check_user_merge(pull_request, auth_user)
2243 user_allowed_to_merge = PullRequestModel().check_user_merge(pull_request, auth_user)
2245 if not user_allowed_to_merge:
2244 if not user_allowed_to_merge:
2246 log.debug("MergeCheck: cannot merge, approval is pending.")
2245 log.debug("MergeCheck: cannot merge, approval is pending.")
2247
2246
2248 msg = _('User `{}` not allowed to perform merge.').format(auth_user.username)
2247 msg = _('User `{}` not allowed to perform merge.').format(auth_user.username)
2249 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
2248 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
2250 if fail_early:
2249 if fail_early:
2251 return merge_check
2250 return merge_check
2252
2251
2253 # permission to merge into the target branch
2252 # permission to merge into the target branch
2254 target_commit_id = pull_request.target_ref_parts.commit_id
2253 target_commit_id = pull_request.target_ref_parts.commit_id
2255 if pull_request.target_ref_parts.type == 'branch':
2254 if pull_request.target_ref_parts.type == 'branch':
2256 branch_name = pull_request.target_ref_parts.name
2255 branch_name = pull_request.target_ref_parts.name
2257 else:
2256 else:
2258 # for mercurial we can always figure out the branch from the commit
2257 # for mercurial we can always figure out the branch from the commit
2259 # in case of bookmark
2258 # in case of bookmark
2260 target_commit = pull_request.target_repo.get_commit(target_commit_id)
2259 target_commit = pull_request.target_repo.get_commit(target_commit_id)
2261 branch_name = target_commit.branch
2260 branch_name = target_commit.branch
2262
2261
2263 rule, branch_perm = auth_user.get_rule_and_branch_permission(
2262 rule, branch_perm = auth_user.get_rule_and_branch_permission(
2264 pull_request.target_repo.repo_name, branch_name)
2263 pull_request.target_repo.repo_name, branch_name)
2265 if branch_perm and branch_perm == 'branch.none':
2264 if branch_perm and branch_perm == 'branch.none':
2266 msg = _('Target branch `{}` changes rejected by rule {}.').format(
2265 msg = _('Target branch `{}` changes rejected by rule {}.').format(
2267 branch_name, rule)
2266 branch_name, rule)
2268 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
2267 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
2269 if fail_early:
2268 if fail_early:
2270 return merge_check
2269 return merge_check
2271
2270
2272 # review status, must be always present
2271 # review status, must be always present
2273 review_status = pull_request.calculated_review_status()
2272 review_status = pull_request.calculated_review_status()
2274 merge_check.review_status = review_status
2273 merge_check.review_status = review_status
2275 merge_check.reviewers_count = pull_request.reviewers_count
2274 merge_check.reviewers_count = pull_request.reviewers_count
2276 merge_check.observers_count = pull_request.observers_count
2275 merge_check.observers_count = pull_request.observers_count
2277
2276
2278 status_approved = review_status == ChangesetStatus.STATUS_APPROVED
2277 status_approved = review_status == ChangesetStatus.STATUS_APPROVED
2279 if not status_approved and merge_check.reviewers_count:
2278 if not status_approved and merge_check.reviewers_count:
2280 log.debug("MergeCheck: cannot merge, approval is pending.")
2279 log.debug("MergeCheck: cannot merge, approval is pending.")
2281 msg = _('Pull request reviewer approval is pending.')
2280 msg = _('Pull request reviewer approval is pending.')
2282
2281
2283 merge_check.push_error('warning', msg, cls.REVIEW_CHECK, review_status)
2282 merge_check.push_error('warning', msg, cls.REVIEW_CHECK, review_status)
2284
2283
2285 if fail_early:
2284 if fail_early:
2286 return merge_check
2285 return merge_check
2287
2286
2288 # left over TODOs
2287 # left over TODOs
2289 todos = CommentsModel().get_pull_request_unresolved_todos(pull_request)
2288 todos = CommentsModel().get_pull_request_unresolved_todos(pull_request)
2290 if todos:
2289 if todos:
2291 log.debug("MergeCheck: cannot merge, {} "
2290 log.debug("MergeCheck: cannot merge, {} "
2292 "unresolved TODOs left.".format(len(todos)))
2291 "unresolved TODOs left.".format(len(todos)))
2293
2292
2294 if len(todos) == 1:
2293 if len(todos) == 1:
2295 msg = _('Cannot merge, {} TODO still not resolved.').format(
2294 msg = _('Cannot merge, {} TODO still not resolved.').format(
2296 len(todos))
2295 len(todos))
2297 else:
2296 else:
2298 msg = _('Cannot merge, {} TODOs still not resolved.').format(
2297 msg = _('Cannot merge, {} TODOs still not resolved.').format(
2299 len(todos))
2298 len(todos))
2300
2299
2301 merge_check.push_error('warning', msg, cls.TODO_CHECK, todos)
2300 merge_check.push_error('warning', msg, cls.TODO_CHECK, todos)
2302
2301
2303 if fail_early:
2302 if fail_early:
2304 return merge_check
2303 return merge_check
2305
2304
2306 # merge possible, here is the filesystem simulation + shadow repo
2305 # merge possible, here is the filesystem simulation + shadow repo
2307 merge_response, merge_status, msg = PullRequestModel().merge_status(
2306 merge_response, merge_status, msg = PullRequestModel().merge_status(
2308 pull_request, translator=translator,
2307 pull_request, translator=translator,
2309 force_shadow_repo_refresh=force_shadow_repo_refresh)
2308 force_shadow_repo_refresh=force_shadow_repo_refresh)
2310
2309
2311 merge_check.merge_possible = merge_status
2310 merge_check.merge_possible = merge_status
2312 merge_check.merge_msg = msg
2311 merge_check.merge_msg = msg
2313 merge_check.merge_response = merge_response
2312 merge_check.merge_response = merge_response
2314
2313
2315 source_ref_id = pull_request.source_ref_parts.commit_id
2314 source_ref_id = pull_request.source_ref_parts.commit_id
2316 target_ref_id = pull_request.target_ref_parts.commit_id
2315 target_ref_id = pull_request.target_ref_parts.commit_id
2317
2316
2318 try:
2317 try:
2319 source_commit, target_commit = PullRequestModel().get_flow_commits(pull_request)
2318 source_commit, target_commit = PullRequestModel().get_flow_commits(pull_request)
2320 merge_check.source_commit.changed = source_ref_id != source_commit.raw_id
2319 merge_check.source_commit.changed = source_ref_id != source_commit.raw_id
2321 merge_check.source_commit.ref_spec = pull_request.source_ref_parts
2320 merge_check.source_commit.ref_spec = pull_request.source_ref_parts
2322 merge_check.source_commit.current_raw_id = source_commit.raw_id
2321 merge_check.source_commit.current_raw_id = source_commit.raw_id
2323 merge_check.source_commit.previous_raw_id = source_ref_id
2322 merge_check.source_commit.previous_raw_id = source_ref_id
2324
2323
2325 merge_check.target_commit.changed = target_ref_id != target_commit.raw_id
2324 merge_check.target_commit.changed = target_ref_id != target_commit.raw_id
2326 merge_check.target_commit.ref_spec = pull_request.target_ref_parts
2325 merge_check.target_commit.ref_spec = pull_request.target_ref_parts
2327 merge_check.target_commit.current_raw_id = target_commit.raw_id
2326 merge_check.target_commit.current_raw_id = target_commit.raw_id
2328 merge_check.target_commit.previous_raw_id = target_ref_id
2327 merge_check.target_commit.previous_raw_id = target_ref_id
2329 except (SourceRefMissing, TargetRefMissing):
2328 except (SourceRefMissing, TargetRefMissing):
2330 pass
2329 pass
2331
2330
2332 if not merge_status:
2331 if not merge_status:
2333 log.debug("MergeCheck: cannot merge, pull request merge not possible.")
2332 log.debug("MergeCheck: cannot merge, pull request merge not possible.")
2334 merge_check.push_error('warning', msg, cls.MERGE_CHECK, None)
2333 merge_check.push_error('warning', msg, cls.MERGE_CHECK, None)
2335
2334
2336 if fail_early:
2335 if fail_early:
2337 return merge_check
2336 return merge_check
2338
2337
2339 log.debug('MergeCheck: is failed: %s', merge_check.failed)
2338 log.debug('MergeCheck: is failed: %s', merge_check.failed)
2340 return merge_check
2339 return merge_check
2341
2340
2342 @classmethod
2341 @classmethod
2343 def get_merge_conditions(cls, pull_request, translator):
2342 def get_merge_conditions(cls, pull_request, translator):
2344 _ = translator
2343 _ = translator
2345 merge_details = {}
2344 merge_details = {}
2346
2345
2347 model = PullRequestModel()
2346 model = PullRequestModel()
2348 use_rebase = model._use_rebase_for_merging(pull_request)
2347 use_rebase = model._use_rebase_for_merging(pull_request)
2349
2348
2350 if use_rebase:
2349 if use_rebase:
2351 merge_details['merge_strategy'] = dict(
2350 merge_details['merge_strategy'] = dict(
2352 details={},
2351 details={},
2353 message=_('Merge strategy: rebase')
2352 message=_('Merge strategy: rebase')
2354 )
2353 )
2355 else:
2354 else:
2356 merge_details['merge_strategy'] = dict(
2355 merge_details['merge_strategy'] = dict(
2357 details={},
2356 details={},
2358 message=_('Merge strategy: explicit merge commit')
2357 message=_('Merge strategy: explicit merge commit')
2359 )
2358 )
2360
2359
2361 close_branch = model._close_branch_before_merging(pull_request)
2360 close_branch = model._close_branch_before_merging(pull_request)
2362 if close_branch:
2361 if close_branch:
2363 repo_type = pull_request.target_repo.repo_type
2362 repo_type = pull_request.target_repo.repo_type
2364 close_msg = ''
2363 close_msg = ''
2365 if repo_type == 'hg':
2364 if repo_type == 'hg':
2366 close_msg = _('Source branch will be closed before the merge.')
2365 close_msg = _('Source branch will be closed before the merge.')
2367 elif repo_type == 'git':
2366 elif repo_type == 'git':
2368 close_msg = _('Source branch will be deleted after the merge.')
2367 close_msg = _('Source branch will be deleted after the merge.')
2369
2368
2370 merge_details['close_branch'] = dict(
2369 merge_details['close_branch'] = dict(
2371 details={},
2370 details={},
2372 message=close_msg
2371 message=close_msg
2373 )
2372 )
2374
2373
2375 return merge_details
2374 return merge_details
2376
2375
2377
2376
2378 @dataclasses.dataclass
2377 @dataclasses.dataclass
2379 class ChangeTuple:
2378 class ChangeTuple:
2380 added: list
2379 added: list
2381 common: list
2380 common: list
2382 removed: list
2381 removed: list
2383 total: list
2382 total: list
2384
2383
2385
2384
2386 @dataclasses.dataclass
2385 @dataclasses.dataclass
2387 class FileChangeTuple:
2386 class FileChangeTuple:
2388 added: list
2387 added: list
2389 modified: list
2388 modified: list
2390 removed: list
2389 removed: list
@@ -1,487 +1,451 b''
1
1
2 # Copyright (C) 2010-2023 RhodeCode GmbH
2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 #
3 #
4 # This program is free software: you can redistribute it and/or modify
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License, version 3
5 # it under the terms of the GNU Affero General Public License, version 3
6 # (only), as published by the Free Software Foundation.
6 # (only), as published by the Free Software Foundation.
7 #
7 #
8 # This program is distributed in the hope that it will be useful,
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # GNU General Public License for more details.
11 # GNU General Public License for more details.
12 #
12 #
13 # You should have received a copy of the GNU Affero General Public License
13 # You should have received a copy of the GNU Affero General Public License
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 #
15 #
16 # This program is dual-licensed. If you wish to learn more about the
16 # This program is dual-licensed. If you wish to learn more about the
17 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19
19
20 import mock
20 import mock
21 import pytest
21 import pytest
22
22
23 from rhodecode.lib.str_utils import base64_to_str
23 from rhodecode.lib.str_utils import base64_to_str
24 from rhodecode.lib.utils2 import AttributeDict
24 from rhodecode.lib.utils2 import AttributeDict
25 from rhodecode.tests.utils import CustomTestApp
25 from rhodecode.tests.utils import CustomTestApp
26
26
27 from rhodecode.lib.caching_query import FromCache
27 from rhodecode.lib.caching_query import FromCache
28 from rhodecode.lib.hooks_daemon import DummyHooksCallbackDaemon
29 from rhodecode.lib.middleware import simplevcs
28 from rhodecode.lib.middleware import simplevcs
30 from rhodecode.lib.middleware.https_fixup import HttpsFixup
29 from rhodecode.lib.middleware.https_fixup import HttpsFixup
31 from rhodecode.lib.middleware.utils import scm_app_http
30 from rhodecode.lib.middleware.utils import scm_app_http
32 from rhodecode.model.db import User, _hash_key
31 from rhodecode.model.db import User, _hash_key
33 from rhodecode.model.meta import Session, cache as db_cache
32 from rhodecode.model.meta import Session, cache as db_cache
34 from rhodecode.tests import (
33 from rhodecode.tests import (
35 HG_REPO, TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS)
34 HG_REPO, TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS)
36 from rhodecode.tests.lib.middleware import mock_scm_app
35 from rhodecode.tests.lib.middleware import mock_scm_app
37
36
38
37
39 class StubVCSController(simplevcs.SimpleVCS):
38 class StubVCSController(simplevcs.SimpleVCS):
40
39
41 SCM = 'hg'
40 SCM = 'hg'
42 stub_response_body = tuple()
41 stub_response_body = tuple()
43
42
44 def __init__(self, *args, **kwargs):
43 def __init__(self, *args, **kwargs):
45 super(StubVCSController, self).__init__(*args, **kwargs)
44 super(StubVCSController, self).__init__(*args, **kwargs)
46 self._action = 'pull'
45 self._action = 'pull'
47 self._is_shadow_repo_dir = True
46 self._is_shadow_repo_dir = True
48 self._name = HG_REPO
47 self._name = HG_REPO
49 self.set_repo_names(None)
48 self.set_repo_names(None)
50
49
51 @property
50 @property
52 def is_shadow_repo_dir(self):
51 def is_shadow_repo_dir(self):
53 return self._is_shadow_repo_dir
52 return self._is_shadow_repo_dir
54
53
55 def _get_repository_name(self, environ):
54 def _get_repository_name(self, environ):
56 return self._name
55 return self._name
57
56
58 def _get_action(self, environ):
57 def _get_action(self, environ):
59 return self._action
58 return self._action
60
59
61 def _create_wsgi_app(self, repo_path, repo_name, config):
60 def _create_wsgi_app(self, repo_path, repo_name, config):
62 def fake_app(environ, start_response):
61 def fake_app(environ, start_response):
63 headers = [
62 headers = [
64 ('Http-Accept', 'application/mercurial')
63 ('Http-Accept', 'application/mercurial')
65 ]
64 ]
66 start_response('200 OK', headers)
65 start_response('200 OK', headers)
67 return self.stub_response_body
66 return self.stub_response_body
68 return fake_app
67 return fake_app
69
68
70 def _create_config(self, extras, repo_name, scheme='http'):
69 def _create_config(self, extras, repo_name, scheme='http'):
71 return None
70 return None
72
71
73
72
74 @pytest.fixture()
73 @pytest.fixture()
75 def vcscontroller(baseapp, config_stub, request_stub):
74 def vcscontroller(baseapp, config_stub, request_stub):
76 from rhodecode.config.middleware import ce_auth_resources
75 from rhodecode.config.middleware import ce_auth_resources
77
76
78 config_stub.testing_securitypolicy()
77 config_stub.testing_securitypolicy()
79 config_stub.include('rhodecode.authentication')
78 config_stub.include('rhodecode.authentication')
80
79
81 for resource in ce_auth_resources:
80 for resource in ce_auth_resources:
82 config_stub.include(resource)
81 config_stub.include(resource)
83
82
84 controller = StubVCSController(
83 controller = StubVCSController(
85 baseapp.config.get_settings(), request_stub.registry)
84 baseapp.config.get_settings(), request_stub.registry)
86 app = HttpsFixup(controller, baseapp.config.get_settings())
85 app = HttpsFixup(controller, baseapp.config.get_settings())
87 app = CustomTestApp(app)
86 app = CustomTestApp(app)
88
87
89 _remove_default_user_from_query_cache()
88 _remove_default_user_from_query_cache()
90
89
91 # Sanity checks that things are set up correctly
90 # Sanity checks that things are set up correctly
92 app.get('/' + HG_REPO, status=200)
91 app.get('/' + HG_REPO, status=200)
93
92
94 app.controller = controller
93 app.controller = controller
95 return app
94 return app
96
95
97
96
98 def _remove_default_user_from_query_cache():
97 def _remove_default_user_from_query_cache():
99 user = User.get_default_user(cache=True)
98 user = User.get_default_user(cache=True)
100 query = Session().query(User).filter(User.username == user.username)
99 query = Session().query(User).filter(User.username == user.username)
101 query = query.options(
100 query = query.options(
102 FromCache("sql_cache_short", f"get_user_{_hash_key(user.username)}"))
101 FromCache("sql_cache_short", f"get_user_{_hash_key(user.username)}"))
103
102
104 db_cache.invalidate(
103 db_cache.invalidate(
105 query, {},
104 query, {},
106 FromCache("sql_cache_short", f"get_user_{_hash_key(user.username)}"))
105 FromCache("sql_cache_short", f"get_user_{_hash_key(user.username)}"))
107
106
108 Session().expire(user)
107 Session().expire(user)
109
108
110
109
111 def test_handles_exceptions_during_permissions_checks(
110 def test_handles_exceptions_during_permissions_checks(
112 vcscontroller, disable_anonymous_user, enable_auth_plugins, test_user_factory):
111 vcscontroller, disable_anonymous_user, enable_auth_plugins, test_user_factory):
113
112
114 test_password = 'qweqwe'
113 test_password = 'qweqwe'
115 test_user = test_user_factory(password=test_password, extern_type='headers', extern_name='headers')
114 test_user = test_user_factory(password=test_password, extern_type='headers', extern_name='headers')
116 test_username = test_user.username
115 test_username = test_user.username
117
116
118 enable_auth_plugins.enable([
117 enable_auth_plugins.enable([
119 'egg:rhodecode-enterprise-ce#headers',
118 'egg:rhodecode-enterprise-ce#headers',
120 'egg:rhodecode-enterprise-ce#token',
119 'egg:rhodecode-enterprise-ce#token',
121 'egg:rhodecode-enterprise-ce#rhodecode'],
120 'egg:rhodecode-enterprise-ce#rhodecode'],
122 override={
121 override={
123 'egg:rhodecode-enterprise-ce#headers': {'auth_headers_header': 'REMOTE_USER'}
122 'egg:rhodecode-enterprise-ce#headers': {'auth_headers_header': 'REMOTE_USER'}
124 })
123 })
125
124
126 user_and_pass = f'{test_username}:{test_password}'
125 user_and_pass = f'{test_username}:{test_password}'
127 auth_password = base64_to_str(user_and_pass)
126 auth_password = base64_to_str(user_and_pass)
128
127
129 extra_environ = {
128 extra_environ = {
130 'AUTH_TYPE': 'Basic',
129 'AUTH_TYPE': 'Basic',
131 'HTTP_AUTHORIZATION': f'Basic {auth_password}',
130 'HTTP_AUTHORIZATION': f'Basic {auth_password}',
132 'REMOTE_USER': test_username,
131 'REMOTE_USER': test_username,
133 }
132 }
134
133
135 # Verify that things are hooked up correctly, we pass user with headers bound auth, and headers filled in
134 # Verify that things are hooked up correctly, we pass user with headers bound auth, and headers filled in
136 vcscontroller.get('/', status=200, extra_environ=extra_environ)
135 vcscontroller.get('/', status=200, extra_environ=extra_environ)
137
136
138 # Simulate trouble during permission checks
137 # Simulate trouble during permission checks
139 with mock.patch('rhodecode.model.db.User.get_by_username',
138 with mock.patch('rhodecode.model.db.User.get_by_username',
140 side_effect=Exception('permission_error_test')) as get_user:
139 side_effect=Exception('permission_error_test')) as get_user:
141 # Verify that a correct 500 is returned and check that the expected
140 # Verify that a correct 500 is returned and check that the expected
142 # code path was hit.
141 # code path was hit.
143 vcscontroller.get('/', status=500, extra_environ=extra_environ)
142 vcscontroller.get('/', status=500, extra_environ=extra_environ)
144 assert get_user.called
143 assert get_user.called
145
144
146
145
147 class StubFailVCSController(simplevcs.SimpleVCS):
146 class StubFailVCSController(simplevcs.SimpleVCS):
148 def _handle_request(self, environ, start_response):
147 def _handle_request(self, environ, start_response):
149 raise Exception("BOOM")
148 raise Exception("BOOM")
150
149
151
150
152 @pytest.fixture(scope='module')
151 @pytest.fixture(scope='module')
153 def fail_controller(baseapp):
152 def fail_controller(baseapp):
154 controller = StubFailVCSController(
153 controller = StubFailVCSController(
155 baseapp.config.get_settings(), baseapp.config)
154 baseapp.config.get_settings(), baseapp.config)
156 controller = HttpsFixup(controller, baseapp.config.get_settings())
155 controller = HttpsFixup(controller, baseapp.config.get_settings())
157 controller = CustomTestApp(controller)
156 controller = CustomTestApp(controller)
158 return controller
157 return controller
159
158
160
159
161 def test_handles_exceptions_as_internal_server_error(fail_controller):
160 def test_handles_exceptions_as_internal_server_error(fail_controller):
162 fail_controller.get('/', status=500)
161 fail_controller.get('/', status=500)
163
162
164
163
165 def test_provides_traceback_for_appenlight(fail_controller):
164 def test_provides_traceback_for_appenlight(fail_controller):
166 response = fail_controller.get(
165 response = fail_controller.get(
167 '/', status=500, extra_environ={'appenlight.client': 'fake'})
166 '/', status=500, extra_environ={'appenlight.client': 'fake'})
168 assert 'appenlight.__traceback' in response.request.environ
167 assert 'appenlight.__traceback' in response.request.environ
169
168
170
169
171 def test_provides_utils_scm_app_as_scm_app_by_default(baseapp, request_stub):
170 def test_provides_utils_scm_app_as_scm_app_by_default(baseapp, request_stub):
172 controller = StubVCSController(baseapp.config.get_settings(), request_stub.registry)
171 controller = StubVCSController(baseapp.config.get_settings(), request_stub.registry)
173 assert controller.scm_app is scm_app_http
172 assert controller.scm_app is scm_app_http
174
173
175
174
176 def test_allows_to_override_scm_app_via_config(baseapp, request_stub):
175 def test_allows_to_override_scm_app_via_config(baseapp, request_stub):
177 config = baseapp.config.get_settings().copy()
176 config = baseapp.config.get_settings().copy()
178 config['vcs.scm_app_implementation'] = (
177 config['vcs.scm_app_implementation'] = (
179 'rhodecode.tests.lib.middleware.mock_scm_app')
178 'rhodecode.tests.lib.middleware.mock_scm_app')
180 controller = StubVCSController(config, request_stub.registry)
179 controller = StubVCSController(config, request_stub.registry)
181 assert controller.scm_app is mock_scm_app
180 assert controller.scm_app is mock_scm_app
182
181
183
182
184 @pytest.mark.parametrize('query_string, expected', [
183 @pytest.mark.parametrize('query_string, expected', [
185 ('cmd=stub_command', True),
184 ('cmd=stub_command', True),
186 ('cmd=listkeys', False),
185 ('cmd=listkeys', False),
187 ])
186 ])
188 def test_should_check_locking(query_string, expected):
187 def test_should_check_locking(query_string, expected):
189 result = simplevcs._should_check_locking(query_string)
188 result = simplevcs._should_check_locking(query_string)
190 assert result == expected
189 assert result == expected
191
190
192
191
193 class TestShadowRepoRegularExpression(object):
192 class TestShadowRepoRegularExpression(object):
194 pr_segment = 'pull-request'
193 pr_segment = 'pull-request'
195 shadow_segment = 'repository'
194 shadow_segment = 'repository'
196
195
197 @pytest.mark.parametrize('url, expected', [
196 @pytest.mark.parametrize('url, expected', [
198 # repo with/without groups
197 # repo with/without groups
199 ('My-Repo/{pr_segment}/1/{shadow_segment}', True),
198 ('My-Repo/{pr_segment}/1/{shadow_segment}', True),
200 ('Group/My-Repo/{pr_segment}/2/{shadow_segment}', True),
199 ('Group/My-Repo/{pr_segment}/2/{shadow_segment}', True),
201 ('Group/Sub-Group/My-Repo/{pr_segment}/3/{shadow_segment}', True),
200 ('Group/Sub-Group/My-Repo/{pr_segment}/3/{shadow_segment}', True),
202 ('Group/Sub-Group1/Sub-Group2/My-Repo/{pr_segment}/3/{shadow_segment}', True),
201 ('Group/Sub-Group1/Sub-Group2/My-Repo/{pr_segment}/3/{shadow_segment}', True),
203
202
204 # pull request ID
203 # pull request ID
205 ('MyRepo/{pr_segment}/1/{shadow_segment}', True),
204 ('MyRepo/{pr_segment}/1/{shadow_segment}', True),
206 ('MyRepo/{pr_segment}/1234567890/{shadow_segment}', True),
205 ('MyRepo/{pr_segment}/1234567890/{shadow_segment}', True),
207 ('MyRepo/{pr_segment}/-1/{shadow_segment}', False),
206 ('MyRepo/{pr_segment}/-1/{shadow_segment}', False),
208 ('MyRepo/{pr_segment}/invalid/{shadow_segment}', False),
207 ('MyRepo/{pr_segment}/invalid/{shadow_segment}', False),
209
208
210 # unicode
209 # unicode
211 (u'Sp€çîál-Repö/{pr_segment}/1/{shadow_segment}', True),
210 (u'Sp€çîál-Repö/{pr_segment}/1/{shadow_segment}', True),
212 (u'Sp€çîál-Gröüp/Sp€çîál-Repö/{pr_segment}/1/{shadow_segment}', True),
211 (u'Sp€çîál-Gröüp/Sp€çîál-Repö/{pr_segment}/1/{shadow_segment}', True),
213
212
214 # trailing/leading slash
213 # trailing/leading slash
215 ('/My-Repo/{pr_segment}/1/{shadow_segment}', False),
214 ('/My-Repo/{pr_segment}/1/{shadow_segment}', False),
216 ('My-Repo/{pr_segment}/1/{shadow_segment}/', False),
215 ('My-Repo/{pr_segment}/1/{shadow_segment}/', False),
217 ('/My-Repo/{pr_segment}/1/{shadow_segment}/', False),
216 ('/My-Repo/{pr_segment}/1/{shadow_segment}/', False),
218
217
219 # misc
218 # misc
220 ('My-Repo/{pr_segment}/1/{shadow_segment}/extra', False),
219 ('My-Repo/{pr_segment}/1/{shadow_segment}/extra', False),
221 ('My-Repo/{pr_segment}/1/{shadow_segment}extra', False),
220 ('My-Repo/{pr_segment}/1/{shadow_segment}extra', False),
222 ])
221 ])
223 def test_shadow_repo_regular_expression(self, url, expected):
222 def test_shadow_repo_regular_expression(self, url, expected):
224 from rhodecode.lib.middleware.simplevcs import SimpleVCS
223 from rhodecode.lib.middleware.simplevcs import SimpleVCS
225 url = url.format(
224 url = url.format(
226 pr_segment=self.pr_segment,
225 pr_segment=self.pr_segment,
227 shadow_segment=self.shadow_segment)
226 shadow_segment=self.shadow_segment)
228 match_obj = SimpleVCS.shadow_repo_re.match(url)
227 match_obj = SimpleVCS.shadow_repo_re.match(url)
229 assert (match_obj is not None) == expected
228 assert (match_obj is not None) == expected
230
229
231
230
232 @pytest.mark.backends('git', 'hg')
231 @pytest.mark.backends('git', 'hg')
233 class TestShadowRepoExposure(object):
232 class TestShadowRepoExposure(object):
234
233
235 def test_pull_on_shadow_repo_propagates_to_wsgi_app(
234 def test_pull_on_shadow_repo_propagates_to_wsgi_app(
236 self, baseapp, request_stub):
235 self, baseapp, request_stub):
237 """
236 """
238 Check that a pull action to a shadow repo is propagated to the
237 Check that a pull action to a shadow repo is propagated to the
239 underlying wsgi app.
238 underlying wsgi app.
240 """
239 """
241 controller = StubVCSController(
240 controller = StubVCSController(
242 baseapp.config.get_settings(), request_stub.registry)
241 baseapp.config.get_settings(), request_stub.registry)
243 controller._check_ssl = mock.Mock()
242 controller._check_ssl = mock.Mock()
244 controller.is_shadow_repo = True
243 controller.is_shadow_repo = True
245 controller._action = 'pull'
244 controller._action = 'pull'
246 controller._is_shadow_repo_dir = True
245 controller._is_shadow_repo_dir = True
247 controller.stub_response_body = (b'dummy body value',)
246 controller.stub_response_body = (b'dummy body value',)
248 controller._get_default_cache_ttl = mock.Mock(
247 controller._get_default_cache_ttl = mock.Mock(
249 return_value=(False, 0))
248 return_value=(False, 0))
250
249
251 environ_stub = {
250 environ_stub = {
252 'HTTP_HOST': 'test.example.com',
251 'HTTP_HOST': 'test.example.com',
253 'HTTP_ACCEPT': 'application/mercurial',
252 'HTTP_ACCEPT': 'application/mercurial',
254 'REQUEST_METHOD': 'GET',
253 'REQUEST_METHOD': 'GET',
255 'wsgi.url_scheme': 'http',
254 'wsgi.url_scheme': 'http',
256 }
255 }
257
256
258 response = controller(environ_stub, mock.Mock())
257 response = controller(environ_stub, mock.Mock())
259 response_body = b''.join(response)
258 response_body = b''.join(response)
260
259
261 # Assert that we got the response from the wsgi app.
260 # Assert that we got the response from the wsgi app.
262 assert response_body == b''.join(controller.stub_response_body)
261 assert response_body == b''.join(controller.stub_response_body)
263
262
264 def test_pull_on_shadow_repo_that_is_missing(self, baseapp, request_stub):
263 def test_pull_on_shadow_repo_that_is_missing(self, baseapp, request_stub):
265 """
264 """
266 Check that a pull action to a shadow repo is propagated to the
265 Check that a pull action to a shadow repo is propagated to the
267 underlying wsgi app.
266 underlying wsgi app.
268 """
267 """
269 controller = StubVCSController(
268 controller = StubVCSController(
270 baseapp.config.get_settings(), request_stub.registry)
269 baseapp.config.get_settings(), request_stub.registry)
271 controller._check_ssl = mock.Mock()
270 controller._check_ssl = mock.Mock()
272 controller.is_shadow_repo = True
271 controller.is_shadow_repo = True
273 controller._action = 'pull'
272 controller._action = 'pull'
274 controller._is_shadow_repo_dir = False
273 controller._is_shadow_repo_dir = False
275 controller.stub_response_body = (b'dummy body value',)
274 controller.stub_response_body = (b'dummy body value',)
276 environ_stub = {
275 environ_stub = {
277 'HTTP_HOST': 'test.example.com',
276 'HTTP_HOST': 'test.example.com',
278 'HTTP_ACCEPT': 'application/mercurial',
277 'HTTP_ACCEPT': 'application/mercurial',
279 'REQUEST_METHOD': 'GET',
278 'REQUEST_METHOD': 'GET',
280 'wsgi.url_scheme': 'http',
279 'wsgi.url_scheme': 'http',
281 }
280 }
282
281
283 response = controller(environ_stub, mock.Mock())
282 response = controller(environ_stub, mock.Mock())
284 response_body = b''.join(response)
283 response_body = b''.join(response)
285
284
286 # Assert that we got the response from the wsgi app.
285 # Assert that we got the response from the wsgi app.
287 assert b'404 Not Found' in response_body
286 assert b'404 Not Found' in response_body
288
287
289 def test_push_on_shadow_repo_raises(self, baseapp, request_stub):
288 def test_push_on_shadow_repo_raises(self, baseapp, request_stub):
290 """
289 """
291 Check that a push action to a shadow repo is aborted.
290 Check that a push action to a shadow repo is aborted.
292 """
291 """
293 controller = StubVCSController(
292 controller = StubVCSController(
294 baseapp.config.get_settings(), request_stub.registry)
293 baseapp.config.get_settings(), request_stub.registry)
295 controller._check_ssl = mock.Mock()
294 controller._check_ssl = mock.Mock()
296 controller.is_shadow_repo = True
295 controller.is_shadow_repo = True
297 controller._action = 'push'
296 controller._action = 'push'
298 controller.stub_response_body = (b'dummy body value',)
297 controller.stub_response_body = (b'dummy body value',)
299 environ_stub = {
298 environ_stub = {
300 'HTTP_HOST': 'test.example.com',
299 'HTTP_HOST': 'test.example.com',
301 'HTTP_ACCEPT': 'application/mercurial',
300 'HTTP_ACCEPT': 'application/mercurial',
302 'REQUEST_METHOD': 'GET',
301 'REQUEST_METHOD': 'GET',
303 'wsgi.url_scheme': 'http',
302 'wsgi.url_scheme': 'http',
304 }
303 }
305
304
306 response = controller(environ_stub, mock.Mock())
305 response = controller(environ_stub, mock.Mock())
307 response_body = b''.join(response)
306 response_body = b''.join(response)
308
307
309 assert response_body != controller.stub_response_body
308 assert response_body != controller.stub_response_body
310 # Assert that a 406 error is returned.
309 # Assert that a 406 error is returned.
311 assert b'406 Not Acceptable' in response_body
310 assert b'406 Not Acceptable' in response_body
312
311
313 def test_set_repo_names_no_shadow(self, baseapp, request_stub):
312 def test_set_repo_names_no_shadow(self, baseapp, request_stub):
314 """
313 """
315 Check that the set_repo_names method sets all names to the one returned
314 Check that the set_repo_names method sets all names to the one returned
316 by the _get_repository_name method on a request to a non shadow repo.
315 by the _get_repository_name method on a request to a non shadow repo.
317 """
316 """
318 environ_stub = {}
317 environ_stub = {}
319 controller = StubVCSController(
318 controller = StubVCSController(
320 baseapp.config.get_settings(), request_stub.registry)
319 baseapp.config.get_settings(), request_stub.registry)
321 controller._name = 'RepoGroup/MyRepo'
320 controller._name = 'RepoGroup/MyRepo'
322 controller.set_repo_names(environ_stub)
321 controller.set_repo_names(environ_stub)
323 assert not controller.is_shadow_repo
322 assert not controller.is_shadow_repo
324 assert (controller.url_repo_name ==
323 assert (controller.url_repo_name ==
325 controller.acl_repo_name ==
324 controller.acl_repo_name ==
326 controller.vcs_repo_name ==
325 controller.vcs_repo_name ==
327 controller._get_repository_name(environ_stub))
326 controller._get_repository_name(environ_stub))
328
327
329 def test_set_repo_names_with_shadow(
328 def test_set_repo_names_with_shadow(
330 self, baseapp, pr_util, config_stub, request_stub):
329 self, baseapp, pr_util, config_stub, request_stub):
331 """
330 """
332 Check that the set_repo_names method sets correct names on a request
331 Check that the set_repo_names method sets correct names on a request
333 to a shadow repo.
332 to a shadow repo.
334 """
333 """
335 from rhodecode.model.pull_request import PullRequestModel
334 from rhodecode.model.pull_request import PullRequestModel
336
335
337 pull_request = pr_util.create_pull_request()
336 pull_request = pr_util.create_pull_request()
338 shadow_url = '{target}/{pr_segment}/{pr_id}/{shadow_segment}'.format(
337 shadow_url = '{target}/{pr_segment}/{pr_id}/{shadow_segment}'.format(
339 target=pull_request.target_repo.repo_name,
338 target=pull_request.target_repo.repo_name,
340 pr_id=pull_request.pull_request_id,
339 pr_id=pull_request.pull_request_id,
341 pr_segment=TestShadowRepoRegularExpression.pr_segment,
340 pr_segment=TestShadowRepoRegularExpression.pr_segment,
342 shadow_segment=TestShadowRepoRegularExpression.shadow_segment)
341 shadow_segment=TestShadowRepoRegularExpression.shadow_segment)
343 controller = StubVCSController(
342 controller = StubVCSController(
344 baseapp.config.get_settings(), request_stub.registry)
343 baseapp.config.get_settings(), request_stub.registry)
345 controller._name = shadow_url
344 controller._name = shadow_url
346 controller.set_repo_names({})
345 controller.set_repo_names({})
347
346
348 # Get file system path to shadow repo for assertions.
347 # Get file system path to shadow repo for assertions.
349 workspace_id = PullRequestModel()._workspace_id(pull_request)
348 workspace_id = PullRequestModel()._workspace_id(pull_request)
350 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
349 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
351
350
352 assert controller.vcs_repo_name == vcs_repo_name
351 assert controller.vcs_repo_name == vcs_repo_name
353 assert controller.url_repo_name == shadow_url
352 assert controller.url_repo_name == shadow_url
354 assert controller.acl_repo_name == pull_request.target_repo.repo_name
353 assert controller.acl_repo_name == pull_request.target_repo.repo_name
355 assert controller.is_shadow_repo
354 assert controller.is_shadow_repo
356
355
357 def test_set_repo_names_with_shadow_but_missing_pr(
356 def test_set_repo_names_with_shadow_but_missing_pr(
358 self, baseapp, pr_util, config_stub, request_stub):
357 self, baseapp, pr_util, config_stub, request_stub):
359 """
358 """
360 Checks that the set_repo_names method enforces matching target repos
359 Checks that the set_repo_names method enforces matching target repos
361 and pull request IDs.
360 and pull request IDs.
362 """
361 """
363 pull_request = pr_util.create_pull_request()
362 pull_request = pr_util.create_pull_request()
364 shadow_url = '{target}/{pr_segment}/{pr_id}/{shadow_segment}'.format(
363 shadow_url = '{target}/{pr_segment}/{pr_id}/{shadow_segment}'.format(
365 target=pull_request.target_repo.repo_name,
364 target=pull_request.target_repo.repo_name,
366 pr_id=999999999,
365 pr_id=999999999,
367 pr_segment=TestShadowRepoRegularExpression.pr_segment,
366 pr_segment=TestShadowRepoRegularExpression.pr_segment,
368 shadow_segment=TestShadowRepoRegularExpression.shadow_segment)
367 shadow_segment=TestShadowRepoRegularExpression.shadow_segment)
369 controller = StubVCSController(
368 controller = StubVCSController(
370 baseapp.config.get_settings(), request_stub.registry)
369 baseapp.config.get_settings(), request_stub.registry)
371 controller._name = shadow_url
370 controller._name = shadow_url
372 controller.set_repo_names({})
371 controller.set_repo_names({})
373
372
374 assert not controller.is_shadow_repo
373 assert not controller.is_shadow_repo
375 assert (controller.url_repo_name ==
374 assert (controller.url_repo_name ==
376 controller.acl_repo_name ==
375 controller.acl_repo_name ==
377 controller.vcs_repo_name)
376 controller.vcs_repo_name)
378
377
379
378
380 @pytest.mark.usefixtures('baseapp')
379 @pytest.mark.usefixtures('baseapp')
381 class TestGenerateVcsResponse(object):
380 class TestGenerateVcsResponse(object):
382
381
383 def test_ensures_that_start_response_is_called_early_enough(self):
382 def test_ensures_that_start_response_is_called_early_enough(self):
384 self.call_controller_with_response_body(iter(['a', 'b']))
383 self.call_controller_with_response_body(iter(['a', 'b']))
385 assert self.start_response.called
384 assert self.start_response.called
386
385
387 def test_invalidates_cache_after_body_is_consumed(self):
386 def test_invalidates_cache_after_body_is_consumed(self):
388 result = self.call_controller_with_response_body(iter(['a', 'b']))
387 result = self.call_controller_with_response_body(iter(['a', 'b']))
389 assert not self.was_cache_invalidated()
388 assert not self.was_cache_invalidated()
390 # Consume the result
389 # Consume the result
391 list(result)
390 list(result)
392 assert self.was_cache_invalidated()
391 assert self.was_cache_invalidated()
393
392
394 def test_raises_unknown_exceptions(self):
393 def test_raises_unknown_exceptions(self):
395 result = self.call_controller_with_response_body(
394 result = self.call_controller_with_response_body(
396 self.raise_result_iter(vcs_kind='unknown'))
395 self.raise_result_iter(vcs_kind='unknown'))
397 with pytest.raises(Exception):
396 with pytest.raises(Exception):
398 list(result)
397 list(result)
399
398
400 def test_prepare_callback_daemon_is_called(self):
401 def side_effect(extras, environ, action, txn_id=None):
402 return DummyHooksCallbackDaemon(), extras
403
404 prepare_patcher = mock.patch.object(
405 StubVCSController, '_prepare_callback_daemon')
406 with prepare_patcher as prepare_mock:
407 prepare_mock.side_effect = side_effect
408 self.call_controller_with_response_body(iter(['a', 'b']))
409 assert prepare_mock.called
410 assert prepare_mock.call_count == 1
411
412 def call_controller_with_response_body(self, response_body):
399 def call_controller_with_response_body(self, response_body):
413 settings = {
400 settings = {
414 'base_path': 'fake_base_path',
401 'base_path': 'fake_base_path',
415 'vcs.hooks.protocol': 'http',
402 'vcs.hooks.protocol': 'http',
416 'vcs.hooks.direct_calls': False,
403 'vcs.hooks.direct_calls': False,
417 }
404 }
418 registry = AttributeDict()
405 registry = AttributeDict()
419 controller = StubVCSController(settings, registry)
406 controller = StubVCSController(settings, registry)
420 controller._invalidate_cache = mock.Mock()
407 controller._invalidate_cache = mock.Mock()
421 controller.stub_response_body = response_body
408 controller.stub_response_body = response_body
422 self.start_response = mock.Mock()
409 self.start_response = mock.Mock()
423 result = controller._generate_vcs_response(
410 result = controller._generate_vcs_response(
424 environ={}, start_response=self.start_response,
411 environ={}, start_response=self.start_response,
425 repo_path='fake_repo_path',
412 repo_path='fake_repo_path',
426 extras={}, action='push')
413 extras={}, action='push')
427 self.controller = controller
414 self.controller = controller
428 return result
415 return result
429
416
430 def raise_result_iter(self, vcs_kind='repo_locked'):
417 def raise_result_iter(self, vcs_kind='repo_locked'):
431 """
418 """
432 Simulates an exception due to a vcs raised exception if kind vcs_kind
419 Simulates an exception due to a vcs raised exception if kind vcs_kind
433 """
420 """
434 raise self.vcs_exception(vcs_kind=vcs_kind)
421 raise self.vcs_exception(vcs_kind=vcs_kind)
435 yield "never_reached"
422 yield "never_reached"
436
423
437 def vcs_exception(self, vcs_kind='repo_locked'):
424 def vcs_exception(self, vcs_kind='repo_locked'):
438 locked_exception = Exception('TEST_MESSAGE')
425 locked_exception = Exception('TEST_MESSAGE')
439 locked_exception._vcs_kind = vcs_kind
426 locked_exception._vcs_kind = vcs_kind
440 return locked_exception
427 return locked_exception
441
428
442 def was_cache_invalidated(self):
429 def was_cache_invalidated(self):
443 return self.controller._invalidate_cache.called
430 return self.controller._invalidate_cache.called
444
431
445
432
446 class TestInitializeGenerator(object):
433 class TestInitializeGenerator(object):
447
434
448 def test_drains_first_element(self):
435 def test_drains_first_element(self):
449 gen = self.factory(['__init__', 1, 2])
436 gen = self.factory(['__init__', 1, 2])
450 result = list(gen)
437 result = list(gen)
451 assert result == [1, 2]
438 assert result == [1, 2]
452
439
453 @pytest.mark.parametrize('values', [
440 @pytest.mark.parametrize('values', [
454 [],
441 [],
455 [1, 2],
442 [1, 2],
456 ])
443 ])
457 def test_raises_value_error(self, values):
444 def test_raises_value_error(self, values):
458 with pytest.raises(ValueError):
445 with pytest.raises(ValueError):
459 self.factory(values)
446 self.factory(values)
460
447
461 @simplevcs.initialize_generator
448 @simplevcs.initialize_generator
462 def factory(self, iterable):
449 def factory(self, iterable):
463 for elem in iterable:
450 for elem in iterable:
464 yield elem
451 yield elem
465
466
467 class TestPrepareHooksDaemon(object):
468 def test_calls_imported_prepare_callback_daemon(self, app_settings, request_stub):
469 expected_extras = {'extra1': 'value1'}
470 daemon = DummyHooksCallbackDaemon()
471
472 controller = StubVCSController(app_settings, request_stub.registry)
473 prepare_patcher = mock.patch.object(
474 simplevcs, 'prepare_callback_daemon',
475 return_value=(daemon, expected_extras))
476 with prepare_patcher as prepare_mock:
477 callback_daemon, extras = controller._prepare_callback_daemon(
478 expected_extras.copy(), {}, 'push')
479 prepare_mock.assert_called_once_with(
480 expected_extras,
481 protocol=app_settings['vcs.hooks.protocol'],
482 host=app_settings['vcs.hooks.host'],
483 txn_id=None,
484 use_direct_calls=app_settings['vcs.hooks.direct_calls'])
485
486 assert callback_daemon == daemon
487 assert extras == extras
@@ -1,376 +1,355 b''
1
1
2 # Copyright (C) 2010-2023 RhodeCode GmbH
2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 #
3 #
4 # This program is free software: you can redistribute it and/or modify
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License, version 3
5 # it under the terms of the GNU Affero General Public License, version 3
6 # (only), as published by the Free Software Foundation.
6 # (only), as published by the Free Software Foundation.
7 #
7 #
8 # This program is distributed in the hope that it will be useful,
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # GNU General Public License for more details.
11 # GNU General Public License for more details.
12 #
12 #
13 # You should have received a copy of the GNU Affero General Public License
13 # You should have received a copy of the GNU Affero General Public License
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 #
15 #
16 # This program is dual-licensed. If you wish to learn more about the
16 # This program is dual-licensed. If you wish to learn more about the
17 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19
19
20 import logging
20 import logging
21 import io
21 import io
22
22
23 import mock
23 import mock
24 import msgpack
24 import msgpack
25 import pytest
25 import pytest
26 import tempfile
26
27
27 from rhodecode.lib import hooks_daemon
28 from rhodecode.lib import hooks_daemon
28 from rhodecode.lib.str_utils import safe_bytes
29 from rhodecode.lib.str_utils import safe_bytes
29 from rhodecode.tests.utils import assert_message_in_log
30 from rhodecode.tests.utils import assert_message_in_log
30 from rhodecode.lib.ext_json import json
31 from rhodecode.lib.ext_json import json
31
32
32 test_proto = hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO
33 test_proto = hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO
33
34
34
35
35 class TestDummyHooksCallbackDaemon(object):
36 def test_hooks_module_path_set_properly(self):
37 daemon = hooks_daemon.DummyHooksCallbackDaemon()
38 assert daemon.hooks_module == 'rhodecode.lib.hooks_daemon'
39
40 def test_logs_entering_the_hook(self):
41 daemon = hooks_daemon.DummyHooksCallbackDaemon()
42 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
43 with daemon as return_value:
44 log_mock.assert_called_once_with(
45 'Running `%s` callback daemon', 'DummyHooksCallbackDaemon')
46 assert return_value == daemon
47
48 def test_logs_exiting_the_hook(self):
49 daemon = hooks_daemon.DummyHooksCallbackDaemon()
50 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
51 with daemon:
52 pass
53 log_mock.assert_called_with(
54 'Exiting `%s` callback daemon', 'DummyHooksCallbackDaemon')
55
56
57 class TestHooks(object):
36 class TestHooks(object):
58 def test_hooks_can_be_used_as_a_context_processor(self):
37 def test_hooks_can_be_used_as_a_context_processor(self):
59 hooks = hooks_daemon.Hooks()
38 hooks = hooks_daemon.Hooks()
60 with hooks as return_value:
39 with hooks as return_value:
61 pass
40 pass
62 assert hooks == return_value
41 assert hooks == return_value
63
42
64
43
65 class TestHooksHttpHandler(object):
44 class TestHooksHttpHandler(object):
66 def test_read_request_parses_method_name_and_arguments(self):
45 def test_read_request_parses_method_name_and_arguments(self):
67 data = {
46 data = {
68 'method': 'test',
47 'method': 'test',
69 'extras': {
48 'extras': {
70 'param1': 1,
49 'param1': 1,
71 'param2': 'a'
50 'param2': 'a'
72 }
51 }
73 }
52 }
74 request = self._generate_post_request(data)
53 request = self._generate_post_request(data)
75 hooks_patcher = mock.patch.object(
54 hooks_patcher = mock.patch.object(
76 hooks_daemon.Hooks, data['method'], create=True, return_value=1)
55 hooks_daemon.Hooks, data['method'], create=True, return_value=1)
77
56
78 with hooks_patcher as hooks_mock:
57 with hooks_patcher as hooks_mock:
79 handler = hooks_daemon.HooksHttpHandler
58 handler = hooks_daemon.HooksHttpHandler
80 handler.DEFAULT_HOOKS_PROTO = test_proto
59 handler.DEFAULT_HOOKS_PROTO = test_proto
81 handler.wbufsize = 10240
60 handler.wbufsize = 10240
82 MockServer(handler, request)
61 MockServer(handler, request)
83
62
84 hooks_mock.assert_called_once_with(data['extras'])
63 hooks_mock.assert_called_once_with(data['extras'])
85
64
86 def test_hooks_serialized_result_is_returned(self):
65 def test_hooks_serialized_result_is_returned(self):
87 request = self._generate_post_request({})
66 request = self._generate_post_request({})
88 rpc_method = 'test'
67 rpc_method = 'test'
89 hook_result = {
68 hook_result = {
90 'first': 'one',
69 'first': 'one',
91 'second': 2
70 'second': 2
92 }
71 }
93 extras = {}
72 extras = {}
94
73
95 # patching our _read to return test method and proto used
74 # patching our _read to return test method and proto used
96 read_patcher = mock.patch.object(
75 read_patcher = mock.patch.object(
97 hooks_daemon.HooksHttpHandler, '_read_request',
76 hooks_daemon.HooksHttpHandler, '_read_request',
98 return_value=(test_proto, rpc_method, extras))
77 return_value=(test_proto, rpc_method, extras))
99
78
100 # patch Hooks instance to return hook_result data on 'test' call
79 # patch Hooks instance to return hook_result data on 'test' call
101 hooks_patcher = mock.patch.object(
80 hooks_patcher = mock.patch.object(
102 hooks_daemon.Hooks, rpc_method, create=True,
81 hooks_daemon.Hooks, rpc_method, create=True,
103 return_value=hook_result)
82 return_value=hook_result)
104
83
105 with read_patcher, hooks_patcher:
84 with read_patcher, hooks_patcher:
106 handler = hooks_daemon.HooksHttpHandler
85 handler = hooks_daemon.HooksHttpHandler
107 handler.DEFAULT_HOOKS_PROTO = test_proto
86 handler.DEFAULT_HOOKS_PROTO = test_proto
108 handler.wbufsize = 10240
87 handler.wbufsize = 10240
109 server = MockServer(handler, request)
88 server = MockServer(handler, request)
110
89
111 expected_result = hooks_daemon.HooksHttpHandler.serialize_data(hook_result)
90 expected_result = hooks_daemon.HooksHttpHandler.serialize_data(hook_result)
112
91
113 server.request.output_stream.seek(0)
92 server.request.output_stream.seek(0)
114 assert server.request.output_stream.readlines()[-1] == expected_result
93 assert server.request.output_stream.readlines()[-1] == expected_result
115
94
116 def test_exception_is_returned_in_response(self):
95 def test_exception_is_returned_in_response(self):
117 request = self._generate_post_request({})
96 request = self._generate_post_request({})
118 rpc_method = 'test'
97 rpc_method = 'test'
119
98
120 read_patcher = mock.patch.object(
99 read_patcher = mock.patch.object(
121 hooks_daemon.HooksHttpHandler, '_read_request',
100 hooks_daemon.HooksHttpHandler, '_read_request',
122 return_value=(test_proto, rpc_method, {}))
101 return_value=(test_proto, rpc_method, {}))
123
102
124 hooks_patcher = mock.patch.object(
103 hooks_patcher = mock.patch.object(
125 hooks_daemon.Hooks, rpc_method, create=True,
104 hooks_daemon.Hooks, rpc_method, create=True,
126 side_effect=Exception('Test exception'))
105 side_effect=Exception('Test exception'))
127
106
128 with read_patcher, hooks_patcher:
107 with read_patcher, hooks_patcher:
129 handler = hooks_daemon.HooksHttpHandler
108 handler = hooks_daemon.HooksHttpHandler
130 handler.DEFAULT_HOOKS_PROTO = test_proto
109 handler.DEFAULT_HOOKS_PROTO = test_proto
131 handler.wbufsize = 10240
110 handler.wbufsize = 10240
132 server = MockServer(handler, request)
111 server = MockServer(handler, request)
133
112
134 server.request.output_stream.seek(0)
113 server.request.output_stream.seek(0)
135 data = server.request.output_stream.readlines()
114 data = server.request.output_stream.readlines()
136 msgpack_data = b''.join(data[5:])
115 msgpack_data = b''.join(data[5:])
137 org_exc = hooks_daemon.HooksHttpHandler.deserialize_data(msgpack_data)
116 org_exc = hooks_daemon.HooksHttpHandler.deserialize_data(msgpack_data)
138 expected_result = {
117 expected_result = {
139 'exception': 'Exception',
118 'exception': 'Exception',
140 'exception_traceback': org_exc['exception_traceback'],
119 'exception_traceback': org_exc['exception_traceback'],
141 'exception_args': ['Test exception']
120 'exception_args': ['Test exception']
142 }
121 }
143 assert org_exc == expected_result
122 assert org_exc == expected_result
144
123
145 def test_log_message_writes_to_debug_log(self, caplog):
124 def test_log_message_writes_to_debug_log(self, caplog):
146 ip_port = ('0.0.0.0', 8888)
125 ip_port = ('0.0.0.0', 8888)
147 handler = hooks_daemon.HooksHttpHandler(
126 handler = hooks_daemon.HooksHttpHandler(
148 MockRequest('POST /'), ip_port, mock.Mock())
127 MockRequest('POST /'), ip_port, mock.Mock())
149 fake_date = '1/Nov/2015 00:00:00'
128 fake_date = '1/Nov/2015 00:00:00'
150 date_patcher = mock.patch.object(
129 date_patcher = mock.patch.object(
151 handler, 'log_date_time_string', return_value=fake_date)
130 handler, 'log_date_time_string', return_value=fake_date)
152
131
153 with date_patcher, caplog.at_level(logging.DEBUG):
132 with date_patcher, caplog.at_level(logging.DEBUG):
154 handler.log_message('Some message %d, %s', 123, 'string')
133 handler.log_message('Some message %d, %s', 123, 'string')
155
134
156 expected_message = f"HOOKS: client={ip_port} - - [{fake_date}] Some message 123, string"
135 expected_message = f"HOOKS: client={ip_port} - - [{fake_date}] Some message 123, string"
157
136
158 assert_message_in_log(
137 assert_message_in_log(
159 caplog.records, expected_message,
138 caplog.records, expected_message,
160 levelno=logging.DEBUG, module='hooks_daemon')
139 levelno=logging.DEBUG, module='hooks_daemon')
161
140
162 def _generate_post_request(self, data, proto=test_proto):
141 def _generate_post_request(self, data, proto=test_proto):
163 if proto == hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO:
142 if proto == hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO:
164 payload = msgpack.packb(data)
143 payload = msgpack.packb(data)
165 else:
144 else:
166 payload = json.dumps(data)
145 payload = json.dumps(data)
167
146
168 return b'POST / HTTP/1.0\nContent-Length: %d\n\n%b' % (
147 return b'POST / HTTP/1.0\nContent-Length: %d\n\n%b' % (
169 len(payload), payload)
148 len(payload), payload)
170
149
171
150
172 class ThreadedHookCallbackDaemon(object):
151 class ThreadedHookCallbackDaemon(object):
173 def test_constructor_calls_prepare(self):
152 def test_constructor_calls_prepare(self):
174 prepare_daemon_patcher = mock.patch.object(
153 prepare_daemon_patcher = mock.patch.object(
175 hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
154 hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
176 with prepare_daemon_patcher as prepare_daemon_mock:
155 with prepare_daemon_patcher as prepare_daemon_mock:
177 hooks_daemon.ThreadedHookCallbackDaemon()
156 hooks_daemon.ThreadedHookCallbackDaemon()
178 prepare_daemon_mock.assert_called_once_with()
157 prepare_daemon_mock.assert_called_once_with()
179
158
180 def test_run_is_called_on_context_start(self):
159 def test_run_is_called_on_context_start(self):
181 patchers = mock.patch.multiple(
160 patchers = mock.patch.multiple(
182 hooks_daemon.ThreadedHookCallbackDaemon,
161 hooks_daemon.ThreadedHookCallbackDaemon,
183 _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)
162 _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)
184
163
185 with patchers as mocks:
164 with patchers as mocks:
186 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
165 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
187 with daemon as daemon_context:
166 with daemon as daemon_context:
188 pass
167 pass
189 mocks['_run'].assert_called_once_with()
168 mocks['_run'].assert_called_once_with()
190 assert daemon_context == daemon
169 assert daemon_context == daemon
191
170
192 def test_stop_is_called_on_context_exit(self):
171 def test_stop_is_called_on_context_exit(self):
193 patchers = mock.patch.multiple(
172 patchers = mock.patch.multiple(
194 hooks_daemon.ThreadedHookCallbackDaemon,
173 hooks_daemon.ThreadedHookCallbackDaemon,
195 _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)
174 _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)
196
175
197 with patchers as mocks:
176 with patchers as mocks:
198 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
177 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
199 with daemon as daemon_context:
178 with daemon as daemon_context:
200 assert mocks['_stop'].call_count == 0
179 assert mocks['_stop'].call_count == 0
201
180
202 mocks['_stop'].assert_called_once_with()
181 mocks['_stop'].assert_called_once_with()
203 assert daemon_context == daemon
182 assert daemon_context == daemon
204
183
205
184
206 class TestHttpHooksCallbackDaemon(object):
185 class TestHttpHooksCallbackDaemon(object):
207 def test_hooks_callback_generates_new_port(self, caplog):
186 def test_hooks_callback_generates_new_port(self, caplog):
208 with caplog.at_level(logging.DEBUG):
187 with caplog.at_level(logging.DEBUG):
209 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
188 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
210 assert daemon._daemon.server_address == ('127.0.0.1', 8881)
189 assert daemon._daemon.server_address == ('127.0.0.1', 8881)
211
190
212 with caplog.at_level(logging.DEBUG):
191 with caplog.at_level(logging.DEBUG):
213 daemon = hooks_daemon.HttpHooksCallbackDaemon(host=None, port=None)
192 daemon = hooks_daemon.HttpHooksCallbackDaemon(host=None, port=None)
214 assert daemon._daemon.server_address[1] in range(0, 66000)
193 assert daemon._daemon.server_address[1] in range(0, 66000)
215 assert daemon._daemon.server_address[0] != '127.0.0.1'
194 assert daemon._daemon.server_address[0] != '127.0.0.1'
216
195
217 def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
196 def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
218 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
197 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
219 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
198 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
220 assert daemon._daemon == tcp_server
199 assert daemon._daemon == tcp_server
221
200
222 _, port = tcp_server.server_address
201 _, port = tcp_server.server_address
223
202
224 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
203 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
225 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
204 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
226 assert_message_in_log(
205 assert_message_in_log(
227 caplog.records, msg, levelno=logging.DEBUG, module='hooks_daemon')
206 caplog.records, msg, levelno=logging.DEBUG, module='hooks_daemon')
228
207
229 def test_prepare_inits_hooks_uri_and_logs_it(
208 def test_prepare_inits_hooks_uri_and_logs_it(
230 self, tcp_server, caplog):
209 self, tcp_server, caplog):
231 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
210 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
232 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
211 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
233
212
234 _, port = tcp_server.server_address
213 _, port = tcp_server.server_address
235 expected_uri = '{}:{}'.format('127.0.0.1', port)
214 expected_uri = '{}:{}'.format('127.0.0.1', port)
236 assert daemon.hooks_uri == expected_uri
215 assert daemon.hooks_uri == expected_uri
237
216
238 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
217 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
239 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
218 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
240 assert_message_in_log(
219 assert_message_in_log(
241 caplog.records, msg,
220 caplog.records, msg,
242 levelno=logging.DEBUG, module='hooks_daemon')
221 levelno=logging.DEBUG, module='hooks_daemon')
243
222
244 def test_run_creates_a_thread(self, tcp_server):
223 def test_run_creates_a_thread(self, tcp_server):
245 thread = mock.Mock()
224 thread = mock.Mock()
246
225
247 with self._tcp_patcher(tcp_server):
226 with self._tcp_patcher(tcp_server):
248 daemon = hooks_daemon.HttpHooksCallbackDaemon()
227 daemon = hooks_daemon.HttpHooksCallbackDaemon()
249
228
250 with self._thread_patcher(thread) as thread_mock:
229 with self._thread_patcher(thread) as thread_mock:
251 daemon._run()
230 daemon._run()
252
231
253 thread_mock.assert_called_once_with(
232 thread_mock.assert_called_once_with(
254 target=tcp_server.serve_forever,
233 target=tcp_server.serve_forever,
255 kwargs={'poll_interval': daemon.POLL_INTERVAL})
234 kwargs={'poll_interval': daemon.POLL_INTERVAL})
256 assert thread.daemon is True
235 assert thread.daemon is True
257 thread.start.assert_called_once_with()
236 thread.start.assert_called_once_with()
258
237
259 def test_run_logs(self, tcp_server, caplog):
238 def test_run_logs(self, tcp_server, caplog):
260
239
261 with self._tcp_patcher(tcp_server):
240 with self._tcp_patcher(tcp_server):
262 daemon = hooks_daemon.HttpHooksCallbackDaemon()
241 daemon = hooks_daemon.HttpHooksCallbackDaemon()
263
242
264 with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
243 with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
265 daemon._run()
244 daemon._run()
266
245
267 assert_message_in_log(
246 assert_message_in_log(
268 caplog.records,
247 caplog.records,
269 'Running thread-based loop of callback daemon in background',
248 'Running thread-based loop of callback daemon in background',
270 levelno=logging.DEBUG, module='hooks_daemon')
249 levelno=logging.DEBUG, module='hooks_daemon')
271
250
272 def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
251 def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
273 thread = mock.Mock()
252 thread = mock.Mock()
274
253
275 with self._tcp_patcher(tcp_server):
254 with self._tcp_patcher(tcp_server):
276 daemon = hooks_daemon.HttpHooksCallbackDaemon()
255 daemon = hooks_daemon.HttpHooksCallbackDaemon()
277
256
278 with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
257 with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
279 with daemon:
258 with daemon:
280 assert daemon._daemon == tcp_server
259 assert daemon._daemon == tcp_server
281 assert daemon._callback_thread == thread
260 assert daemon._callback_thread == thread
282
261
283 assert daemon._daemon is None
262 assert daemon._daemon is None
284 assert daemon._callback_thread is None
263 assert daemon._callback_thread is None
285 tcp_server.shutdown.assert_called_with()
264 tcp_server.shutdown.assert_called_with()
286 thread.join.assert_called_once_with()
265 thread.join.assert_called_once_with()
287
266
288 assert_message_in_log(
267 assert_message_in_log(
289 caplog.records, 'Waiting for background thread to finish.',
268 caplog.records, 'Waiting for background thread to finish.',
290 levelno=logging.DEBUG, module='hooks_daemon')
269 levelno=logging.DEBUG, module='hooks_daemon')
291
270
292 def _tcp_patcher(self, tcp_server):
271 def _tcp_patcher(self, tcp_server):
293 return mock.patch.object(
272 return mock.patch.object(
294 hooks_daemon, 'TCPServer', return_value=tcp_server)
273 hooks_daemon, 'TCPServer', return_value=tcp_server)
295
274
296 def _thread_patcher(self, thread):
275 def _thread_patcher(self, thread):
297 return mock.patch.object(
276 return mock.patch.object(
298 hooks_daemon.threading, 'Thread', return_value=thread)
277 hooks_daemon.threading, 'Thread', return_value=thread)
299
278
300
279
301 class TestPrepareHooksDaemon(object):
280 class TestPrepareHooksDaemon(object):
302 @pytest.mark.parametrize('protocol', ('http',))
281 @pytest.mark.parametrize('protocol', ('celery',))
303 def test_returns_dummy_hooks_callback_daemon_when_using_direct_calls(
282 def test_returns_celery_hooks_callback_daemon_when_celery_protocol_specified(
304 self, protocol):
283 self, protocol):
305 expected_extras = {'extra1': 'value1'}
284 with tempfile.NamedTemporaryFile(mode='w') as temp_file:
306 callback, extras = hooks_daemon.prepare_callback_daemon(
285 temp_file.write("[app:main]\ncelery.broker_url = redis://redis/0\n"
307 expected_extras.copy(), protocol=protocol,
286 "celery.result_backend = redis://redis/0")
308 host='127.0.0.1', use_direct_calls=True)
287 temp_file.flush()
309 assert isinstance(callback, hooks_daemon.DummyHooksCallbackDaemon)
288 expected_extras = {'config': temp_file.name}
310 expected_extras['hooks_module'] = 'rhodecode.lib.hooks_daemon'
289 callback, extras = hooks_daemon.prepare_callback_daemon(
311 expected_extras['time'] = extras['time']
290 expected_extras, protocol=protocol, host='')
312 assert 'extra1' in extras
291 assert isinstance(callback, hooks_daemon.CeleryHooksCallbackDaemon)
313
292
314 @pytest.mark.parametrize('protocol, expected_class', (
293 @pytest.mark.parametrize('protocol, expected_class', (
315 ('http', hooks_daemon.HttpHooksCallbackDaemon),
294 ('http', hooks_daemon.HttpHooksCallbackDaemon),
316 ))
295 ))
317 def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
296 def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
318 self, protocol, expected_class):
297 self, protocol, expected_class):
319 expected_extras = {
298 expected_extras = {
320 'extra1': 'value1',
299 'extra1': 'value1',
321 'txn_id': 'txnid2',
300 'txn_id': 'txnid2',
322 'hooks_protocol': protocol.lower()
301 'hooks_protocol': protocol.lower(),
302 'task_backend': '',
303 'task_queue': ''
323 }
304 }
324 callback, extras = hooks_daemon.prepare_callback_daemon(
305 callback, extras = hooks_daemon.prepare_callback_daemon(
325 expected_extras.copy(), protocol=protocol, host='127.0.0.1',
306 expected_extras.copy(), protocol=protocol, host='127.0.0.1',
326 use_direct_calls=False,
327 txn_id='txnid2')
307 txn_id='txnid2')
328 assert isinstance(callback, expected_class)
308 assert isinstance(callback, expected_class)
329 extras.pop('hooks_uri')
309 extras.pop('hooks_uri')
330 expected_extras['time'] = extras['time']
310 expected_extras['time'] = extras['time']
331 assert extras == expected_extras
311 assert extras == expected_extras
332
312
333 @pytest.mark.parametrize('protocol', (
313 @pytest.mark.parametrize('protocol', (
334 'invalid',
314 'invalid',
335 'Http',
315 'Http',
336 'HTTP',
316 'HTTP',
337 ))
317 ))
338 def test_raises_on_invalid_protocol(self, protocol):
318 def test_raises_on_invalid_protocol(self, protocol):
339 expected_extras = {
319 expected_extras = {
340 'extra1': 'value1',
320 'extra1': 'value1',
341 'hooks_protocol': protocol.lower()
321 'hooks_protocol': protocol.lower()
342 }
322 }
343 with pytest.raises(Exception):
323 with pytest.raises(Exception):
344 callback, extras = hooks_daemon.prepare_callback_daemon(
324 callback, extras = hooks_daemon.prepare_callback_daemon(
345 expected_extras.copy(),
325 expected_extras.copy(),
346 protocol=protocol, host='127.0.0.1',
326 protocol=protocol, host='127.0.0.1')
347 use_direct_calls=False)
348
327
349
328
350 class MockRequest(object):
329 class MockRequest(object):
351
330
352 def __init__(self, request):
331 def __init__(self, request):
353 self.request = request
332 self.request = request
354 self.input_stream = io.BytesIO(safe_bytes(self.request))
333 self.input_stream = io.BytesIO(safe_bytes(self.request))
355 self.output_stream = io.BytesIO() # make it un-closable for testing invesitagion
334 self.output_stream = io.BytesIO() # make it un-closable for testing invesitagion
356 self.output_stream.close = lambda: None
335 self.output_stream.close = lambda: None
357
336
358 def makefile(self, mode, *args, **kwargs):
337 def makefile(self, mode, *args, **kwargs):
359 return self.output_stream if mode == 'wb' else self.input_stream
338 return self.output_stream if mode == 'wb' else self.input_stream
360
339
361
340
362 class MockServer(object):
341 class MockServer(object):
363
342
364 def __init__(self, handler_cls, request):
343 def __init__(self, handler_cls, request):
365 ip_port = ('0.0.0.0', 8888)
344 ip_port = ('0.0.0.0', 8888)
366 self.request = MockRequest(request)
345 self.request = MockRequest(request)
367 self.server_address = ip_port
346 self.server_address = ip_port
368 self.handler = handler_cls(self.request, ip_port, self)
347 self.handler = handler_cls(self.request, ip_port, self)
369
348
370
349
371 @pytest.fixture()
350 @pytest.fixture()
372 def tcp_server():
351 def tcp_server():
373 server = mock.Mock()
352 server = mock.Mock()
374 server.server_address = ('127.0.0.1', 8881)
353 server.server_address = ('127.0.0.1', 8881)
375 server.wbufsize = 1024
354 server.wbufsize = 1024
376 return server
355 return server
General Comments 0
You need to be logged in to leave comments. Login now