##// END OF EJS Templates
feat(ssh-wrapper-speedup): major rewrite of code to address imports problem with ssh-wrapper-v2...
super-admin -
r5325:359b5cac default
parent child Browse files
Show More
@@ -0,0 +1,198 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
19 import os
20 import tempfile
21 import logging
22
23 from pyramid.settings import asbool
24
25 from rhodecode.config.settings_maker import SettingsMaker
26 from rhodecode.config import utils as config_utils
27
28 log = logging.getLogger(__name__)
29
30
def sanitize_settings_and_apply_defaults(global_config, settings):
    """
    Applies settings defaults and does all type conversion.

    We would move all settings parsing and preparation into this place, so that
    we have only one place left which deals with this part. The remaining parts
    of the application would start to rely fully on well-prepared settings.

    This piece would later be split up per topic to avoid a big fat monster
    function.

    :param global_config: the [DEFAULT]/global paste config mapping (contains
        ``__file__`` and ``debug``)
    :param settings: the ``app:main`` settings mapping, mutated in place
    :returns: the sanitized ``settings`` mapping
    """
    jn = os.path.join

    global_settings_maker = SettingsMaker(global_config)
    global_settings_maker.make_setting('debug', default=False, parser='bool')
    debug_enabled = asbool(global_config.get('debug'))

    settings_maker = SettingsMaker(settings)

    settings_maker.make_setting(
        'logging.autoconfigure',
        default=False,
        parser='bool')

    logging_conf = jn(os.path.dirname(global_config.get('__file__')), 'logging.ini')
    # BUGFIX: the levels were inverted; debug mode must enable the more
    # verbose DEBUG logging, non-debug runs stay at INFO.
    settings_maker.enable_logging(logging_conf, level='DEBUG' if debug_enabled else 'INFO')

    # Default includes, possible to change as a user
    pyramid_includes = settings_maker.make_setting('pyramid.includes', [], parser='list:newline')
    log.debug(
        "Using the following pyramid.includes: %s",
        pyramid_includes)

    settings_maker.make_setting('rhodecode.edition', 'Community Edition')
    settings_maker.make_setting('rhodecode.edition_id', 'CE')

    if 'mako.default_filters' not in settings:
        # set custom default filters if we don't have it defined
        settings['mako.imports'] = 'from rhodecode.lib.base import h_filter'
        settings['mako.default_filters'] = 'h_filter'

    if 'mako.directories' not in settings:
        mako_directories = settings.setdefault('mako.directories', [
            # Base templates of the original application
            'rhodecode:templates',
        ])
        log.debug(
            "Using the following Mako template directories: %s",
            mako_directories)

    # NOTE(marcink): fix redis requirement for schema of connection since 3.X
    if 'beaker.session.type' in settings and settings['beaker.session.type'] == 'ext:redis':
        raw_url = settings['beaker.session.url']
        if not raw_url.startswith(('redis://', 'rediss://', 'unix://')):
            settings['beaker.session.url'] = 'redis://' + raw_url

    settings_maker.make_setting('__file__', global_config.get('__file__'))

    # TODO: johbo: Re-think this, usually the call to config.include
    # should allow to pass in a prefix.
    settings_maker.make_setting('rhodecode.api.url', '/_admin/api')

    # Sanitize generic settings.
    settings_maker.make_setting('default_encoding', 'UTF-8', parser='list')
    settings_maker.make_setting('is_test', False, parser='bool')
    settings_maker.make_setting('gzip_responses', False, parser='bool')

    # statsd
    settings_maker.make_setting('statsd.enabled', False, parser='bool')
    settings_maker.make_setting('statsd.statsd_host', 'statsd-exporter', parser='string')
    settings_maker.make_setting('statsd.statsd_port', 9125, parser='int')
    settings_maker.make_setting('statsd.statsd_prefix', '')
    settings_maker.make_setting('statsd.statsd_ipv6', False, parser='bool')

    settings_maker.make_setting('vcs.svn.compatible_version', '')
    settings_maker.make_setting('vcs.hooks.protocol', 'http')
    settings_maker.make_setting('vcs.hooks.host', '*')
    settings_maker.make_setting('vcs.scm_app_implementation', 'http')
    settings_maker.make_setting('vcs.server', '')
    settings_maker.make_setting('vcs.server.protocol', 'http')
    settings_maker.make_setting('vcs.server.enable', 'true', parser='bool')
    settings_maker.make_setting('startup.import_repos', 'false', parser='bool')
    settings_maker.make_setting('vcs.hooks.direct_calls', 'false', parser='bool')
    settings_maker.make_setting('vcs.start_server', 'false', parser='bool')
    settings_maker.make_setting('vcs.backends', 'hg, git, svn', parser='list')
    settings_maker.make_setting('vcs.connection_timeout', 3600, parser='int')

    settings_maker.make_setting('vcs.methods.cache', True, parser='bool')

    # Support legacy values of vcs.scm_app_implementation. Legacy
    # configurations may use 'rhodecode.lib.middleware.utils.scm_app_http', or
    # disabled since 4.13 'vcsserver.scm_app' which is now mapped to 'http'.
    scm_app_impl = settings['vcs.scm_app_implementation']
    if scm_app_impl in ['rhodecode.lib.middleware.utils.scm_app_http', 'vcsserver.scm_app']:
        settings['vcs.scm_app_implementation'] = 'http'

    settings_maker.make_setting('appenlight', False, parser='bool')

    temp_store = tempfile.gettempdir()
    tmp_cache_dir = jn(temp_store, 'rc_cache')

    # save default, cache dir, and use it for all backends later.
    default_cache_dir = settings_maker.make_setting(
        'cache_dir',
        default=tmp_cache_dir, default_when_empty=True,
        parser='dir:ensured')

    # exception store cache
    settings_maker.make_setting(
        'exception_tracker.store_path',
        default=jn(default_cache_dir, 'exc_store'), default_when_empty=True,
        parser='dir:ensured'
    )

    settings_maker.make_setting(
        'celerybeat-schedule.path',
        default=jn(default_cache_dir, 'celerybeat_schedule', 'celerybeat-schedule.db'), default_when_empty=True,
        parser='file:ensured'
    )

    settings_maker.make_setting('exception_tracker.send_email', False, parser='bool')
    settings_maker.make_setting('exception_tracker.email_prefix', '[RHODECODE ERROR]', default_when_empty=True)

    # sessions, ensure file since no-value is memory
    settings_maker.make_setting('beaker.session.type', 'file')
    settings_maker.make_setting('beaker.session.data_dir', jn(default_cache_dir, 'session_data'))

    # cache_general
    settings_maker.make_setting('rc_cache.cache_general.backend', 'dogpile.cache.rc.file_namespace')
    settings_maker.make_setting('rc_cache.cache_general.expiration_time', 60 * 60 * 12, parser='int')
    settings_maker.make_setting('rc_cache.cache_general.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_general.db'))

    # cache_perms
    settings_maker.make_setting('rc_cache.cache_perms.backend', 'dogpile.cache.rc.file_namespace')
    settings_maker.make_setting('rc_cache.cache_perms.expiration_time', 60 * 60, parser='int')
    settings_maker.make_setting('rc_cache.cache_perms.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_perms_db'))

    # cache_repo
    settings_maker.make_setting('rc_cache.cache_repo.backend', 'dogpile.cache.rc.file_namespace')
    settings_maker.make_setting('rc_cache.cache_repo.expiration_time', 60 * 60 * 24 * 30, parser='int')
    settings_maker.make_setting('rc_cache.cache_repo.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_repo_db'))

    # cache_license
    settings_maker.make_setting('rc_cache.cache_license.backend', 'dogpile.cache.rc.file_namespace')
    settings_maker.make_setting('rc_cache.cache_license.expiration_time', 60 * 5, parser='int')
    settings_maker.make_setting('rc_cache.cache_license.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_license_db'))

    # cache_repo_longterm memory, 96H
    settings_maker.make_setting('rc_cache.cache_repo_longterm.backend', 'dogpile.cache.rc.memory_lru')
    settings_maker.make_setting('rc_cache.cache_repo_longterm.expiration_time', 345600, parser='int')
    settings_maker.make_setting('rc_cache.cache_repo_longterm.max_size', 10000, parser='int')

    # sql_cache_short
    settings_maker.make_setting('rc_cache.sql_cache_short.backend', 'dogpile.cache.rc.memory_lru')
    settings_maker.make_setting('rc_cache.sql_cache_short.expiration_time', 30, parser='int')
    settings_maker.make_setting('rc_cache.sql_cache_short.max_size', 10000, parser='int')

    # archive_cache
    settings_maker.make_setting('archive_cache.store_dir', jn(default_cache_dir, 'archive_cache'), default_when_empty=True,)
    settings_maker.make_setting('archive_cache.cache_size_gb', 10, parser='float')
    settings_maker.make_setting('archive_cache.cache_shards', 10, parser='int')

    settings_maker.env_expand()

    # configure instance id
    config_utils.set_instance_id(settings)

    return settings
@@ -0,0 +1,38 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
19 import urllib.parse
20
21 from rhodecode.lib.vcs import CurlSession
22 from rhodecode.lib.ext_json import json
23
24
def call_service_api(service_api_host, service_api_token, api_url, payload):
    """
    POST a JSON-RPC style *payload* to the internal service API and return
    the ``result`` portion of the decoded response.

    Raises a generic ``Exception`` on any non-200 HTTP status.
    """
    # every service call carries a fixed id and the shared auth token
    payload['id'] = 'service'
    payload['auth_token'] = service_api_token

    service_api_url = urllib.parse.urljoin(service_api_host, api_url)
    session = CurlSession()
    response = session.post(service_api_url, json.dumps(payload))

    if response.status_code != 200:
        raise Exception(f"Service API at {service_api_url} responded with error: {response.status_code}")

    return json.loads(response.content)['result']
@@ -0,0 +1,40 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 import os
19
20
def get_config(ini_path, **kwargs):
    """Parse *ini_path* and return the resulting ``ConfigParser`` object.

    Any **kwargs are forwarded to the ``ConfigParser`` constructor.
    """
    import configparser
    config = configparser.ConfigParser(**kwargs)
    config.read(ini_path)
    return config
26
27
def get_app_config_lightweight(ini_path):
    """Return the ``[app:main]`` section of *ini_path* as a plain dict.

    Emulates the ``here`` / ``__file__`` variables that a full paste
    bootstrap would normally provide, without the heavy app loading.
    """
    config = get_config(ini_path)
    config.set('app:main', 'here', os.getcwd())
    config.set('app:main', '__file__', ini_path)
    return dict(config.items('app:main'))
33
34
def get_app_config(ini_path):
    """
    This loads the app context and provides a heavy type initialization of config
    """
    from paste.deploy.loadwsgi import appconfig
    config_spec = f'config:{ini_path}'
    return appconfig(config_spec, relative_to=os.getcwd())
@@ -0,0 +1,17 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
@@ -0,0 +1,115 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 import os
19 import time
20 import logging
21 import tempfile
22
23 from rhodecode.lib.config_utils import get_config
24 from rhodecode.lib.ext_json import json
25
26 log = logging.getLogger(__name__)
27
28
class BaseHooksCallbackDaemon:
    """
    Minimal no-op context manager for hook actions that need no extra setup;
    it only logs entry/exit of the daemon context.
    """

    def __init__(self):
        pass

    def __enter__(self):
        log.debug('Running `%s` callback daemon', type(self).__name__)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        log.debug('Exiting `%s` callback daemon', type(self).__name__)
42
43
class HooksModuleCallbackDaemon(BaseHooksCallbackDaemon):
    """Callback daemon that records the dotted path of an in-process Python
    module implementing the hooks; no network daemon is started."""

    def __init__(self, module):
        super().__init__()
        # dotted module path, e.g. Hooks.__module__ from the 'local' protocol
        self.hooks_module = module
49
50
def get_txn_id_data_path(txn_id):
    """
    Return the filesystem path used to persist metadata for *txn_id*.

    The storage root is the configured ``cache_dir`` when available,
    falling back to the system temp directory; the ``svn_txn_id``
    subdirectory is created on demand.
    """
    import rhodecode

    root = rhodecode.CONFIG.get('cache_dir') or tempfile.gettempdir()
    final_dir = os.path.join(root, 'svn_txn_id')

    # exist_ok avoids a TOCTOU race when concurrent hook processes
    # create the directory at the same time
    os.makedirs(final_dir, exist_ok=True)
    return os.path.join(final_dir, f'rc_txn_id_{txn_id}')
60
61
def store_txn_id_data(txn_id, data_dict):
    """Best-effort persist of *data_dict* under the path for *txn_id*.

    An empty txn_id is logged and skipped; write failures are logged,
    never raised.
    """
    if not txn_id:
        log.warning('Cannot store txn_id because it is empty')
        return

    storage_path = get_txn_id_data_path(txn_id)
    try:
        with open(storage_path, 'wb') as file_obj:
            file_obj.write(json.dumps(data_dict))
    except Exception:
        log.exception('Failed to write txn_id metadata')
73
74
def get_txn_id_from_store(txn_id):
    """
    Reads txn_id from store and if present returns the data for callback manager
    """
    storage_path = get_txn_id_data_path(txn_id)
    try:
        with open(storage_path, 'rb') as file_obj:
            raw = file_obj.read()
        return json.loads(raw)
    except Exception:
        # missing/corrupt store entries simply yield no metadata
        return {}
85
86
def prepare_callback_daemon(extras, protocol, host, txn_id=None):
    """
    Select and construct the hooks callback daemon for *protocol*
    ('http', 'celery' or 'local') and enrich *extras* with the connection
    details the VCS hooks environment needs.

    :param extras: dict of hook environment data, mutated and returned
    :param protocol: one of 'http', 'celery', 'local'
    :param host: bind host for the http daemon ('*' means auto-detect)
    :param txn_id: optional SVN transaction id used to look up a
        previously-registered port from the txn store
    :returns: tuple of (callback_daemon, extras)
    :raises Exception: for an unsupported protocol value
    """
    txn_details = get_txn_id_from_store(txn_id)
    # a stored txn entry may pin the port of an already-running daemon
    port = txn_details.get('port', 0)
    match protocol:
        case 'http':
            # imports are kept local so only the selected backend is loaded
            from rhodecode.lib.hook_daemon.http_hooks_deamon import HttpHooksCallbackDaemon
            callback_daemon = HttpHooksCallbackDaemon(
                txn_id=txn_id, host=host, port=port)
        case 'celery':
            from rhodecode.lib.hook_daemon.celery_hooks_deamon import CeleryHooksCallbackDaemon
            callback_daemon = CeleryHooksCallbackDaemon(get_config(extras['config']))
        case 'local':
            from rhodecode.lib.hook_daemon.hook_module import Hooks
            callback_daemon = HooksModuleCallbackDaemon(Hooks.__module__)
        case _:
            log.error('Unsupported callback daemon protocol "%s"', protocol)
            raise Exception('Unsupported callback daemon protocol.')

    # expose the daemon's connection info; empty string when the chosen
    # daemon type has no such attribute
    extras['hooks_uri'] = getattr(callback_daemon, 'hooks_uri', '')
    extras['task_queue'] = getattr(callback_daemon, 'task_queue', '')
    extras['task_backend'] = getattr(callback_daemon, 'task_backend', '')
    extras['hooks_protocol'] = protocol
    extras['time'] = time.time()

    # register txn_id
    extras['txn_id'] = txn_id
    log.debug('Prepared a callback daemon: %s',
              callback_daemon.__class__.__name__)
    return callback_daemon, extras
@@ -0,0 +1,30 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
19 from rhodecode.lib.hook_daemon.base import BaseHooksCallbackDaemon
20
21
class CeleryHooksCallbackDaemon(BaseHooksCallbackDaemon):
    """
    Context manager for achieving compatibility with the celery backend;
    instead of running a network daemon it only exposes the broker/backend
    URLs the hooks need to enqueue celery tasks.
    """

    def __init__(self, config):
        # keep base-class initialization consistent with the other
        # daemon implementations (e.g. HooksModuleCallbackDaemon)
        super().__init__()
        # TODO: replace this with settings bootstrapped from the app config
        self.task_queue = config.get('app:main', 'celery.broker_url')
        self.task_backend = config.get('app:main', 'celery.result_backend')
@@ -0,0 +1,104 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
19 import logging
20 import traceback
21
22 from rhodecode.model import meta
23
24 from rhodecode.lib import hooks_base
25 from rhodecode.lib.exceptions import HTTPLockedRC, HTTPBranchProtected
26 from rhodecode.lib.utils2 import AttributeDict
27
28 log = logging.getLogger(__name__)
29
30
class Hooks(object):
    """
    Exposes the hooks for remote callbacks
    """
    def __init__(self, request=None, log_prefix=''):
        # optional prefix prepended to every log line; helps telling
        # concurrent daemon instances apart in the logs
        self.log_prefix = log_prefix
        # active web request (if any), forwarded into each hook call
        self.request = request

    def repo_size(self, extras):
        """Proxy to ``hooks_base.repo_size``."""
        log.debug("%sCalled repo_size of %s object", self.log_prefix, self)
        return self._call_hook(hooks_base.repo_size, extras)

    def pre_pull(self, extras):
        """Proxy to ``hooks_base.pre_pull``."""
        log.debug("%sCalled pre_pull of %s object", self.log_prefix, self)
        return self._call_hook(hooks_base.pre_pull, extras)

    def post_pull(self, extras):
        """Proxy to ``hooks_base.post_pull``."""
        log.debug("%sCalled post_pull of %s object", self.log_prefix, self)
        return self._call_hook(hooks_base.post_pull, extras)

    def pre_push(self, extras):
        """Proxy to ``hooks_base.pre_push``."""
        log.debug("%sCalled pre_push of %s object", self.log_prefix, self)
        return self._call_hook(hooks_base.pre_push, extras)

    def post_push(self, extras):
        """Proxy to ``hooks_base.post_push``."""
        log.debug("%sCalled post_push of %s object", self.log_prefix, self)
        return self._call_hook(hooks_base.post_push, extras)

    def _call_hook(self, hook, extras):
        """
        Invoke *hook* with AttributeDict-wrapped *extras* and normalize the
        outcome into a ``{'status': ..., 'output': ...}`` dict; unexpected
        errors are reported as a status-128 dict instead of raising.
        """
        extras = AttributeDict(extras)
        # NOTE(review): value is unused; the subscript also has the effect of
        # raising KeyError when 'server_url' is missing - confirm if intended
        _server_url = extras['server_url']

        extras.request = self.request

        try:
            result = hook(extras)
            if result is None:
                raise Exception(f'Failed to obtain hook result from func: {hook}')
        except HTTPBranchProtected as handled_error:
            # Those special cases don't need error reporting. It's a case of
            # locked repo or protected branch
            result = AttributeDict({
                'status': handled_error.code,
                'output': handled_error.explanation
            })
        except (HTTPLockedRC, Exception) as error:
            # locked needs different handling since we need to also
            # handle PULL operations
            exc_tb = ''
            if not isinstance(error, HTTPLockedRC):
                # only unexpected errors get a traceback + error log; the
                # locked-repo case is an expected control-flow signal
                exc_tb = traceback.format_exc()
                log.exception('%sException when handling hook %s', self.log_prefix, hook)
            error_args = error.args
            return {
                'status': 128,
                'output': '',
                'exception': type(error).__name__,
                'exception_traceback': exc_tb,
                'exception_args': error_args,
            }
        finally:
            # always detach the scoped DB session, success or failure
            meta.Session.remove()

        log.debug('%sGot hook call response %s', self.log_prefix, result)
        return {
            'status': result.status,
            'output': result.output,
        }

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
@@ -0,0 +1,280 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
6 #
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
11 #
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
19 import os
20 import logging
21 import traceback
22 import threading
23 import socket
24 import msgpack
25 import gevent
26
27 from http.server import BaseHTTPRequestHandler
28 from socketserver import TCPServer
29
30 from rhodecode.model import meta
31 from rhodecode.lib.ext_json import json
32 from rhodecode.lib import rc_cache
33 from rhodecode.lib.hook_daemon.base import get_txn_id_data_path
34 from rhodecode.lib.hook_daemon.hook_module import Hooks
35
36 log = logging.getLogger(__name__)
37
38
class HooksHttpHandler(BaseHTTPRequestHandler):
    """
    HTTP request handler for the hooks callback daemon.

    The client POSTs a serialized ``{'method': ..., 'extras': ...}`` payload,
    in either msgpack or json form selected via the ``rc-hooks-protocol``
    header; the named method is dispatched on a ``Hooks`` instance and the
    result is sent back using the same wire protocol.
    """

    JSON_HOOKS_PROTO = 'json.v1'
    MSGPACK_HOOKS_PROTO = 'msgpack.v1'
    # starting with RhodeCode 5.0.0 MsgPack is the default, prior it used json
    DEFAULT_HOOKS_PROTO = MSGPACK_HOOKS_PROTO

    @classmethod
    def serialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
        """Serialize *data* with the given wire protocol."""
        if proto == cls.MSGPACK_HOOKS_PROTO:
            return msgpack.packb(data)
        return json.dumps(data)

    @classmethod
    def deserialize_data(cls, data, proto=DEFAULT_HOOKS_PROTO):
        """Deserialize *data* with the given wire protocol."""
        if proto == cls.MSGPACK_HOOKS_PROTO:
            return msgpack.unpackb(data)
        return json.loads(data)

    def do_POST(self):
        hooks_proto, method, extras = self._read_request()
        log.debug('Handling HooksHttpHandler %s with %s proto', method, hooks_proto)

        txn_id = getattr(self.server, 'txn_id', None)
        if txn_id:
            # verify the request belongs to the transaction this daemon
            # was started for
            log.debug('Computing TXN_ID based on `%s`:`%s`',
                      extras['repository'], extras['txn_id'])
            computed_txn_id = rc_cache.utils.compute_key_from_params(
                extras['repository'], extras['txn_id'])
            if txn_id != computed_txn_id:
                raise Exception(
                    'TXN ID fail: expected {} got {} instead'.format(
                        txn_id, computed_txn_id))

        request = getattr(self.server, 'request', None)
        try:
            hooks = Hooks(request=request, log_prefix='HOOKS: {} '.format(self.server.server_address))
            result = self._call_hook_method(hooks, method, extras)

        except Exception as e:
            # report the failure to the client instead of dropping the request
            exc_tb = traceback.format_exc()
            result = {
                'exception': e.__class__.__name__,
                'exception_traceback': exc_tb,
                'exception_args': e.args
            }
        self._write_response(hooks_proto, result)

    def _read_request(self):
        """Read the POST body; returns ``(hooks_proto, method, extras)``."""
        length = int(self.headers['Content-Length'])
        # respect sent headers, fallback to OLD proto for compatibility
        hooks_proto = self.headers.get('rc-hooks-protocol') or self.JSON_HOOKS_PROTO
        body = self.rfile.read(length)
        # BUGFIX: decode with the protocol the client announced; previously
        # both branches used the default (msgpack), mis-decoding json.v1 bodies
        data = self.deserialize_data(body, proto=hooks_proto)
        return hooks_proto, data['method'], data['extras']

    def _write_response(self, hooks_proto, result):
        """Serialize *result* with the request's protocol and reply 200."""
        self.send_response(200)
        if hooks_proto == self.MSGPACK_HOOKS_PROTO:
            self.send_header("Content-type", "application/msgpack")
        else:
            self.send_header("Content-type", "text/json")
        self.end_headers()
        # BUGFIX: serialize with the client's protocol, not the default
        data = self.serialize_data(result, proto=hooks_proto)
        if isinstance(data, str):
            # a text serializer result must be encoded before hitting the socket
            data = data.encode('utf-8')
        self.wfile.write(data)

    def _call_hook_method(self, hooks, method, extras):
        """Dispatch *method* on *hooks*, always detaching the DB session."""
        try:
            result = getattr(hooks, method)(extras)
        finally:
            meta.Session.remove()
        return result

    def log_message(self, format, *args):
        """
        This is an overridden method of BaseHTTPRequestHandler which logs using
        a logging library instead of writing directly to stderr.
        """

        message = format % args

        log.debug(
            "HOOKS: client=%s - - [%s] %s", self.client_address,
            self.log_date_time_string(), message)
132
133
class ThreadedHookCallbackDaemon(object):
    """
    Abstract context manager scaffolding for callback daemons that serve in a
    background worker; subclasses implement the thread- and gevent-based
    run/stop primitives selected via ``use_gevent``.
    """

    _callback_thread = None
    _daemon = None
    _done = False
    use_gevent = False

    def __init__(self, txn_id=None, host=None, port=None):
        self._prepare(txn_id=txn_id, host=host, port=port)
        # bind the run/stop strategy once, based on the concurrency model
        if self.use_gevent:
            self._run_func, self._stop_func = self._run_gevent, self._stop_gevent
        else:
            self._run_func, self._stop_func = self._run, self._stop

    def __enter__(self):
        log.debug('Running `%s` callback daemon', self.__class__.__name__)
        self._run_func()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
        self._stop_func()

    def _prepare(self, txn_id=None, host=None, port=None):
        raise NotImplementedError()

    def _run(self):
        raise NotImplementedError()

    def _stop(self):
        raise NotImplementedError()

    def _run_gevent(self):
        raise NotImplementedError()

    def _stop_gevent(self):
        raise NotImplementedError()
173
174
class HttpHooksCallbackDaemon(ThreadedHookCallbackDaemon):
    """
    Context manager which will run a callback daemon in a background thread.
    """

    hooks_uri = None

    # From Python docs: Polling reduces our responsiveness to a shutdown
    # request and wastes cpu at all other times.
    POLL_INTERVAL = 0.01

    use_gevent = False

    # pre-declared so _stop_gevent cannot AttributeError when the
    # greenlet was never started
    _callback_greenlet = None

    @property
    def _hook_prefix(self):
        # log prefix identifying this daemon instance by its bind address
        return 'HOOKS: {} '.format(self.hooks_uri)

    def get_hostname(self):
        """Return the local hostname, falling back to loopback."""
        return socket.gethostname() or '127.0.0.1'

    def get_available_port(self, min_port=20000, max_port=65535):
        """Pick a free TCP port within [min_port, max_port]."""
        from rhodecode.lib.utils2 import get_available_port as _get_port
        return _get_port(min_port, max_port)

    def _prepare(self, txn_id=None, host=None, port=None):
        """Create (but do not start) the TCPServer handling hook callbacks."""
        from pyramid.threadlocal import get_current_request

        if not host or host == "*":
            host = self.get_hostname()
        if not port:
            port = self.get_available_port()

        server_address = (host, port)
        self.hooks_uri = '{}:{}'.format(host, port)
        self.txn_id = txn_id
        self._done = False

        log.debug(
            "%s Preparing HTTP callback daemon registering hook object: %s",
            self._hook_prefix, HooksHttpHandler)

        self._daemon = TCPServer(server_address, HooksHttpHandler)
        # inject transaction_id for later verification
        self._daemon.txn_id = self.txn_id

        # pass the WEB app request into daemon
        self._daemon.request = get_current_request()

    def _cleanup_txn_id(self):
        # remove the on-disk txn metadata once the daemon is finished;
        # shared by the thread- and gevent-based stop paths
        if self.txn_id:
            txn_id_file = get_txn_id_data_path(self.txn_id)
            log.debug('Cleaning up TXN ID %s', txn_id_file)
            if os.path.isfile(txn_id_file):
                os.remove(txn_id_file)

    def _run(self):
        log.debug("Running thread-based loop of callback daemon in background")
        callback_thread = threading.Thread(
            target=self._daemon.serve_forever,
            kwargs={'poll_interval': self.POLL_INTERVAL})
        # daemon thread: never blocks interpreter shutdown
        callback_thread.daemon = True
        callback_thread.start()
        self._callback_thread = callback_thread

    def _run_gevent(self):
        log.debug("Running gevent-based loop of callback daemon in background")
        # create a new greenlet for the daemon's serve_forever method
        callback_greenlet = gevent.spawn(
            self._daemon.serve_forever,
            poll_interval=self.POLL_INTERVAL)

        # store reference to greenlet
        self._callback_greenlet = callback_greenlet

        # switch to this greenlet
        gevent.sleep(0.01)

    def _stop(self):
        log.debug("Waiting for background thread to finish.")
        self._daemon.shutdown()
        self._callback_thread.join()
        self._daemon = None
        self._callback_thread = None
        self._cleanup_txn_id()
        log.debug("Background thread done.")

    def _stop_gevent(self):
        log.debug("Waiting for background greenlet to finish.")

        # if greenlet exists and is running
        if self._callback_greenlet and not self._callback_greenlet.dead:
            # shutdown daemon if it exists
            if self._daemon:
                self._daemon.shutdown()

            # kill the greenlet
            self._callback_greenlet.kill()

        self._daemon = None
        self._callback_greenlet = None
        self._cleanup_txn_id()
        log.debug("Background greenlet done.")
@@ -1,392 +1,417 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import re
21 21 import logging
22 22 import datetime
23 import configparser
24 23 from sqlalchemy import Table
25 24
26 from rhodecode.lib.utils import call_service_api
25 from rhodecode.lib.api_utils import call_service_api
27 26 from rhodecode.lib.utils2 import AttributeDict
28 from rhodecode.model.scm import ScmModel
27 from rhodecode.lib.vcs.exceptions import ImproperlyConfiguredError
29 28
30 29 from .hg import MercurialServer
31 30 from .git import GitServer
32 31 from .svn import SubversionServer
33 32 log = logging.getLogger(__name__)
34 33
35 34
36 35 class SshWrapper(object):
37 36 hg_cmd_pat = re.compile(r'^hg\s+\-R\s+(\S+)\s+serve\s+\-\-stdio$')
38 37 git_cmd_pat = re.compile(r'^git-(receive-pack|upload-pack)\s\'[/]?(\S+?)(|\.git)\'$')
39 38 svn_cmd_pat = re.compile(r'^svnserve -t')
40 39
41 40 def __init__(self, command, connection_info, mode,
42 user, user_id, key_id: int, shell, ini_path: str, env):
41 user, user_id, key_id: int, shell, ini_path: str, settings, env):
43 42 self.command = command
44 43 self.connection_info = connection_info
45 44 self.mode = mode
46 45 self.username = user
47 46 self.user_id = user_id
48 47 self.key_id = key_id
49 48 self.shell = shell
50 49 self.ini_path = ini_path
51 50 self.env = env
52
53 self.config = self.parse_config(ini_path)
51 self.settings = settings
54 52 self.server_impl = None
55 53
56 def parse_config(self, config_path):
57 parser = configparser.ConfigParser()
58 parser.read(config_path)
59 return parser
60
61 54 def update_key_access_time(self, key_id):
62 55 from rhodecode.model.meta import raw_query_executor, Base
63 56
64 57 table = Table('user_ssh_keys', Base.metadata, autoload=False)
65 58 atime = datetime.datetime.utcnow()
66 59 stmt = (
67 60 table.update()
68 61 .where(table.c.ssh_key_id == key_id)
69 62 .values(accessed_on=atime)
70 63 # no MySQL Support for .returning :((
71 64 #.returning(table.c.accessed_on, table.c.ssh_key_fingerprint)
72 65 )
73 66
74 67 res_count = None
75 68 with raw_query_executor() as session:
76 69 result = session.execute(stmt)
77 70 if result.rowcount:
78 71 res_count = result.rowcount
79 72
80 73 if res_count:
81 74 log.debug('Update key id:`%s` access time', key_id)
82 75
83 76 def get_user(self, user_id):
84 77 user = AttributeDict()
85 78 # lazy load db imports
86 79 from rhodecode.model.db import User
87 80 dbuser = User.get(user_id)
88 81 if not dbuser:
89 82 return None
90 83 user.user_id = dbuser.user_id
91 84 user.username = dbuser.username
92 85 user.auth_user = dbuser.AuthUser()
93 86 return user
94 87
95 88 def get_connection_info(self):
96 89 """
97 90 connection_info
98 91
99 92 Identifies the client and server ends of the connection.
100 93 The variable contains four space-separated values: client IP address,
101 94 client port number, server IP address, and server port number.
102 95 """
103 96 conn = dict(
104 97 client_ip=None,
105 98 client_port=None,
106 99 server_ip=None,
107 100 server_port=None,
108 101 )
109 102
110 103 info = self.connection_info.split(' ')
111 104 if len(info) == 4:
112 105 conn['client_ip'] = info[0]
113 106 conn['client_port'] = info[1]
114 107 conn['server_ip'] = info[2]
115 108 conn['server_port'] = info[3]
116 109
117 110 return conn
118 111
119 112 def maybe_translate_repo_uid(self, repo_name):
120 113 _org_name = repo_name
121 114 if _org_name.startswith('_'):
122 115 # remove format of _ID/subrepo
123 116 _org_name = _org_name.split('/', 1)[0]
124 117
125 118 if repo_name.startswith('_'):
126 119 from rhodecode.model.repo import RepoModel
127 120 org_repo_name = repo_name
128 121 log.debug('translating UID repo %s', org_repo_name)
129 122 by_id_match = RepoModel().get_repo_by_id(repo_name)
130 123 if by_id_match:
131 124 repo_name = by_id_match.repo_name
132 125 log.debug('translation of UID repo %s got `%s`', org_repo_name, repo_name)
133 126
134 127 return repo_name, _org_name
135 128
136 129 def get_repo_details(self, mode):
137 130 vcs_type = mode if mode in ['svn', 'hg', 'git'] else None
138 131 repo_name = None
139 132
140 133 hg_match = self.hg_cmd_pat.match(self.command)
141 134 if hg_match is not None:
142 135 vcs_type = 'hg'
143 136 repo_id = hg_match.group(1).strip('/')
144 137 repo_name, org_name = self.maybe_translate_repo_uid(repo_id)
145 138 return vcs_type, repo_name, mode
146 139
147 140 git_match = self.git_cmd_pat.match(self.command)
148 141 if git_match is not None:
149 142 mode = git_match.group(1)
150 143 vcs_type = 'git'
151 144 repo_id = git_match.group(2).strip('/')
152 145 repo_name, org_name = self.maybe_translate_repo_uid(repo_id)
153 146 return vcs_type, repo_name, mode
154 147
155 148 svn_match = self.svn_cmd_pat.match(self.command)
156 149 if svn_match is not None:
157 150 vcs_type = 'svn'
158 151 # Repo name should be extracted from the input stream, we're unable to
159 152 # extract it at this point in execution
160 153 return vcs_type, repo_name, mode
161 154
162 155 return vcs_type, repo_name, mode
163 156
164 157 def serve(self, vcs, repo, mode, user, permissions, branch_permissions):
158 # TODO: remove this once we have .ini defined access path...
159 from rhodecode.model.scm import ScmModel
160
165 161 store = ScmModel().repos_path
166 162
167 163 check_branch_perms = False
168 164 detect_force_push = False
169 165
170 166 if branch_permissions:
171 167 check_branch_perms = True
172 168 detect_force_push = True
173 169
174 170 log.debug(
175 171 'VCS detected:`%s` mode: `%s` repo_name: %s, branch_permission_checks:%s',
176 172 vcs, mode, repo, check_branch_perms)
177 173
178 174 # detect if we have to check branch permissions
179 175 extras = {
180 176 'detect_force_push': detect_force_push,
181 177 'check_branch_perms': check_branch_perms,
182 178 'config': self.ini_path
183 179 }
184 180
185 181 if vcs == 'hg':
186 182 server = MercurialServer(
187 183 store=store, ini_path=self.ini_path,
188 184 repo_name=repo, user=user,
189 user_permissions=permissions, config=self.config, env=self.env)
185 user_permissions=permissions, settings=self.settings, env=self.env)
190 186 self.server_impl = server
191 187 return server.run(tunnel_extras=extras)
192 188
193 189 elif vcs == 'git':
194 190 server = GitServer(
195 191 store=store, ini_path=self.ini_path,
196 192 repo_name=repo, repo_mode=mode, user=user,
197 user_permissions=permissions, config=self.config, env=self.env)
193 user_permissions=permissions, settings=self.settings, env=self.env)
198 194 self.server_impl = server
199 195 return server.run(tunnel_extras=extras)
200 196
201 197 elif vcs == 'svn':
202 198 server = SubversionServer(
203 199 store=store, ini_path=self.ini_path,
204 200 repo_name=None, user=user,
205 user_permissions=permissions, config=self.config, env=self.env)
201 user_permissions=permissions, settings=self.settings, env=self.env)
206 202 self.server_impl = server
207 203 return server.run(tunnel_extras=extras)
208 204
209 205 else:
210 206 raise Exception(f'Unrecognised VCS: {vcs}')
211 207
212 208 def wrap(self):
213 209 mode = self.mode
214 210 username = self.username
215 211 user_id = self.user_id
216 212 key_id = self.key_id
217 213 shell = self.shell
218 214
219 215 scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)
220 216
221 217 log.debug(
222 218 'Mode: `%s` User: `name:%s : id:%s` Shell: `%s` SSH Command: `\"%s\"` '
223 219 'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
224 220 mode, username, user_id, shell, self.command,
225 221 scm_detected, scm_mode, scm_repo)
226 222
227 223 log.debug('SSH Connection info %s', self.get_connection_info())
228 224
229 225 # update last access time for this key
230 226 if key_id:
231 227 self.update_key_access_time(key_id)
232 228
233 229 if shell and self.command is None:
234 230 log.info('Dropping to shell, no command given and shell is allowed')
235 231 os.execl('/bin/bash', '-l')
236 232 exit_code = 1
237 233
238 234 elif scm_detected:
239 235 user = self.get_user(user_id)
240 236 if not user:
241 237 log.warning('User with id %s not found', user_id)
242 238 exit_code = -1
243 239 return exit_code
244 240
245 241 auth_user = user.auth_user
246 242 permissions = auth_user.permissions['repositories']
247 243 repo_branch_permissions = auth_user.get_branch_permissions(scm_repo)
248 244 try:
249 245 exit_code, is_updated = self.serve(
250 246 scm_detected, scm_repo, scm_mode, user, permissions,
251 247 repo_branch_permissions)
252 248 except Exception:
253 249 log.exception('Error occurred during execution of SshWrapper')
254 250 exit_code = -1
255 251
256 252 elif self.command is None and shell is False:
257 253 log.error('No Command given.')
258 254 exit_code = -1
259 255
260 256 else:
261 257 log.error('Unhandled Command: "%s" Aborting.', self.command)
262 258 exit_code = -1
263 259
264 260 return exit_code
265 261
266 262
267 263 class SshWrapperStandalone(SshWrapper):
268 264 """
269 265 New version of SshWrapper designed to be depended only on service API
270 266 """
271 267 repos_path = None
268 service_api_host: str
269 service_api_token: str
270 api_url: str
271
272 def __init__(self, command, connection_info, mode,
273 user, user_id, key_id: int, shell, ini_path: str, settings, env):
274
275 # validate our settings for making a standalone calls
276 try:
277 self.service_api_host = settings['app.service_api.host']
278 self.service_api_token = settings['app.service_api.token']
279 except KeyError:
280 raise ImproperlyConfiguredError(
281 "app.service_api.host or app.service_api.token are missing. "
282 "Please ensure that app.service_api.host and app.service_api.token are "
283 "defined inside of .ini configuration file."
284 )
285
286 try:
287 self.api_url = settings['rhodecode.api.url']
288 except KeyError:
289 raise ImproperlyConfiguredError(
290 "rhodecode.api.url is missing. "
291 "Please ensure that rhodecode.api.url is "
292 "defined inside of .ini configuration file."
293 )
294
295 super(SshWrapperStandalone, self).__init__(
296 command, connection_info, mode, user, user_id, key_id, shell, ini_path, settings, env)
272 297
273 298 @staticmethod
274 299 def parse_user_related_data(user_data):
275 300 user = AttributeDict()
276 301 user.user_id = user_data['user_id']
277 302 user.username = user_data['username']
278 303 user.repo_permissions = user_data['repo_permissions']
279 304 user.branch_permissions = user_data['branch_permissions']
280 305 return user
281 306
282 307 def wrap(self):
283 308 mode = self.mode
284 309 username = self.username
285 310 user_id = self.user_id
286 311 shell = self.shell
287 312
288 313 scm_detected, scm_repo, scm_mode = self.get_repo_details(mode)
289 314
290 315 log.debug(
291 316 'Mode: `%s` User: `name:%s : id:%s` Shell: `%s` SSH Command: `\"%s\"` '
292 317 'SCM_DETECTED: `%s` SCM Mode: `%s` SCM Repo: `%s`',
293 318 mode, username, user_id, shell, self.command,
294 319 scm_detected, scm_mode, scm_repo)
295 320
296 321 log.debug('SSH Connection info %s', self.get_connection_info())
297 322
298 323 if shell and self.command is None:
299 324 log.info('Dropping to shell, no command given and shell is allowed')
300 325 os.execl('/bin/bash', '-l')
301 326 exit_code = 1
302 327
303 328 elif scm_detected:
304 data = call_service_api(self.ini_path, {
329 data = call_service_api(self.service_api_host, self.service_api_token, self.api_url, {
305 330 "method": "service_get_data_for_ssh_wrapper",
306 331 "args": {"user_id": user_id, "repo_name": scm_repo, "key_id": self.key_id}
307 332 })
308 333 user = self.parse_user_related_data(data)
309 334 if not user:
310 335 log.warning('User with id %s not found', user_id)
311 336 exit_code = -1
312 337 return exit_code
313 338 self.repos_path = data['repos_path']
314 339 permissions = user.repo_permissions
315 340 repo_branch_permissions = user.branch_permissions
316 341 try:
317 342 exit_code, is_updated = self.serve(
318 343 scm_detected, scm_repo, scm_mode, user, permissions,
319 344 repo_branch_permissions)
320 345 except Exception:
321 346 log.exception('Error occurred during execution of SshWrapper')
322 347 exit_code = -1
323 348
324 349 elif self.command is None and shell is False:
325 350 log.error('No Command given.')
326 351 exit_code = -1
327 352
328 353 else:
329 354 log.error('Unhandled Command: "%s" Aborting.', self.command)
330 355 exit_code = -1
331 356
332 357 return exit_code
333 358
334 359 def maybe_translate_repo_uid(self, repo_name):
335 360 _org_name = repo_name
336 361 if _org_name.startswith('_'):
337 362 _org_name = _org_name.split('/', 1)[0]
338 363
339 364 if repo_name.startswith('_'):
340 365 org_repo_name = repo_name
341 366 log.debug('translating UID repo %s', org_repo_name)
342 by_id_match = call_service_api(self.ini_path, {
367 by_id_match = call_service_api(self.service_api_host, self.service_api_token, self.api_url, {
343 368 'method': 'service_get_repo_name_by_id',
344 369 "args": {"repo_id": repo_name}
345 370 })
346 371 if by_id_match:
347 372 repo_name = by_id_match['repo_name']
348 373 log.debug('translation of UID repo %s got `%s`', org_repo_name, repo_name)
349 374
350 375 return repo_name, _org_name
351 376
352 377 def serve(self, vcs, repo, mode, user, permissions, branch_permissions):
353 378 store = self.repos_path
354 379
355 380 check_branch_perms = False
356 381 detect_force_push = False
357 382
358 383 if branch_permissions:
359 384 check_branch_perms = True
360 385 detect_force_push = True
361 386
362 387 log.debug(
363 388 'VCS detected:`%s` mode: `%s` repo_name: %s, branch_permission_checks:%s',
364 389 vcs, mode, repo, check_branch_perms)
365 390
366 391 # detect if we have to check branch permissions
367 392 extras = {
368 393 'detect_force_push': detect_force_push,
369 394 'check_branch_perms': check_branch_perms,
370 395 'config': self.ini_path
371 396 }
372 397
373 398 match vcs:
374 399 case 'hg':
375 400 server = MercurialServer(
376 401 store=store, ini_path=self.ini_path,
377 402 repo_name=repo, user=user,
378 user_permissions=permissions, config=self.config, env=self.env)
403 user_permissions=permissions, settings=self.settings, env=self.env)
379 404 case 'git':
380 405 server = GitServer(
381 406 store=store, ini_path=self.ini_path,
382 407 repo_name=repo, repo_mode=mode, user=user,
383 user_permissions=permissions, config=self.config, env=self.env)
408 user_permissions=permissions, settings=self.settings, env=self.env)
384 409 case 'svn':
385 410 server = SubversionServer(
386 411 store=store, ini_path=self.ini_path,
387 412 repo_name=None, user=user,
388 user_permissions=permissions, config=self.config, env=self.env)
413 user_permissions=permissions, settings=self.settings, env=self.env)
389 414 case _:
390 415 raise Exception(f'Unrecognised VCS: {vcs}')
391 416 self.server_impl = server
392 417 return server.run(tunnel_extras=extras)
@@ -1,174 +1,179 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import sys
21 21 import logging
22 22
23 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
23 from rhodecode.lib.hook_daemon.base import prepare_callback_daemon
24 24 from rhodecode.lib.ext_json import sjson as json
25 25 from rhodecode.lib.vcs.conf import settings as vcs_settings
26 from rhodecode.lib.utils import call_service_api
27 from rhodecode.model.scm import ScmModel
26 from rhodecode.lib.api_utils import call_service_api
28 27
29 28 log = logging.getLogger(__name__)
30 29
31 30
32 class VcsServer(object):
31 class SSHVcsServer(object):
33 32 repo_user_agent = None # set in child classes
34 33 _path = None # set executable path for hg/git/svn binary
35 34 backend = None # set in child classes
36 35 tunnel = None # subprocess handling tunnel
36 settings = None # parsed settings module
37 37 write_perms = ['repository.admin', 'repository.write']
38 38 read_perms = ['repository.read', 'repository.admin', 'repository.write']
39 39
40 def __init__(self, user, user_permissions, config, env):
40 def __init__(self, user, user_permissions, settings, env):
41 41 self.user = user
42 42 self.user_permissions = user_permissions
43 self.config = config
43 self.settings = settings
44 44 self.env = env
45 45 self.stdin = sys.stdin
46 46
47 47 self.repo_name = None
48 48 self.repo_mode = None
49 49 self.store = ''
50 50 self.ini_path = ''
51 51 self.hooks_protocol = None
52 52
53 53 def _invalidate_cache(self, repo_name):
54 54 """
55 55 Set's cache for this repository for invalidation on next access
56 56
57 57 :param repo_name: full repo name, also a cache key
58 58 """
59 59 # Todo: Leave only "celery" case after transition.
60 60 match self.hooks_protocol:
61 61 case 'http':
62 from rhodecode.model.scm import ScmModel
62 63 ScmModel().mark_for_invalidation(repo_name)
63 64 case 'celery':
64 call_service_api(self.ini_path, {
65 service_api_host = self.settings['app.service_api.host']
66 service_api_token = self.settings['app.service_api.token']
67 api_url = self.settings['rhodecode.api.url']
68
69 call_service_api(service_api_host, service_api_token, api_url, {
65 70 "method": "service_mark_for_invalidation",
66 71 "args": {"repo_name": repo_name}
67 72 })
68 73
69 74 def has_write_perm(self):
70 75 permission = self.user_permissions.get(self.repo_name)
71 76 if permission in ['repository.write', 'repository.admin']:
72 77 return True
73 78
74 79 return False
75 80
76 81 def _check_permissions(self, action):
77 82 permission = self.user_permissions.get(self.repo_name)
78 83 user_info = f'{self.user["user_id"]}:{self.user["username"]}'
79 84 log.debug('permission for %s on %s are: %s',
80 85 user_info, self.repo_name, permission)
81 86
82 87 if not permission:
83 88 log.error('user `%s` permissions to repo:%s are empty. Forbidding access.',
84 89 user_info, self.repo_name)
85 90 return -2
86 91
87 92 if action == 'pull':
88 93 if permission in self.read_perms:
89 94 log.info(
90 95 'READ Permissions for User "%s" detected to repo "%s"!',
91 96 user_info, self.repo_name)
92 97 return 0
93 98 else:
94 99 if permission in self.write_perms:
95 100 log.info(
96 101 'WRITE, or Higher Permissions for User "%s" detected to repo "%s"!',
97 102 user_info, self.repo_name)
98 103 return 0
99 104
100 105 log.error('Cannot properly fetch or verify user `%s` permissions. '
101 106 'Permissions: %s, vcs action: %s',
102 107 user_info, permission, action)
103 108 return -2
104 109
105 110 def update_environment(self, action, extras=None):
106 111
107 112 scm_data = {
108 113 'ip': os.environ['SSH_CLIENT'].split()[0],
109 114 'username': self.user.username,
110 115 'user_id': self.user.user_id,
111 116 'action': action,
112 117 'repository': self.repo_name,
113 118 'scm': self.backend,
114 119 'config': self.ini_path,
115 120 'repo_store': self.store,
116 121 'make_lock': None,
117 122 'locked_by': [None, None],
118 123 'server_url': None,
119 124 'user_agent': f'{self.repo_user_agent}/ssh-user-agent',
120 125 'hooks': ['push', 'pull'],
121 'hooks_module': 'rhodecode.lib.hooks_daemon',
126 'hooks_module': 'rhodecode.lib.hook_daemon.hook_module',
122 127 'is_shadow_repo': False,
123 128 'detect_force_push': False,
124 129 'check_branch_perms': False,
125 130
126 131 'SSH': True,
127 132 'SSH_PERMISSIONS': self.user_permissions.get(self.repo_name),
128 133 }
129 134 if extras:
130 135 scm_data.update(extras)
131 136 os.putenv("RC_SCM_DATA", json.dumps(scm_data))
132 137 return scm_data
133 138
134 139 def get_root_store(self):
135 140 root_store = self.store
136 141 if not root_store.endswith('/'):
137 142 # always append trailing slash
138 143 root_store = root_store + '/'
139 144 return root_store
140 145
141 146 def _handle_tunnel(self, extras):
142 147 # pre-auth
143 148 action = 'pull'
144 149 exit_code = self._check_permissions(action)
145 150 if exit_code:
146 151 return exit_code, False
147 152
148 153 req = self.env.get('request')
149 154 if req:
150 155 server_url = req.host_url + req.script_name
151 156 extras['server_url'] = server_url
152 157
153 158 log.debug('Using %s binaries from path %s', self.backend, self._path)
154 159 exit_code = self.tunnel.run(extras)
155 160
156 161 return exit_code, action == "push"
157 162
158 163 def run(self, tunnel_extras=None):
159 self.hooks_protocol = self.config.get('app:main', 'vcs.hooks.protocol')
164 self.hooks_protocol = self.settings['vcs.hooks.protocol']
160 165 tunnel_extras = tunnel_extras or {}
161 166 extras = {}
162 167 extras.update(tunnel_extras)
163 168
164 169 callback_daemon, extras = prepare_callback_daemon(
165 170 extras, protocol=self.hooks_protocol,
166 171 host=vcs_settings.HOOKS_HOST)
167 172
168 173 with callback_daemon:
169 174 try:
170 175 return self._handle_tunnel(extras)
171 176 finally:
172 177 log.debug('Running cleanup with cache invalidation')
173 178 if self.repo_name:
174 179 self._invalidate_cache(self.repo_name)
@@ -1,88 +1,86 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import sys
20 20 import logging
21 21 import subprocess
22 22
23 23 from vcsserver import hooks
24 from .base import VcsServer
24 from .base import SSHVcsServer
25 25
26 26 log = logging.getLogger(__name__)
27 27
28 28
29 29 class GitTunnelWrapper(object):
30 30 process = None
31 31
32 32 def __init__(self, server):
33 33 self.server = server
34 34 self.stdin = sys.stdin
35 35 self.stdout = sys.stdout
36 36
37 37 def create_hooks_env(self):
38 38 pass
39 39
40 40 def command(self):
41 41 root = self.server.get_root_store()
42 42 command = "cd {root}; {git_path} {mode} '{root}{repo_name}'".format(
43 43 root=root, git_path=self.server.git_path,
44 44 mode=self.server.repo_mode, repo_name=self.server.repo_name)
45 45 log.debug("Final CMD: %s", command)
46 46 return command
47 47
48 48 def run(self, extras):
49 49 action = "push" if self.server.repo_mode == "receive-pack" else "pull"
50 50 exit_code = self.server._check_permissions(action)
51 51 if exit_code:
52 52 return exit_code
53 53
54 54 scm_extras = self.server.update_environment(action=action, extras=extras)
55 55
56 56 if action == "pull":
57 57 hook_response = hooks.git_pre_pull(scm_extras)
58 58 pre_pull_messages = hook_response.output
59 59 sys.stdout.write(pre_pull_messages)
60 60
61 61 self.create_hooks_env()
62 62 result = subprocess.run(self.command(), shell=True)
63 63 result = result.returncode
64 64
65 65 # Upload-pack == clone
66 66 if action == "pull":
67 67 hook_response = hooks.git_post_pull(scm_extras)
68 68 post_pull_messages = hook_response.output
69 69 sys.stderr.write(post_pull_messages)
70 70 return result
71 71
72 72
73 class GitServer(VcsServer):
73 class GitServer(SSHVcsServer):
74 74 backend = 'git'
75 75 repo_user_agent = 'git'
76 76
77 def __init__(self, store, ini_path, repo_name, repo_mode,
78 user, user_permissions, config, env):
79 super().\
80 __init__(user, user_permissions, config, env)
77 def __init__(self, store, ini_path, repo_name, repo_mode, user, user_permissions, settings, env):
78 super().__init__(user, user_permissions, settings, env)
81 79
82 80 self.store = store
83 81 self.ini_path = ini_path
84 82 self.repo_name = repo_name
85 self._path = self.git_path = config.get('app:main', 'ssh.executable.git')
83 self._path = self.git_path = settings['ssh.executable.git']
86 84
87 85 self.repo_mode = repo_mode
88 86 self.tunnel = GitTunnelWrapper(server=self)
@@ -1,155 +1,160 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import sys
21 21 import logging
22 22 import tempfile
23 23 import textwrap
24 24 import collections
25 from .base import VcsServer
26 from rhodecode.lib.utils import call_service_api
27 from rhodecode.model.db import RhodeCodeUi
28 from rhodecode.model.settings import VcsSettingsModel
25
26 from .base import SSHVcsServer
27
28 from rhodecode.lib.api_utils import call_service_api
29 29
30 30 log = logging.getLogger(__name__)
31 31
32 32
33 33 class MercurialTunnelWrapper(object):
34 34 process = None
35 35
36 36 def __init__(self, server):
37 37 self.server = server
38 38 self.stdin = sys.stdin
39 39 self.stdout = sys.stdout
40 40 self.hooks_env_fd, self.hooks_env_path = tempfile.mkstemp(prefix='hgrc_rhodecode_')
41 41
42 42 def create_hooks_env(self):
43 43 repo_name = self.server.repo_name
44 44 hg_flags = self.server.config_to_hgrc(repo_name)
45 45
46 46 content = textwrap.dedent(
47 47 '''
48 48 # RhodeCode SSH hooks version=2.0.0
49 49 {custom}
50 50 '''
51 51 ).format(custom='\n'.join(hg_flags))
52 52
53 53 root = self.server.get_root_store()
54 54 hgrc_custom = os.path.join(root, repo_name, '.hg', 'hgrc_rhodecode')
55 55 hgrc_main = os.path.join(root, repo_name, '.hg', 'hgrc')
56 56
57 57 # cleanup custom hgrc file
58 58 if os.path.isfile(hgrc_custom):
59 59 with open(hgrc_custom, 'wb') as f:
60 f.write('')
60 f.write(b'')
61 61 log.debug('Cleanup custom hgrc file under %s', hgrc_custom)
62 62
63 63 # write temp
64 64 with os.fdopen(self.hooks_env_fd, 'w') as hooks_env_file:
65 65 hooks_env_file.write(content)
66 66
67 67 return self.hooks_env_path
68 68
69 69 def remove_configs(self):
70 70 os.remove(self.hooks_env_path)
71 71
72 72 def command(self, hgrc_path):
73 73 root = self.server.get_root_store()
74 74
75 75 command = (
76 76 "cd {root}; HGRCPATH={hgrc} {hg_path} -R {root}{repo_name} "
77 77 "serve --stdio".format(
78 78 root=root, hg_path=self.server.hg_path,
79 79 repo_name=self.server.repo_name, hgrc=hgrc_path))
80 80 log.debug("Final CMD: %s", command)
81 81 return command
82 82
83 83 def run(self, extras):
84 84 # at this point we cannot tell, we do further ACL checks
85 85 # inside the hooks
86 86 action = '?'
87 87 # permissions are check via `pre_push_ssh_auth` hook
88 88 self.server.update_environment(action=action, extras=extras)
89 89 custom_hgrc_file = self.create_hooks_env()
90 90
91 91 try:
92 92 return os.system(self.command(custom_hgrc_file))
93 93 finally:
94 94 self.remove_configs()
95 95
96 96
97 class MercurialServer(VcsServer):
97 class MercurialServer(SSHVcsServer):
98 98 backend = 'hg'
99 99 repo_user_agent = 'mercurial'
100 100 cli_flags = ['phases', 'largefiles', 'extensions', 'experimental', 'hooks']
101 101
102 def __init__(self, store, ini_path, repo_name, user, user_permissions, config, env):
103 super().__init__(user, user_permissions, config, env)
102 def __init__(self, store, ini_path, repo_name, user, user_permissions, settings, env):
103 super().__init__(user, user_permissions, settings, env)
104 104
105 105 self.store = store
106 106 self.ini_path = ini_path
107 107 self.repo_name = repo_name
108 self._path = self.hg_path = config.get('app:main', 'ssh.executable.hg')
108 self._path = self.hg_path = settings['ssh.executable.hg']
109 109 self.tunnel = MercurialTunnelWrapper(server=self)
110 110
111 111 def config_to_hgrc(self, repo_name):
112 112 # Todo: once transition is done only call to service api should exist
113 113 if self.hooks_protocol == 'celery':
114 data = call_service_api(self.ini_path, {
114 service_api_host = self.settings['app.service_api.host']
115 service_api_token = self.settings['app.service_api.token']
116 api_url = self.settings['rhodecode.api.url']
117 data = call_service_api(service_api_host, service_api_token, api_url, {
115 118 "method": "service_config_to_hgrc",
116 119 "args": {"cli_flags": self.cli_flags, "repo_name": repo_name}
117 120 })
118 121 return data['flags']
119
120 ui_sections = collections.defaultdict(list)
121 ui = VcsSettingsModel(repo=repo_name).get_ui_settings(section=None, key=None)
122 else:
123 from rhodecode.model.db import RhodeCodeUi
124 from rhodecode.model.settings import VcsSettingsModel
125 ui_sections = collections.defaultdict(list)
126 ui = VcsSettingsModel(repo=repo_name).get_ui_settings(section=None, key=None)
122 127
123 # write default hooks
124 default_hooks = [
125 ('pretxnchangegroup.ssh_auth', 'python:vcsserver.hooks.pre_push_ssh_auth'),
126 ('pretxnchangegroup.ssh', 'python:vcsserver.hooks.pre_push_ssh'),
127 ('changegroup.ssh', 'python:vcsserver.hooks.post_push_ssh'),
128 # write default hooks
129 default_hooks = [
130 ('pretxnchangegroup.ssh_auth', 'python:vcsserver.hooks.pre_push_ssh_auth'),
131 ('pretxnchangegroup.ssh', 'python:vcsserver.hooks.pre_push_ssh'),
132 ('changegroup.ssh', 'python:vcsserver.hooks.post_push_ssh'),
128 133
129 ('preoutgoing.ssh', 'python:vcsserver.hooks.pre_pull_ssh'),
130 ('outgoing.ssh', 'python:vcsserver.hooks.post_pull_ssh'),
131 ]
134 ('preoutgoing.ssh', 'python:vcsserver.hooks.pre_pull_ssh'),
135 ('outgoing.ssh', 'python:vcsserver.hooks.post_pull_ssh'),
136 ]
132 137
133 for k, v in default_hooks:
134 ui_sections['hooks'].append((k, v))
138 for k, v in default_hooks:
139 ui_sections['hooks'].append((k, v))
135 140
136 for entry in ui:
137 if not entry.active:
138 continue
139 sec = entry.section
140 key = entry.key
141
142 if sec in self.cli_flags:
143 # we want only custom hooks, so we skip builtins
144 if sec == 'hooks' and key in RhodeCodeUi.HOOKS_BUILTIN:
141 for entry in ui:
142 if not entry.active:
145 143 continue
144 sec = entry.section
145 key = entry.key
146 146
147 ui_sections[sec].append([key, entry.value])
147 if sec in self.cli_flags:
148 # we want only custom hooks, so we skip builtins
149 if sec == 'hooks' and key in RhodeCodeUi.HOOKS_BUILTIN:
150 continue
148 151
149 flags = []
150 for _sec, key_val in ui_sections.items():
151 flags.append(' ')
152 flags.append(f'[{_sec}]')
153 for key, val in key_val:
154 flags.append(f'{key}= {val}')
155 return flags
152 ui_sections[sec].append([key, entry.value])
153
154 flags = []
155 for _sec, key_val in ui_sections.items():
156 flags.append(' ')
157 flags.append(f'[{_sec}]')
158 for key, val in key_val:
159 flags.append(f'{key}= {val}')
160 return flags
@@ -1,256 +1,254 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import re
21 21 import sys
22 22 import logging
23 23 import signal
24 24 import tempfile
25 25 from subprocess import Popen, PIPE
26 26 import urllib.parse
27 27
28 from .base import VcsServer
28 from .base import SSHVcsServer
29 29
30 30 log = logging.getLogger(__name__)
31 31
32 32
33 33 class SubversionTunnelWrapper(object):
34 34 process = None
35 35
36 36 def __init__(self, server):
37 37 self.server = server
38 38 self.timeout = 30
39 39 self.stdin = sys.stdin
40 40 self.stdout = sys.stdout
41 41 self.svn_conf_fd, self.svn_conf_path = tempfile.mkstemp()
42 42 self.hooks_env_fd, self.hooks_env_path = tempfile.mkstemp()
43 43
44 44 self.read_only = True # flag that we set to make the hooks readonly
45 45
46 46 def create_svn_config(self):
47 47 content = (
48 48 '[general]\n'
49 49 'hooks-env = {}\n').format(self.hooks_env_path)
50 50 with os.fdopen(self.svn_conf_fd, 'w') as config_file:
51 51 config_file.write(content)
52 52
53 53 def create_hooks_env(self):
54 54 content = (
55 55 '[default]\n'
56 56 'LANG = en_US.UTF-8\n')
57 57 if self.read_only:
58 58 content += 'SSH_READ_ONLY = 1\n'
59 59 with os.fdopen(self.hooks_env_fd, 'w') as hooks_env_file:
60 60 hooks_env_file.write(content)
61 61
62 62 def remove_configs(self):
63 63 os.remove(self.svn_conf_path)
64 64 os.remove(self.hooks_env_path)
65 65
66 66 def command(self):
67 67 root = self.server.get_root_store()
68 68 username = self.server.user.username
69 69
70 70 command = [
71 71 self.server.svn_path, '-t',
72 72 '--config-file', self.svn_conf_path,
73 73 '--tunnel-user', username,
74 74 '-r', root]
75 75 log.debug("Final CMD: %s", ' '.join(command))
76 76 return command
77 77
78 78 def start(self):
79 79 command = self.command()
80 80 self.process = Popen(' '.join(command), stdin=PIPE, shell=True)
81 81
82 82 def sync(self):
83 83 while self.process.poll() is None:
84 84 next_byte = self.stdin.read(1)
85 85 if not next_byte:
86 86 break
87 87 self.process.stdin.write(next_byte)
88 88 self.remove_configs()
89 89
90 90 @property
91 91 def return_code(self):
92 92 return self.process.returncode
93 93
94 94 def get_first_client_response(self):
95 95 signal.signal(signal.SIGALRM, self.interrupt)
96 96 signal.alarm(self.timeout)
97 97 first_response = self._read_first_client_response()
98 98 signal.alarm(0)
99 99 return (self._parse_first_client_response(first_response)
100 100 if first_response else None)
101 101
102 102 def patch_first_client_response(self, response, **kwargs):
103 103 self.create_hooks_env()
104 104 data = response.copy()
105 105 data.update(kwargs)
106 106 data['url'] = self._svn_string(data['url'])
107 107 data['ra_client'] = self._svn_string(data['ra_client'])
108 108 data['client'] = data['client'] or ''
109 109 buffer_ = (
110 110 "( {version} ( {capabilities} ) {url}{ra_client}"
111 111 "( {client}) ) ".format(**data))
112 112 self.process.stdin.write(buffer_)
113 113
114 114 def fail(self, message):
115 115 print("( failure ( ( 210005 {message} 0: 0 ) ) )".format(
116 116 message=self._svn_string(message)))
117 117 self.remove_configs()
118 118 self.process.kill()
119 119 return 1
120 120
121 121 def interrupt(self, signum, frame):
122 122 self.fail("Exited by timeout")
123 123
124 124 def _svn_string(self, str_):
125 125 if not str_:
126 126 return ''
127 127 return f'{len(str_)}:{str_} '
128 128
129 129 def _read_first_client_response(self):
130 130 buffer_ = ""
131 131 brackets_stack = []
132 132 while True:
133 133 next_byte = self.stdin.read(1)
134 134 buffer_ += next_byte
135 135 if next_byte == "(":
136 136 brackets_stack.append(next_byte)
137 137 elif next_byte == ")":
138 138 brackets_stack.pop()
139 139 elif next_byte == " " and not brackets_stack:
140 140 break
141 141
142 142 return buffer_
143 143
144 144 def _parse_first_client_response(self, buffer_):
145 145 """
146 146 According to the Subversion RA protocol, the first request
147 147 should look like:
148 148
149 149 ( version:number ( cap:word ... ) url:string ? ra-client:string
150 150 ( ? client:string ) )
151 151
152 152 Please check https://svn.apache.org/repos/asf/subversion/trunk/subversion/libsvn_ra_svn/protocol
153 153 """
154 154 version_re = r'(?P<version>\d+)'
155 155 capabilities_re = r'\(\s(?P<capabilities>[\w\d\-\ ]+)\s\)'
156 156 url_re = r'\d+\:(?P<url>[\W\w]+)'
157 157 ra_client_re = r'(\d+\:(?P<ra_client>[\W\w]+)\s)'
158 158 client_re = r'(\d+\:(?P<client>[\W\w]+)\s)*'
159 159 regex = re.compile(
160 160 r'^\(\s{version}\s{capabilities}\s{url}\s{ra_client}'
161 161 r'\(\s{client}\)\s\)\s*$'.format(
162 162 version=version_re, capabilities=capabilities_re,
163 163 url=url_re, ra_client=ra_client_re, client=client_re))
164 164 matcher = regex.match(buffer_)
165 165
166 166 return matcher.groupdict() if matcher else None
167 167
168 168 def _match_repo_name(self, url):
169 169 """
170 170 Given an server url, try to match it against ALL known repository names.
171 171 This handles a tricky SVN case for SSH and subdir commits.
172 172 E.g if our repo name is my-svn-repo, a svn commit on file in a subdir would
173 173 result in the url with this subdir added.
174 174 """
175 175 # case 1 direct match, we don't do any "heavy" lookups
176 176 if url in self.server.user_permissions:
177 177 return url
178 178
179 179 log.debug('Extracting repository name from subdir path %s', url)
180 180 # case 2 we check all permissions, and match closes possible case...
181 181 # NOTE(dan): In this case we only know that url has a subdir parts, it's safe
182 182 # to assume that it will have the repo name as prefix, we ensure the prefix
183 183 # for similar repositories isn't matched by adding a /
184 184 # e.g subgroup/repo-name/ and subgroup/repo-name-1/ would work correct.
185 185 for repo_name in self.server.user_permissions:
186 186 repo_name_prefix = repo_name + '/'
187 187 if url.startswith(repo_name_prefix):
188 188 log.debug('Found prefix %s match, returning proper repository name',
189 189 repo_name_prefix)
190 190 return repo_name
191 191
192 192 return
193 193
194 194 def run(self, extras):
195 195 action = 'pull'
196 196 self.create_svn_config()
197 197 self.start()
198 198
199 199 first_response = self.get_first_client_response()
200 200 if not first_response:
201 201 return self.fail("Repository name cannot be extracted")
202 202
203 203 url_parts = urllib.parse.urlparse(first_response['url'])
204 204
205 205 self.server.repo_name = self._match_repo_name(url_parts.path.strip('/'))
206 206
207 207 exit_code = self.server._check_permissions(action)
208 208 if exit_code:
209 209 return exit_code
210 210
211 211 # set the readonly flag to False if we have proper permissions
212 212 if self.server.has_write_perm():
213 213 self.read_only = False
214 214 self.server.update_environment(action=action, extras=extras)
215 215
216 216 self.patch_first_client_response(first_response)
217 217 self.sync()
218 218 return self.return_code
219 219
220 220
221 class SubversionServer(VcsServer):
221 class SubversionServer(SSHVcsServer):
222 222 backend = 'svn'
223 223 repo_user_agent = 'svn'
224 224
225 def __init__(self, store, ini_path, repo_name,
226 user, user_permissions, config, env):
227 super()\
228 .__init__(user, user_permissions, config, env)
225 def __init__(self, store, ini_path, repo_name, user, user_permissions, settings, env):
226 super().__init__(user, user_permissions, settings, env)
229 227 self.store = store
230 228 self.ini_path = ini_path
231 229 # NOTE(dan): repo_name at this point is empty,
232 230 # this is set later in .run() based from parsed input stream
233 231 self.repo_name = repo_name
234 self._path = self.svn_path = config.get('app:main', 'ssh.executable.svn')
232 self._path = self.svn_path = settings['ssh.executable.svn']
235 233
236 234 self.tunnel = SubversionTunnelWrapper(server=self)
237 235
238 236 def _handle_tunnel(self, extras):
239 237
240 238 # pre-auth
241 239 action = 'pull'
242 240 # Special case for SVN, we extract repo name at later stage
243 241 # exit_code = self._check_permissions(action)
244 242 # if exit_code:
245 243 # return exit_code, False
246 244
247 245 req = self.env['request']
248 246 server_url = req.host_url + req.script_name
249 247 extras['server_url'] = server_url
250 248
251 249 log.debug('Using %s binaries from path %s', self.backend, self._path)
252 250 exit_code = self.tunnel.run(extras)
253 251
254 252 return exit_code, action == "push"
255 253
256 254
@@ -1,72 +1,72 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import sys
21 21 import time
22 22 import logging
23 23
24 24 import click
25 25
26 26 from rhodecode.lib.pyramid_utils import bootstrap
27 27 from rhodecode.lib.statsd_client import StatsdClient
28 28 from .backends import SshWrapper
29 29 from .utils import setup_custom_logging
30 30
31 31 log = logging.getLogger(__name__)
32 32
33 33
34
35 34 @click.command()
36 35 @click.argument('ini_path', type=click.Path(exists=True))
37 36 @click.option(
38 37 '--mode', '-m', required=False, default='auto',
39 38 type=click.Choice(['auto', 'vcs', 'git', 'hg', 'svn', 'test']),
40 39 help='mode of operation')
41 40 @click.option('--user', help='Username for which the command will be executed')
42 41 @click.option('--user-id', help='User ID for which the command will be executed')
43 42 @click.option('--key-id', help='ID of the key from the database')
44 43 @click.option('--shell', '-s', is_flag=True, help='Allow Shell')
45 44 @click.option('--debug', is_flag=True, help='Enabled detailed output logging')
46 45 def main(ini_path, mode, user, user_id, key_id, shell, debug):
47 46 setup_custom_logging(ini_path, debug)
48 47
49 48 command = os.environ.get('SSH_ORIGINAL_COMMAND', '')
50 49 if not command and mode not in ['test']:
51 50 raise ValueError(
52 51 'Unable to fetch SSH_ORIGINAL_COMMAND from environment.'
53 52 'Please make sure this is set and available during execution '
54 53 'of this script.')
55 54 connection_info = os.environ.get('SSH_CONNECTION', '')
56 55 time_start = time.time()
57 56 with bootstrap(ini_path, env={'RC_CMD_SSH_WRAPPER': '1'}) as env:
57 settings = env['registry'].settings
58 58 statsd = StatsdClient.statsd
59 59 try:
60 60 ssh_wrapper = SshWrapper(
61 61 command, connection_info, mode,
62 user, user_id, key_id, shell, ini_path, env)
62 user, user_id, key_id, shell, ini_path, settings, env)
63 63 except Exception:
64 64 log.exception('Failed to execute SshWrapper')
65 65 sys.exit(-5)
66 66 return_code = ssh_wrapper.wrap()
67 67 operation_took = time.time() - time_start
68 68 if statsd:
69 69 operation_took_ms = round(1000.0 * operation_took)
70 70 statsd.timing("rhodecode_ssh_wrapper_timing.histogram", operation_took_ms,
71 71 use_decimals=False)
72 72 sys.exit(return_code)
@@ -1,70 +1,92 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 """
20 WARNING: be really carefully with changing ANY imports in this file
21 # This script is to mean as really fast executable, doing some imports here that would yield an import chain change
22 # can affect execution times...
23 # This can be easily debugged using such command::
24 # time PYTHONPROFILEIMPORTTIME=1 rc-ssh-wrapper-v2 --debug --mode=test .dev/dev.ini
25 """
26
19 27 import os
20 28 import sys
21 29 import time
22 30 import logging
23 31
24 32 import click
25 33
34 from rhodecode.config.config_maker import sanitize_settings_and_apply_defaults
26 35 from rhodecode.lib.statsd_client import StatsdClient
36 from rhodecode.lib.config_utils import get_app_config_lightweight
37
38 from .utils import setup_custom_logging
27 39 from .backends import SshWrapperStandalone
28 from .utils import setup_custom_logging
29 40
30 41 log = logging.getLogger(__name__)
31 42
32 43
33 44 @click.command()
34 45 @click.argument('ini_path', type=click.Path(exists=True))
35 46 @click.option(
36 47 '--mode', '-m', required=False, default='auto',
37 48 type=click.Choice(['auto', 'vcs', 'git', 'hg', 'svn', 'test']),
38 49 help='mode of operation')
39 50 @click.option('--user', help='Username for which the command will be executed')
40 51 @click.option('--user-id', help='User ID for which the command will be executed')
41 52 @click.option('--key-id', help='ID of the key from the database')
42 53 @click.option('--shell', '-s', is_flag=True, help='Allow Shell')
43 54 @click.option('--debug', is_flag=True, help='Enabled detailed output logging')
44 55 def main(ini_path, mode, user, user_id, key_id, shell, debug):
56
57 time_start = time.time()
45 58 setup_custom_logging(ini_path, debug)
46 59
47 60 command = os.environ.get('SSH_ORIGINAL_COMMAND', '')
48 61 if not command and mode not in ['test']:
49 62 raise ValueError(
50 63 'Unable to fetch SSH_ORIGINAL_COMMAND from environment.'
51 64 'Please make sure this is set and available during execution '
52 65 'of this script.')
53 connection_info = os.environ.get('SSH_CONNECTION', '')
54 time_start = time.time()
55 env = {'RC_CMD_SSH_WRAPPER': '1'}
66
67 # initialize settings and get defaults
68 settings = get_app_config_lightweight(ini_path)
69 settings = sanitize_settings_and_apply_defaults({'__file__': ini_path}, settings)
70
71 # init and bootstrap StatsdClient
72 StatsdClient.setup(settings)
56 73 statsd = StatsdClient.statsd
74
57 75 try:
76 connection_info = os.environ.get('SSH_CONNECTION', '')
77 env = {'RC_CMD_SSH_WRAPPER': '1'}
58 78 ssh_wrapper = SshWrapperStandalone(
59 79 command, connection_info, mode,
60 user, user_id, key_id, shell, ini_path, env)
80 user, user_id, key_id, shell, ini_path, settings, env)
61 81 except Exception:
62 82 log.exception('Failed to execute SshWrapper')
63 83 sys.exit(-5)
84
64 85 return_code = ssh_wrapper.wrap()
65 86 operation_took = time.time() - time_start
66 87 if statsd:
67 88 operation_took_ms = round(1000.0 * operation_took)
68 89 statsd.timing("rhodecode_ssh_wrapper_timing.histogram", operation_took_ms,
69 90 use_decimals=False)
91
70 92 sys.exit(return_code)
@@ -1,34 +1,34 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import logging
20 from pyramid.paster import setup_logging
21 20
22 21
23 22 def setup_custom_logging(ini_path, debug):
24 23 if debug:
24 from pyramid.paster import setup_logging # Lazy import
25 25 # enabled rhodecode.ini controlled logging setup
26 26 setup_logging(ini_path)
27 27 else:
28 28 # configure logging in a mode that doesn't print anything.
29 29 # in case of regularly configured logging it gets printed out back
30 30 # to the client doing an SSH command.
31 31 logger = logging.getLogger('')
32 32 null = logging.NullHandler()
33 33 # add the handler to the root logger
34 34 logger.handlers = [null]
@@ -1,68 +1,71 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import pytest
21 21 import configparser
22 22
23 23 from rhodecode.apps.ssh_support.lib.ssh_wrapper_v1 import SshWrapper
24 24 from rhodecode.lib.utils2 import AttributeDict
25 25
26 26
27 27 @pytest.fixture()
28 28 def dummy_conf_file(tmpdir):
29 29 conf = configparser.ConfigParser()
30 30 conf.add_section('app:main')
31 31 conf.set('app:main', 'ssh.executable.hg', '/usr/bin/hg')
32 32 conf.set('app:main', 'ssh.executable.git', '/usr/bin/git')
33 33 conf.set('app:main', 'ssh.executable.svn', '/usr/bin/svnserve')
34 34
35 35 f_path = os.path.join(str(tmpdir), 'ssh_wrapper_test.ini')
36 36 with open(f_path, 'wt') as f:
37 37 conf.write(f)
38 38
39 39 return os.path.join(f_path)
40 40
41 41
42 42 def plain_dummy_env():
43 43 return {
44 44 'request':
45 45 AttributeDict(host_url='http://localhost', script_name='/')
46 46 }
47 47
48 48
49 49 @pytest.fixture()
50 50 def dummy_env():
51 51 return plain_dummy_env()
52 52
53 53
54 54 def plain_dummy_user():
55 return AttributeDict(username='test_user')
55 return AttributeDict(
56 user_id=1,
57 username='test_user'
58 )
56 59
57 60
58 61 @pytest.fixture()
59 62 def dummy_user():
60 63 return plain_dummy_user()
61 64
62 65
63 66 @pytest.fixture()
64 67 def ssh_wrapper(app, dummy_conf_file, dummy_env):
65 68 conn_info = '127.0.0.1 22 10.0.0.1 443'
66 69 return SshWrapper(
67 70 'random command', conn_info, 'auto', 'admin', '1', key_id='1',
68 shell=False, ini_path=dummy_conf_file, env=dummy_env)
71 shell=False, ini_path=dummy_conf_file, settings={}, env=dummy_env)
@@ -1,153 +1,151 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20
21 21 import mock
22 22 import pytest
23 23
24 24 from rhodecode.apps.ssh_support.lib.backends.git import GitServer
25 25 from rhodecode.apps.ssh_support.tests.conftest import plain_dummy_env, plain_dummy_user
26 26 from rhodecode.lib.ext_json import json
27 27
28
28 29 class GitServerCreator(object):
29 30 root = '/tmp/repo/path/'
30 31 git_path = '/usr/local/bin/git'
31 32 config_data = {
32 33 'app:main': {
33 34 'ssh.executable.git': git_path,
34 35 'vcs.hooks.protocol': 'http',
35 36 }
36 37 }
37 38 repo_name = 'test_git'
38 39 repo_mode = 'receive-pack'
39 40 user = plain_dummy_user()
40 41
41 42 def __init__(self):
42 def config_get(part, key):
43 return self.config_data.get(part, {}).get(key)
44 self.config_mock = mock.Mock()
45 self.config_mock.get = mock.Mock(side_effect=config_get)
43 pass
46 44
47 45 def create(self, **kwargs):
48 46 parameters = {
49 47 'store': self.root,
50 48 'ini_path': '',
51 49 'user': self.user,
52 50 'repo_name': self.repo_name,
53 51 'repo_mode': self.repo_mode,
54 52 'user_permissions': {
55 53 self.repo_name: 'repository.admin'
56 54 },
57 'config': self.config_mock,
55 'settings': self.config_data['app:main'],
58 56 'env': plain_dummy_env()
59 57 }
60 58 parameters.update(kwargs)
61 59 server = GitServer(**parameters)
62 60 return server
63 61
64 62
65 63 @pytest.fixture()
66 64 def git_server(app):
67 65 return GitServerCreator()
68 66
69 67
70 68 class TestGitServer(object):
71 69
72 70 def test_command(self, git_server):
73 71 server = git_server.create()
74 72 expected_command = (
75 73 'cd {root}; {git_path} {repo_mode} \'{root}{repo_name}\''.format(
76 74 root=git_server.root, git_path=git_server.git_path,
77 75 repo_mode=git_server.repo_mode, repo_name=git_server.repo_name)
78 76 )
79 77 assert expected_command == server.tunnel.command()
80 78
81 79 @pytest.mark.parametrize('permissions, action, code', [
82 80 ({}, 'pull', -2),
83 81 ({'test_git': 'repository.read'}, 'pull', 0),
84 82 ({'test_git': 'repository.read'}, 'push', -2),
85 83 ({'test_git': 'repository.write'}, 'push', 0),
86 84 ({'test_git': 'repository.admin'}, 'push', 0),
87 85
88 86 ])
89 87 def test_permission_checks(self, git_server, permissions, action, code):
90 88 server = git_server.create(user_permissions=permissions)
91 89 result = server._check_permissions(action)
92 90 assert result is code
93 91
94 92 @pytest.mark.parametrize('permissions, value', [
95 93 ({}, False),
96 94 ({'test_git': 'repository.read'}, False),
97 95 ({'test_git': 'repository.write'}, True),
98 96 ({'test_git': 'repository.admin'}, True),
99 97
100 98 ])
101 99 def test_has_write_permissions(self, git_server, permissions, value):
102 100 server = git_server.create(user_permissions=permissions)
103 101 result = server.has_write_perm()
104 102 assert result is value
105 103
106 104 def test_run_returns_executes_command(self, git_server):
107 105 server = git_server.create()
108 106 from rhodecode.apps.ssh_support.lib.backends.git import GitTunnelWrapper
109 107
110 108 os.environ['SSH_CLIENT'] = '127.0.0.1'
111 109 with mock.patch.object(GitTunnelWrapper, 'create_hooks_env') as _patch:
112 110 _patch.return_value = 0
113 111 with mock.patch.object(GitTunnelWrapper, 'command', return_value='date'):
114 112 exit_code = server.run()
115 113
116 114 assert exit_code == (0, False)
117 115
118 116 @pytest.mark.parametrize(
119 117 'repo_mode, action', [
120 118 ['receive-pack', 'push'],
121 119 ['upload-pack', 'pull']
122 120 ])
123 121 def test_update_environment(self, git_server, repo_mode, action):
124 122 server = git_server.create(repo_mode=repo_mode)
125 123 store = server.store
126 124
127 125 with mock.patch('os.environ', {'SSH_CLIENT': '10.10.10.10 b'}):
128 126 with mock.patch('os.putenv') as putenv_mock:
129 127 server.update_environment(action)
130 128
131 129 expected_data = {
132 130 'username': git_server.user.username,
133 131 'user_id': git_server.user.user_id,
134 132 'scm': 'git',
135 133 'repository': git_server.repo_name,
136 134 'make_lock': None,
137 135 'action': action,
138 136 'ip': '10.10.10.10',
139 137 'locked_by': [None, None],
140 138 'config': '',
141 139 'repo_store': store,
142 140 'server_url': None,
143 141 'hooks': ['push', 'pull'],
144 142 'is_shadow_repo': False,
145 'hooks_module': 'rhodecode.lib.hooks_daemon',
143 'hooks_module': 'rhodecode.lib.hook_daemon.hook_module',
146 144 'check_branch_perms': False,
147 145 'detect_force_push': False,
148 146 'user_agent': u'git/ssh-user-agent',
149 147 'SSH': True,
150 148 'SSH_PERMISSIONS': 'repository.admin',
151 149 }
152 150 args, kwargs = putenv_mock.call_args
153 151 assert json.loads(args[1]) == expected_data
@@ -1,118 +1,115 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import mock
21 21 import pytest
22 22
23 23 from rhodecode.apps.ssh_support.lib.backends.hg import MercurialServer
24 24 from rhodecode.apps.ssh_support.tests.conftest import plain_dummy_env, plain_dummy_user
25 25
26 26
27 27 class MercurialServerCreator(object):
28 28 root = '/tmp/repo/path/'
29 29 hg_path = '/usr/local/bin/hg'
30 30
31 31 config_data = {
32 32 'app:main': {
33 33 'ssh.executable.hg': hg_path,
34 34 'vcs.hooks.protocol': 'http',
35 35 }
36 36 }
37 37 repo_name = 'test_hg'
38 38 user = plain_dummy_user()
39 39
40 40 def __init__(self):
41 def config_get(part, key):
42 return self.config_data.get(part, {}).get(key)
43 self.config_mock = mock.Mock()
44 self.config_mock.get = mock.Mock(side_effect=config_get)
41 pass
45 42
46 43 def create(self, **kwargs):
47 44 parameters = {
48 45 'store': self.root,
49 46 'ini_path': '',
50 47 'user': self.user,
51 48 'repo_name': self.repo_name,
52 49 'user_permissions': {
53 50 'test_hg': 'repository.admin'
54 51 },
55 'config': self.config_mock,
52 'settings': self.config_data['app:main'],
56 53 'env': plain_dummy_env()
57 54 }
58 55 parameters.update(kwargs)
59 56 server = MercurialServer(**parameters)
60 57 return server
61 58
62 59
63 60 @pytest.fixture()
64 61 def hg_server(app):
65 62 return MercurialServerCreator()
66 63
67 64
68 65 class TestMercurialServer(object):
69 66
70 67 def test_command(self, hg_server, tmpdir):
71 68 server = hg_server.create()
72 69 custom_hgrc = os.path.join(str(tmpdir), 'hgrc')
73 70 expected_command = (
74 71 'cd {root}; HGRCPATH={custom_hgrc} {hg_path} -R {root}{repo_name} serve --stdio'.format(
75 72 root=hg_server.root, custom_hgrc=custom_hgrc, hg_path=hg_server.hg_path,
76 73 repo_name=hg_server.repo_name)
77 74 )
78 75 server_command = server.tunnel.command(custom_hgrc)
79 76 assert expected_command == server_command
80 77
81 78 @pytest.mark.parametrize('permissions, action, code', [
82 79 ({}, 'pull', -2),
83 80 ({'test_hg': 'repository.read'}, 'pull', 0),
84 81 ({'test_hg': 'repository.read'}, 'push', -2),
85 82 ({'test_hg': 'repository.write'}, 'push', 0),
86 83 ({'test_hg': 'repository.admin'}, 'push', 0),
87 84
88 85 ])
89 86 def test_permission_checks(self, hg_server, permissions, action, code):
90 87 server = hg_server.create(user_permissions=permissions)
91 88 result = server._check_permissions(action)
92 89 assert result is code
93 90
94 91 @pytest.mark.parametrize('permissions, value', [
95 92 ({}, False),
96 93 ({'test_hg': 'repository.read'}, False),
97 94 ({'test_hg': 'repository.write'}, True),
98 95 ({'test_hg': 'repository.admin'}, True),
99 96
100 97 ])
101 98 def test_has_write_permissions(self, hg_server, permissions, value):
102 99 server = hg_server.create(user_permissions=permissions)
103 100 result = server.has_write_perm()
104 101 assert result is value
105 102
106 103 def test_run_returns_executes_command(self, hg_server):
107 104 server = hg_server.create()
108 105 from rhodecode.apps.ssh_support.lib.backends.hg import MercurialTunnelWrapper
109 106 os.environ['SSH_CLIENT'] = '127.0.0.1'
110 107 with mock.patch.object(MercurialTunnelWrapper, 'create_hooks_env') as _patch:
111 108 _patch.return_value = 0
112 109 with mock.patch.object(MercurialTunnelWrapper, 'command', return_value='date'):
113 110 exit_code = server.run()
114 111
115 112 assert exit_code == (0, False)
116 113
117 114
118 115
@@ -1,205 +1,202 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18 import os
19 19 import mock
20 20 import pytest
21 21
22 22 from rhodecode.apps.ssh_support.lib.backends.svn import SubversionServer
23 23 from rhodecode.apps.ssh_support.tests.conftest import plain_dummy_env, plain_dummy_user
24 24
25 25
26 26 class SubversionServerCreator(object):
27 27 root = '/tmp/repo/path/'
28 28 svn_path = '/usr/local/bin/svnserve'
29 29 config_data = {
30 30 'app:main': {
31 31 'ssh.executable.svn': svn_path,
32 32 'vcs.hooks.protocol': 'http',
33 33 }
34 34 }
35 35 repo_name = 'test-svn'
36 36 user = plain_dummy_user()
37 37
38 38 def __init__(self):
39 def config_get(part, key):
40 return self.config_data.get(part, {}).get(key)
41 self.config_mock = mock.Mock()
42 self.config_mock.get = mock.Mock(side_effect=config_get)
39 pass
43 40
44 41 def create(self, **kwargs):
45 42 parameters = {
46 43 'store': self.root,
47 44 'repo_name': self.repo_name,
48 45 'ini_path': '',
49 46 'user': self.user,
50 47 'user_permissions': {
51 48 self.repo_name: 'repository.admin'
52 49 },
53 'config': self.config_mock,
50 'settings': self.config_data['app:main'],
54 51 'env': plain_dummy_env()
55 52 }
56 53
57 54 parameters.update(kwargs)
58 55 server = SubversionServer(**parameters)
59 56 return server
60 57
61 58
62 59 @pytest.fixture()
63 60 def svn_server(app):
64 61 return SubversionServerCreator()
65 62
66 63
67 64 class TestSubversionServer(object):
68 65 def test_command(self, svn_server):
69 66 server = svn_server.create()
70 67 expected_command = [
71 68 svn_server.svn_path, '-t',
72 69 '--config-file', server.tunnel.svn_conf_path,
73 70 '--tunnel-user', svn_server.user.username,
74 71 '-r', svn_server.root
75 72 ]
76 73
77 74 assert expected_command == server.tunnel.command()
78 75
79 76 @pytest.mark.parametrize('permissions, action, code', [
80 77 ({}, 'pull', -2),
81 78 ({'test-svn': 'repository.read'}, 'pull', 0),
82 79 ({'test-svn': 'repository.read'}, 'push', -2),
83 80 ({'test-svn': 'repository.write'}, 'push', 0),
84 81 ({'test-svn': 'repository.admin'}, 'push', 0),
85 82
86 83 ])
87 84 def test_permission_checks(self, svn_server, permissions, action, code):
88 85 server = svn_server.create(user_permissions=permissions)
89 86 result = server._check_permissions(action)
90 87 assert result is code
91 88
92 89 @pytest.mark.parametrize('permissions, access_paths, expected_match', [
93 90 # not matched repository name
94 91 ({
95 92 'test-svn': ''
96 93 }, ['test-svn-1', 'test-svn-1/subpath'],
97 94 None),
98 95
99 96 # exact match
100 97 ({
101 98 'test-svn': ''
102 99 },
103 100 ['test-svn'],
104 101 'test-svn'),
105 102
106 103 # subdir commits
107 104 ({
108 105 'test-svn': ''
109 106 },
110 107 ['test-svn/foo',
111 108 'test-svn/foo/test-svn',
112 109 'test-svn/trunk/development.txt',
113 110 ],
114 111 'test-svn'),
115 112
116 113 # subgroups + similar patterns
117 114 ({
118 115 'test-svn': '',
119 116 'test-svn-1': '',
120 117 'test-svn-subgroup/test-svn': '',
121 118
122 119 },
123 120 ['test-svn-1',
124 121 'test-svn-1/foo/test-svn',
125 122 'test-svn-1/test-svn',
126 123 ],
127 124 'test-svn-1'),
128 125
129 126 # subgroups + similar patterns
130 127 ({
131 128 'test-svn-1': '',
132 129 'test-svn-10': '',
133 130 'test-svn-100': '',
134 131 },
135 132 ['test-svn-10',
136 133 'test-svn-10/foo/test-svn',
137 134 'test-svn-10/test-svn',
138 135 ],
139 136 'test-svn-10'),
140 137
141 138 # subgroups + similar patterns
142 139 ({
143 140 'name': '',
144 141 'nameContains': '',
145 142 'nameContainsThis': '',
146 143 },
147 144 ['nameContains',
148 145 'nameContains/This',
149 146 'nameContains/This/test-svn',
150 147 ],
151 148 'nameContains'),
152 149
153 150 # subgroups + similar patterns
154 151 ({
155 152 'test-svn': '',
156 153 'test-svn-1': '',
157 154 'test-svn-subgroup/test-svn': '',
158 155
159 156 },
160 157 ['test-svn-subgroup/test-svn',
161 158 'test-svn-subgroup/test-svn/foo/test-svn',
162 159 'test-svn-subgroup/test-svn/trunk/example.txt',
163 160 ],
164 161 'test-svn-subgroup/test-svn'),
165 162 ])
166 163 def test_repo_extraction_on_subdir(self, svn_server, permissions, access_paths, expected_match):
167 164 server = svn_server.create(user_permissions=permissions)
168 165 for path in access_paths:
169 166 repo_name = server.tunnel._match_repo_name(path)
170 167 assert repo_name == expected_match
171 168
172 169 def test_run_returns_executes_command(self, svn_server):
173 170 server = svn_server.create()
174 171 from rhodecode.apps.ssh_support.lib.backends.svn import SubversionTunnelWrapper
175 172 os.environ['SSH_CLIENT'] = '127.0.0.1'
176 173 with mock.patch.object(
177 174 SubversionTunnelWrapper, 'get_first_client_response',
178 175 return_value={'url': 'http://server/test-svn'}):
179 176 with mock.patch.object(
180 177 SubversionTunnelWrapper, 'patch_first_client_response',
181 178 return_value=0):
182 179 with mock.patch.object(
183 180 SubversionTunnelWrapper, 'sync',
184 181 return_value=0):
185 182 with mock.patch.object(
186 183 SubversionTunnelWrapper, 'command',
187 184 return_value=['date']):
188 185
189 186 exit_code = server.run()
190 187 # SVN has this differently configured, and we get in our mock env
191 188 # None as return code
192 189 assert exit_code == (None, False)
193 190
194 191 def test_run_returns_executes_command_that_cannot_extract_repo_name(self, svn_server):
195 192 server = svn_server.create()
196 193 from rhodecode.apps.ssh_support.lib.backends.svn import SubversionTunnelWrapper
197 194 with mock.patch.object(
198 195 SubversionTunnelWrapper, 'command',
199 196 return_value=['date']):
200 197 with mock.patch.object(
201 198 SubversionTunnelWrapper, 'get_first_client_response',
202 199 return_value=None):
203 200 exit_code = server.run()
204 201
205 202 assert exit_code == (1, False)
@@ -1,52 +1,48 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import pytest
20 20
21 21
22 22 class TestSSHWrapper(object):
23 23
24 24 def test_serve_raises_an_exception_when_vcs_is_not_recognized(self, ssh_wrapper):
25 25 with pytest.raises(Exception) as exc_info:
26 26 ssh_wrapper.serve(
27 27 vcs='microsoft-tfs', repo='test-repo', mode=None, user='test',
28 28 permissions={}, branch_permissions={})
29 29 assert str(exc_info.value) == 'Unrecognised VCS: microsoft-tfs'
30 30
31 def test_parse_config(self, ssh_wrapper):
32 config = ssh_wrapper.parse_config(ssh_wrapper.ini_path)
33 assert config
34
35 31 def test_get_connection_info(self, ssh_wrapper):
36 32 conn_info = ssh_wrapper.get_connection_info()
37 33 assert {'client_ip': '127.0.0.1',
38 34 'client_port': '22',
39 35 'server_ip': '10.0.0.1',
40 36 'server_port': '443'} == conn_info
41 37
42 38 @pytest.mark.parametrize('command, vcs', [
43 39 ('xxx', None),
44 40 ('svnserve -t', 'svn'),
45 41 ('hg -R repo serve --stdio', 'hg'),
46 42 ('git-receive-pack \'repo.git\'', 'git'),
47 43
48 44 ])
49 45 def test_get_repo_details(self, ssh_wrapper, command, vcs):
50 46 ssh_wrapper.command = command
51 47 vcs_type, repo_name, mode = ssh_wrapper.get_repo_details(mode='auto')
52 48 assert vcs_type == vcs
@@ -1,637 +1,466 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import sys
21 21 import collections
22 import tempfile
22
23 23 import time
24 24 import logging.config
25 25
26 26 from paste.gzipper import make_gzip_middleware
27 27 import pyramid.events
28 28 from pyramid.wsgi import wsgiapp
29 29 from pyramid.config import Configurator
30 30 from pyramid.settings import asbool, aslist
31 31 from pyramid.httpexceptions import (
32 32 HTTPException, HTTPError, HTTPInternalServerError, HTTPFound, HTTPNotFound)
33 33 from pyramid.renderers import render_to_response
34 34
35 from rhodecode import api
36 35 from rhodecode.model import meta
37 36 from rhodecode.config import patches
38 from rhodecode.config import utils as config_utils
39 from rhodecode.config.settings_maker import SettingsMaker
37
40 38 from rhodecode.config.environment import load_pyramid_environment
41 39
42 40 import rhodecode.events
41 from rhodecode.config.config_maker import sanitize_settings_and_apply_defaults
43 42 from rhodecode.lib.middleware.vcs import VCSMiddleware
44 43 from rhodecode.lib.request import Request
45 44 from rhodecode.lib.vcs import VCSCommunicationError
46 45 from rhodecode.lib.exceptions import VCSServerUnavailable
47 46 from rhodecode.lib.middleware.appenlight import wrap_in_appenlight_if_enabled
48 47 from rhodecode.lib.middleware.https_fixup import HttpsFixup
49 48 from rhodecode.lib.plugins.utils import register_rhodecode_plugin
50 49 from rhodecode.lib.utils2 import AttributeDict
51 50 from rhodecode.lib.exc_tracking import store_exception, format_exc
52 51 from rhodecode.subscribers import (
53 52 scan_repositories_if_enabled, write_js_routes_if_enabled,
54 53 write_metadata_if_needed, write_usage_data)
55 54 from rhodecode.lib.statsd_client import StatsdClient
56 55
57 56 log = logging.getLogger(__name__)
58 57
59 58
60 59 def is_http_error(response):
61 60 # error which should have traceback
62 61 return response.status_code > 499
63 62
64 63
65 64 def should_load_all():
66 65 """
67 66 Returns if all application components should be loaded. In some cases it's
68 67 desired to skip apps loading for faster shell script execution
69 68 """
70 69 ssh_cmd = os.environ.get('RC_CMD_SSH_WRAPPER')
71 70 if ssh_cmd:
72 71 return False
73 72
74 73 return True
75 74
76 75
77 76 def make_pyramid_app(global_config, **settings):
78 77 """
79 78 Constructs the WSGI application based on Pyramid.
80 79
81 80 Specials:
82 81
83 82 * The application can also be integrated like a plugin via the call to
84 83 `includeme`. This is accompanied with the other utility functions which
85 84 are called. Changing this should be done with great care to not break
86 85 cases when these fragments are assembled from another place.
87 86
88 87 """
89 88 start_time = time.time()
90 89 log.info('Pyramid app config starting')
91 90
92 91 sanitize_settings_and_apply_defaults(global_config, settings)
93 92
94 93 # init and bootstrap StatsdClient
95 94 StatsdClient.setup(settings)
96 95
97 96 config = Configurator(settings=settings)
98 97 # Init our statsd at very start
99 98 config.registry.statsd = StatsdClient.statsd
100 99
101 100 # Apply compatibility patches
102 101 patches.inspect_getargspec()
103 102
104 103 load_pyramid_environment(global_config, settings)
105 104
106 105 # Static file view comes first
107 106 includeme_first(config)
108 107
109 108 includeme(config)
110 109
111 110 pyramid_app = config.make_wsgi_app()
112 111 pyramid_app = wrap_app_in_wsgi_middlewares(pyramid_app, config)
113 112 pyramid_app.config = config
114 113
115 114 celery_settings = get_celery_config(settings)
116 115 config.configure_celery(celery_settings)
117 116
118 117 # creating the app uses a connection - return it after we are done
119 118 meta.Session.remove()
120 119
121 120 total_time = time.time() - start_time
122 121 log.info('Pyramid app created and configured in %.2fs', total_time)
123 122 return pyramid_app
124 123
125 124
126 125 def get_celery_config(settings):
127 126 """
128 127 Converts basic ini configuration into celery 4.X options
129 128 """
130 129
131 130 def key_converter(key_name):
132 131 pref = 'celery.'
133 132 if key_name.startswith(pref):
134 133 return key_name[len(pref):].replace('.', '_').lower()
135 134
136 135 def type_converter(parsed_key, value):
137 136 # cast to int
138 137 if value.isdigit():
139 138 return int(value)
140 139
141 140 # cast to bool
142 141 if value.lower() in ['true', 'false', 'True', 'False']:
143 142 return value.lower() == 'true'
144 143 return value
145 144
146 145 celery_config = {}
147 146 for k, v in settings.items():
148 147 pref = 'celery.'
149 148 if k.startswith(pref):
150 149 celery_config[key_converter(k)] = type_converter(key_converter(k), v)
151 150
152 151 # TODO:rethink if we want to support celerybeat based file config, probably NOT
153 152 # beat_config = {}
154 153 # for section in parser.sections():
155 154 # if section.startswith('celerybeat:'):
156 155 # name = section.split(':', 1)[1]
157 156 # beat_config[name] = get_beat_config(parser, section)
158 157
159 158 # final compose of settings
160 159 celery_settings = {}
161 160
162 161 if celery_config:
163 162 celery_settings.update(celery_config)
164 163 # if beat_config:
165 164 # celery_settings.update({'beat_schedule': beat_config})
166 165
167 166 return celery_settings
168 167
169 168
170 169 def not_found_view(request):
171 170 """
172 171 This creates the view which should be registered as not-found-view to
173 172 pyramid.
174 173 """
175 174
176 175 if not getattr(request, 'vcs_call', None):
177 176 # handle like regular case with our error_handler
178 177 return error_handler(HTTPNotFound(), request)
179 178
180 179 # handle not found view as a vcs call
181 180 settings = request.registry.settings
182 181 ae_client = getattr(request, 'ae_client', None)
183 182 vcs_app = VCSMiddleware(
184 183 HTTPNotFound(), request.registry, settings,
185 184 appenlight_client=ae_client)
186 185
187 186 return wsgiapp(vcs_app)(None, request)
188 187
189 188
190 189 def error_handler(exception, request):
191 190 import rhodecode
192 191 from rhodecode.lib import helpers
193 192
194 193 rhodecode_title = rhodecode.CONFIG.get('rhodecode_title') or 'RhodeCode'
195 194
196 195 base_response = HTTPInternalServerError()
197 196 # prefer original exception for the response since it may have headers set
198 197 if isinstance(exception, HTTPException):
199 198 base_response = exception
200 199 elif isinstance(exception, VCSCommunicationError):
201 200 base_response = VCSServerUnavailable()
202 201
203 202 if is_http_error(base_response):
204 203 traceback_info = format_exc(request.exc_info)
205 204 log.error(
206 205 'error occurred handling this request for path: %s, \n%s',
207 206 request.path, traceback_info)
208 207
209 208 error_explanation = base_response.explanation or str(base_response)
210 209 if base_response.status_code == 404:
211 210 error_explanation += " Optionally you don't have permission to access this page."
212 211 c = AttributeDict()
213 212 c.error_message = base_response.status
214 213 c.error_explanation = error_explanation
215 214 c.visual = AttributeDict()
216 215
217 216 c.visual.rhodecode_support_url = (
218 217 request.registry.settings.get('rhodecode_support_url') or
219 218 request.route_url('rhodecode_support')
220 219 )
221 220 c.redirect_time = 0
222 221 c.rhodecode_name = rhodecode_title
223 222 if not c.rhodecode_name:
224 223 c.rhodecode_name = 'Rhodecode'
225 224
226 225 c.causes = []
227 226 if is_http_error(base_response):
228 227 c.causes.append('Server is overloaded.')
229 228 c.causes.append('Server database connection is lost.')
230 229 c.causes.append('Server expected unhandled error.')
231 230
232 231 if hasattr(base_response, 'causes'):
233 232 c.causes = base_response.causes
234 233
235 234 c.messages = helpers.flash.pop_messages(request=request)
236 235 exc_info = sys.exc_info()
237 236 c.exception_id = id(exc_info)
238 237 c.show_exception_id = isinstance(base_response, VCSServerUnavailable) \
239 238 or base_response.status_code > 499
240 239 c.exception_id_url = request.route_url(
241 240 'admin_settings_exception_tracker_show', exception_id=c.exception_id)
242 241
243 242 debug_mode = rhodecode.ConfigGet().get_bool('debug')
244 243 if c.show_exception_id:
245 244 store_exception(c.exception_id, exc_info)
246 245 c.exception_debug = debug_mode
247 246 c.exception_config_ini = rhodecode.CONFIG.get('__file__')
248 247
249 248 if debug_mode:
250 249 try:
251 250 from rich.traceback import install
252 251 install(show_locals=True)
253 252 log.debug('Installing rich tracebacks...')
254 253 except ImportError:
255 254 pass
256 255
257 256 response = render_to_response(
258 257 '/errors/error_document.mako', {'c': c, 'h': helpers}, request=request,
259 258 response=base_response)
260 259
261 260 response.headers["X-RC-Exception-Id"] = str(c.exception_id)
262 261
263 262 statsd = request.registry.statsd
264 263 if statsd and base_response.status_code > 499:
265 264 exc_type = f"{exception.__class__.__module__}.{exception.__class__.__name__}"
266 265 statsd.incr('rhodecode_exception_total',
267 266 tags=["exc_source:web",
268 267 f"http_code:{base_response.status_code}",
269 268 f"type:{exc_type}"])
270 269
271 270 return response
272 271
273 272
274 273 def includeme_first(config):
275 274 # redirect automatic browser favicon.ico requests to correct place
276 275 def favicon_redirect(context, request):
277 276 return HTTPFound(
278 277 request.static_path('rhodecode:public/images/favicon.ico'))
279 278
280 279 config.add_view(favicon_redirect, route_name='favicon')
281 280 config.add_route('favicon', '/favicon.ico')
282 281
283 282 def robots_redirect(context, request):
284 283 return HTTPFound(
285 284 request.static_path('rhodecode:public/robots.txt'))
286 285
287 286 config.add_view(robots_redirect, route_name='robots')
288 287 config.add_route('robots', '/robots.txt')
289 288
290 289 config.add_static_view(
291 290 '_static/deform', 'deform:static')
292 291 config.add_static_view(
293 292 '_static/rhodecode', path='rhodecode:public', cache_max_age=3600 * 24)
294 293
295 294
296 295 ce_auth_resources = [
297 296 'rhodecode.authentication.plugins.auth_crowd',
298 297 'rhodecode.authentication.plugins.auth_headers',
299 298 'rhodecode.authentication.plugins.auth_jasig_cas',
300 299 'rhodecode.authentication.plugins.auth_ldap',
301 300 'rhodecode.authentication.plugins.auth_pam',
302 301 'rhodecode.authentication.plugins.auth_rhodecode',
303 302 'rhodecode.authentication.plugins.auth_token',
304 303 ]
305 304
306 305
307 306 def includeme(config, auth_resources=None):
308 307 from rhodecode.lib.celerylib.loader import configure_celery
309 308 log.debug('Initializing main includeme from %s', os.path.basename(__file__))
310 309 settings = config.registry.settings
311 310 config.set_request_factory(Request)
312 311
313 312 # plugin information
314 313 config.registry.rhodecode_plugins = collections.OrderedDict()
315 314
316 315 config.add_directive(
317 316 'register_rhodecode_plugin', register_rhodecode_plugin)
318 317
319 318 config.add_directive('configure_celery', configure_celery)
320 319
321 320 if settings.get('appenlight', False):
322 321 config.include('appenlight_client.ext.pyramid_tween')
323 322
324 323 load_all = should_load_all()
325 324
326 325 # Includes which are required. The application would fail without them.
327 326 config.include('pyramid_mako')
328 327 config.include('rhodecode.lib.rc_beaker')
329 328 config.include('rhodecode.lib.rc_cache')
330 329 config.include('rhodecode.lib.rc_cache.archive_cache')
331 330
332 331 config.include('rhodecode.apps._base.navigation')
333 332 config.include('rhodecode.apps._base.subscribers')
334 333 config.include('rhodecode.tweens')
335 334 config.include('rhodecode.authentication')
336 335
337 336 if load_all:
338 337
339 338 # load CE authentication plugins
340 339
341 340 if auth_resources:
342 341 ce_auth_resources.extend(auth_resources)
343 342
344 343 for resource in ce_auth_resources:
345 344 config.include(resource)
346 345
347 346 # Auto discover authentication plugins and include their configuration.
348 347 if asbool(settings.get('auth_plugin.import_legacy_plugins', 'true')):
349 348 from rhodecode.authentication import discover_legacy_plugins
350 349 discover_legacy_plugins(config)
351 350
352 351 # apps
353 352 if load_all:
354 353 log.debug('Starting config.include() calls')
355 354 config.include('rhodecode.api.includeme')
356 355 config.include('rhodecode.apps._base.includeme')
357 356 config.include('rhodecode.apps._base.navigation.includeme')
358 357 config.include('rhodecode.apps._base.subscribers.includeme')
359 358 config.include('rhodecode.apps.hovercards.includeme')
360 359 config.include('rhodecode.apps.ops.includeme')
361 360 config.include('rhodecode.apps.channelstream.includeme')
362 361 config.include('rhodecode.apps.file_store.includeme')
363 362 config.include('rhodecode.apps.admin.includeme')
364 363 config.include('rhodecode.apps.login.includeme')
365 364 config.include('rhodecode.apps.home.includeme')
366 365 config.include('rhodecode.apps.journal.includeme')
367 366
368 367 config.include('rhodecode.apps.repository.includeme')
369 368 config.include('rhodecode.apps.repo_group.includeme')
370 369 config.include('rhodecode.apps.user_group.includeme')
371 370 config.include('rhodecode.apps.search.includeme')
372 371 config.include('rhodecode.apps.user_profile.includeme')
373 372 config.include('rhodecode.apps.user_group_profile.includeme')
374 373 config.include('rhodecode.apps.my_account.includeme')
375 374 config.include('rhodecode.apps.gist.includeme')
376 375
377 376 config.include('rhodecode.apps.svn_support.includeme')
378 377 config.include('rhodecode.apps.ssh_support.includeme')
379 378 config.include('rhodecode.apps.debug_style')
380 379
381 380 if load_all:
382 381 config.include('rhodecode.integrations.includeme')
383 382 config.include('rhodecode.integrations.routes.includeme')
384 383
385 384 config.add_route('rhodecode_support', 'https://rhodecode.com/help/', static=True)
386 385 settings['default_locale_name'] = settings.get('lang', 'en')
387 386 config.add_translation_dirs('rhodecode:i18n/')
388 387
389 388 # Add subscribers.
390 389 if load_all:
391 390 log.debug('Adding subscribers...')
392 391 config.add_subscriber(scan_repositories_if_enabled,
393 392 pyramid.events.ApplicationCreated)
394 393 config.add_subscriber(write_metadata_if_needed,
395 394 pyramid.events.ApplicationCreated)
396 395 config.add_subscriber(write_usage_data,
397 396 pyramid.events.ApplicationCreated)
398 397 config.add_subscriber(write_js_routes_if_enabled,
399 398 pyramid.events.ApplicationCreated)
400 399
401 400
402 401 # Set the default renderer for HTML templates to mako.
403 402 config.add_mako_renderer('.html')
404 403
405 404 config.add_renderer(
406 405 name='json_ext',
407 406 factory='rhodecode.lib.ext_json_renderer.pyramid_ext_json')
408 407
409 408 config.add_renderer(
410 409 name='string_html',
411 410 factory='rhodecode.lib.string_renderer.html')
412 411
413 412 # include RhodeCode plugins
414 413 includes = aslist(settings.get('rhodecode.includes', []))
415 414 log.debug('processing rhodecode.includes data...')
416 415 for inc in includes:
417 416 config.include(inc)
418 417
419 418 # custom not found view, if our pyramid app doesn't know how to handle
420 419 # the request pass it to potential VCS handling ap
421 420 config.add_notfound_view(not_found_view)
422 421 if not settings.get('debugtoolbar.enabled', False):
423 422 # disabled debugtoolbar handle all exceptions via the error_handlers
424 423 config.add_view(error_handler, context=Exception)
425 424
426 425 # all errors including 403/404/50X
427 426 config.add_view(error_handler, context=HTTPError)
428 427
429 428
430 429 def wrap_app_in_wsgi_middlewares(pyramid_app, config):
431 430 """
432 431 Apply outer WSGI middlewares around the application.
433 432 """
434 433 registry = config.registry
435 434 settings = registry.settings
436 435
437 436 # enable https redirects based on HTTP_X_URL_SCHEME set by proxy
438 437 pyramid_app = HttpsFixup(pyramid_app, settings)
439 438
440 439 pyramid_app, _ae_client = wrap_in_appenlight_if_enabled(
441 440 pyramid_app, settings)
442 441 registry.ae_client = _ae_client
443 442
444 443 if settings['gzip_responses']:
445 444 pyramid_app = make_gzip_middleware(
446 445 pyramid_app, settings, compress_level=1)
447 446
448 447 # this should be the outer most middleware in the wsgi stack since
449 448 # middleware like Routes make database calls
450 449 def pyramid_app_with_cleanup(environ, start_response):
451 450 start = time.time()
452 451 try:
453 452 return pyramid_app(environ, start_response)
454 453 finally:
455 454 # Dispose current database session and rollback uncommitted
456 455 # transactions.
457 456 meta.Session.remove()
458 457
459 458 # In a single threaded mode server, on non sqlite db we should have
460 459 # '0 Current Checked out connections' at the end of a request,
461 460 # if not, then something, somewhere is leaving a connection open
462 461 pool = meta.get_engine().pool
463 462 log.debug('sa pool status: %s', pool.status())
464 463 total = time.time() - start
465 464 log.debug('Request processing finalized: %.4fs', total)
466 465
467 466 return pyramid_app_with_cleanup
468
469
470 def sanitize_settings_and_apply_defaults(global_config, settings):
471 """
472 Applies settings defaults and does all type conversion.
473
474 We would move all settings parsing and preparation into this place, so that
475 we have only one place left which deals with this part. The remaining parts
476 of the application would start to rely fully on well prepared settings.
477
478 This piece would later be split up per topic to avoid a big fat monster
479 function.
480 """
481 jn = os.path.join
482
483 global_settings_maker = SettingsMaker(global_config)
484 global_settings_maker.make_setting('debug', default=False, parser='bool')
485 debug_enabled = asbool(global_config.get('debug'))
486
487 settings_maker = SettingsMaker(settings)
488
489 settings_maker.make_setting(
490 'logging.autoconfigure',
491 default=False,
492 parser='bool')
493
494 logging_conf = jn(os.path.dirname(global_config.get('__file__')), 'logging.ini')
495 settings_maker.enable_logging(logging_conf, level='INFO' if debug_enabled else 'DEBUG')
496
497 # Default includes, possible to change as a user
498 pyramid_includes = settings_maker.make_setting('pyramid.includes', [], parser='list:newline')
499 log.debug(
500 "Using the following pyramid.includes: %s",
501 pyramid_includes)
502
503 settings_maker.make_setting('rhodecode.edition', 'Community Edition')
504 settings_maker.make_setting('rhodecode.edition_id', 'CE')
505
506 if 'mako.default_filters' not in settings:
507 # set custom default filters if we don't have it defined
508 settings['mako.imports'] = 'from rhodecode.lib.base import h_filter'
509 settings['mako.default_filters'] = 'h_filter'
510
511 if 'mako.directories' not in settings:
512 mako_directories = settings.setdefault('mako.directories', [
513 # Base templates of the original application
514 'rhodecode:templates',
515 ])
516 log.debug(
517 "Using the following Mako template directories: %s",
518 mako_directories)
519
520 # NOTE(marcink): fix redis requirement for schema of connection since 3.X
521 if 'beaker.session.type' in settings and settings['beaker.session.type'] == 'ext:redis':
522 raw_url = settings['beaker.session.url']
523 if not raw_url.startswith(('redis://', 'rediss://', 'unix://')):
524 settings['beaker.session.url'] = 'redis://' + raw_url
525
526 settings_maker.make_setting('__file__', global_config.get('__file__'))
527
528 # TODO: johbo: Re-think this, usually the call to config.include
529 # should allow to pass in a prefix.
530 settings_maker.make_setting('rhodecode.api.url', api.DEFAULT_URL)
531
532 # Sanitize generic settings.
533 settings_maker.make_setting('default_encoding', 'UTF-8', parser='list')
534 settings_maker.make_setting('is_test', False, parser='bool')
535 settings_maker.make_setting('gzip_responses', False, parser='bool')
536
537 # statsd
538 settings_maker.make_setting('statsd.enabled', False, parser='bool')
539 settings_maker.make_setting('statsd.statsd_host', 'statsd-exporter', parser='string')
540 settings_maker.make_setting('statsd.statsd_port', 9125, parser='int')
541 settings_maker.make_setting('statsd.statsd_prefix', '')
542 settings_maker.make_setting('statsd.statsd_ipv6', False, parser='bool')
543
544 settings_maker.make_setting('vcs.svn.compatible_version', '')
545 settings_maker.make_setting('vcs.hooks.protocol', 'http')
546 settings_maker.make_setting('vcs.hooks.host', '*')
547 settings_maker.make_setting('vcs.scm_app_implementation', 'http')
548 settings_maker.make_setting('vcs.server', '')
549 settings_maker.make_setting('vcs.server.protocol', 'http')
550 settings_maker.make_setting('vcs.server.enable', 'true', parser='bool')
551 settings_maker.make_setting('startup.import_repos', 'false', parser='bool')
552 settings_maker.make_setting('vcs.hooks.direct_calls', 'false', parser='bool')
553 settings_maker.make_setting('vcs.start_server', 'false', parser='bool')
554 settings_maker.make_setting('vcs.backends', 'hg, git, svn', parser='list')
555 settings_maker.make_setting('vcs.connection_timeout', 3600, parser='int')
556
557 settings_maker.make_setting('vcs.methods.cache', True, parser='bool')
558
559 # Support legacy values of vcs.scm_app_implementation. Legacy
560 # configurations may use 'rhodecode.lib.middleware.utils.scm_app_http', or
561 # disabled since 4.13 'vcsserver.scm_app' which is now mapped to 'http'.
562 scm_app_impl = settings['vcs.scm_app_implementation']
563 if scm_app_impl in ['rhodecode.lib.middleware.utils.scm_app_http', 'vcsserver.scm_app']:
564 settings['vcs.scm_app_implementation'] = 'http'
565
566 settings_maker.make_setting('appenlight', False, parser='bool')
567
568 temp_store = tempfile.gettempdir()
569 tmp_cache_dir = jn(temp_store, 'rc_cache')
570
571 # save default, cache dir, and use it for all backends later.
572 default_cache_dir = settings_maker.make_setting(
573 'cache_dir',
574 default=tmp_cache_dir, default_when_empty=True,
575 parser='dir:ensured')
576
577 # exception store cache
578 settings_maker.make_setting(
579 'exception_tracker.store_path',
580 default=jn(default_cache_dir, 'exc_store'), default_when_empty=True,
581 parser='dir:ensured'
582 )
583
584 settings_maker.make_setting(
585 'celerybeat-schedule.path',
586 default=jn(default_cache_dir, 'celerybeat_schedule', 'celerybeat-schedule.db'), default_when_empty=True,
587 parser='file:ensured'
588 )
589
590 settings_maker.make_setting('exception_tracker.send_email', False, parser='bool')
591 settings_maker.make_setting('exception_tracker.email_prefix', '[RHODECODE ERROR]', default_when_empty=True)
592
593 # sessions, ensure file since no-value is memory
594 settings_maker.make_setting('beaker.session.type', 'file')
595 settings_maker.make_setting('beaker.session.data_dir', jn(default_cache_dir, 'session_data'))
596
597 # cache_general
598 settings_maker.make_setting('rc_cache.cache_general.backend', 'dogpile.cache.rc.file_namespace')
599 settings_maker.make_setting('rc_cache.cache_general.expiration_time', 60 * 60 * 12, parser='int')
600 settings_maker.make_setting('rc_cache.cache_general.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_general.db'))
601
602 # cache_perms
603 settings_maker.make_setting('rc_cache.cache_perms.backend', 'dogpile.cache.rc.file_namespace')
604 settings_maker.make_setting('rc_cache.cache_perms.expiration_time', 60 * 60, parser='int')
605 settings_maker.make_setting('rc_cache.cache_perms.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_perms_db'))
606
607 # cache_repo
608 settings_maker.make_setting('rc_cache.cache_repo.backend', 'dogpile.cache.rc.file_namespace')
609 settings_maker.make_setting('rc_cache.cache_repo.expiration_time', 60 * 60 * 24 * 30, parser='int')
610 settings_maker.make_setting('rc_cache.cache_repo.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_repo_db'))
611
612 # cache_license
613 settings_maker.make_setting('rc_cache.cache_license.backend', 'dogpile.cache.rc.file_namespace')
614 settings_maker.make_setting('rc_cache.cache_license.expiration_time', 60 * 5, parser='int')
615 settings_maker.make_setting('rc_cache.cache_license.arguments.filename', jn(default_cache_dir, 'rhodecode_cache_license_db'))
616
617 # cache_repo_longterm memory, 96H
618 settings_maker.make_setting('rc_cache.cache_repo_longterm.backend', 'dogpile.cache.rc.memory_lru')
619 settings_maker.make_setting('rc_cache.cache_repo_longterm.expiration_time', 345600, parser='int')
620 settings_maker.make_setting('rc_cache.cache_repo_longterm.max_size', 10000, parser='int')
621
622 # sql_cache_short
623 settings_maker.make_setting('rc_cache.sql_cache_short.backend', 'dogpile.cache.rc.memory_lru')
624 settings_maker.make_setting('rc_cache.sql_cache_short.expiration_time', 30, parser='int')
625 settings_maker.make_setting('rc_cache.sql_cache_short.max_size', 10000, parser='int')
626
627 # archive_cache
628 settings_maker.make_setting('archive_cache.store_dir', jn(default_cache_dir, 'archive_cache'), default_when_empty=True,)
629 settings_maker.make_setting('archive_cache.cache_size_gb', 10, parser='float')
630 settings_maker.make_setting('archive_cache.cache_shards', 10, parser='int')
631
632 settings_maker.env_expand()
633
634 # configure instance id
635 config_utils.set_instance_id(settings)
636
637 return settings
@@ -1,117 +1,116 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import platform
21 21
22 from rhodecode.model import init_model
23
24 22
def configure_vcs(config):
    """
    Patch VCS config with some RhodeCode specific stuff
    """
    from rhodecode.lib.vcs import conf
    import rhodecode.lib.vcs.conf.settings

    vcs_settings = conf.settings

    # register the repository backend implementations by scheme
    vcs_settings.BACKENDS = {
        'hg': 'rhodecode.lib.vcs.backends.hg.MercurialRepository',
        'git': 'rhodecode.lib.vcs.backends.git.GitRepository',
        'svn': 'rhodecode.lib.vcs.backends.svn.SubversionRepository',
    }

    # propagate selected .ini values into the VCS layer settings
    vcs_settings.HOOKS_PROTOCOL = config['vcs.hooks.protocol']
    vcs_settings.HOOKS_HOST = config['vcs.hooks.host']
    vcs_settings.DEFAULT_ENCODINGS = config['default_encoding']
    # mutate the existing ALIASES list in place so other references see it
    vcs_settings.ALIASES[:] = config['vcs.backends']
    vcs_settings.SVN_COMPATIBLE_VERSION = config['vcs.svn.compatible_version']
44 42
def initialize_database(config):
    """Create the SQLAlchemy engine from *config* and bind the model layer to it."""
    from rhodecode.lib.utils2 import engine_from_config, get_encryption_key
    from rhodecode.model import init_model

    db_engine = engine_from_config(config, 'sqlalchemy.db1.')
    init_model(db_engine, encryption_key=get_encryption_key(config))
49 48
50 49
def initialize_test_environment(settings, test_env=None):
    """
    Prepare the directories, database, repositories and search index used
    by the test suite.

    :param test_env: when None, the flag is derived from the
        ``RC_NO_TMP_PATH`` environment variable (a non-zero value disables
        the tmp-path setup).
    """
    if test_env is None:
        test_env = not int(os.environ.get('RC_NO_TMP_PATH', 0))

    from rhodecode.lib.utils import (
        create_test_directory, create_test_database, create_test_repositories,
        create_test_index)
    from rhodecode.tests import TESTS_TMP_PATH
    from rhodecode.lib.vcs.backends.hg import largefiles_store
    from rhodecode.lib.vcs.backends.git import lfs_store

    if test_env:
        # test repos directory plus the large-object stores
        for directory in (TESTS_TMP_PATH,
                          largefiles_store(TESTS_TMP_PATH),
                          lfs_store(TESTS_TMP_PATH)):
            create_test_directory(directory)

        create_test_database(TESTS_TMP_PATH, settings)
        create_test_repositories(TESTS_TMP_PATH, settings)
        create_test_index(TESTS_TMP_PATH, settings)
72 71
73 72
def get_vcs_server_protocol(config):
    """Return the configured VCS server communication protocol."""
    protocol = config['vcs.server.protocol']
    return protocol
76 75
77 76
def set_instance_id(config):
    """
    Sets a dynamic generated config['instance_id'] if missing or '*'
    E.g instance_id = *cluster-1 or instance_id = *
    """
    instance_id = config.get('instance_id') or ''
    config['instance_id'] = instance_id

    if not instance_id or instance_id.startswith('*'):
        # generate a per-process identifier; any text after the leading
        # '*' is kept as a user-supplied prefix
        prefix = instance_id.lstrip('*')
        node_name = platform.uname()[1] or 'instance'
        config['instance_id'] = f'{prefix}uname:{node_name}-pid:{os.getpid()}'
93 92
94 93
def get_default_user_id():
    """Return the user_id of the built-in ``default`` (anonymous) user."""
    from sqlalchemy import text
    from rhodecode.model import meta

    default_username = 'default'
    engine = meta.get_engine()
    with meta.SA_Session(engine) as session:
        query = text("SELECT user_id from users where username = :uname")
        result = session.execute(query, {'uname': default_username})
        user_id = result.first()[0]

    return user_id
106 105
107 106
def get_default_base_path():
    """Return the repository store base path stored under the ``/`` ui key."""
    from sqlalchemy import text
    from rhodecode.model import meta

    engine = meta.get_engine()
    with meta.SA_Session(engine) as session:
        query = text("SELECT ui_value from rhodecode_ui where ui_key = '/'")
        first_row = session.execute(query).first()

    return first_row[0]
@@ -1,609 +1,609 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 """
20 20 The base Controller API
21 21 Provides the BaseController class for subclassing. And usage in different
22 22 controllers
23 23 """
24 24
25 25 import logging
26 26 import socket
27 27 import base64
28 28
29 29 import markupsafe
30 30 import ipaddress
31 31
32 32 import paste.httpheaders
33 33 from paste.auth.basic import AuthBasicAuthenticator
34 34 from paste.httpexceptions import HTTPUnauthorized, HTTPForbidden, get_exception
35 35
36 36 import rhodecode
37 37 from rhodecode.authentication.base import VCS_TYPE
38 38 from rhodecode.lib import auth, utils2
39 39 from rhodecode.lib import helpers as h
40 40 from rhodecode.lib.auth import AuthUser, CookieStoreWrapper
41 41 from rhodecode.lib.exceptions import UserCreationError
42 42 from rhodecode.lib.utils import (password_changed, get_enabled_hook_classes)
43 43 from rhodecode.lib.utils2 import AttributeDict
44 44 from rhodecode.lib.str_utils import ascii_bytes, safe_int, safe_str
45 45 from rhodecode.lib.type_utils import aslist, str2bool
46 46 from rhodecode.lib.hash_utils import sha1
47 47 from rhodecode.model.db import Repository, User, ChangesetComment, UserBookmark
48 48 from rhodecode.model.notification import NotificationModel
49 49 from rhodecode.model.settings import VcsSettingsModel, SettingsModel
50 50
51 51 log = logging.getLogger(__name__)
52 52
53 53
54 54 def _filter_proxy(ip):
55 55 """
56 56 Passed in IP addresses in HEADERS can be in a special format of multiple
57 57 ips. Those comma separated IPs are passed from various proxies in the
58 58 chain of request processing. The left-most being the original client.
59 59 We only care about the first IP which came from the org. client.
60 60
61 61 :param ip: ip string from headers
62 62 """
63 63 if ',' in ip:
64 64 _ips = ip.split(',')
65 65 _first_ip = _ips[0].strip()
66 66 log.debug('Got multiple IPs %s, using %s', ','.join(_ips), _first_ip)
67 67 return _first_ip
68 68 return ip
69 69
70 70
71 71 def _filter_port(ip):
72 72 """
73 73 Removes a port from ip, there are 4 main cases to handle here.
74 74 - ipv4 eg. 127.0.0.1
75 75 - ipv6 eg. ::1
76 76 - ipv4+port eg. 127.0.0.1:8080
77 77 - ipv6+port eg. [::1]:8080
78 78
79 79 :param ip:
80 80 """
81 81 def is_ipv6(ip_addr):
82 82 if hasattr(socket, 'inet_pton'):
83 83 try:
84 84 socket.inet_pton(socket.AF_INET6, ip_addr)
85 85 except socket.error:
86 86 return False
87 87 else:
88 88 # fallback to ipaddress
89 89 try:
90 90 ipaddress.IPv6Address(safe_str(ip_addr))
91 91 except Exception:
92 92 return False
93 93 return True
94 94
95 95 if ':' not in ip: # must be ipv4 pure ip
96 96 return ip
97 97
98 98 if '[' in ip and ']' in ip: # ipv6 with port
99 99 return ip.split(']')[0][1:].lower()
100 100
101 101 # must be ipv6 or ipv4 with port
102 102 if is_ipv6(ip):
103 103 return ip
104 104 else:
105 105 ip, _port = ip.split(':')[:2] # means ipv4+port
106 106 return ip
107 107
108 108
def get_ip_addr(environ):
    """
    Return the client IP for this WSGI environ, preferring proxy-provided
    headers over REMOTE_ADDR and normalizing away ports and proxy chains.
    """
    def _clean(raw_ip):
        return _filter_port(_filter_proxy(raw_ip))

    # ordered by preference: explicit real-ip header, then the
    # forwarded-for chain, finally the raw socket peer address
    for header in ('HTTP_X_REAL_IP', 'HTTP_X_FORWARDED_FOR'):
        candidate = environ.get(header)
        if candidate:
            return _clean(candidate)

    return _clean(environ.get('REMOTE_ADDR', '0.0.0.0'))
127 127
128 128
def get_server_ip_addr(environ, log_errors=True):
    """Resolve SERVER_NAME to an IP address, falling back to the raw hostname."""
    server_name = environ.get('SERVER_NAME')
    try:
        resolved = socket.gethostbyname(server_name)
    except Exception as e:
        if log_errors:
            # in some cases this lookup is not possible, and we don't want to
            # make it an exception in logs
            log.exception('Could not retrieve server ip address: %s', e)
        return server_name
    return resolved
139 139
140 140
def get_server_port(environ):
    """Return the SERVER_PORT value from the WSGI environ (None when unset)."""
    port = environ.get('SERVER_PORT')
    return port
143 143
144 144
145 145
def get_user_agent(environ):
    """Return the HTTP User-Agent header from the WSGI environ (None when unset)."""
    agent = environ.get('HTTP_USER_AGENT')
    return agent
148 148
149 149
def vcs_operation_context(
        environ, repo_name, username, action, scm, check_locking=True,
        is_shadow_repo=False, check_branch_perms=False, detect_force_push=False):
    """
    Generate the context for a vcs operation, e.g. push or pull.

    This context is passed over the layers so that hooks triggered by the
    vcs operation know details like the user, the user's IP address etc.

    :param environ: WSGI environ of the incoming request
    :param repo_name: target repository name
    :param username: acting user's username
    :param action: vcs action, e.g. push/pull
    :param scm: scm backend alias, e.g. 'hg', 'git', 'svn'
    :param check_locking: Allows to switch of the computation of the locking
        data. This serves mainly the need of the simplevcs middleware to be
        able to disable this for certain operations.

    :returns: dict of scm data consumed by the hook layer
    """
    # Tri-state value: False: unlock, None: nothing, True: lock
    make_lock = None
    locked_by = [None, None, None]
    is_anonymous = username == User.DEFAULT_USER
    # NOTE(review): get_by_username may return None for an unknown username,
    # which would make user.user_id below raise AttributeError — presumably
    # callers only pass authenticated usernames; confirm.
    user = User.get_by_username(username)
    if not is_anonymous and check_locking:
        log.debug('Checking locking on repository "%s"', repo_name)
        repo = Repository.get_by_repo_name(repo_name)
        make_lock, __, locked_by = repo.get_locking_state(
            action, user.user_id)
    user_id = user.user_id
    settings_model = VcsSettingsModel(repo=repo_name)
    ui_settings = settings_model.get_ui_settings()

    # NOTE(marcink): This should be also in sync with
    # rhodecode/apps/ssh_support/lib/backends/base.py:update_environment scm_data
    # the '/' ui key holds the repository store base path
    store = [x for x in ui_settings if x.key == '/']
    repo_store = ''
    if store:
        repo_store = store[0].value

    scm_data = {
        'ip': get_ip_addr(environ),
        'username': username,
        'user_id': user_id,
        'action': action,
        'repository': repo_name,
        'scm': scm,
        'config': rhodecode.CONFIG['__file__'],
        'repo_store': repo_store,
        'make_lock': make_lock,
        'locked_by': locked_by,
        'server_url': utils2.get_server_url(environ),
        'user_agent': get_user_agent(environ),
        'hooks': get_enabled_hook_classes(ui_settings),
        'is_shadow_repo': is_shadow_repo,
        'detect_force_push': detect_force_push,
        'check_branch_perms': check_branch_perms,
    }
    return scm_data
204 204
205 205
class BasicAuth(AuthBasicAuthenticator):
    """
    HTTP Basic authentication handler used for VCS operations, with support
    for returning an alternative (configurable) HTTP status code when
    authentication fails.
    """

    def __init__(self, realm, authfunc, registry, auth_http_code=None,
                 initial_call_detection=False, acl_repo_name=None, rc_realm=''):
        super().__init__(realm=realm, authfunc=authfunc)
        self.realm = realm
        self.rc_realm = rc_realm
        self.initial_call = initial_call_detection
        self.authfunc = authfunc
        self.registry = registry
        self.acl_repo_name = acl_repo_name
        self._rc_auth_http_code = auth_http_code

    def _get_response_from_code(self, http_code, fallback):
        # resolve an HTTP exception class for the configured status code,
        # using the fallback when the code is unknown/invalid
        try:
            return get_exception(safe_int(http_code))
        except Exception:
            log.exception('Failed to fetch response class for code %s, using fallback: %s', http_code, fallback)
            return fallback

    def get_rc_realm(self):
        return safe_str(self.rc_realm)

    def build_authentication(self):
        header = [('WWW-Authenticate', f'Basic realm="{self.realm}"')]

        # NOTE: the initial_call detection seems to be not working/not needed with latest Mercurial
        # investigate if we still need it.
        if self._rc_auth_http_code and not self.initial_call:
            # return alternative HTTP code if alternative http return code
            # is specified in RhodeCode config, but ONLY if it's not the
            # FIRST call
            custom_response_klass = self._get_response_from_code(self._rc_auth_http_code, fallback=HTTPUnauthorized)
            log.debug('Using custom response class: %s', custom_response_klass)
            return custom_response_klass(headers=header)
        return HTTPUnauthorized(headers=header)

    def authenticate(self, environ):
        authorization = paste.httpheaders.AUTHORIZATION(environ)
        if not authorization:
            return self.build_authentication()
        try:
            (auth_meth, auth_creds_b64) = authorization.split(' ', 1)
        except ValueError:
            # header without a credentials part (e.g. just "Basic"):
            # treat as failed authentication instead of crashing with 500
            return self.build_authentication()
        if 'basic' != auth_meth.lower():
            return self.build_authentication()

        try:
            credentials = safe_str(base64.b64decode(auth_creds_b64.strip()))
        except (TypeError, ValueError):
            # malformed base64 payload (binascii.Error is a ValueError
            # subclass) must not raise a server error; reject it as a
            # failed authentication attempt
            return self.build_authentication()
        _parts = credentials.split(':', 1)
        if len(_parts) == 2:
            username, password = _parts
            auth_data = self.authfunc(
                username, password, environ, VCS_TYPE,
                registry=self.registry, acl_repo_name=self.acl_repo_name)
            if auth_data:
                return {'username': username, 'auth_data': auth_data}
            if username and password:
                # we mark that we actually executed authentication once, at
                # that point we can use the alternative auth code
                self.initial_call = False

        return self.build_authentication()

    __call__ = authenticate
268 268
269 269
def calculate_version_hash(config):
    """Return a short hash of the RhodeCode version salted with the session secret."""
    salted = config.get(b'beaker.session.secret', b'') + ascii_bytes(rhodecode.__version__)
    return sha1(salted)[:8]
274 274
275 275
def get_current_lang(request):
    """Return the request locale, preferring an explicitly set ``_LOCALE_``."""
    fallback = request.locale_name
    return getattr(request, '_LOCALE_', fallback)
278 278
279 279
def attach_context_attributes(context, request, user_id=None, is_api=None):
    """
    Attach variables into template context called `c`.

    :param context: the template context object to populate
    :param request: current pyramid request
    :param user_id: when set, also load unread notifications and bookmarks
    :param is_api: marker for API calls; skips pyramid-session access
    """
    config = request.registry.settings

    # DB-backed application settings (cached)
    rc_config = SettingsModel().get_all_settings(cache=True, from_request=False)
    context.rc_config = rc_config
    context.rhodecode_version = rhodecode.__version__
    context.rhodecode_edition = config.get('rhodecode.edition')
    context.rhodecode_edition_id = config.get('rhodecode.edition_id')
    # unique secret + version does not leak the version but keep consistency
    context.rhodecode_version_hash = calculate_version_hash(config)

    # Default language set for the incoming request
    context.language = get_current_lang(request)

    # Visual options
    context.visual = AttributeDict({})

    # DB stored Visual Items
    context.visual.show_public_icon = str2bool(
        rc_config.get('rhodecode_show_public_icon'))
    context.visual.show_private_icon = str2bool(
        rc_config.get('rhodecode_show_private_icon'))
    context.visual.stylify_metatags = str2bool(
        rc_config.get('rhodecode_stylify_metatags'))
    context.visual.dashboard_items = safe_int(
        rc_config.get('rhodecode_dashboard_items', 100))
    context.visual.admin_grid_items = safe_int(
        rc_config.get('rhodecode_admin_grid_items', 100))
    context.visual.show_revision_number = str2bool(
        rc_config.get('rhodecode_show_revision_number', True))
    context.visual.show_sha_length = safe_int(
        rc_config.get('rhodecode_show_sha_length', 100))
    context.visual.repository_fields = str2bool(
        rc_config.get('rhodecode_repository_fields'))
    context.visual.show_version = str2bool(
        rc_config.get('rhodecode_show_version'))
    context.visual.use_gravatar = str2bool(
        rc_config.get('rhodecode_use_gravatar'))
    context.visual.gravatar_url = rc_config.get('rhodecode_gravatar_url')
    context.visual.default_renderer = rc_config.get(
        'rhodecode_markup_renderer', 'rst')
    context.visual.comment_types = ChangesetComment.COMMENT_TYPES
    context.visual.rhodecode_support_url = \
        rc_config.get('rhodecode_support_url') or h.route_url('rhodecode_support')

    context.visual.affected_files_cut_off = 60

    context.pre_code = rc_config.get('rhodecode_pre_code')
    context.post_code = rc_config.get('rhodecode_post_code')
    context.rhodecode_name = rc_config.get('rhodecode_title')
    context.default_encodings = aslist(config.get('default_encoding'), sep=',')
    # if we have specified default_encoding in the request, it has more
    # priority
    if request.GET.get('default_encoding'):
        context.default_encodings.insert(0, request.GET.get('default_encoding'))
    context.clone_uri_tmpl = rc_config.get('rhodecode_clone_uri_tmpl')
    context.clone_uri_id_tmpl = rc_config.get('rhodecode_clone_uri_id_tmpl')
    context.clone_uri_ssh_tmpl = rc_config.get('rhodecode_clone_uri_ssh_tmpl')

    # INI stored
    context.labs_active = str2bool(
        config.get('labs_settings_active', 'false'))
    context.ssh_enabled = str2bool(
        config.get('ssh.generate_authorized_keyfile', 'false'))
    context.ssh_key_generator_enabled = str2bool(
        config.get('ssh.enable_ui_key_generator', 'true'))

    context.visual.allow_repo_location_change = str2bool(
        config.get('allow_repo_location_change', True))
    context.visual.allow_custom_hooks_settings = str2bool(
        config.get('allow_custom_hooks_settings', True))
    context.debug_style = str2bool(config.get('debug_style', False))

    context.rhodecode_instanceid = config.get('instance_id')

    context.visual.cut_off_limit_diff = safe_int(
        config.get('cut_off_limit_diff'), default=0)
    context.visual.cut_off_limit_file = safe_int(
        config.get('cut_off_limit_file'), default=0)

    context.license = AttributeDict({})
    context.license.hide_license_info = str2bool(
        config.get('license.hide_license_info', False))

    # AppEnlight
    context.appenlight_enabled = config.get('appenlight', False)
    context.appenlight_api_public_key = config.get(
        'appenlight.api_public_key', '')
    context.appenlight_server_url = config.get('appenlight.server_url', '')

    # normalize the requested diff mode to one of the two known values
    diffmode = {
        "unified": "unified",
        "sideside": "sideside"
    }.get(request.GET.get('diffmode'))

    # NOTE(review): this condition looks inverted — when is_api is None
    # (the default) detection never runs, while an explicitly passed value
    # gets overwritten; presumably `is_api is None` was intended. Confirm
    # against callers before changing.
    if is_api is not None:
        is_api = hasattr(request, 'rpc_user')
    session_attrs = {
        # defaults
        "clone_url_format": "http",
        "diffmode": "sideside",
        "license_fingerprint": request.session.get('license_fingerprint')
    }

    if not is_api:
        # don't access pyramid session for API calls
        if diffmode and diffmode != request.session.get('rc_user_session_attr.diffmode'):
            request.session['rc_user_session_attr.diffmode'] = diffmode

        # session settings per user

        for k, v in list(request.session.items()):
            pref = 'rc_user_session_attr.'
            if k and k.startswith(pref):
                k = k[len(pref):]
                session_attrs[k] = v

    context.user_session_attrs = session_attrs

    # JS template context
    context.template_context = {
        'repo_name': None,
        'repo_type': None,
        'repo_landing_commit': None,
        'rhodecode_user': {
            'username': None,
            'email': None,
            'notification_status': False
        },
        'session_attrs': session_attrs,
        'visual': {
            'default_renderer': None
        },
        'commit_data': {
            'commit_id': None
        },
        'pull_request_data': {'pull_request_id': None},
        'timeago': {
            'refresh_time': 120 * 1000,
            'cutoff_limit': 1000 * 60 * 60 * 24 * 7
        },
        'pyramid_dispatch': {

        },
        'extra': {'plugins': {}}
    }
    # END CONFIG VARS
    if is_api:
        csrf_token = None
    else:
        csrf_token = auth.get_csrf_token(session=request.session)

    context.csrf_token = csrf_token
    context.backends = list(rhodecode.BACKENDS.keys())

    # notifications/bookmarks are only loaded when a user_id was provided
    unread_count = 0
    user_bookmark_list = []
    if user_id:
        unread_count = NotificationModel().get_unread_cnt_for_user(user_id)
        user_bookmark_list = UserBookmark.get_bookmarks_for_user(user_id)
    context.unread_notifications = unread_count
    context.bookmark_items = user_bookmark_list

    # web case
    if hasattr(request, 'user'):
        context.auth_user = request.user
        context.rhodecode_user = request.user

    # api case
    if hasattr(request, 'rpc_user'):
        context.auth_user = request.rpc_user
        context.rhodecode_user = request.rpc_user

    # attach the whole call context to the request
    request.set_call_context(context)
458 458
459 459
def get_auth_user(request):
    """
    Resolve the authenticated user for this request.

    Tries token-based auth first (query params, header, or URL match),
    falling back to cookie/session-based auth.

    :returns: tuple of (AuthUser instance, auth token string or '')
    """
    environ = request.environ
    session = request.session

    ip_addr = get_ip_addr(environ)

    # make sure that we update permissions each time we call controller
    _auth_token = (
        # ?auth_token=XXX
        request.GET.get('auth_token', '')
        # ?api_key=XXX !LEGACY
        or request.GET.get('api_key', '')
        # or headers....
        or request.headers.get('X-Rc-Auth-Token', '')
    )
    if not _auth_token and request.matchdict:
        # token can also be embedded in the matched route
        url_auth_token = request.matchdict.get('_auth_token')
        _auth_token = url_auth_token
        if _auth_token:
            log.debug('Using URL extracted auth token `...%s`', _auth_token[-4:])

    if _auth_token:
        # when using API_KEY we assume user exists, and
        # doesn't need auth based on cookies.
        auth_user = AuthUser(api_key=_auth_token, ip_addr=ip_addr)
        authenticated = False
    else:
        cookie_store = CookieStoreWrapper(session.get('rhodecode_user'))
        try:
            auth_user = AuthUser(user_id=cookie_store.get('user_id', None),
                                 ip_addr=ip_addr)
        except UserCreationError as e:
            h.flash(e, 'error')
            # container auth or other auth functions that create users
            # on the fly can throw this exception signaling that there's
            # issue with user creation, explanation should be provided
            # in Exception itself. We then create a simple blank
            # AuthUser
            auth_user = AuthUser(ip_addr=ip_addr)

        # in case someone changes a password for user it triggers session
        # flush and forces a re-login
        if password_changed(auth_user, session):
            session.invalidate()
            cookie_store = CookieStoreWrapper(session.get('rhodecode_user'))
            auth_user = AuthUser(ip_addr=ip_addr)

        authenticated = cookie_store.get('is_authenticated')

    if not auth_user.is_authenticated and auth_user.is_user_object:
        # user is not authenticated and not empty
        auth_user.set_authenticated(authenticated)

    return auth_user, _auth_token
514 514
515 515
def h_filter(s):
    """
    Custom filter for Mako templates. Mako by standard uses `markupsafe.escape`
    we wrap this with additional functionality that converts None to empty
    strings
    """
    # None becomes an empty (safe) markup string instead of "None"
    return markupsafe.Markup() if s is None else markupsafe.escape(s)
525 525
526 526
def add_events_routes(config):
    """
    Adds routing that can be used in events. Because some events are triggered
    outside of pyramid context, we need to bootstrap request with some
    routing registered
    """

    from rhodecode.apps._base import ADMIN_PREFIX

    # (route name, url pattern) pairs, registered in this exact order
    route_table = [
        ('home', '/'),
        ('main_page_repos_data', '/_home_repos'),
        ('main_page_repo_groups_data', '/_home_repo_groups'),

        ('login', ADMIN_PREFIX + '/login'),
        ('logout', ADMIN_PREFIX + '/logout'),
        ('repo_summary', '/{repo_name}'),
        ('repo_summary_explicit', '/{repo_name}/summary'),
        ('repo_group_home', '/{repo_group_name}'),

        ('pullrequest_show', '/{repo_name}/pull-request/{pull_request_id}'),
        ('pull_requests_global', '/pull-request/{pull_request_id}'),

        ('repo_commit', '/{repo_name}/changeset/{commit_id}'),
        ('repo_files', '/{repo_name}/files/{commit_id}/{f_path}'),

        ('hovercard_user', '/_hovercard/user/{user_id}'),
        ('hovercard_user_group', '/_hovercard/user_group/{user_group_id}'),
        ('hovercard_pull_request', '/_hovercard/pull_request/{pull_request_id}'),
        ('hovercard_repo_commit', '/_hovercard/commit/{repo_name}/{commit_id}'),
    ]
    for route_name, pattern in route_table:
        config.add_route(name=route_name, pattern=pattern)
567 567
568 568
def bootstrap_config(request, registry_name='RcTestRegistry'):
    """
    Build a minimal pyramid Configurator for testing/non-web contexts,
    with sanitized settings and the event routes registered.
    """
    from rhodecode.config.config_maker import sanitize_settings_and_apply_defaults
    import pyramid.testing
    registry = pyramid.testing.Registry(registry_name)

    # empty __file__ mimics an app started without an ini file
    global_config = {'__file__': ''}

    config = pyramid.testing.setUp(registry=registry, request=request)
    sanitize_settings_and_apply_defaults(global_config, config.registry.settings)

    # allow pyramid lookup in testing
    config.include('pyramid_mako')
    config.include('rhodecode.lib.rc_beaker')
    config.include('rhodecode.lib.rc_cache')
    config.include('rhodecode.lib.rc_cache.archive_cache')
    add_events_routes(config)

    return config
587 587
588 588
def bootstrap_request(**kwargs):
    """
    Returns a thin version of Request Object that is used in non-web context like testing/celery

    :param kwargs: `application_url`, `host` and `domain` are consumed here;
        everything else is forwarded to the ThinRequest constructor.
    """

    import pyramid.testing
    from rhodecode.lib.request import ThinRequest as _ThinRequest

    class ThinRequest(_ThinRequest):
        # kwargs.pop here both configures the class attributes and removes
        # these keys before the remaining kwargs reach the constructor below
        application_url = kwargs.pop('application_url', 'http://example.com')
        host = kwargs.pop('host', 'example.com:80')
        domain = kwargs.pop('domain', 'example.com')

    class ThinSession(pyramid.testing.DummySession):
        # no-op save; nothing is persisted outside of a real web session
        def save(*arg, **kw):
            pass

    request = ThinRequest(**kwargs)
    request.session = ThinSession()

    return request
@@ -1,244 +1,243 b''
1 1
2 2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 import base64
21 21 import logging
22 22 import urllib.request
23 23 import urllib.parse
24 24 import urllib.error
25 25 import urllib.parse
26 26
27 27 import requests
28 28 from pyramid.httpexceptions import HTTPNotAcceptable
29 29
30 30 from rhodecode.lib import rc_cache
31 31 from rhodecode.lib.middleware import simplevcs
32 32 from rhodecode.lib.middleware.utils import get_path_info
33 33 from rhodecode.lib.utils import is_valid_repo
34 34 from rhodecode.lib.str_utils import safe_str, safe_int, safe_bytes
35 35 from rhodecode.lib.type_utils import str2bool
36 36 from rhodecode.lib.ext_json import json
37 from rhodecode.lib.hooks_daemon import store_txn_id_data
38
37 from rhodecode.lib.hook_daemon.base import store_txn_id_data
39 38
40 39 log = logging.getLogger(__name__)
41 40
42 41
class SimpleSvnApp(object):
    """
    WSGI application that proxies Subversion HTTP requests to the
    mod_dav_svn server configured via ``subversion_http_server_url``.

    For ``create-txn-with-props`` commit bodies it injects the RhodeCode
    extras (``rc_extras``) as svn revision properties, so the data can be
    read later inside the executed svn hooks.
    """
    # hop-by-hop / transport-level headers that must not be forwarded from
    # the proxied response back to the client
    IGNORED_HEADERS = [
        'connection', 'keep-alive', 'content-encoding',
        'transfer-encoding', 'content-length']
    rc_extras = {}

    def __init__(self, config):
        self.config = config
        self.session = requests.Session()

    def __call__(self, environ, start_response):
        request_headers = self._get_request_headers(environ)
        data_io = environ['wsgi.input']
        req_method: str = environ['REQUEST_METHOD']
        has_content_length = 'CONTENT_LENGTH' in environ

        path_info = self._get_url(
            self.config.get('subversion_http_server_url', ''), get_path_info(environ))
        transfer_encoding = environ.get('HTTP_TRANSFER_ENCODING', '')
        log.debug('Handling: %s method via `%s`', req_method, path_info)

        # stream control flag, based on request and content type...
        stream = False

        if req_method in ['MKCOL'] or has_content_length:
            data_processed = False
            # read chunk to check if we have txn-with-props
            initial_data: bytes = data_io.read(1024)
            if initial_data.startswith(b'(create-txn-with-props'):
                data_io = initial_data + data_io.read()
                # store on-the-fly our rc_extra using svn revision properties
                # those can be read later on in hooks executed so we have a way
                # to pass in the data into svn hooks
                rc_data = base64.urlsafe_b64encode(json.dumps(self.rc_extras))
                rc_data_len = str(len(rc_data))
                # header defines data length, and serialized data
                skel = b' rc-scm-extras %b %b' % (safe_bytes(rc_data_len), safe_bytes(rc_data))
                # splice the extras before the two closing parens of the body
                data_io = data_io[:-2] + skel + b'))'
                data_processed = True

            if not data_processed:
                # NOTE(johbo): Avoid that we end up with sending the request in chunked
                # transfer encoding (mainly on Gunicorn). If we know the content
                # length, then we should transfer the payload in one request.
                data_io = initial_data + data_io.read()

        if req_method in ['GET', 'PUT'] or transfer_encoding == 'chunked':
            # NOTE(marcink): when getting/uploading files, we want to STREAM content
            # back to the client/proxy instead of buffering it here...
            stream = True

        # NOTE(review): a redundant no-op `stream = stream` assignment was removed here
        log.debug('Calling SVN PROXY at `%s`, using method:%s. Stream: %s',
                  path_info, req_method, stream)

        call_kwargs = dict(
            data=data_io,
            headers=request_headers,
            stream=stream
        )
        if req_method in ['HEAD', 'DELETE']:
            # these methods carry no request body
            del call_kwargs['data']

        try:
            response = self.session.request(
                req_method, path_info, **call_kwargs)
        except requests.ConnectionError:
            log.exception('ConnectionError occurred for endpoint %s', path_info)
            raise

        if response.status_code not in [200, 401]:
            text = '\n{}'.format(safe_str(response.text)) if response.text else ''
            if response.status_code >= 500:
                log.error('Got SVN response:%s with text:`%s`', response, text)
            else:
                log.debug('Got SVN response:%s with text:`%s`', response, text)
        else:
            log.debug('got response code: %s', response.status_code)

        response_headers = self._get_response_headers(response.headers)

        if response.headers.get('SVN-Txn-name'):
            # remember the svn transaction id so hooks can later resolve the
            # callback daemon port for this transaction
            svn_tx_id = response.headers.get('SVN-Txn-name')
            txn_id = rc_cache.utils.compute_key_from_params(
                self.config['repository'], svn_tx_id)
            port = safe_int(self.rc_extras['hooks_uri'].split(':')[-1])
            store_txn_id_data(txn_id, {'port': port})

        start_response(f'{response.status_code} {response.reason}', response_headers)
        return response.iter_content(chunk_size=1024)

    def _get_url(self, svn_http_server, path):
        """Build the full proxied URL for `path` on the svn http server."""
        svn_http_server_url = (svn_http_server or '').rstrip('/')
        url_path = urllib.parse.urljoin(svn_http_server_url + '/', (path or '').lstrip('/'))
        url_path = urllib.parse.quote(url_path, safe="/:=~+!$,;'")
        return url_path

    def _get_request_headers(self, environ):
        """Translate WSGI ``HTTP_*`` environ keys back into HTTP header names."""
        headers = {}
        # non-HTTP_* environ keys forwarded verbatim
        whitelist = {
            'Authorization': {}
        }
        for key in environ:
            if key in whitelist:
                headers[key] = environ[key]
            elif not key.startswith('HTTP_'):
                continue
            else:
                # e.g. HTTP_USER_AGENT -> User-Agent
                new_key = key.split('_')
                new_key = [k.capitalize() for k in new_key[1:]]
                new_key = '-'.join(new_key)
                headers[new_key] = environ[key]

        if 'CONTENT_TYPE' in environ:
            headers['Content-Type'] = environ['CONTENT_TYPE']

        if 'CONTENT_LENGTH' in environ:
            headers['Content-Length'] = environ['CONTENT_LENGTH']

        return headers

    def _get_response_headers(self, headers):
        """Filter proxied response headers, dropping IGNORED_HEADERS."""
        headers = [
            (h, headers[h])
            for h in headers
            if h.lower() not in self.IGNORED_HEADERS
        ]

        return headers
172 171
173 172
class DisabledSimpleSvnApp(object):
    """
    Stand-in WSGI handler used when the SVN HTTP proxy feature is switched
    off; every incoming call is rejected with 406 Not Acceptable.
    """

    def __init__(self, config):
        self.config = config

    def __call__(self, environ, start_response):
        reason = 'Cannot handle SVN call because: SVN HTTP Proxy is not enabled'
        log.warning(reason)
        response = HTTPNotAcceptable(reason)
        return response(environ, start_response)
182 181
183 182
class SimpleSvn(simplevcs.SimpleVCS):
    """SVN flavour of the VCS middleware."""

    SCM = 'svn'
    READ_ONLY_COMMANDS = ('OPTIONS', 'PROPFIND', 'GET', 'REPORT')
    DEFAULT_HTTP_SERVER = 'http://localhost:8090'

    def _get_repository_name(self, environ):
        """
        Gets repository name out of PATH_INFO header

        :param environ: environ where PATH_INFO is stored
        """
        repo_name = get_path_info(environ).split('!')[0].strip('/')

        # SVN sends the whole path in its requests, including sub-directories
        # inside the repo, so walk the path components until a repository
        # root is found.
        if is_valid_repo(repo_name, self.base_path, explicit_scm=self.SCM):
            return repo_name

        current_path = ''
        for component in repo_name.split('/'):
            current_path += component
            if is_valid_repo(current_path, self.base_path, explicit_scm=self.SCM):
                return current_path
            current_path += '/'

        return repo_name

    def _get_action(self, environ):
        # read-only DAV commands map to 'pull', everything else is a 'push'
        if environ['REQUEST_METHOD'] in self.READ_ONLY_COMMANDS:
            return 'pull'
        return 'push'

    def _should_use_callback_daemon(self, extras, environ, action):
        # only MERGE command triggers hooks, so we don't want to start
        # hooks server too many times. POST however starts the svn transaction
        # so we also need to run the init of callback daemon of POST
        return environ['REQUEST_METHOD'] in ('MERGE', 'POST')

    def _create_wsgi_app(self, repo_path, repo_name, config):
        if not self._is_svn_enabled():
            # http proxy disabled - hand back the rejecting dummy handler
            return DisabledSimpleSvnApp(config)
        return SimpleSvnApp(config)

    def _is_svn_enabled(self):
        return str2bool(
            self.repo_vcs_config.get('vcs_svn_proxy', 'http_requests_enabled'))

    def _create_config(self, extras, repo_name, scheme='http'):
        server_url = (
            self.repo_vcs_config.get('vcs_svn_proxy', 'http_server_url')
            or self.DEFAULT_HTTP_SERVER)
        extras['subversion_http_server_url'] = server_url
        return extras
@@ -1,701 +1,701 b''
1 1
2 2
3 3 # Copyright (C) 2014-2023 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleVCS middleware for handling protocol request (push/clone etc.)
23 23 It's implemented with basic auth function
24 24 """
25 25
26 26 import os
27 27 import re
28 28 import io
29 29 import logging
30 30 import importlib
31 31 from functools import wraps
32 32 from lxml import etree
33 33
34 34 import time
35 35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
36 36
37 37 from pyramid.httpexceptions import (
38 38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
39 39 from zope.cachedescriptors.property import Lazy as LazyProperty
40 40
41 41 import rhodecode
42 42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
43 43 from rhodecode.lib import rc_cache
44 44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
45 45 from rhodecode.lib.base import (
46 46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
47 47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
48 from rhodecode.lib.hook_daemon.base import prepare_callback_daemon
49 49 from rhodecode.lib.middleware import appenlight
50 50 from rhodecode.lib.middleware.utils import scm_app_http
51 51 from rhodecode.lib.str_utils import safe_bytes
52 52 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
53 53 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool
54 54 from rhodecode.lib.vcs.conf import settings as vcs_settings
55 55 from rhodecode.lib.vcs.backends import base
56 56
57 57 from rhodecode.model import meta
58 58 from rhodecode.model.db import User, Repository, PullRequest
59 59 from rhodecode.model.scm import ScmModel
60 60 from rhodecode.model.pull_request import PullRequestModel
61 61 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
62 62
63 63 log = logging.getLogger(__name__)
64 64
65 65
def extract_svn_txn_id(acl_repo_name, data: bytes):
    """
    Helper method for extraction of svn txn_id from submitted XML data during
    POST operations
    """
    txn_pat = re.compile(r'/txn/(?P<txn_id>.*)')
    try:
        root = etree.fromstring(data)
        for source_el in root:
            if source_el.tag != '{DAV:}source':
                continue
            for href_el in source_el:
                if href_el.tag != '{DAV:}href':
                    continue
                match = txn_pat.search(href_el.text)
                if match:
                    svn_tx_id = match.groupdict()['txn_id']
                    return rc_cache.utils.compute_key_from_params(
                        acl_repo_name, svn_tx_id)
    except Exception:
        # best-effort extraction; any parse failure is only logged
        log.exception('Failed to extract txn_id')
87 87
88 88
def initialize_generator(factory):
    """
    Decorator that drains the first element of the generator returned by
    *factory* before handing the generator to the caller.

    This lets a generator have an explicit initialization phase: everything
    up to the first yield runs eagerly. The first yielded value is required
    to be ``"__init__"`` so that this special contract is visible in the
    decorated code.
    """

    @wraps(factory)
    def wrapper(*args, **kwargs):
        gen = factory(*args, **kwargs)
        try:
            first_item = next(gen)
        except StopIteration:
            raise ValueError('Generator must yield at least one element.')
        else:
            if first_item != "__init__":
                raise ValueError('First yielded element must be "__init__".')
            return gen
    return wrapper
110 110
111 111
112 112 class SimpleVCS(object):
113 113 """Common functionality for SCM HTTP handlers."""
114 114
115 115 SCM = 'unknown'
116 116
117 117 acl_repo_name = None
118 118 url_repo_name = None
119 119 vcs_repo_name = None
120 120 rc_extras = {}
121 121
122 122 # We have to handle requests to shadow repositories different than requests
123 123 # to normal repositories. Therefore we have to distinguish them. To do this
124 124 # we use this regex which will match only on URLs pointing to shadow
125 125 # repositories.
126 126 shadow_repo_re = re.compile(
127 127 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
128 128 '(?P<target>{slug_pat})/' # target repo
129 129 'pull-request/(?P<pr_id>\\d+)/' # pull request
130 130 'repository$' # shadow repo
131 131 .format(slug_pat=SLUG_RE.pattern))
132 132
    def __init__(self, config, registry):
        """
        :param config: middleware settings dict
        :param registry: application registry, also handed to the auth layer
        """
        self.registry = registry
        self.config = config
        # re-populated by specialized middleware
        self.repo_vcs_config = base.Config()

        # realm string presented in the basic-auth challenge
        rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
        realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'

        # authenticate this VCS request using authfunc
        auth_ret_code_detection = \
            str2bool(self.config.get('auth_ret_code_detection', False))
        self.authenticate = BasicAuth(
            '', authenticate, registry, config.get('auth_ret_code'),
            auth_ret_code_detection, rc_realm=realm)
        # placeholder; replaced per-request with the real client address
        self.ip_addr = '0.0.0.0'
149 149
    @LazyProperty
    def global_vcs_config(self):
        # instance-wide ui settings, computed once thanks to LazyProperty
        try:
            return VcsSettingsModel().get_ui_settings_as_config_obj()
        except Exception:
            # be defensive: any failure yields an empty config object
            return base.Config()
156 156
157 157 @property
158 158 def base_path(self):
159 159 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
160 160
161 161 if not settings_path:
162 162 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
163 163
164 164 if not settings_path:
165 165 # try, maybe we passed in explicitly as config option
166 166 settings_path = self.config.get('base_path')
167 167
168 168 if not settings_path:
169 169 raise ValueError('FATAL: base_path is empty')
170 170 return settings_path
171 171
    def set_repo_names(self, environ):
        """
        This will populate the attributes acl_repo_name, url_repo_name,
        vcs_repo_name and is_shadow_repo. In case of requests to normal (non
        shadow) repositories all names are equal. In case of requests to a
        shadow repository the acl-name points to the target repo of the pull
        request and the vcs-name points to the shadow repo file system path.
        The url-name is always the URL used by the vcs client program.

        Example in case of a shadow repo:
            acl_repo_name = RepoGroup/MyRepo
            url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
            vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3'
        """
        # First we set the repo name from URL for all attributes. This is the
        # default if handling normal (non shadow) repo requests.
        self.url_repo_name = self._get_repository_name(environ)
        self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
        self.is_shadow_repo = False

        # Check if this is a request to a shadow repository.
        match = self.shadow_repo_re.match(self.url_repo_name)
        if match:
            match_dict = match.groupdict()

            # Build acl repo name from regex match.
            acl_repo_name = safe_str('{groups}{target}'.format(
                groups=match_dict['groups'] or '',
                target=match_dict['target']))

            # Retrieve pull request instance by ID from regex match.
            pull_request = PullRequest.get(match_dict['pr_id'])

            # Only proceed if we got a pull request and if acl repo name from
            # URL equals the target repo name of the pull request.
            if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):

                # Get file system path to shadow repository.
                workspace_id = PullRequestModel()._workspace_id(pull_request)
                vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)

                # Store names for later usage.
                self.vcs_repo_name = vcs_repo_name
                self.acl_repo_name = acl_repo_name
                self.is_shadow_repo = True

        # NOTE: on a non-matching URL the defaults set above remain in place
        log.debug('Setting all VCS repository names: %s', {
            'acl_repo_name': self.acl_repo_name,
            'url_repo_name': self.url_repo_name,
            'vcs_repo_name': self.vcs_repo_name,
        })
223 223
    @property
    def scm_app(self):
        # resolve which scm application module serves the actual VCS protocol:
        # 'http' selects the built-in HTTP implementation, any other value is
        # treated as a dotted module path and imported dynamically
        custom_implementation = self.config['vcs.scm_app_implementation']
        if custom_implementation == 'http':
            log.debug('Using HTTP implementation of scm app.')
            scm_app_impl = scm_app_http
        else:
            log.debug('Using custom implementation of scm_app: "{}"'.format(
                custom_implementation))
            scm_app_impl = importlib.import_module(custom_implementation)
        return scm_app_impl
235 235
    def _get_by_id(self, repo_name):
        """
        Gets a special pattern _<ID> from clone url and tries to replace it
        with a repository_name for support of _<ID> non changeable urls
        """

        data = repo_name.split('/')
        if len(data) >= 2:
            # local import to avoid a circular dependency at module load time
            from rhodecode.model.repo import RepoModel
            by_id_match = RepoModel().get_repo_by_id(repo_name)
            if by_id_match:
                # swap the _<ID> segment for the resolved repository name
                data[1] = by_id_match.repo_name

        # Because PEP-3333-WSGI uses bytes-tunneled-in-latin-1 as PATH_INFO
        # and we use this data
        maybe_new_path = '/'.join(data)
        return safe_bytes(maybe_new_path).decode('latin1')
253 253
    def _invalidate_cache(self, repo_name):
        """
        Sets cache for this repository for invalidation on next access

        :param repo_name: full repo name, also a cache key
        """
        ScmModel().mark_for_invalidation(repo_name)
261 261
262 262 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
263 263 db_repo = Repository.get_by_repo_name(repo_name)
264 264 if not db_repo:
265 265 log.debug('Repository `%s` not found inside the database.',
266 266 repo_name)
267 267 return False
268 268
269 269 if db_repo.repo_type != scm_type:
270 270 log.warning(
271 271 'Repository `%s` have incorrect scm_type, expected %s got %s',
272 272 repo_name, db_repo.repo_type, scm_type)
273 273 return False
274 274
275 275 config = db_repo._config
276 276 config.set('extensions', 'largefiles', '')
277 277 return is_valid_repo(
278 278 repo_name, base_path,
279 279 explicit_scm=scm_type, expect_scm=scm_type, config=config)
280 280
281 281 def valid_and_active_user(self, user):
282 282 """
283 283 Checks if that user is not empty, and if it's actually object it checks
284 284 if he's active.
285 285
286 286 :param user: user object or None
287 287 :return: boolean
288 288 """
289 289 if user is None:
290 290 return False
291 291
292 292 elif user.active:
293 293 return True
294 294
295 295 return False
296 296
    @property
    def is_shadow_repo_dir(self):
        # True when the shadow repository directory actually exists on disk;
        # it may be gone, e.g. after a successful merge removed it
        return os.path.isdir(self.vcs_repo_name)
300 300
    def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
                          plugin_id='', plugin_cache_active=False, cache_ttl=0):
        """
        Checks permissions using action (push/pull) user and repository
        name. If plugin_cache and ttl is set it will use the plugin which
        authenticated the user to store the cached permissions result for N
        amount of seconds as in cache_ttl

        :param action: push or pull action
        :param user: user instance
        :param repo_name: repository name
        """

        log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
                  plugin_id, plugin_cache_active, cache_ttl)

        user_id = user.user_id
        cache_namespace_uid = f'cache_user_auth.{rc_cache.PERMISSIONS_CACHE_VER}.{user_id}'
        region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)

        # result is cached only when the authenticating plugin enabled it
        # (condition=plugin_cache_active), for cache_ttl seconds
        @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
                                               expiration_time=cache_ttl,
                                               condition=plugin_cache_active)
        def compute_perm_vcs(
                cache_name, plugin_id, action, user_id, repo_name, ip_addr):

            log.debug('auth: calculating permission access now...')
            # check IP
            inherit = user.inherit_default_permissions
            ip_allowed = AuthUser.check_ip_allowed(
                user_id, ip_addr, inherit_from_default=inherit)
            if ip_allowed:
                log.info('Access for IP:%s allowed', ip_addr)
            else:
                return False

            if action == 'push':
                perms = ('repository.write', 'repository.admin')
                if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
                    return False

            else:
                # any other action need at least read permission
                perms = (
                    'repository.read', 'repository.write', 'repository.admin')
                if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
                    return False

            return True

        start = time.time()
        log.debug('Running plugin `%s` permissions check', plugin_id)

        # for environ based auth, password can be empty, but then the validation is
        # on the server that fills in the env data needed for authentication
        perm_result = compute_perm_vcs(
            'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)

        auth_time = time.time() - start
        log.debug('Permissions for plugin `%s` completed in %.4fs, '
                  'expiration time of fetched cache %.1fs.',
                  plugin_id, auth_time, cache_ttl)

        return perm_result
365 365
366 366 def _get_http_scheme(self, environ):
367 367 try:
368 368 return environ['wsgi.url_scheme']
369 369 except Exception:
370 370 log.exception('Failed to read http scheme')
371 371 return 'http'
372 372
373 373 def _check_ssl(self, environ, start_response):
374 374 """
375 375 Checks the SSL check flag and returns False if SSL is not present
376 376 and required True otherwise
377 377 """
378 378 org_proto = environ['wsgi._org_proto']
379 379 # check if we have SSL required ! if not it's a bad request !
380 380 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
381 381 if require_ssl and org_proto == 'http':
382 382 log.debug(
383 383 'Bad request: detected protocol is `%s` and '
384 384 'SSL/HTTPS is required.', org_proto)
385 385 return False
386 386 return True
387 387
388 388 def _get_default_cache_ttl(self):
389 389 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
390 390 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
391 391 plugin_settings = plugin.get_settings()
392 392 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
393 393 plugin_settings) or (False, 0)
394 394 return plugin_cache_active, cache_ttl
395 395
    def __call__(self, environ, start_response):
        # top-level WSGI entry point: unexpected errors become a 500 response
        # and the db session is always detached when the request ends
        try:
            return self._handle_request(environ, start_response)
        except Exception:
            log.exception("Exception while handling request")
            appenlight.track_exception(environ)
            return HTTPInternalServerError()(environ, start_response)
        finally:
            meta.Session.remove()
405 405
406 406 def _handle_request(self, environ, start_response):
407 407 if not self._check_ssl(environ, start_response):
408 408 reason = ('SSL required, while RhodeCode was unable '
409 409 'to detect this as SSL request')
410 410 log.debug('User not allowed to proceed, %s', reason)
411 411 return HTTPNotAcceptable(reason)(environ, start_response)
412 412
413 413 if not self.url_repo_name:
414 414 log.warning('Repository name is empty: %s', self.url_repo_name)
415 415 # failed to get repo name, we fail now
416 416 return HTTPNotFound()(environ, start_response)
417 417 log.debug('Extracted repo name is %s', self.url_repo_name)
418 418
419 419 ip_addr = get_ip_addr(environ)
420 420 user_agent = get_user_agent(environ)
421 421 username = None
422 422
423 423 # skip passing error to error controller
424 424 environ['pylons.status_code_redirect'] = True
425 425
426 426 # ======================================================================
427 427 # GET ACTION PULL or PUSH
428 428 # ======================================================================
429 429 action = self._get_action(environ)
430 430
431 431 # ======================================================================
432 432 # Check if this is a request to a shadow repository of a pull request.
433 433 # In this case only pull action is allowed.
434 434 # ======================================================================
435 435 if self.is_shadow_repo and action != 'pull':
436 436 reason = 'Only pull action is allowed for shadow repositories.'
437 437 log.debug('User not allowed to proceed, %s', reason)
438 438 return HTTPNotAcceptable(reason)(environ, start_response)
439 439
440 440 # Check if the shadow repo actually exists, in case someone refers
441 441 # to it, and it has been deleted because of successful merge.
442 442 if self.is_shadow_repo and not self.is_shadow_repo_dir:
443 443 log.debug(
444 444 'Shadow repo detected, and shadow repo dir `%s` is missing',
445 445 self.is_shadow_repo_dir)
446 446 return HTTPNotFound()(environ, start_response)
447 447
448 448 # ======================================================================
449 449 # CHECK ANONYMOUS PERMISSION
450 450 # ======================================================================
451 451 detect_force_push = False
452 452 check_branch_perms = False
453 453 if action in ['pull', 'push']:
454 454 user_obj = anonymous_user = User.get_default_user()
455 455 auth_user = user_obj.AuthUser()
456 456 username = anonymous_user.username
457 457 if anonymous_user.active:
458 458 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
459 459 # ONLY check permissions if the user is activated
460 460 anonymous_perm = self._check_permission(
461 461 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
462 462 plugin_id='anonymous_access',
463 463 plugin_cache_active=plugin_cache_active,
464 464 cache_ttl=cache_ttl,
465 465 )
466 466 else:
467 467 anonymous_perm = False
468 468
469 469 if not anonymous_user.active or not anonymous_perm:
470 470 if not anonymous_user.active:
471 471 log.debug('Anonymous access is disabled, running '
472 472 'authentication')
473 473
474 474 if not anonymous_perm:
475 475 log.debug('Not enough credentials to access repo: `%s` '
476 476 'repository as anonymous user', self.acl_repo_name)
477 477
478 478
479 479 username = None
480 480 # ==============================================================
481 481 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
482 482 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
483 483 # ==============================================================
484 484
485 485 # try to auth based on environ, container auth methods
486 486 log.debug('Running PRE-AUTH for container|headers based authentication')
487 487
488 488 # headers auth, by just reading special headers and bypass the auth with user/passwd
489 489 pre_auth = authenticate(
490 490 '', '', environ, VCS_TYPE, registry=self.registry,
491 491 acl_repo_name=self.acl_repo_name)
492 492
493 493 if pre_auth and pre_auth.get('username'):
494 494 username = pre_auth['username']
495 495 log.debug('PRE-AUTH got `%s` as username', username)
496 496 if pre_auth:
497 497 log.debug('PRE-AUTH successful from %s',
498 498 pre_auth.get('auth_data', {}).get('_plugin'))
499 499
500 500 # If not authenticated by the container, running basic auth
501 501 # before inject the calling repo_name for special scope checks
502 502 self.authenticate.acl_repo_name = self.acl_repo_name
503 503
504 504 plugin_cache_active, cache_ttl = False, 0
505 505 plugin = None
506 506
507 507 # regular auth chain
508 508 if not username:
509 509 self.authenticate.realm = self.authenticate.get_rc_realm()
510 510
511 511 try:
512 512 auth_result = self.authenticate(environ)
513 513 except (UserCreationError, NotAllowedToCreateUserError) as e:
514 514 log.error(e)
515 515 reason = safe_str(e)
516 516 return HTTPNotAcceptable(reason)(environ, start_response)
517 517
518 518 if isinstance(auth_result, dict):
519 519 AUTH_TYPE.update(environ, 'basic')
520 520 REMOTE_USER.update(environ, auth_result['username'])
521 521 username = auth_result['username']
522 522 plugin = auth_result.get('auth_data', {}).get('_plugin')
523 523 log.info(
524 524 'MAIN-AUTH successful for user `%s` from %s plugin',
525 525 username, plugin)
526 526
527 527 plugin_cache_active, cache_ttl = auth_result.get(
528 528 'auth_data', {}).get('_ttl_cache') or (False, 0)
529 529 else:
530 530 return auth_result.wsgi_application(environ, start_response)
531 531
532 532 # ==============================================================
533 533 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
534 534 # ==============================================================
535 535 user = User.get_by_username(username)
536 536 if not self.valid_and_active_user(user):
537 537 return HTTPForbidden()(environ, start_response)
538 538 username = user.username
539 539 user_id = user.user_id
540 540
541 541 # check user attributes for password change flag
542 542 user_obj = user
543 543 auth_user = user_obj.AuthUser()
544 544 if user_obj and user_obj.username != User.DEFAULT_USER and \
545 545 user_obj.user_data.get('force_password_change'):
546 546 reason = 'password change required'
547 547 log.debug('User not allowed to authenticate, %s', reason)
548 548 return HTTPNotAcceptable(reason)(environ, start_response)
549 549
550 550 # check permissions for this repository
551 551 perm = self._check_permission(
552 552 action, user, auth_user, self.acl_repo_name, ip_addr,
553 553 plugin, plugin_cache_active, cache_ttl)
554 554 if not perm:
555 555 return HTTPForbidden()(environ, start_response)
556 556 environ['rc_auth_user_id'] = str(user_id)
557 557
558 558 if action == 'push':
559 559 perms = auth_user.get_branch_permissions(self.acl_repo_name)
560 560 if perms:
561 561 check_branch_perms = True
562 562 detect_force_push = True
563 563
564 564 # extras are injected into UI object and later available
565 565 # in hooks executed by RhodeCode
566 566 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
567 567
568 568 extras = vcs_operation_context(
569 569 environ, repo_name=self.acl_repo_name, username=username,
570 570 action=action, scm=self.SCM, check_locking=check_locking,
571 571 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
572 572 detect_force_push=detect_force_push
573 573 )
574 574
575 575 # ======================================================================
576 576 # REQUEST HANDLING
577 577 # ======================================================================
578 578 repo_path = os.path.join(
579 579 safe_str(self.base_path), safe_str(self.vcs_repo_name))
580 580 log.debug('Repository path is %s', repo_path)
581 581
582 582 fix_PATH()
583 583
584 584 log.info(
585 585 '%s action on %s repo "%s" by "%s" from %s %s',
586 586 action, self.SCM, safe_str(self.url_repo_name),
587 587 safe_str(username), ip_addr, user_agent)
588 588
589 589 return self._generate_vcs_response(
590 590 environ, start_response, repo_path, extras, action)
591 591
592 592 @initialize_generator
593 593 def _generate_vcs_response(
594 594 self, environ, start_response, repo_path, extras, action):
595 595 """
596 596 Returns a generator for the response content.
597 597
598 598 This method is implemented as a generator, so that it can trigger
599 599 the cache validation after all content sent back to the client. It
600 600 also handles the locking exceptions which will be triggered when
601 601 the first chunk is produced by the underlying WSGI application.
602 602 """
603 603
604 604 txn_id = ''
605 605 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
606 606 # case for SVN, we want to re-use the callback daemon port
607 607 # so we use the txn_id, for this we peek the body, and still save
608 608 # it as wsgi.input
609 609
610 610 stream = environ['wsgi.input']
611 611
612 612 if isinstance(stream, io.BytesIO):
613 613 data: bytes = stream.getvalue()
614 614 elif hasattr(stream, 'buf'): # most likely gunicorn.http.body.Body
615 615 data: bytes = stream.buf.getvalue()
616 616 else:
617 617 # fallback to the crudest way, copy the iterator
618 618 data = safe_bytes(stream.read())
619 619 environ['wsgi.input'] = io.BytesIO(data)
620 620
621 621 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
622 622
623 623 callback_daemon, extras = self._prepare_callback_daemon(
624 624 extras, environ, action, txn_id=txn_id)
625 625 log.debug('HOOKS extras is %s', extras)
626 626
627 627 http_scheme = self._get_http_scheme(environ)
628 628
629 629 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
630 630 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
631 631 with callback_daemon:
632 632 app.rc_extras = extras
633 633
634 634 try:
635 635 response = app(environ, start_response)
636 636 finally:
637 637 # This statement works together with the decorator
638 638 # "initialize_generator" above. The decorator ensures that
639 639 # we hit the first yield statement before the generator is
640 640 # returned back to the WSGI server. This is needed to
641 641 # ensure that the call to "app" above triggers the
642 642 # needed callback to "start_response" before the
643 643 # generator is actually used.
644 644 yield "__init__"
645 645
646 646 # iter content
647 647 for chunk in response:
648 648 yield chunk
649 649
650 650 try:
651 651 # invalidate cache on push
652 652 if action == 'push':
653 653 self._invalidate_cache(self.url_repo_name)
654 654 finally:
655 655 meta.Session.remove()
656 656
657 657 def _get_repository_name(self, environ):
658 658 """Get repository name out of the environmnent
659 659
660 660 :param environ: WSGI environment
661 661 """
662 662 raise NotImplementedError()
663 663
664 664 def _get_action(self, environ):
665 665 """Map request commands into a pull or push command.
666 666
667 667 :param environ: WSGI environment
668 668 """
669 669 raise NotImplementedError()
670 670
671 671 def _create_wsgi_app(self, repo_path, repo_name, config):
672 672 """Return the WSGI app that will finally handle the request."""
673 673 raise NotImplementedError()
674 674
675 675 def _create_config(self, extras, repo_name, scheme='http'):
676 676 """Create a safe config representation."""
677 677 raise NotImplementedError()
678 678
679 679 def _should_use_callback_daemon(self, extras, environ, action):
680 680 if extras.get('is_shadow_repo'):
681 681 # we don't want to execute hooks, and callback daemon for shadow repos
682 682 return False
683 683 return True
684 684
685 685 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
686 686 protocol = vcs_settings.HOOKS_PROTOCOL
687 687 if not self._should_use_callback_daemon(extras, environ, action):
688 688 # disable callback daemon for actions that don't require it
689 689 protocol = 'local'
690 690
691 691 return prepare_callback_daemon(
692 692 extras, protocol=protocol,
693 693 host=vcs_settings.HOOKS_HOST, txn_id=txn_id)
694 694
695 695
696 696 def _should_check_locking(query_string):
697 697 # this is kind of hacky, but due to how mercurial handles client-server
698 698 # server see all operation on commit; bookmarks, phases and
699 699 # obsolescence marker in different transaction, we don't want to check
700 700 # locking on those
701 701 return query_string not in ['cmd=listkeys']
@@ -1,58 +1,48 b''
1 1
2 2
3 3 # Copyright (C) 2016-2023 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import configparser
23
24 from rhodecode.lib.config_utils import get_config
23 25 from pyramid.paster import bootstrap as pyramid_bootstrap, setup_logging # pragma: no cover
24 26
25 from rhodecode.lib.request import Request
26
27
28 def get_config(ini_path, **kwargs):
29 parser = configparser.ConfigParser(**kwargs)
30 parser.read(ini_path)
31 return parser
32
33
34 def get_app_config(ini_path):
35 from paste.deploy.loadwsgi import appconfig
36 return appconfig(f'config:{ini_path}', relative_to=os.getcwd())
37
38 27
39 28 def bootstrap(config_uri, options=None, env=None):
40 29 from rhodecode.lib.utils2 import AttributeDict
30 from rhodecode.lib.request import Request
41 31
42 32 if env:
43 33 os.environ.update(env)
44 34
45 35 config = get_config(config_uri)
46 36 base_url = 'http://rhodecode.local'
47 37 try:
48 38 base_url = config.get('app:main', 'app.base_url')
49 39 except (configparser.NoSectionError, configparser.NoOptionError):
50 40 pass
51 41
52 42 request = Request.blank('/', base_url=base_url)
53 43 # fake inject a running user for bootstrap request !
54 44 request.user = AttributeDict({'username': 'bootstrap-user',
55 45 'user_id': 1,
56 46 'ip_addr': '127.0.0.1'})
57 47 return pyramid_bootstrap(config_uri, request=request, options=options)
58 48
@@ -1,123 +1,124 b''
1 1 # Copyright (C) 2016-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18 import logging
19 19
20 20 import click
21 21 import pyramid.paster
22 22
23 from rhodecode.lib.pyramid_utils import bootstrap, get_app_config
23 from rhodecode.lib.pyramid_utils import bootstrap
24 from rhodecode.lib.config_utils import get_app_config
24 25 from rhodecode.lib.db_manage import DbManage
25 26 from rhodecode.lib.utils2 import get_encryption_key
26 27 from rhodecode.model.db import Session
27 28
28 29
29 30 log = logging.getLogger(__name__)
30 31
31 32
32 33 @click.command()
33 34 @click.argument('ini_path', type=click.Path(exists=True))
34 35 @click.option(
35 36 '--force-yes/--force-no', default=None,
36 37 help="Force yes/no to every question")
37 38 @click.option(
38 39 '--user',
39 40 default=None,
40 41 help='Initial super-admin username')
41 42 @click.option(
42 43 '--email',
43 44 default=None,
44 45 help='Initial super-admin email address.')
45 46 @click.option(
46 47 '--password',
47 48 default=None,
48 49 help='Initial super-admin password. Minimum 6 chars.')
49 50 @click.option(
50 51 '--api-key',
51 52 help='Initial API key for the admin user')
52 53 @click.option(
53 54 '--repos',
54 55 default=None,
55 56 help='Absolute path to storage location. This is storage for all '
56 57 'existing and future repositories, and repository groups.')
57 58 @click.option(
58 59 '--public-access/--no-public-access',
59 60 default=None,
60 61 help='Enable public access on this installation. '
61 62 'Default is public access enabled.')
62 63 @click.option(
63 64 '--skip-existing-db',
64 65 default=False,
65 66 is_flag=True,
66 67 help='Do not destroy and re-initialize the database if it already exist.')
67 68 @click.option(
68 69 '--apply-license-key',
69 70 default=False,
70 71 is_flag=True,
71 72 help='Get the license key from a license file or ENV and apply during DB creation.')
72 73 def main(ini_path, force_yes, user, email, password, api_key, repos,
73 74 public_access, skip_existing_db, apply_license_key):
74 75 return command(ini_path, force_yes, user, email, password, api_key,
75 76 repos, public_access, skip_existing_db, apply_license_key)
76 77
77 78
78 79 def command(ini_path, force_yes, user, email, password, api_key, repos,
79 80 public_access, skip_existing_db, apply_license_key):
80 81 # mapping of old parameters to new CLI from click
81 82 options = dict(
82 83 username=user,
83 84 email=email,
84 85 password=password,
85 86 api_key=api_key,
86 87 repos_location=repos,
87 88 force_ask=force_yes,
88 89 public_access=public_access
89 90 )
90 91 pyramid.paster.setup_logging(ini_path)
91 92
92 93 config = get_app_config(ini_path)
93 94
94 95 db_uri = config['sqlalchemy.db1.url']
95 96 enc_key = get_encryption_key(config)
96 97 dbmanage = DbManage(log_sql=True, dbconf=db_uri, root='.',
97 98 tests=False, cli_args=options, enc_key=enc_key)
98 99 if skip_existing_db and dbmanage.db_exists():
99 100 return
100 101
101 102 dbmanage.create_tables(override=True)
102 103 dbmanage.set_db_version()
103 104 opts = dbmanage.config_prompt(None)
104 105 dbmanage.create_settings(opts)
105 106 dbmanage.create_default_user()
106 107 dbmanage.create_admin_and_prompt()
107 108 dbmanage.create_permissions()
108 109 dbmanage.populate_default_permissions()
109 110 if apply_license_key:
110 111 try:
111 112 from rc_license.models import apply_trial_license_if_missing
112 113 apply_trial_license_if_missing(force=True)
113 114 except ImportError:
114 115 pass
115 116
116 117 Session().commit()
117 118
118 119 with bootstrap(ini_path, env={'RC_CMD_SETUP_RC': '1'}) as env:
119 120 msg = 'Successfully initialized database, schema and default data.'
120 121 print()
121 122 print('*' * len(msg))
122 123 print(msg.upper())
123 124 print('*' * len(msg))
@@ -1,859 +1,824 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 """
20 20 Utilities library for RhodeCode
21 21 """
22 22
23 23 import datetime
24 24
25 25 import decorator
26 26 import logging
27 27 import os
28 28 import re
29 29 import sys
30 30 import shutil
31 31 import socket
32 32 import tempfile
33 33 import traceback
34 34 import tarfile
35 import urllib.parse
36 import warnings
35
37 36 from functools import wraps
38 37 from os.path import join as jn
39 from configparser import NoOptionError
40 38
41 39 import paste
42 40 import pkg_resources
43 41 from webhelpers2.text import collapse, strip_tags, convert_accented_entities, convert_misc_entities
44 42
45 43 from mako import exceptions
46 44
47 45 from rhodecode.lib.hash_utils import sha256_safe, md5, sha1
48 46 from rhodecode.lib.type_utils import AttributeDict
49 47 from rhodecode.lib.str_utils import safe_bytes, safe_str
50 48 from rhodecode.lib.vcs.backends.base import Config
51 49 from rhodecode.lib.vcs.exceptions import VCSError
52 50 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
53 51 from rhodecode.lib.ext_json import sjson as json
54 52 from rhodecode.model import meta
55 53 from rhodecode.model.db import (
56 54 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
57 55 from rhodecode.model.meta import Session
58 from rhodecode.lib.pyramid_utils import get_config
59 from rhodecode.lib.vcs import CurlSession
60 from rhodecode.lib.vcs.exceptions import ImproperlyConfiguredError
61 56
62 57
63 58 log = logging.getLogger(__name__)
64 59
65 60 REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
66 61
67 62 # String which contains characters that are not allowed in slug names for
68 63 # repositories or repository groups. It is properly escaped to use it in
69 64 # regular expressions.
70 65 SLUG_BAD_CHARS = re.escape(r'`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')
71 66
72 67 # Regex that matches forbidden characters in repo/group slugs.
73 68 SLUG_BAD_CHAR_RE = re.compile(r'[{}\x00-\x08\x0b-\x0c\x0e-\x1f]'.format(SLUG_BAD_CHARS))
74 69
75 70 # Regex that matches allowed characters in repo/group slugs.
76 71 SLUG_GOOD_CHAR_RE = re.compile(r'[^{}]'.format(SLUG_BAD_CHARS))
77 72
78 73 # Regex that matches whole repo/group slugs.
79 74 SLUG_RE = re.compile(r'[^{}]+'.format(SLUG_BAD_CHARS))
80 75
81 76 _license_cache = None
82 77
83 78
84 79 def adopt_for_celery(func):
85 80 """
86 81 Decorator designed to adopt hooks (from rhodecode.lib.hooks_base)
87 82 for further usage as a celery tasks.
88 83 """
89 84 @wraps(func)
90 85 def wrapper(extras):
91 86 extras = AttributeDict(extras)
92 87 # HooksResponse implements to_json method which must be used there.
93 88 return func(extras).to_json()
94 89 return wrapper
95 90
96 91
97 92 def repo_name_slug(value):
98 93 """
99 94 Return slug of name of repository
100 95 This function is called on each creation/modification
101 96 of repository to prevent bad names in repo
102 97 """
103 98
104 99 replacement_char = '-'
105 100
106 101 slug = strip_tags(value)
107 102 slug = convert_accented_entities(slug)
108 103 slug = convert_misc_entities(slug)
109 104
110 105 slug = SLUG_BAD_CHAR_RE.sub('', slug)
111 106 slug = re.sub(r'[\s]+', '-', slug)
112 107 slug = collapse(slug, replacement_char)
113 108
114 109 return slug
115 110
116 111
117 112 #==============================================================================
118 113 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
119 114 #==============================================================================
120 115 def get_repo_slug(request):
121 116 _repo = ''
122 117
123 118 if hasattr(request, 'db_repo_name'):
124 119 # if our requests has set db reference use it for name, this
125 120 # translates the example.com/_<id> into proper repo names
126 121 _repo = request.db_repo_name
127 122 elif getattr(request, 'matchdict', None):
128 123 # pyramid
129 124 _repo = request.matchdict.get('repo_name')
130 125
131 126 if _repo:
132 127 _repo = _repo.rstrip('/')
133 128 return _repo
134 129
135 130
136 131 def get_repo_group_slug(request):
137 132 _group = ''
138 133 if hasattr(request, 'db_repo_group'):
139 134 # if our requests has set db reference use it for name, this
140 135 # translates the example.com/_<id> into proper repo group names
141 136 _group = request.db_repo_group.group_name
142 137 elif getattr(request, 'matchdict', None):
143 138 # pyramid
144 139 _group = request.matchdict.get('repo_group_name')
145 140
146 141 if _group:
147 142 _group = _group.rstrip('/')
148 143 return _group
149 144
150 145
151 146 def get_user_group_slug(request):
152 147 _user_group = ''
153 148
154 149 if hasattr(request, 'db_user_group'):
155 150 _user_group = request.db_user_group.users_group_name
156 151 elif getattr(request, 'matchdict', None):
157 152 # pyramid
158 153 _user_group = request.matchdict.get('user_group_id')
159 154 _user_group_name = request.matchdict.get('user_group_name')
160 155 try:
161 156 if _user_group:
162 157 _user_group = UserGroup.get(_user_group)
163 158 elif _user_group_name:
164 159 _user_group = UserGroup.get_by_group_name(_user_group_name)
165 160
166 161 if _user_group:
167 162 _user_group = _user_group.users_group_name
168 163 except Exception:
169 164 log.exception('Failed to get user group by id and name')
170 165 # catch all failures here
171 166 return None
172 167
173 168 return _user_group
174 169
175 170
176 171 def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
177 172 """
178 173 Scans given path for repos and return (name,(type,path)) tuple
179 174
180 175 :param path: path to scan for repositories
181 176 :param recursive: recursive search and return names with subdirs in front
182 177 """
183 178
184 179 # remove ending slash for better results
185 180 path = path.rstrip(os.sep)
186 181 log.debug('now scanning in %s location recursive:%s...', path, recursive)
187 182
188 183 def _get_repos(p):
189 184 dirpaths = get_dirpaths(p)
190 185 if not _is_dir_writable(p):
191 186 log.warning('repo path without write access: %s', p)
192 187
193 188 for dirpath in dirpaths:
194 189 if os.path.isfile(os.path.join(p, dirpath)):
195 190 continue
196 191 cur_path = os.path.join(p, dirpath)
197 192
198 193 # skip removed repos
199 194 if skip_removed_repos and REMOVED_REPO_PAT.match(dirpath):
200 195 continue
201 196
202 197 #skip .<somethin> dirs
203 198 if dirpath.startswith('.'):
204 199 continue
205 200
206 201 try:
207 202 scm_info = get_scm(cur_path)
208 203 yield scm_info[1].split(path, 1)[-1].lstrip(os.sep), scm_info
209 204 except VCSError:
210 205 if not recursive:
211 206 continue
212 207 #check if this dir containts other repos for recursive scan
213 208 rec_path = os.path.join(p, dirpath)
214 209 if os.path.isdir(rec_path):
215 210 yield from _get_repos(rec_path)
216 211
217 212 return _get_repos(path)
218 213
219 214
220 215 def get_dirpaths(p: str) -> list:
221 216 try:
222 217 # OS-independable way of checking if we have at least read-only
223 218 # access or not.
224 219 dirpaths = os.listdir(p)
225 220 except OSError:
226 221 log.warning('ignoring repo path without read access: %s', p)
227 222 return []
228 223
229 224 # os.listpath has a tweak: If a unicode is passed into it, then it tries to
230 225 # decode paths and suddenly returns unicode objects itself. The items it
231 226 # cannot decode are returned as strings and cause issues.
232 227 #
233 228 # Those paths are ignored here until a solid solution for path handling has
234 229 # been built.
235 230 expected_type = type(p)
236 231
237 232 def _has_correct_type(item):
238 233 if type(item) is not expected_type:
239 234 log.error(
240 235 "Ignoring path %s since it cannot be decoded into str.",
241 236 # Using "repr" to make sure that we see the byte value in case
242 237 # of support.
243 238 repr(item))
244 239 return False
245 240 return True
246 241
247 242 dirpaths = [item for item in dirpaths if _has_correct_type(item)]
248 243
249 244 return dirpaths
250 245
251 246
252 247 def _is_dir_writable(path):
253 248 """
254 249 Probe if `path` is writable.
255 250
256 251 Due to trouble on Cygwin / Windows, this is actually probing if it is
257 252 possible to create a file inside of `path`, stat does not produce reliable
258 253 results in this case.
259 254 """
260 255 try:
261 256 with tempfile.TemporaryFile(dir=path):
262 257 pass
263 258 except OSError:
264 259 return False
265 260 return True
266 261
267 262
268 263 def is_valid_repo(repo_name, base_path, expect_scm=None, explicit_scm=None, config=None):
269 264 """
270 265 Returns True if given path is a valid repository False otherwise.
271 266 If expect_scm param is given also, compare if given scm is the same
272 267 as expected from scm parameter. If explicit_scm is given don't try to
273 268 detect the scm, just use the given one to check if repo is valid
274 269
275 270 :param repo_name:
276 271 :param base_path:
277 272 :param expect_scm:
278 273 :param explicit_scm:
279 274 :param config:
280 275
281 276 :return True: if given path is a valid repository
282 277 """
283 278 full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
284 279 log.debug('Checking if `%s` is a valid path for repository. '
285 280 'Explicit type: %s', repo_name, explicit_scm)
286 281
287 282 try:
288 283 if explicit_scm:
289 284 detected_scms = [get_scm_backend(explicit_scm)(
290 285 full_path, config=config).alias]
291 286 else:
292 287 detected_scms = get_scm(full_path)
293 288
294 289 if expect_scm:
295 290 return detected_scms[0] == expect_scm
296 291 log.debug('path: %s is an vcs object:%s', full_path, detected_scms)
297 292 return True
298 293 except VCSError:
299 294 log.debug('path: %s is not a valid repo !', full_path)
300 295 return False
301 296
302 297
303 298 def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
304 299 """
305 300 Returns True if a given path is a repository group, False otherwise
306 301
307 302 :param repo_group_name:
308 303 :param base_path:
309 304 """
310 305 full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
311 306 log.debug('Checking if `%s` is a valid path for repository group',
312 307 repo_group_name)
313 308
314 309 # check if it's not a repo
315 310 if is_valid_repo(repo_group_name, base_path):
316 311 log.debug('Repo called %s exist, it is not a valid repo group', repo_group_name)
317 312 return False
318 313
319 314 try:
320 315 # we need to check bare git repos at higher level
321 316 # since we might match branches/hooks/info/objects or possible
322 317 # other things inside bare git repo
323 318 maybe_repo = os.path.dirname(full_path)
324 319 if maybe_repo == base_path:
325 320 # skip root level repo check; we know root location CANNOT BE a repo group
326 321 return False
327 322
328 323 scm_ = get_scm(maybe_repo)
329 324 log.debug('path: %s is a vcs object:%s, not valid repo group', full_path, scm_)
330 325 return False
331 326 except VCSError:
332 327 pass
333 328
334 329 # check if it's a valid path
335 330 if skip_path_check or os.path.isdir(full_path):
336 331 log.debug('path: %s is a valid repo group !', full_path)
337 332 return True
338 333
339 334 log.debug('path: %s is not a valid repo group !', full_path)
340 335 return False
341 336
342 337
343 338 def ask_ok(prompt, retries=4, complaint='[y]es or [n]o please!'):
344 339 while True:
345 340 ok = input(prompt)
346 341 if ok.lower() in ('y', 'ye', 'yes'):
347 342 return True
348 343 if ok.lower() in ('n', 'no', 'nop', 'nope'):
349 344 return False
350 345 retries = retries - 1
351 346 if retries < 0:
352 347 raise OSError
353 348 print(complaint)
354 349
355 350 # propagated from mercurial documentation
356 351 ui_sections = [
357 352 'alias', 'auth',
358 353 'decode/encode', 'defaults',
359 354 'diff', 'email',
360 355 'extensions', 'format',
361 356 'merge-patterns', 'merge-tools',
362 357 'hooks', 'http_proxy',
363 358 'smtp', 'patch',
364 359 'paths', 'profiling',
365 360 'server', 'trusted',
366 361 'ui', 'web', ]
367 362
368 363
369 364 def config_data_from_db(clear_session=True, repo=None):
370 365 """
371 366 Read the configuration data from the database and return configuration
372 367 tuples.
373 368 """
374 369 from rhodecode.model.settings import VcsSettingsModel
375 370
376 371 config = []
377 372
378 373 sa = meta.Session()
379 374 settings_model = VcsSettingsModel(repo=repo, sa=sa)
380 375
381 376 ui_settings = settings_model.get_ui_settings()
382 377
383 378 ui_data = []
384 379 for setting in ui_settings:
385 380 if setting.active:
386 381 ui_data.append((setting.section, setting.key, setting.value))
387 382 config.append((
388 383 safe_str(setting.section), safe_str(setting.key),
389 384 safe_str(setting.value)))
390 385 if setting.key == 'push_ssl':
391 386 # force set push_ssl requirement to False, rhodecode
392 387 # handles that
393 388 config.append((
394 389 safe_str(setting.section), safe_str(setting.key), False))
395 390 log.debug(
396 391 'settings ui from db@repo[%s]: %s',
397 392 repo,
398 393 ','.join(['[{}] {}={}'.format(*s) for s in ui_data]))
399 394 if clear_session:
400 395 meta.Session.remove()
401 396
402 397 # TODO: mikhail: probably it makes no sense to re-read hooks information.
403 398 # It's already there and activated/deactivated
404 399 skip_entries = []
405 400 enabled_hook_classes = get_enabled_hook_classes(ui_settings)
406 401 if 'pull' not in enabled_hook_classes:
407 402 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PULL))
408 403 if 'push' not in enabled_hook_classes:
409 404 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PUSH))
410 405 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRETX_PUSH))
411 406 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PUSH_KEY))
412 407
413 408 config = [entry for entry in config if entry[:2] not in skip_entries]
414 409
415 410 return config
416 411
417 412
418 413 def make_db_config(clear_session=True, repo=None):
419 414 """
420 415 Create a :class:`Config` instance based on the values in the database.
421 416 """
422 417 config = Config()
423 418 config_data = config_data_from_db(clear_session=clear_session, repo=repo)
424 419 for section, option, value in config_data:
425 420 config.set(section, option, value)
426 421 return config
427 422
428 423
429 424 def get_enabled_hook_classes(ui_settings):
430 425 """
431 426 Return the enabled hook classes.
432 427
433 428 :param ui_settings: List of ui_settings as returned
434 429 by :meth:`VcsSettingsModel.get_ui_settings`
435 430
436 431 :return: a list with the enabled hook classes. The order is not guaranteed.
437 432 :rtype: list
438 433 """
439 434 enabled_hooks = []
440 435 active_hook_keys = [
441 436 key for section, key, value, active in ui_settings
442 437 if section == 'hooks' and active]
443 438
444 439 hook_names = {
445 440 RhodeCodeUi.HOOK_PUSH: 'push',
446 441 RhodeCodeUi.HOOK_PULL: 'pull',
447 442 RhodeCodeUi.HOOK_REPO_SIZE: 'repo_size'
448 443 }
449 444
450 445 for key in active_hook_keys:
451 446 hook = hook_names.get(key)
452 447 if hook:
453 448 enabled_hooks.append(hook)
454 449
455 450 return enabled_hooks
456 451
457 452
458 453 def set_rhodecode_config(config):
459 454 """
460 455 Updates pyramid config with new settings from database
461 456
462 457 :param config:
463 458 """
464 459 from rhodecode.model.settings import SettingsModel
465 460 app_settings = SettingsModel().get_all_settings()
466 461
467 462 for k, v in list(app_settings.items()):
468 463 config[k] = v
469 464
470 465
471 466 def get_rhodecode_realm():
472 467 """
473 468 Return the rhodecode realm from database.
474 469 """
475 470 from rhodecode.model.settings import SettingsModel
476 471 realm = SettingsModel().get_setting_by_name('realm')
477 472 return safe_str(realm.app_settings_value)
478 473
479 474
480 475 def get_rhodecode_base_path():
481 476 """
482 477 Returns the base path. The base path is the filesystem path which points
483 478 to the repository store.
484 479 """
485 480
486 481 import rhodecode
487 482 return rhodecode.CONFIG['default_base_path']
488 483
489 484
490 485 def map_groups(path):
491 486 """
492 487 Given a full path to a repository, create all nested groups that this
493 488 repo is inside. This function creates parent-child relationships between
494 489 groups and creates default perms for all new groups.
495 490
496 491 :param paths: full path to repository
497 492 """
498 493 from rhodecode.model.repo_group import RepoGroupModel
499 494 sa = meta.Session()
500 495 groups = path.split(Repository.NAME_SEP)
501 496 parent = None
502 497 group = None
503 498
504 499 # last element is repo in nested groups structure
505 500 groups = groups[:-1]
506 501 rgm = RepoGroupModel(sa)
507 502 owner = User.get_first_super_admin()
508 503 for lvl, group_name in enumerate(groups):
509 504 group_name = '/'.join(groups[:lvl] + [group_name])
510 505 group = RepoGroup.get_by_group_name(group_name)
511 506 desc = '%s group' % group_name
512 507
513 508 # skip folders that are now removed repos
514 509 if REMOVED_REPO_PAT.match(group_name):
515 510 break
516 511
517 512 if group is None:
518 513 log.debug('creating group level: %s group_name: %s',
519 514 lvl, group_name)
520 515 group = RepoGroup(group_name, parent)
521 516 group.group_description = desc
522 517 group.user = owner
523 518 sa.add(group)
524 519 perm_obj = rgm._create_default_perms(group)
525 520 sa.add(perm_obj)
526 521 sa.flush()
527 522
528 523 parent = group
529 524 return group
530 525
531 526
def repo2db_mapper(initial_repo_list, remove_obsolete=False, force_hooks_rebuild=False):
    """
    Maps all repos given in initial_repo_list; non existing repositories
    are created. If remove_obsolete is True it also checks for db entries
    that are not in initial_repo_list and removes them.

    :param initial_repo_list: dict of repositories found by scanning methods
        (name -> scm repo object)
    :param remove_obsolete: check for obsolete entries in database and delete
        repos/repo-groups no longer present on the filesystem
    :param force_hooks_rebuild: force re-installation of VCS hooks for every repo
    :returns: tuple of (added, removed) lists of repo/group names
    """
    from rhodecode.model.repo import RepoModel
    from rhodecode.model.repo_group import RepoGroupModel
    from rhodecode.model.settings import SettingsModel

    sa = meta.Session()
    repo_model = RepoModel()
    user = User.get_first_super_admin()
    added = []

    # creation defaults, taken from the stored default repo settings
    defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
    enable_statistics = defs.get('repo_enable_statistics')
    enable_locking = defs.get('repo_enable_locking')
    enable_downloads = defs.get('repo_enable_downloads')
    private = defs.get('repo_private')

    for name, repo in list(initial_repo_list.items()):
        # ensure the parent repo-group chain exists in the DB
        group = map_groups(name)
        str_name = safe_str(name)
        db_repo = repo_model.get_by_repo_name(str_name)

        # found repo that is on filesystem not in RhodeCode database
        if not db_repo:
            log.info('repository `%s` not found in the database, creating now', name)
            added.append(name)
            desc = (repo.description
                    if repo.description != 'unknown'
                    else '%s repository' % name)

            db_repo = repo_model._create_repo(
                repo_name=name,
                repo_type=repo.alias,
                description=desc,
                repo_group=getattr(group, 'group_id', None),
                owner=user,
                enable_locking=enable_locking,
                enable_downloads=enable_downloads,
                enable_statistics=enable_statistics,
                private=private,
                state=Repository.STATE_CREATED
            )
            sa.commit()
            # we added that repo just now, and make sure we updated server info
            if db_repo.repo_type == 'git':
                git_repo = db_repo.scm_instance()
                # update repository server-info
                log.debug('Running update server info')
                git_repo._update_server_info(force=True)

            db_repo.update_commit_cache()

        # always (re-)install hooks, with largefiles extension enabled so the
        # scm instance behaves consistently
        config = db_repo._config
        config.set('extensions', 'largefiles', '')
        repo = db_repo.scm_instance(config=config)
        repo.install_hooks(force=force_hooks_rebuild)

    removed = []
    if remove_obsolete:
        # remove from database those repositories that are not in the filesystem
        for repo in sa.query(Repository).all():
            if repo.repo_name not in list(initial_repo_list.keys()):
                log.debug("Removing non-existing repository found in db `%s`",
                          repo.repo_name)
                try:
                    RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
                    sa.commit()
                    removed.append(repo.repo_name)
                except Exception:
                    # don't hold further removals on error
                    log.error(traceback.format_exc())
                    sa.rollback()

        def splitter(full_repo_name):
            # returns the group part of "group/sub/repo", or None for
            # top-level repos with no group
            _parts = full_repo_name.rsplit(RepoGroup.url_sep(), 1)
            gr_name = None
            if len(_parts) == 2:
                gr_name = _parts[0]
            return gr_name

        initial_repo_group_list = [splitter(x) for x in
                                   list(initial_repo_list.keys()) if splitter(x)]

        # remove from database those repository groups that are not in the
        # filesystem due to parent child relationships we need to delete them
        # in a specific order of most nested first
        all_groups = [x.group_name for x in sa.query(RepoGroup).all()]
        def nested_sort(gr):
            return len(gr.split('/'))
        for group_name in sorted(all_groups, key=nested_sort, reverse=True):
            if group_name not in initial_repo_group_list:
                repo_group = RepoGroup.get_by_group_name(group_name)
                # NOTE(review): groups with DB children, or groups failing the
                # filesystem-existence check, are skipped here — confirm this
                # matches the intended obsolete-group semantics
                if (repo_group.children.all() or
                    not RepoGroupModel().check_exist_filesystem(
                        group_name=group_name, exc_on_failure=False)):
                    continue

                log.info(
                    'Removing non-existing repository group found in db `%s`',
                    group_name)
                try:
                    RepoGroupModel(sa).delete(group_name, fs_remove=False)
                    sa.commit()
                    removed.append(group_name)
                except Exception:
                    # don't hold further removals on error
                    log.exception(
                        'Unable to remove repository group `%s`',
                        group_name)
                    sa.rollback()
                    raise
    return added, removed
653 648
654 649
def load_rcextensions(root_path):
    """
    Import the optional ``rcextensions`` package from *root_path* and publish
    it as ``rhodecode.EXTENSIONS``.

    Side effects: appends *root_path* to ``sys.path`` and, when extensions are
    found, merges their ``EXTRA_MAPPINGS`` into the global language map.

    :param root_path: directory expected to contain an ``rcextensions`` package
    """
    import rhodecode
    from rhodecode.config import conf

    path = os.path.join(root_path)
    sys.path.append(path)

    try:
        rcextensions = __import__('rcextensions')
    except ImportError:
        # only warn when the directory exists but could not be imported;
        # a missing directory simply means no extensions are installed
        if os.path.isdir(os.path.join(path, 'rcextensions')):
            log.warning('Unable to load rcextensions from %s', path)
        rcextensions = None

    if rcextensions:
        log.info('Loaded rcextensions from %s...', rcextensions)
        rhodecode.EXTENSIONS = rcextensions

        # Additional mappings that are not present in the pygments lexers
        conf.LANGUAGES_EXTENSIONS_MAP.update(
            getattr(rhodecode.EXTENSIONS, 'EXTRA_MAPPINGS', {}))
676 671
677 672
def get_custom_lexer(extension):
    """
    Return a pygments lexer for *extension* when one is registered via
    rcextensions (``EXTRA_LEXERS``), or via the built-in overrides.
    Returns None when no custom lexer is defined.
    """
    import rhodecode
    from pygments import lexers

    # custom override made by RhodeCode
    if extension in ['mako']:
        return lexers.get_lexer_by_name('html+mako')

    # fall back to lexers registered by the rcextensions package
    has_extra = rhodecode.EXTENSIONS and getattr(
        rhodecode.EXTENSIONS, 'EXTRA_LEXERS', None)
    if has_extra and extension in rhodecode.EXTENSIONS.EXTRA_LEXERS:
        custom_name = rhodecode.EXTENSIONS.EXTRA_LEXERS[extension]
        return lexers.get_lexer_by_name(custom_name)
695 690
696 691
697 692 #==============================================================================
698 693 # TEST FUNCTIONS AND CREATORS
699 694 #==============================================================================
def create_test_index(repo_location, config):
    """
    Extract the pre-built search index used by the test suite into the
    directory containing ``config['search.location']``.
    """
    try:
        import rc_testdata
    except ImportError:
        raise ImportError('Failed to import rc_testdata, '
                          'please make sure this package is installed from requirements_test.txt')
    index_destination = os.path.dirname(config['search.location'])
    rc_testdata.extract_search_index('vcs_search_index', index_destination)
711 706
712 707
def create_test_directory(test_path):
    """
    Ensure the test directory *test_path* exists, creating it when missing.
    """
    if os.path.isdir(test_path):
        return
    log.debug('Creating testdir %s', test_path)
    os.makedirs(test_path)
720 715
721 716
def create_test_database(test_path, config):
    """
    Makes a fresh database for tests: recreates all tables and seeds the
    default users, permissions and settings.

    :param test_path: base path used when prompting for test config values
    :param config: dict-like config with 'sqlalchemy.db1.url' and 'here' keys
    """
    from rhodecode.lib.db_manage import DbManage
    from rhodecode.lib.utils2 import get_encryption_key

    # PART ONE create db
    dbconf = config['sqlalchemy.db1.url']
    enc_key = get_encryption_key(config)

    log.debug('making test db %s', dbconf)

    dbmanage = DbManage(log_sql=False, dbconf=dbconf, root=config['here'],
                        tests=True, cli_args={'force_ask': True}, enc_key=enc_key)
    dbmanage.create_tables(override=True)
    dbmanage.set_db_version()
    # for tests dynamically set new root paths based on generated content
    dbmanage.create_settings(dbmanage.config_prompt(test_path))
    dbmanage.create_default_user()
    dbmanage.create_test_admin_and_users()
    dbmanage.create_permissions()
    dbmanage.populate_default_permissions()
    Session().commit()
746 741
747 742
def create_test_repositories(test_path, config):
    """
    Creates test repositories in the temporary directory. Repositories are
    extracted from archives within the rc_testdata package.

    Also wipes the configured search index and cache directories first so each
    test run starts clean.
    """
    import rc_testdata
    from rhodecode.tests import HG_REPO, GIT_REPO, SVN_REPO

    log.debug('making test vcs repositories')

    idx_path = config['search.location']
    data_path = config['cache_dir']

    # clean index and data
    if idx_path and os.path.exists(idx_path):
        log.debug('remove %s', idx_path)
        shutil.rmtree(idx_path)

    if data_path and os.path.exists(data_path):
        log.debug('remove %s', data_path)
        shutil.rmtree(data_path)

    rc_testdata.extract_hg_dump('vcs_test_hg', jn(test_path, HG_REPO))
    rc_testdata.extract_git_dump('vcs_test_git', jn(test_path, GIT_REPO))

    # Note: Subversion is in the process of being integrated with the system,
    # until we have a properly packed version of the test svn repository, this
    # tries to copy over the repo from a package "rc_testdata"
    svn_repo_path = rc_testdata.get_svn_repo_archive()
    # NOTE(review): extractall() on an archive without a `filter=` argument
    # trusts the archive contents; acceptable here only because rc_testdata
    # is a bundled test fixture — confirm if the source ever becomes external
    with tarfile.open(svn_repo_path) as tar:
        tar.extractall(jn(test_path, SVN_REPO))
779 774
780 775
def password_changed(auth_user, session):
    """
    Return True when the current password hash of *auth_user* differs from the
    hash stored in the web session (i.e. the password changed mid-session).
    """
    # Never report password change in case of default user or anonymous user.
    is_anonymous = auth_user.user_id is None
    if auth_user.username == User.DEFAULT_USER or is_anonymous:
        return False

    current_hash = None
    if auth_user.password:
        current_hash = md5(safe_bytes(auth_user.password))
    stored_hash = session.get('rhodecode_user', {}).get('password', '')
    return current_hash != stored_hash
790 785
791 786
def read_opensource_licenses():
    """
    Load and memoize the bundled open-source licenses JSON.

    The parsed result is cached in the module-level ``_license_cache``.
    """
    global _license_cache

    if not _license_cache:
        raw = pkg_resources.resource_string(
            'rhodecode', 'config/licenses.json')
        _license_cache = json.loads(raw)

    return _license_cache
801 796
802 797
def generate_platform_uuid():
    """
    Derive a stable identifier for this host from its platform string;
    returns 'UNDEFINED' if the platform cannot be queried.
    """
    import platform

    try:
        parts = [platform.platform()]
        return sha256_safe(':'.join(parts))
    except Exception as e:
        log.error('Failed to generate host uuid: %s', e)
        return 'UNDEFINED'
815 810
816 811
def send_test_email(recipients, email_body='TEST EMAIL'):
    """
    Simple code for generating test emails.
    Usage::

        from rhodecode.lib import utils
        utils.send_test_email()

    :param recipients: list of recipient addresses
    :param email_body: body used for both plaintext and HTML parts
    """
    # drop unused `run_task` import; only `tasks` is needed here
    from rhodecode.lib.celerylib import tasks

    # same body serves as the plaintext alternative
    email_body_plaintext = email_body
    subject = f'SUBJECT FROM: {socket.gethostname()}'
    tasks.send_email(recipients, subject, email_body_plaintext, email_body)
830
831
def call_service_api(ini_path, payload):
    """
    Call the internal RhodeCode service API using host/token read from the
    given .ini file, and return the JSON-RPC 'result' field.

    :param ini_path: path to the .ini configuration file
    :param payload: dict with the RPC method/args; mutated in place to add
        the 'id' and 'auth_token' keys
    :raises ImproperlyConfiguredError: when app.service_api.host is not set
    :raises Exception: when the API responds with a non-200 status
    """
    config = get_config(ini_path)
    try:
        host = config.get('app:main', 'app.service_api.host')
    except NoOptionError:
        raise ImproperlyConfiguredError(
            "app.service_api.host is missing. "
            "Please ensure that app.service_api.host and app.service_api.token are "
            "defined inside of .ini configuration file."
        )
    try:
        api_url = config.get('app:main', 'rhodecode.api.url')
    except NoOptionError:
        # fall back to the library default API endpoint
        from rhodecode import api
        log.debug('Cannot find rhodecode.api.url, setting API URL TO Default value')
        api_url = api.DEFAULT_URL

    payload.update({
        'id': 'service',
        'auth_token': config.get('app:main', 'app.service_api.token')
    })

    response = CurlSession().post(urllib.parse.urljoin(host, api_url), json.dumps(payload))

    if response.status_code != 200:
        raise Exception("Service API responded with error")

    return json.loads(response.content)['result']
@@ -1,2389 +1,2389 b''
1 1 # Copyright (C) 2012-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19
20 20 """
21 21 pull request model for RhodeCode
22 22 """
23 23
24 24 import logging
25 25 import os
26 26
27 27 import datetime
28 28 import urllib.request
29 29 import urllib.parse
30 30 import urllib.error
31 31 import collections
32 32
33 33 import dataclasses as dataclasses
34 34 from pyramid.threadlocal import get_current_request
35 35
36 36 from rhodecode.lib.vcs.nodes import FileNode
37 37 from rhodecode.translation import lazy_ugettext
38 38 from rhodecode.lib import helpers as h, hooks_utils, diffs
39 39 from rhodecode.lib import audit_logger
40 40 from collections import OrderedDict
41 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
41 from rhodecode.lib.hook_daemon.base import prepare_callback_daemon
42 42 from rhodecode.lib.ext_json import sjson as json
43 43 from rhodecode.lib.markup_renderer import (
44 44 DEFAULT_COMMENTS_RENDERER, RstTemplateRenderer)
45 45 from rhodecode.lib.hash_utils import md5_safe
46 46 from rhodecode.lib.str_utils import safe_str
47 47 from rhodecode.lib.utils2 import AttributeDict, get_current_rhodecode_user
48 48 from rhodecode.lib.vcs.backends.base import (
49 49 Reference, MergeResponse, MergeFailureReason, UpdateFailureReason,
50 50 TargetRefMissing, SourceRefMissing)
51 51 from rhodecode.lib.vcs.conf import settings as vcs_settings
52 52 from rhodecode.lib.vcs.exceptions import (
53 53 CommitDoesNotExistError, EmptyRepositoryError)
54 54 from rhodecode.model import BaseModel
55 55 from rhodecode.model.changeset_status import ChangesetStatusModel
56 56 from rhodecode.model.comment import CommentsModel
57 57 from rhodecode.model.db import (
58 58 aliased, null, lazyload, and_, or_, select, func, String, cast, PullRequest, PullRequestReviewers, ChangesetStatus,
59 59 PullRequestVersion, ChangesetComment, Repository, RepoReviewRule, User)
60 60 from rhodecode.model.meta import Session
61 61 from rhodecode.model.notification import NotificationModel, \
62 62 EmailNotificationModel
63 63 from rhodecode.model.scm import ScmModel
64 64 from rhodecode.model.settings import VcsSettingsModel
65 65
66 66
67 67 log = logging.getLogger(__name__)
68 68
69 69
70 70 # Data structure to hold the response data when updating commits during a pull
71 71 # request update.
class UpdateResponse(object):
    """
    Value object describing the outcome of updating the commits of a pull
    request: whether the update ran, the reason, old/new refs, the common
    ancestor, and per-side change flags.
    """

    def __init__(self, executed, reason, new, old, common_ancestor_id,
                 commit_changes, source_changed, target_changed):
        # outcome of the update attempt
        self.executed = executed
        self.reason = reason
        # refs after/before the update
        self.new = new
        self.old = old
        self.common_ancestor_id = common_ancestor_id
        # exposed under the shorter attribute name `changes`
        self.changes = commit_changes
        # which side of the PR moved
        self.source_changed = source_changed
        self.target_changed = target_changed
85 85
86 86
def get_diff_info(
        source_repo, source_ref, target_repo, target_ref, get_authors=False,
        get_commit_authors=True):
    """
    Calculates detailed diff information for usage in preview of creation of a pull-request.
    This is also used for default reviewers logic.

    :param source_repo: repo the changes come from
    :param source_ref: tip commit id of the source
    :param target_repo: repo the changes go into
    :param target_ref: target commit id
    :param get_authors: also annotate changed lines to collect original authors
    :param get_commit_authors: collect authors of the incoming commits
    :returns: dict with 'commits', 'files', 'stats', 'ancestor',
        'original_authors' and 'commit_authors' keys
    :raises ValueError: when the two refs have no common ancestor
    """

    source_scm = source_repo.scm_instance()
    target_scm = target_repo.scm_instance()

    ancestor_id = target_scm.get_common_ancestor(target_ref, source_ref, source_scm)
    if not ancestor_id:
        raise ValueError(
            'cannot calculate diff info without a common ancestor. '
            'Make sure both repositories are related, and have a common forking commit.')

    # case here is that want a simple diff without incoming commits,
    # previewing what will be merged based only on commits in the source.
    log.debug('Using ancestor %s as source_ref instead of %s',
              ancestor_id, source_ref)

    # source of changes now is the common ancestor
    source_commit = source_scm.get_commit(commit_id=ancestor_id)
    # target commit becomes the source ref as it is the last commit
    # for diff generation this logic gives proper diff
    target_commit = source_scm.get_commit(commit_id=source_ref)

    vcs_diff = \
        source_scm.get_diff(commit1=source_commit, commit2=target_commit,
                            ignore_whitespace=False, context=3)

    diff_processor = diffs.DiffProcessor(vcs_diff, diff_format='newdiff',
                                         diff_limit=0, file_limit=0, show_full_diff=True)

    _parsed = diff_processor.prepare()

    # collect per-file stats and the old line numbers touched by del/mod
    all_files = []
    all_files_changes = []
    changed_lines = {}
    stats = [0, 0]  # [added, deleted] totals
    for f in _parsed:
        all_files.append(f['filename'])
        all_files_changes.append({
            'filename': f['filename'],
            'stats': f['stats']
        })
        stats[0] += f['stats']['added']
        stats[1] += f['stats']['deleted']

        changed_lines[f['filename']] = []
        if len(f['chunks']) < 2:
            continue
        # first line is "context" information
        for chunks in f['chunks'][1:]:
            for chunk in chunks['lines']:
                if chunk['action'] not in ('del', 'mod'):
                    continue
                changed_lines[f['filename']].append(chunk['old_lineno'])

    commit_authors = []
    user_counts = {}
    email_counts = {}
    author_counts = {}
    _commit_cache = {}

    commits = []
    if get_commit_authors:
        log.debug('Obtaining commit authors from set of commits')
        _compare_data = target_scm.compare(
            target_ref, source_ref, source_scm, merge=True,
            pre_load=["author", "date", "message"]
        )

        for commit in _compare_data:
            # NOTE(marcink): we serialize here, so we don't produce more vcsserver calls on data returned
            # at this function which is later called via JSON serialization
            serialized_commit = dict(
                author=commit.author,
                date=commit.date,
                message=commit.message,
                commit_id=commit.raw_id,
                raw_id=commit.raw_id
            )
            commits.append(serialized_commit)
            user = User.get_from_cs_author(serialized_commit['author'])
            if user and user not in commit_authors:
                commit_authors.append(user)

    # lines
    if get_authors:
        log.debug('Calculating authors of changed files')
        target_commit = source_repo.get_commit(ancestor_id)

        for fname, lines in changed_lines.items():

            try:
                node = target_commit.get_node(fname, pre_load=["is_binary"])
            except Exception:
                log.exception("Failed to load node with path %s", fname)
                continue

            if not isinstance(node, FileNode):
                continue

            # NOTE(marcink): for binary node we don't do annotation, just use last author
            if node.is_binary:
                author = node.last_commit.author
                email = node.last_commit.author_email

                user = User.get_from_cs_author(author)
                if user:
                    user_counts[user.user_id] = user_counts.get(user.user_id, 0) + 1
                    author_counts[author] = author_counts.get(author, 0) + 1
                    email_counts[email] = email_counts.get(email, 0) + 1

                continue

            # annotate and count authorship only for the touched lines;
            # commits are memoized in _commit_cache to avoid repeated lookups
            for annotation in node.annotate:
                line_no, commit_id, get_commit_func, line_text = annotation
                if line_no in lines:
                    if commit_id not in _commit_cache:
                        _commit_cache[commit_id] = get_commit_func()
                    commit = _commit_cache[commit_id]
                    author = commit.author
                    email = commit.author_email
                    user = User.get_from_cs_author(author)
                    if user:
                        user_counts[user.user_id] = user_counts.get(user.user_id, 0) + 1
                        author_counts[author] = author_counts.get(author, 0) + 1
                        email_counts[email] = email_counts.get(email, 0) + 1

    log.debug('Default reviewers processing finished')

    return {
        'commits': commits,
        'files': all_files_changes,
        'stats': stats,
        'ancestor': ancestor_id,
        # original authors of modified files
        'original_authors': {
            'users': user_counts,
            'authors': author_counts,
            'emails': email_counts,
        },
        'commit_authors': commit_authors
    }
234 234
235 235
class PullRequestModel(BaseModel):
    """Business logic for creating, querying and updating pull requests."""

    # model class this BaseModel subclass operates on
    cls = PullRequest

    # default number of context lines for diffs
    DIFF_CONTEXT = diffs.DEFAULT_CONTEXT

    # human-readable messages keyed by UpdateFailureReason, shown after an
    # attempted pull-request update
    UPDATE_STATUS_MESSAGES = {
        UpdateFailureReason.NONE: lazy_ugettext(
            'Pull request update successful.'),
        UpdateFailureReason.UNKNOWN: lazy_ugettext(
            'Pull request update failed because of an unknown error.'),
        UpdateFailureReason.NO_CHANGE: lazy_ugettext(
            'No update needed because the source and target have not changed.'),
        UpdateFailureReason.WRONG_REF_TYPE: lazy_ugettext(
            'Pull request cannot be updated because the reference type is '
            'not supported for an update. Only Branch, Tag or Bookmark is allowed.'),
        UpdateFailureReason.MISSING_TARGET_REF: lazy_ugettext(
            'This pull request cannot be updated because the target '
            'reference is missing.'),
        UpdateFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
            'This pull request cannot be updated because the source '
            'reference is missing.'),
    }
    # ref types accepted as PR source/target; only the movable ones
    # (not tags) support updates
    REF_TYPES = ['bookmark', 'book', 'tag', 'branch']
    UPDATABLE_REF_TYPES = ['bookmark', 'book', 'branch']
261 261
262 262 def __get_pull_request(self, pull_request):
263 263 return self._get_instance((
264 264 PullRequest, PullRequestVersion), pull_request)
265 265
266 266 def _check_perms(self, perms, pull_request, user, api=False):
267 267 if not api:
268 268 return h.HasRepoPermissionAny(*perms)(
269 269 user=user, repo_name=pull_request.target_repo.repo_name)
270 270 else:
271 271 return h.HasRepoPermissionAnyApi(*perms)(
272 272 user=user, repo_name=pull_request.target_repo.repo_name)
273 273
274 274 def check_user_read(self, pull_request, user, api=False):
275 275 _perms = ('repository.admin', 'repository.write', 'repository.read',)
276 276 return self._check_perms(_perms, pull_request, user, api)
277 277
278 278 def check_user_merge(self, pull_request, user, api=False):
279 279 _perms = ('repository.admin', 'repository.write', 'hg.admin',)
280 280 return self._check_perms(_perms, pull_request, user, api)
281 281
282 282 def check_user_update(self, pull_request, user, api=False):
283 283 owner = user.user_id == pull_request.user_id
284 284 return self.check_user_merge(pull_request, user, api) or owner
285 285
286 286 def check_user_delete(self, pull_request, user):
287 287 owner = user.user_id == pull_request.user_id
288 288 _perms = ('repository.admin',)
289 289 return self._check_perms(_perms, pull_request, user) or owner
290 290
291 291 def is_user_reviewer(self, pull_request, user):
292 292 return user.user_id in [
293 293 x.user_id for x in
294 294 pull_request.get_pull_request_reviewers(PullRequestReviewers.ROLE_REVIEWER)
295 295 if x.user
296 296 ]
297 297
298 298 def check_user_change_status(self, pull_request, user, api=False):
299 299 return self.check_user_update(pull_request, user, api) \
300 300 or self.is_user_reviewer(pull_request, user)
301 301
302 302 def check_user_comment(self, pull_request, user):
303 303 owner = user.user_id == pull_request.user_id
304 304 return self.check_user_read(pull_request, user) or owner
305 305
    def get(self, pull_request):
        """Public accessor delegating to the name-mangled instance resolver."""
        return self.__get_pull_request(pull_request)
308 308
    def _prepare_get_all_query(self, repo_name, search_q=None, source=False,
                               statuses=None, opened_by=None, order_by=None,
                               order_dir='desc', only_created=False):
        """
        Build the base SQLAlchemy query for listing pull requests of a repo,
        applying text search, source/target selection, status, author and
        ordering filters. Returns the query unexecuted.
        """
        repo = None
        if repo_name:
            repo = self._get_repo(repo_name)

        q = PullRequest.query()

        # free-text search over id, author username, title and description
        if search_q:
            like_expression = u'%{}%'.format(safe_str(search_q))
            q = q.join(User, User.user_id == PullRequest.user_id)
            q = q.filter(or_(
                cast(PullRequest.pull_request_id, String).ilike(like_expression),
                User.username.ilike(like_expression),
                PullRequest.title.ilike(like_expression),
                PullRequest.description.ilike(like_expression),
            ))

        # source or target
        if repo and source:
            q = q.filter(PullRequest.source_repo == repo)
        elif repo:
            q = q.filter(PullRequest.target_repo == repo)

        # closed,opened
        if statuses:
            q = q.filter(PullRequest.status.in_(statuses))

        # opened by filter
        if opened_by:
            q = q.filter(PullRequest.user_id.in_(opened_by))

        # only get those that are in "created" state
        if only_created:
            q = q.filter(PullRequest.pull_request_state == PullRequest.STATE_CREATED)

        # map external sort keys to columns; unknown keys leave order unset
        order_map = {
            'name_raw': PullRequest.pull_request_id,
            'id': PullRequest.pull_request_id,
            'title': PullRequest.title,
            'updated_on_raw': PullRequest.updated_on,
            'target_repo': PullRequest.target_repo_id
        }
        if order_by and order_by in order_map:
            if order_dir == 'asc':
                q = q.order_by(order_map[order_by].asc())
            else:
                q = q.order_by(order_map[order_by].desc())

        return q
360 360
361 361 def count_all(self, repo_name, search_q=None, source=False, statuses=None,
362 362 opened_by=None):
363 363 """
364 364 Count the number of pull requests for a specific repository.
365 365
366 366 :param repo_name: target or source repo
367 367 :param search_q: filter by text
368 368 :param source: boolean flag to specify if repo_name refers to source
369 369 :param statuses: list of pull request statuses
370 370 :param opened_by: author user of the pull request
371 371 :returns: int number of pull requests
372 372 """
373 373 q = self._prepare_get_all_query(
374 374 repo_name, search_q=search_q, source=source, statuses=statuses,
375 375 opened_by=opened_by)
376 376
377 377 return q.count()
378 378
379 379 def get_all(self, repo_name, search_q=None, source=False, statuses=None,
380 380 opened_by=None, offset=0, length=None, order_by=None, order_dir='desc'):
381 381 """
382 382 Get all pull requests for a specific repository.
383 383
384 384 :param repo_name: target or source repo
385 385 :param search_q: filter by text
386 386 :param source: boolean flag to specify if repo_name refers to source
387 387 :param statuses: list of pull request statuses
388 388 :param opened_by: author user of the pull request
389 389 :param offset: pagination offset
390 390 :param length: length of returned list
391 391 :param order_by: order of the returned list
392 392 :param order_dir: 'asc' or 'desc' ordering direction
393 393 :returns: list of pull requests
394 394 """
395 395 q = self._prepare_get_all_query(
396 396 repo_name, search_q=search_q, source=source, statuses=statuses,
397 397 opened_by=opened_by, order_by=order_by, order_dir=order_dir)
398 398
399 399 if length:
400 400 pull_requests = q.limit(length).offset(offset).all()
401 401 else:
402 402 pull_requests = q.all()
403 403
404 404 return pull_requests
405 405
406 406 def count_awaiting_review(self, repo_name, search_q=None, statuses=None):
407 407 """
408 408 Count the number of pull requests for a specific repository that are
409 409 awaiting review.
410 410
411 411 :param repo_name: target or source repo
412 412 :param search_q: filter by text
413 413 :param statuses: list of pull request statuses
414 414 :returns: int number of pull requests
415 415 """
416 416 pull_requests = self.get_awaiting_review(
417 417 repo_name, search_q=search_q, statuses=statuses)
418 418
419 419 return len(pull_requests)
420 420
421 421 def get_awaiting_review(self, repo_name, search_q=None, statuses=None,
422 422 offset=0, length=None, order_by=None, order_dir='desc'):
423 423 """
424 424 Get all pull requests for a specific repository that are awaiting
425 425 review.
426 426
427 427 :param repo_name: target or source repo
428 428 :param search_q: filter by text
429 429 :param statuses: list of pull request statuses
430 430 :param offset: pagination offset
431 431 :param length: length of returned list
432 432 :param order_by: order of the returned list
433 433 :param order_dir: 'asc' or 'desc' ordering direction
434 434 :returns: list of pull requests
435 435 """
436 436 pull_requests = self.get_all(
437 437 repo_name, search_q=search_q, statuses=statuses,
438 438 order_by=order_by, order_dir=order_dir)
439 439
440 440 _filtered_pull_requests = []
441 441 for pr in pull_requests:
442 442 status = pr.calculated_review_status()
443 443 if status in [ChangesetStatus.STATUS_NOT_REVIEWED,
444 444 ChangesetStatus.STATUS_UNDER_REVIEW]:
445 445 _filtered_pull_requests.append(pr)
446 446 if length:
447 447 return _filtered_pull_requests[offset:offset+length]
448 448 else:
449 449 return _filtered_pull_requests
450 450
    def _prepare_awaiting_my_review_review_query(
            self, repo_name, user_id, search_q=None, statuses=None,
            order_by=None, order_dir='desc'):
        """
        Build the query for pull requests on *repo_name* where *user_id* is a
        reviewer and their latest review status is still not-reviewed or
        under-review. Returns the query unexecuted.
        """

        for_review_statuses = [
            ChangesetStatus.STATUS_UNDER_REVIEW, ChangesetStatus.STATUS_NOT_REVIEWED
        ]

        pull_request_alias = aliased(PullRequest)
        status_alias = aliased(ChangesetStatus)
        reviewers_alias = aliased(PullRequestReviewers)
        repo_alias = aliased(Repository)

        # subquery selecting the most recent status version per reviewer/PR
        last_ver_subq = Session()\
            .query(func.min(ChangesetStatus.version)) \
            .filter(ChangesetStatus.pull_request_id == reviewers_alias.pull_request_id)\
            .filter(ChangesetStatus.user_id == reviewers_alias.user_id) \
            .subquery()

        # outer-join statuses so reviewers without any recorded status
        # (version/status NULL) still count as awaiting review
        q = Session().query(pull_request_alias) \
            .options(lazyload(pull_request_alias.author)) \
            .join(reviewers_alias,
                  reviewers_alias.pull_request_id == pull_request_alias.pull_request_id) \
            .join(repo_alias,
                  repo_alias.repo_id == pull_request_alias.target_repo_id) \
            .outerjoin(status_alias,
                       and_(status_alias.user_id == reviewers_alias.user_id,
                            status_alias.pull_request_id == reviewers_alias.pull_request_id)) \
            .filter(or_(status_alias.version == null(),
                        status_alias.version == last_ver_subq)) \
            .filter(reviewers_alias.user_id == user_id) \
            .filter(repo_alias.repo_name == repo_name) \
            .filter(or_(status_alias.status == null(), status_alias.status.in_(for_review_statuses))) \
            .group_by(pull_request_alias)

        # closed,opened
        if statuses:
            q = q.filter(pull_request_alias.status.in_(statuses))

        # free-text search over id, author username, title and description
        if search_q:
            like_expression = u'%{}%'.format(safe_str(search_q))
            q = q.join(User, User.user_id == pull_request_alias.user_id)
            q = q.filter(or_(
                cast(pull_request_alias.pull_request_id, String).ilike(like_expression),
                User.username.ilike(like_expression),
                pull_request_alias.title.ilike(like_expression),
                pull_request_alias.description.ilike(like_expression),
            ))

        # map external sort keys to columns; unknown keys leave order unset
        order_map = {
            'name_raw': pull_request_alias.pull_request_id,
            'title': pull_request_alias.title,
            'updated_on_raw': pull_request_alias.updated_on,
            'target_repo': pull_request_alias.target_repo_id
        }
        if order_by and order_by in order_map:
            if order_dir == 'asc':
                q = q.order_by(order_map[order_by].asc())
            else:
                q = q.order_by(order_map[order_by].desc())

        return q
513 513
514 514 def count_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None):
515 515 """
516 516 Count the number of pull requests for a specific repository that are
517 517 awaiting review from a specific user.
518 518
519 519 :param repo_name: target or source repo
520 520 :param user_id: reviewer user of the pull request
521 521 :param search_q: filter by text
522 522 :param statuses: list of pull request statuses
523 523 :returns: int number of pull requests
524 524 """
525 525 q = self._prepare_awaiting_my_review_review_query(
526 526 repo_name, user_id, search_q=search_q, statuses=statuses)
527 527 return q.count()
528 528
529 529 def get_awaiting_my_review(self, repo_name, user_id, search_q=None, statuses=None,
530 530 offset=0, length=None, order_by=None, order_dir='desc'):
531 531 """
532 532 Get all pull requests for a specific repository that are awaiting
533 533 review from a specific user.
534 534
535 535 :param repo_name: target or source repo
536 536 :param user_id: reviewer user of the pull request
537 537 :param search_q: filter by text
538 538 :param statuses: list of pull request statuses
539 539 :param offset: pagination offset
540 540 :param length: length of returned list
541 541 :param order_by: order of the returned list
542 542 :param order_dir: 'asc' or 'desc' ordering direction
543 543 :returns: list of pull requests
544 544 """
545 545
546 546 q = self._prepare_awaiting_my_review_review_query(
547 547 repo_name, user_id, search_q=search_q, statuses=statuses,
548 548 order_by=order_by, order_dir=order_dir)
549 549
550 550 if length:
551 551 pull_requests = q.limit(length).offset(offset).all()
552 552 else:
553 553 pull_requests = q.all()
554 554
555 555 return pull_requests
556 556
557 557 def _prepare_im_participating_query(self, user_id=None, statuses=None, query='',
558 558 order_by=None, order_dir='desc'):
559 559 """
560 560 return a query of pull-requests user is an creator, or he's added as a reviewer
561 561 """
562 562 q = PullRequest.query()
563 563 if user_id:
564 564
565 565 base_query = select(PullRequestReviewers)\
566 566 .where(PullRequestReviewers.user_id == user_id)\
567 567 .with_only_columns(PullRequestReviewers.pull_request_id)
568 568
569 569 user_filter = or_(
570 570 PullRequest.user_id == user_id,
571 571 PullRequest.pull_request_id.in_(base_query)
572 572 )
573 573 q = PullRequest.query().filter(user_filter)
574 574
575 575 # closed,opened
576 576 if statuses:
577 577 q = q.filter(PullRequest.status.in_(statuses))
578 578
579 579 if query:
580 580 like_expression = u'%{}%'.format(safe_str(query))
581 581 q = q.join(User, User.user_id == PullRequest.user_id)
582 582 q = q.filter(or_(
583 583 cast(PullRequest.pull_request_id, String).ilike(like_expression),
584 584 User.username.ilike(like_expression),
585 585 PullRequest.title.ilike(like_expression),
586 586 PullRequest.description.ilike(like_expression),
587 587 ))
588 588
589 589 order_map = {
590 590 'name_raw': PullRequest.pull_request_id,
591 591 'title': PullRequest.title,
592 592 'updated_on_raw': PullRequest.updated_on,
593 593 'target_repo': PullRequest.target_repo_id
594 594 }
595 595 if order_by and order_by in order_map:
596 596 if order_dir == 'asc':
597 597 q = q.order_by(order_map[order_by].asc())
598 598 else:
599 599 q = q.order_by(order_map[order_by].desc())
600 600
601 601 return q
602 602
603 603 def count_im_participating_in(self, user_id=None, statuses=None, query=''):
604 604 q = self._prepare_im_participating_query(user_id, statuses=statuses, query=query)
605 605 return q.count()
606 606
607 607 def get_im_participating_in(
608 608 self, user_id=None, statuses=None, query='', offset=0,
609 609 length=None, order_by=None, order_dir='desc'):
610 610 """
611 611 Get all Pull requests that i'm participating in as a reviewer, or i have opened
612 612 """
613 613
614 614 q = self._prepare_im_participating_query(
615 615 user_id, statuses=statuses, query=query, order_by=order_by,
616 616 order_dir=order_dir)
617 617
618 618 if length:
619 619 pull_requests = q.limit(length).offset(offset).all()
620 620 else:
621 621 pull_requests = q.all()
622 622
623 623 return pull_requests
624 624
    def _prepare_participating_in_for_review_query(
            self, user_id, statuses=None, query='', order_by=None, order_dir='desc'):
        """
        Build a query of pull requests that still need a review action
        (approval/rejection) from `user_id`.

        A PR qualifies when the user is a reviewer and their latest
        changeset-status vote is absent, "under review" or "not reviewed".

        :param user_id: reviewer user of the pull request
        :param statuses: optional list of pull request statuses (closed, opened)
        :param query: optional free-text filter (id, author username, title,
            description)
        :param order_by: one of 'name_raw', 'title', 'updated_on_raw',
            'target_repo'
        :param order_dir: 'asc' or 'desc'
        :returns: un-executed SQLAlchemy query object
        """

        # statuses that count as "still needs my review"
        for_review_statuses = [
            ChangesetStatus.STATUS_UNDER_REVIEW, ChangesetStatus.STATUS_NOT_REVIEWED
        ]

        pull_request_alias = aliased(PullRequest)
        status_alias = aliased(ChangesetStatus)
        reviewers_alias = aliased(PullRequestReviewers)

        # correlated subquery: the smallest (i.e. most recent version marker,
        # per this schema's convention) status version for this reviewer/PR pair
        last_ver_subq = Session()\
            .query(func.min(ChangesetStatus.version)) \
            .filter(ChangesetStatus.pull_request_id == reviewers_alias.pull_request_id)\
            .filter(ChangesetStatus.user_id == reviewers_alias.user_id) \
            .subquery()

        # join PRs to their reviewer rows, outer-join the reviewer's status so
        # "no vote yet" (NULL status) still matches; group to de-duplicate PRs
        q = Session().query(pull_request_alias) \
            .options(lazyload(pull_request_alias.author)) \
            .join(reviewers_alias,
                  reviewers_alias.pull_request_id == pull_request_alias.pull_request_id) \
            .outerjoin(status_alias,
                       and_(status_alias.user_id == reviewers_alias.user_id,
                            status_alias.pull_request_id == reviewers_alias.pull_request_id)) \
            .filter(or_(status_alias.version == null(),
                        status_alias.version == last_ver_subq)) \
            .filter(reviewers_alias.user_id == user_id) \
            .filter(or_(status_alias.status == null(), status_alias.status.in_(for_review_statuses))) \
            .group_by(pull_request_alias)

        # closed,opened
        if statuses:
            q = q.filter(pull_request_alias.status.in_(statuses))

        if query:
            like_expression = u'%{}%'.format(safe_str(query))
            q = q.join(User, User.user_id == pull_request_alias.user_id)
            q = q.filter(or_(
                cast(pull_request_alias.pull_request_id, String).ilike(like_expression),
                User.username.ilike(like_expression),
                pull_request_alias.title.ilike(like_expression),
                pull_request_alias.description.ilike(like_expression),
            ))

        # datagrid column key -> ORM column used for sorting
        order_map = {
            'name_raw': pull_request_alias.pull_request_id,
            'title': pull_request_alias.title,
            'updated_on_raw': pull_request_alias.updated_on,
            'target_repo': pull_request_alias.target_repo_id
        }
        if order_by and order_by in order_map:
            if order_dir == 'asc':
                q = q.order_by(order_map[order_by].asc())
            else:
                q = q.order_by(order_map[order_by].desc())

        return q
682 682
683 683 def count_im_participating_in_for_review(self, user_id, statuses=None, query=''):
684 684 q = self._prepare_participating_in_for_review_query(user_id, statuses=statuses, query=query)
685 685 return q.count()
686 686
687 687 def get_im_participating_in_for_review(
688 688 self, user_id, statuses=None, query='', offset=0,
689 689 length=None, order_by=None, order_dir='desc'):
690 690 """
691 691 Get all Pull requests that needs user approval or rejection
692 692 """
693 693
694 694 q = self._prepare_participating_in_for_review_query(
695 695 user_id, statuses=statuses, query=query, order_by=order_by,
696 696 order_dir=order_dir)
697 697
698 698 if length:
699 699 pull_requests = q.limit(length).offset(offset).all()
700 700 else:
701 701 pull_requests = q.all()
702 702
703 703 return pull_requests
704 704
    def get_versions(self, pull_request):
        """
        Return all versions of the given pull request, ordered by version id
        ascending (oldest first).

        NOTE(review): the previous docstring claimed "descending", but the
        query sorts by ``pull_request_version_id.asc()``.
        """
        return PullRequestVersion.query()\
            .filter(PullRequestVersion.pull_request == pull_request)\
            .order_by(PullRequestVersion.pull_request_version_id.asc())\
            .all()
713 713
714 714 def get_pr_version(self, pull_request_id, version=None):
715 715 at_version = None
716 716
717 717 if version and version == 'latest':
718 718 pull_request_ver = PullRequest.get(pull_request_id)
719 719 pull_request_obj = pull_request_ver
720 720 _org_pull_request_obj = pull_request_obj
721 721 at_version = 'latest'
722 722 elif version:
723 723 pull_request_ver = PullRequestVersion.get_or_404(version)
724 724 pull_request_obj = pull_request_ver
725 725 _org_pull_request_obj = pull_request_ver.pull_request
726 726 at_version = pull_request_ver.pull_request_version_id
727 727 else:
728 728 _org_pull_request_obj = pull_request_obj = PullRequest.get_or_404(
729 729 pull_request_id)
730 730
731 731 pull_request_display_obj = PullRequest.get_pr_display_object(
732 732 pull_request_obj, _org_pull_request_obj)
733 733
734 734 return _org_pull_request_obj, pull_request_obj, \
735 735 pull_request_display_obj, at_version
736 736
737 737 def pr_commits_versions(self, versions):
738 738 """
739 739 Maps the pull-request commits into all known PR versions. This way we can obtain
740 740 each pr version the commit was introduced in.
741 741 """
742 742 commit_versions = collections.defaultdict(list)
743 743 num_versions = [x.pull_request_version_id for x in versions]
744 744 for ver in versions:
745 745 for commit_id in ver.revisions:
746 746 ver_idx = ChangesetComment.get_index_from_version(
747 747 ver.pull_request_version_id, num_versions=num_versions)
748 748 commit_versions[commit_id].append(ver_idx)
749 749 return commit_versions
750 750
    def create(self, created_by, source_repo, source_ref, target_repo,
               target_ref, revisions, reviewers, observers, title, description=None,
               common_ancestor_id=None,
               description_renderer=None,
               reviewer_data=None, translator=None, auth_user=None):
        """
        Create a new pull request with its reviewers and observers, set the
        initial commit statuses, and run an initial merge simulation.

        :param created_by: user id / username / user object of the creator
        :param source_repo: source repository (id, name or object)
        :param source_ref: source reference string
        :param target_repo: target repository (id, name or object)
        :param target_ref: target reference string
        :param revisions: list of commit ids included in the PR
        :param reviewers: iterable of (user_id, reasons, mandatory, role, rules)
        :param observers: iterable shaped like `reviewers`; observers skip
            users that are already reviewers
        :param title: PR title
        :param description: optional PR description
        :param common_ancestor_id: optional pre-computed ancestor commit id
        :param description_renderer: renderer used for the description
        :param reviewer_data: serialized reviewer-rule data stored on the PR
        :param translator: translation function; defaults to the request's
        :param auth_user: acting auth user; defaults to the creator's AuthUser
        :returns: the created `PullRequest` (re-fetched after commit)
        """
        translator = translator or get_current_request().translate

        created_by_user = self._get_user(created_by)
        auth_user = auth_user or created_by_user.AuthUser()
        source_repo = self._get_repo(source_repo)
        target_repo = self._get_repo(target_repo)

        pull_request = PullRequest()
        pull_request.source_repo = source_repo
        pull_request.source_ref = source_ref
        pull_request.target_repo = target_repo
        pull_request.target_ref = target_ref
        pull_request.revisions = revisions
        pull_request.title = title
        pull_request.description = description
        pull_request.description_renderer = description_renderer
        pull_request.author = created_by_user
        pull_request.reviewer_data = reviewer_data
        # PR starts in CREATING state until merge simulation finishes below
        pull_request.pull_request_state = pull_request.STATE_CREATING
        pull_request.common_ancestor_id = common_ancestor_id

        Session().add(pull_request)
        Session().flush()

        reviewer_ids = set()
        # members / reviewers
        for reviewer_object in reviewers:
            user_id, reasons, mandatory, role, rules = reviewer_object
            user = self._get_user(user_id)

            # skip duplicates
            if user.user_id in reviewer_ids:
                continue

            reviewer_ids.add(user.user_id)

            reviewer = PullRequestReviewers()
            reviewer.user = user
            reviewer.pull_request = pull_request
            reviewer.reasons = reasons
            reviewer.mandatory = mandatory
            reviewer.role = role

            # NOTE(marcink): pick only first rule for now
            rule_id = list(rules)[0] if rules else None
            rule = RepoReviewRule.get(rule_id) if rule_id else None
            if rule:
                review_group = rule.user_group_vote_rule(user_id)
                # we check if this particular reviewer is member of a voting group
                if review_group:
                    # NOTE(marcink):
                    # can be that user is member of more but we pick the first same,
                    # same as default reviewers algo
                    review_group = review_group[0]

                    rule_data = {
                        'rule_name':
                            rule.review_rule_name,
                        'rule_user_group_entry_id':
                            review_group.repo_review_rule_users_group_id,
                        'rule_user_group_name':
                            review_group.users_group.users_group_name,
                        'rule_user_group_members':
                            [x.user.username for x in review_group.users_group.members],
                        'rule_user_group_members_id':
                            [x.user.user_id for x in review_group.users_group.members],
                    }
                    # e.g {'vote_rule': -1, 'mandatory': True}
                    rule_data.update(review_group.rule_data())

                    reviewer.rule_data = rule_data

            Session().add(reviewer)
            Session().flush()

        for observer_object in observers:
            user_id, reasons, mandatory, role, rules = observer_object
            user = self._get_user(user_id)

            # skip duplicates from reviewers
            if user.user_id in reviewer_ids:
                continue

            #reviewer_ids.add(user.user_id)

            observer = PullRequestReviewers()
            observer.user = user
            observer.pull_request = pull_request
            observer.reasons = reasons
            observer.mandatory = mandatory
            observer.role = role

            # NOTE(marcink): pick only first rule for now
            rule_id = list(rules)[0] if rules else None
            rule = RepoReviewRule.get(rule_id) if rule_id else None
            if rule:
                # TODO(marcink): do we need this for observers ??
                pass

            Session().add(observer)
            Session().flush()

        # Set approval status to "Under Review" for all commits which are
        # part of this pull request.
        ChangesetStatusModel().set_status(
            repo=target_repo,
            status=ChangesetStatus.STATUS_UNDER_REVIEW,
            user=created_by_user,
            pull_request=pull_request
        )
        # we commit early at this point. This has to do with a fact
        # that before queries do some row-locking. And because of that
        # we need to commit and finish transaction before below validate call
        # that for large repos could be long resulting in long row locks
        Session().commit()

        # prepare workspace, and run initial merge simulation. Set state during that
        # operation
        pull_request = PullRequest.get(pull_request.pull_request_id)

        # set as merging, for merge simulation, and if finished to created so we mark
        # simulation is working fine
        with pull_request.set_state(PullRequest.STATE_MERGING,
                                    final_state=PullRequest.STATE_CREATED) as state_obj:
            MergeCheck.validate(
                pull_request, auth_user=auth_user, translator=translator)

        self.notify_reviewers(pull_request, reviewer_ids, created_by_user)
        self.trigger_pull_request_hook(pull_request, created_by_user, 'create')

        creation_data = pull_request.get_api_data(with_merge_state=False)
        self._log_audit_action(
            'repo.pull_request.create', {'data': creation_data},
            auth_user, pull_request)

        return pull_request
892 892
893 893 def trigger_pull_request_hook(self, pull_request, user, action, data=None):
894 894 pull_request = self.__get_pull_request(pull_request)
895 895 target_scm = pull_request.target_repo.scm_instance()
896 896 if action == 'create':
897 897 trigger_hook = hooks_utils.trigger_create_pull_request_hook
898 898 elif action == 'merge':
899 899 trigger_hook = hooks_utils.trigger_merge_pull_request_hook
900 900 elif action == 'close':
901 901 trigger_hook = hooks_utils.trigger_close_pull_request_hook
902 902 elif action == 'review_status_change':
903 903 trigger_hook = hooks_utils.trigger_review_pull_request_hook
904 904 elif action == 'update':
905 905 trigger_hook = hooks_utils.trigger_update_pull_request_hook
906 906 elif action == 'comment':
907 907 trigger_hook = hooks_utils.trigger_comment_pull_request_hook
908 908 elif action == 'comment_edit':
909 909 trigger_hook = hooks_utils.trigger_comment_pull_request_edit_hook
910 910 else:
911 911 return
912 912
913 913 log.debug('Handling pull_request %s trigger_pull_request_hook with action %s and hook: %s',
914 914 pull_request, action, trigger_hook)
915 915 trigger_hook(
916 916 username=user.username,
917 917 repo_name=pull_request.target_repo.repo_name,
918 918 repo_type=target_scm.alias,
919 919 pull_request=pull_request,
920 920 data=data)
921 921
922 922 def _get_commit_ids(self, pull_request):
923 923 """
924 924 Return the commit ids of the merged pull request.
925 925
926 926 This method is not dealing correctly yet with the lack of autoupdates
927 927 nor with the implicit target updates.
928 928 For example: if a commit in the source repo is already in the target it
929 929 will be reported anyways.
930 930 """
931 931 merge_rev = pull_request.merge_rev
932 932 if merge_rev is None:
933 933 raise ValueError('This pull request was not merged yet')
934 934
935 935 commit_ids = list(pull_request.revisions)
936 936 if merge_rev not in commit_ids:
937 937 commit_ids.append(merge_rev)
938 938
939 939 return commit_ids
940 940
    def merge_repo(self, pull_request, user, extras):
        """
        Merge the pull request into its target repository; on success the PR
        is commented upon and closed, and the merge is audit-logged.

        :param pull_request: pull request to merge
        :param user: user performing the merge
        :param extras: hook extras dict; mutated here to tag the user agent
        :returns: the merge state object returned by the vcs backend
        """
        repo_type = pull_request.source_repo.repo_type
        log.debug("Merging pull request %s", pull_request)

        # mark hook calls as originating from an internal merge
        extras['user_agent'] = '{}/internal-merge'.format(repo_type)
        merge_state = self._merge_pull_request(pull_request, user, extras)
        if merge_state.executed:
            log.debug("Merge was successful, updating the pull request comments.")
            self._comment_and_close_pr(pull_request, user, merge_state)

            self._log_audit_action(
                'repo.pull_request.merge',
                {'merge_state': merge_state.__dict__},
                user, pull_request)

        else:
            log.warning("Merge failed, not updating the pull request.")
        return merge_state
959 959
    def _merge_pull_request(self, pull_request, user, extras, merge_msg=None):
        """
        Perform the actual vcs-level merge of the pull request inside a
        hooks callback daemon, and return the backend's merge state.

        :param pull_request: pull request to merge
        :param user: user performing the merge (name/email used for the commit)
        :param extras: hook extras dict, passed to the callback daemon
        :param merge_msg: optional message template; falls back to
            ``vcs_settings.MERGE_MESSAGE_TMPL``
        :returns: merge state object from ``target_vcs.merge``
        """
        target_vcs = pull_request.target_repo.scm_instance()
        source_vcs = pull_request.source_repo.scm_instance()

        # render the merge commit message from the (possibly custom) template
        message = safe_str(merge_msg or vcs_settings.MERGE_MESSAGE_TMPL).format(
            pr_id=pull_request.pull_request_id,
            pr_title=pull_request.title,
            pr_desc=pull_request.description,
            source_repo=source_vcs.name,
            source_ref_name=pull_request.source_ref_parts.name,
            target_repo=target_vcs.name,
            target_ref_name=pull_request.target_ref_parts.name,
        )

        workspace_id = self._workspace_id(pull_request)
        repo_id = pull_request.target_repo.repo_id
        use_rebase = self._use_rebase_for_merging(pull_request)
        close_branch = self._close_branch_before_merging(pull_request)
        user_name = self._user_name_for_merging(pull_request, user)

        # make sure the target reference is up to date before merging
        target_ref = self._refresh_reference(
            pull_request.target_ref_parts, target_vcs)

        callback_daemon, extras = prepare_callback_daemon(
            extras, protocol=vcs_settings.HOOKS_PROTOCOL,
            host=vcs_settings.HOOKS_HOST)

        with callback_daemon:
            # TODO: johbo: Implement a clean way to run a config_override
            # for a single call.
            target_vcs.config.set(
                'rhodecode', 'RC_SCM_DATA', json.dumps(extras))

            merge_state = target_vcs.merge(
                repo_id, workspace_id, target_ref, source_vcs,
                pull_request.source_ref_parts,
                user_name=user_name, user_email=user.email,
                message=message, use_rebase=use_rebase,
                close_branch=close_branch)

        return merge_state
1001 1001
    def _comment_and_close_pr(self, pull_request, user, merge_state, close_msg=None):
        """
        After a successful merge: record the merge revision on the PR, add a
        closing comment, invalidate the target repo caches and fire the
        'merge' hook.

        :param pull_request: the merged pull request
        :param user: user performing the merge/close
        :param merge_state: merge state object; its merge_ref.commit_id is
            stored as the PR's merge revision
        :param close_msg: optional closing comment text
        """
        pull_request.merge_rev = merge_state.merge_ref.commit_id
        pull_request.updated_on = datetime.datetime.now()
        close_msg = close_msg or 'Pull request merged and closed'

        # closing comment; closing_pr=True also flips the PR status
        CommentsModel().create(
            text=safe_str(close_msg),
            repo=pull_request.target_repo.repo_id,
            user=user.user_id,
            pull_request=pull_request.pull_request_id,
            f_path=None,
            line_no=None,
            closing_pr=True
        )

        Session().add(pull_request)
        Session().flush()
        # TODO: paris: replace invalidation with less radical solution
        ScmModel().mark_for_invalidation(
            pull_request.target_repo.repo_name)
        self.trigger_pull_request_hook(pull_request, user, 'merge')
1023 1023
1024 1024 def has_valid_update_type(self, pull_request):
1025 1025 source_ref_type = pull_request.source_ref_parts.type
1026 1026 return source_ref_type in self.REF_TYPES
1027 1027
    def get_flow_commits(self, pull_request):
        """
        Resolve the source and target commits of the pull request.

        Named refs (types in ``self.REF_TYPES``) are resolved by name so a
        moved ref is followed; anything else is resolved by its pinned
        commit id.

        :returns: tuple of (source_commit, target_commit)
        :raises SourceRefMissing: if the source commit cannot be found
        :raises TargetRefMissing: if the target commit cannot be found
        """

        # source repo
        source_ref_name = pull_request.source_ref_parts.name
        source_ref_type = pull_request.source_ref_parts.type
        source_ref_id = pull_request.source_ref_parts.commit_id
        source_repo = pull_request.source_repo.scm_instance()

        try:
            if source_ref_type in self.REF_TYPES:
                source_commit = source_repo.get_commit(
                    source_ref_name, reference_obj=pull_request.source_ref_parts)
            else:
                source_commit = source_repo.get_commit(source_ref_id)
        except CommitDoesNotExistError:
            raise SourceRefMissing()

        # target repo
        target_ref_name = pull_request.target_ref_parts.name
        target_ref_type = pull_request.target_ref_parts.type
        target_ref_id = pull_request.target_ref_parts.commit_id
        target_repo = pull_request.target_repo.scm_instance()

        try:
            if target_ref_type in self.REF_TYPES:
                target_commit = target_repo.get_commit(
                    target_ref_name, reference_obj=pull_request.target_ref_parts)
            else:
                target_commit = target_repo.get_commit(target_ref_id)
        except CommitDoesNotExistError:
            raise TargetRefMissing()

        return source_commit, target_commit
1061 1061
    def update_commits(self, pull_request, updating_user):
        """
        Get the updated list of commits for the pull request
        and return the new pull request version and the list
        of commits processed by this update action.

        A new PR version snapshot is created only when the source changed;
        a target-only change updates the PR in place.

        :param pull_request: pull request (id or object) to update
        :param updating_user: the user object who triggered the update
        :returns: `UpdateResponse` describing what (if anything) was done
        :raises ValueError: when no common ancestor between the refs exists
        """
        pull_request = self.__get_pull_request(pull_request)
        source_ref_type = pull_request.source_ref_parts.type
        source_ref_name = pull_request.source_ref_parts.name
        source_ref_id = pull_request.source_ref_parts.commit_id

        target_ref_type = pull_request.target_ref_parts.type
        target_ref_name = pull_request.target_ref_parts.name
        target_ref_id = pull_request.target_ref_parts.commit_id

        if not self.has_valid_update_type(pull_request):
            log.debug("Skipping update of pull request %s due to ref type: %s",
                      pull_request, source_ref_type)
            return UpdateResponse(
                executed=False,
                reason=UpdateFailureReason.WRONG_REF_TYPE,
                old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
                source_changed=False, target_changed=False)

        try:
            source_commit, target_commit = self.get_flow_commits(pull_request)
        except SourceRefMissing:
            return UpdateResponse(
                executed=False,
                reason=UpdateFailureReason.MISSING_SOURCE_REF,
                old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
                source_changed=False, target_changed=False)
        except TargetRefMissing:
            return UpdateResponse(
                executed=False,
                reason=UpdateFailureReason.MISSING_TARGET_REF,
                old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
                source_changed=False, target_changed=False)

        # a change is detected when the stored ref id differs from the
        # freshly resolved commit of the (possibly moved) named ref
        source_changed = source_ref_id != source_commit.raw_id
        target_changed = target_ref_id != target_commit.raw_id

        if not (source_changed or target_changed):
            log.debug("Nothing changed in pull request %s", pull_request)
            # NOTE(review): source_changed/target_changed look swapped below,
            # but both are False on this path so the result is identical
            return UpdateResponse(
                executed=False,
                reason=UpdateFailureReason.NO_CHANGE,
                old=pull_request, new=None, common_ancestor_id=None, commit_changes=None,
                source_changed=target_changed, target_changed=source_changed)

        change_in_found = 'target repo' if target_changed else 'source repo'
        log.debug('Updating pull request because of change in %s detected',
                  change_in_found)

        # Finally there is a need for an update, in case of source change
        # we create a new version, else just an update
        if source_changed:
            pull_request_version = self._create_version_from_snapshot(pull_request)
            self._link_comments_to_version(pull_request_version)
        else:
            try:
                ver = pull_request.versions[-1]
            except IndexError:
                ver = None

            pull_request.pull_request_version_id = \
                ver.pull_request_version_id if ver else None
            pull_request_version = pull_request

        source_repo = pull_request.source_repo.scm_instance()
        target_repo = pull_request.target_repo.scm_instance()

        # re-compute commit ids
        old_commit_ids = pull_request.revisions
        pre_load = ["author", "date", "message", "branch"]
        commit_ranges = target_repo.compare(
            target_commit.raw_id, source_commit.raw_id, source_repo, merge=True,
            pre_load=pre_load)

        target_ref = target_commit.raw_id
        source_ref = source_commit.raw_id
        ancestor_commit_id = target_repo.get_common_ancestor(
            target_ref, source_ref, source_repo)

        if not ancestor_commit_id:
            raise ValueError(
                'cannot calculate diff info without a common ancestor. '
                'Make sure both repositories are related, and have a common forking commit.')

        pull_request.common_ancestor_id = ancestor_commit_id

        # the target ref is pinned to the common ancestor, not the target tip
        pull_request.source_ref = f'{source_ref_type}:{source_ref_name}:{source_commit.raw_id}'
        pull_request.target_ref = f'{target_ref_type}:{target_ref_name}:{ancestor_commit_id}'

        pull_request.revisions = [
            commit.raw_id for commit in reversed(commit_ranges)]
        pull_request.updated_on = datetime.datetime.now()
        Session().add(pull_request)
        new_commit_ids = pull_request.revisions

        old_diff_data, new_diff_data = self._generate_update_diffs(
            pull_request, pull_request_version)

        # calculate commit and file changes
        commit_changes = self._calculate_commit_id_changes(
            old_commit_ids, new_commit_ids)
        file_changes = self._calculate_file_changes(
            old_diff_data, new_diff_data)

        # set comments as outdated if DIFFS changed
        CommentsModel().outdate_comments(
            pull_request, old_diff_data=old_diff_data,
            new_diff_data=new_diff_data)

        valid_commit_changes = (commit_changes.added or commit_changes.removed)
        file_node_changes = (
            file_changes.added or file_changes.modified or file_changes.removed)
        pr_has_changes = valid_commit_changes or file_node_changes

        # Add an automatic comment to the pull request, in case
        # anything has changed
        if pr_has_changes:
            update_comment = CommentsModel().create(
                text=self._render_update_message(ancestor_commit_id, commit_changes, file_changes),
                repo=pull_request.target_repo,
                user=pull_request.author,
                pull_request=pull_request,
                send_email=False, renderer=DEFAULT_COMMENTS_RENDERER)

            # Update status to "Under Review" for added commits
            for commit_id in commit_changes.added:
                ChangesetStatusModel().set_status(
                    repo=pull_request.source_repo,
                    status=ChangesetStatus.STATUS_UNDER_REVIEW,
                    comment=update_comment,
                    user=pull_request.author,
                    pull_request=pull_request,
                    revision=commit_id)

        # initial commit
        Session().commit()

        if pr_has_changes:
            # send update email to users
            try:
                self.notify_users(pull_request=pull_request, updating_user=updating_user,
                                  ancestor_commit_id=ancestor_commit_id,
                                  commit_changes=commit_changes,
                                  file_changes=file_changes)
                Session().commit()
            except Exception:
                # notification failure must not roll back the update itself
                log.exception('Failed to send email notification to users')
                Session().rollback()

        log.debug(
            'Updated pull request %s, added_ids: %s, common_ids: %s, '
            'removed_ids: %s', pull_request.pull_request_id,
            commit_changes.added, commit_changes.common, commit_changes.removed)
        log.debug(
            'Updated pull request with the following file changes: %s',
            file_changes)

        log.info(
            "Updated pull request %s from commit %s to commit %s, "
            "stored new version %s of this pull request.",
            pull_request.pull_request_id, source_ref_id,
            pull_request.source_ref_parts.commit_id,
            pull_request_version.pull_request_version_id)

        self.trigger_pull_request_hook(pull_request, pull_request.author, 'update')

        return UpdateResponse(
            executed=True, reason=UpdateFailureReason.NONE,
            old=pull_request, new=pull_request_version,
            common_ancestor_id=ancestor_commit_id, commit_changes=commit_changes,
            source_changed=source_changed, target_changed=target_changed)
1240 1240
1241 1241 def _create_version_from_snapshot(self, pull_request):
1242 1242 version = PullRequestVersion()
1243 1243 version.title = pull_request.title
1244 1244 version.description = pull_request.description
1245 1245 version.status = pull_request.status
1246 1246 version.pull_request_state = pull_request.pull_request_state
1247 1247 version.created_on = datetime.datetime.now()
1248 1248 version.updated_on = pull_request.updated_on
1249 1249 version.user_id = pull_request.user_id
1250 1250 version.source_repo = pull_request.source_repo
1251 1251 version.source_ref = pull_request.source_ref
1252 1252 version.target_repo = pull_request.target_repo
1253 1253 version.target_ref = pull_request.target_ref
1254 1254
1255 1255 version._last_merge_source_rev = pull_request._last_merge_source_rev
1256 1256 version._last_merge_target_rev = pull_request._last_merge_target_rev
1257 1257 version.last_merge_status = pull_request.last_merge_status
1258 1258 version.last_merge_metadata = pull_request.last_merge_metadata
1259 1259 version.shadow_merge_ref = pull_request.shadow_merge_ref
1260 1260 version.merge_rev = pull_request.merge_rev
1261 1261 version.reviewer_data = pull_request.reviewer_data
1262 1262
1263 1263 version.revisions = pull_request.revisions
1264 1264 version.common_ancestor_id = pull_request.common_ancestor_id
1265 1265 version.pull_request = pull_request
1266 1266 Session().add(version)
1267 1267 Session().flush()
1268 1268
1269 1269 return version
1270 1270
    def _generate_update_diffs(self, pull_request, pull_request_version):
        """
        Compute the diff of the previous PR version and the diff of the
        updated PR, so changes between them can be calculated.

        :param pull_request: the updated pull request (new state)
        :param pull_request_version: the snapshot representing the old state
        :returns: tuple of prepared (old_diff_data, new_diff_data)
            `DiffProcessor` objects
        """

        # extra context lines so inline comments keep matching after updates
        diff_context = (
            self.DIFF_CONTEXT +
            CommentsModel.needed_extra_diff_context())
        hide_whitespace_changes = False

        # old diff: taken from the stored version's refs
        source_repo = pull_request_version.source_repo
        source_ref_id = pull_request_version.source_ref_parts.commit_id
        target_ref_id = pull_request_version.target_ref_parts.commit_id
        old_diff = self._get_diff_from_pr_or_version(
            source_repo, source_ref_id, target_ref_id,
            hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)

        # new diff: taken from the live pull request's refs
        source_repo = pull_request.source_repo
        source_ref_id = pull_request.source_ref_parts.commit_id
        target_ref_id = pull_request.target_ref_parts.commit_id

        new_diff = self._get_diff_from_pr_or_version(
            source_repo, source_ref_id, target_ref_id,
            hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)

        # NOTE: this was using diff_format='gitdiff'
        old_diff_data = diffs.DiffProcessor(old_diff, diff_format='newdiff')
        old_diff_data.prepare()
        new_diff_data = diffs.DiffProcessor(new_diff, diff_format='newdiff')
        new_diff_data.prepare()

        return old_diff_data, new_diff_data
1299 1299
1300 1300 def _link_comments_to_version(self, pull_request_version):
1301 1301 """
1302 1302 Link all unlinked comments of this pull request to the given version.
1303 1303
1304 1304 :param pull_request_version: The `PullRequestVersion` to which
1305 1305 the comments shall be linked.
1306 1306
1307 1307 """
1308 1308 pull_request = pull_request_version.pull_request
1309 1309 comments = ChangesetComment.query()\
1310 1310 .filter(
1311 1311 # TODO: johbo: Should we query for the repo at all here?
1312 1312 # Pending decision on how comments of PRs are to be related
1313 1313 # to either the source repo, the target repo or no repo at all.
1314 1314 ChangesetComment.repo_id == pull_request.target_repo.repo_id,
1315 1315 ChangesetComment.pull_request == pull_request,
1316 1316 ChangesetComment.pull_request_version == null())\
1317 1317 .order_by(ChangesetComment.comment_id.asc())
1318 1318
1319 1319 # TODO: johbo: Find out why this breaks if it is done in a bulk
1320 1320 # operation.
1321 1321 for comment in comments:
1322 1322 comment.pull_request_version_id = (
1323 1323 pull_request_version.pull_request_version_id)
1324 1324 Session().add(comment)
1325 1325
1326 1326 def _calculate_commit_id_changes(self, old_ids, new_ids):
1327 1327 added = [x for x in new_ids if x not in old_ids]
1328 1328 common = [x for x in new_ids if x in old_ids]
1329 1329 removed = [x for x in old_ids if x not in new_ids]
1330 1330 total = new_ids
1331 1331 return ChangeTuple(added, common, removed, total)
1332 1332
1333 1333 def _calculate_file_changes(self, old_diff_data, new_diff_data):
1334 1334
1335 1335 old_files = OrderedDict()
1336 1336 for diff_data in old_diff_data.parsed_diff:
1337 1337 old_files[diff_data['filename']] = md5_safe(diff_data['raw_diff'])
1338 1338
1339 1339 added_files = []
1340 1340 modified_files = []
1341 1341 removed_files = []
1342 1342 for diff_data in new_diff_data.parsed_diff:
1343 1343 new_filename = diff_data['filename']
1344 1344 new_hash = md5_safe(diff_data['raw_diff'])
1345 1345
1346 1346 old_hash = old_files.get(new_filename)
1347 1347 if not old_hash:
1348 1348 # file is not present in old diff, we have to figure out from parsed diff
1349 1349 # operation ADD/REMOVE
1350 1350 operations_dict = diff_data['stats']['ops']
1351 1351 if diffs.DEL_FILENODE in operations_dict:
1352 1352 removed_files.append(new_filename)
1353 1353 else:
1354 1354 added_files.append(new_filename)
1355 1355 else:
1356 1356 if new_hash != old_hash:
1357 1357 modified_files.append(new_filename)
1358 1358 # now remove a file from old, since we have seen it already
1359 1359 del old_files[new_filename]
1360 1360
1361 1361 # removed files is when there are present in old, but not in NEW,
1362 1362 # since we remove old files that are present in new diff, left-overs
1363 1363 # if any should be the removed files
1364 1364 removed_files.extend(old_files.keys())
1365 1365
1366 1366 return FileChangeTuple(added_files, modified_files, removed_files)
1367 1367
1368 1368 def _render_update_message(self, ancestor_commit_id, changes, file_changes):
1369 1369 """
1370 1370 render the message using DEFAULT_COMMENTS_RENDERER (RST renderer),
1371 1371 so it's always looking the same disregarding on which default
1372 1372 renderer system is using.
1373 1373
1374 1374 :param ancestor_commit_id: ancestor raw_id
1375 1375 :param changes: changes named tuple
1376 1376 :param file_changes: file changes named tuple
1377 1377
1378 1378 """
1379 1379 new_status = ChangesetStatus.get_status_lbl(
1380 1380 ChangesetStatus.STATUS_UNDER_REVIEW)
1381 1381
1382 1382 changed_files = (
1383 1383 file_changes.added + file_changes.modified + file_changes.removed)
1384 1384
1385 1385 params = {
1386 1386 'under_review_label': new_status,
1387 1387 'added_commits': changes.added,
1388 1388 'removed_commits': changes.removed,
1389 1389 'changed_files': changed_files,
1390 1390 'added_files': file_changes.added,
1391 1391 'modified_files': file_changes.modified,
1392 1392 'removed_files': file_changes.removed,
1393 1393 'ancestor_commit_id': ancestor_commit_id
1394 1394 }
1395 1395 renderer = RstTemplateRenderer()
1396 1396 return renderer.render('pull_request_update.mako', **params)
1397 1397
1398 1398 def edit(self, pull_request, title, description, description_renderer, user):
1399 1399 pull_request = self.__get_pull_request(pull_request)
1400 1400 old_data = pull_request.get_api_data(with_merge_state=False)
1401 1401 if pull_request.is_closed():
1402 1402 raise ValueError('This pull request is closed')
1403 1403 if title:
1404 1404 pull_request.title = title
1405 1405 pull_request.description = description
1406 1406 pull_request.updated_on = datetime.datetime.now()
1407 1407 pull_request.description_renderer = description_renderer
1408 1408 Session().add(pull_request)
1409 1409 self._log_audit_action(
1410 1410 'repo.pull_request.edit', {'old_data': old_data},
1411 1411 user, pull_request)
1412 1412
1413 1413 def update_reviewers(self, pull_request, reviewer_data, user):
1414 1414 """
1415 1415 Update the reviewers in the pull request
1416 1416
1417 1417 :param pull_request: the pr to update
1418 1418 :param reviewer_data: list of tuples
1419 1419 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1420 1420 :param user: current use who triggers this action
1421 1421 """
1422 1422
1423 1423 pull_request = self.__get_pull_request(pull_request)
1424 1424 if pull_request.is_closed():
1425 1425 raise ValueError('This pull request is closed')
1426 1426
1427 1427 reviewers = {}
1428 1428 for user_id, reasons, mandatory, role, rules in reviewer_data:
1429 1429 if isinstance(user_id, (int, str)):
1430 1430 user_id = self._get_user(user_id).user_id
1431 1431 reviewers[user_id] = {
1432 1432 'reasons': reasons, 'mandatory': mandatory, 'role': role}
1433 1433
1434 1434 reviewers_ids = set(reviewers.keys())
1435 1435 current_reviewers = PullRequestReviewers.get_pull_request_reviewers(
1436 1436 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_REVIEWER)
1437 1437
1438 1438 current_reviewers_ids = set([x.user.user_id for x in current_reviewers])
1439 1439
1440 1440 ids_to_add = reviewers_ids.difference(current_reviewers_ids)
1441 1441 ids_to_remove = current_reviewers_ids.difference(reviewers_ids)
1442 1442
1443 1443 log.debug("Adding %s reviewers", ids_to_add)
1444 1444 log.debug("Removing %s reviewers", ids_to_remove)
1445 1445 changed = False
1446 1446 added_audit_reviewers = []
1447 1447 removed_audit_reviewers = []
1448 1448
1449 1449 for uid in ids_to_add:
1450 1450 changed = True
1451 1451 _usr = self._get_user(uid)
1452 1452 reviewer = PullRequestReviewers()
1453 1453 reviewer.user = _usr
1454 1454 reviewer.pull_request = pull_request
1455 1455 reviewer.reasons = reviewers[uid]['reasons']
1456 1456 # NOTE(marcink): mandatory shouldn't be changed now
1457 1457 # reviewer.mandatory = reviewers[uid]['reasons']
1458 1458 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1459 1459 reviewer.role = PullRequestReviewers.ROLE_REVIEWER
1460 1460 Session().add(reviewer)
1461 1461 added_audit_reviewers.append(reviewer.get_dict())
1462 1462
1463 1463 for uid in ids_to_remove:
1464 1464 changed = True
1465 1465 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1466 1466 # This is an edge case that handles previous state of having the same reviewer twice.
1467 1467 # this CAN happen due to the lack of DB checks
1468 1468 reviewers = PullRequestReviewers.query()\
1469 1469 .filter(PullRequestReviewers.user_id == uid,
1470 1470 PullRequestReviewers.role == PullRequestReviewers.ROLE_REVIEWER,
1471 1471 PullRequestReviewers.pull_request == pull_request)\
1472 1472 .all()
1473 1473
1474 1474 for obj in reviewers:
1475 1475 added_audit_reviewers.append(obj.get_dict())
1476 1476 Session().delete(obj)
1477 1477
1478 1478 if changed:
1479 1479 Session().expire_all()
1480 1480 pull_request.updated_on = datetime.datetime.now()
1481 1481 Session().add(pull_request)
1482 1482
1483 1483 # finally store audit logs
1484 1484 for user_data in added_audit_reviewers:
1485 1485 self._log_audit_action(
1486 1486 'repo.pull_request.reviewer.add', {'data': user_data},
1487 1487 user, pull_request)
1488 1488 for user_data in removed_audit_reviewers:
1489 1489 self._log_audit_action(
1490 1490 'repo.pull_request.reviewer.delete', {'old_data': user_data},
1491 1491 user, pull_request)
1492 1492
1493 1493 self.notify_reviewers(pull_request, ids_to_add, user)
1494 1494 return ids_to_add, ids_to_remove
1495 1495
1496 1496 def update_observers(self, pull_request, observer_data, user):
1497 1497 """
1498 1498 Update the observers in the pull request
1499 1499
1500 1500 :param pull_request: the pr to update
1501 1501 :param observer_data: list of tuples
1502 1502 [(user, ['reason1', 'reason2'], mandatory_flag, role, [rules])]
1503 1503 :param user: current use who triggers this action
1504 1504 """
1505 1505 pull_request = self.__get_pull_request(pull_request)
1506 1506 if pull_request.is_closed():
1507 1507 raise ValueError('This pull request is closed')
1508 1508
1509 1509 observers = {}
1510 1510 for user_id, reasons, mandatory, role, rules in observer_data:
1511 1511 if isinstance(user_id, (int, str)):
1512 1512 user_id = self._get_user(user_id).user_id
1513 1513 observers[user_id] = {
1514 1514 'reasons': reasons, 'observers': mandatory, 'role': role}
1515 1515
1516 1516 observers_ids = set(observers.keys())
1517 1517 current_observers = PullRequestReviewers.get_pull_request_reviewers(
1518 1518 pull_request.pull_request_id, role=PullRequestReviewers.ROLE_OBSERVER)
1519 1519
1520 1520 current_observers_ids = set([x.user.user_id for x in current_observers])
1521 1521
1522 1522 ids_to_add = observers_ids.difference(current_observers_ids)
1523 1523 ids_to_remove = current_observers_ids.difference(observers_ids)
1524 1524
1525 1525 log.debug("Adding %s observer", ids_to_add)
1526 1526 log.debug("Removing %s observer", ids_to_remove)
1527 1527 changed = False
1528 1528 added_audit_observers = []
1529 1529 removed_audit_observers = []
1530 1530
1531 1531 for uid in ids_to_add:
1532 1532 changed = True
1533 1533 _usr = self._get_user(uid)
1534 1534 observer = PullRequestReviewers()
1535 1535 observer.user = _usr
1536 1536 observer.pull_request = pull_request
1537 1537 observer.reasons = observers[uid]['reasons']
1538 1538 # NOTE(marcink): mandatory shouldn't be changed now
1539 1539 # observer.mandatory = observer[uid]['reasons']
1540 1540
1541 1541 # NOTE(marcink): role should be hardcoded, so we won't edit it.
1542 1542 observer.role = PullRequestReviewers.ROLE_OBSERVER
1543 1543 Session().add(observer)
1544 1544 added_audit_observers.append(observer.get_dict())
1545 1545
1546 1546 for uid in ids_to_remove:
1547 1547 changed = True
1548 1548 # NOTE(marcink): we fetch "ALL" reviewers objects using .all().
1549 1549 # This is an edge case that handles previous state of having the same reviewer twice.
1550 1550 # this CAN happen due to the lack of DB checks
1551 1551 observers = PullRequestReviewers.query()\
1552 1552 .filter(PullRequestReviewers.user_id == uid,
1553 1553 PullRequestReviewers.role == PullRequestReviewers.ROLE_OBSERVER,
1554 1554 PullRequestReviewers.pull_request == pull_request)\
1555 1555 .all()
1556 1556
1557 1557 for obj in observers:
1558 1558 added_audit_observers.append(obj.get_dict())
1559 1559 Session().delete(obj)
1560 1560
1561 1561 if changed:
1562 1562 Session().expire_all()
1563 1563 pull_request.updated_on = datetime.datetime.now()
1564 1564 Session().add(pull_request)
1565 1565
1566 1566 # finally store audit logs
1567 1567 for user_data in added_audit_observers:
1568 1568 self._log_audit_action(
1569 1569 'repo.pull_request.observer.add', {'data': user_data},
1570 1570 user, pull_request)
1571 1571 for user_data in removed_audit_observers:
1572 1572 self._log_audit_action(
1573 1573 'repo.pull_request.observer.delete', {'old_data': user_data},
1574 1574 user, pull_request)
1575 1575
1576 1576 self.notify_observers(pull_request, ids_to_add, user)
1577 1577 return ids_to_add, ids_to_remove
1578 1578
1579 1579 def get_url(self, pull_request, request=None, permalink=False):
1580 1580 if not request:
1581 1581 request = get_current_request()
1582 1582
1583 1583 if permalink:
1584 1584 return request.route_url(
1585 1585 'pull_requests_global',
1586 1586 pull_request_id=pull_request.pull_request_id,)
1587 1587 else:
1588 1588 return request.route_url('pullrequest_show',
1589 1589 repo_name=safe_str(pull_request.target_repo.repo_name),
1590 1590 pull_request_id=pull_request.pull_request_id,)
1591 1591
1592 1592 def get_shadow_clone_url(self, pull_request, request=None):
1593 1593 """
1594 1594 Returns qualified url pointing to the shadow repository. If this pull
1595 1595 request is closed there is no shadow repository and ``None`` will be
1596 1596 returned.
1597 1597 """
1598 1598 if pull_request.is_closed():
1599 1599 return None
1600 1600 else:
1601 1601 pr_url = urllib.parse.unquote(self.get_url(pull_request, request=request))
1602 1602 return safe_str('{pr_url}/repository'.format(pr_url=pr_url))
1603 1603
    def _notify_reviewers(self, pull_request, user_ids, role, user):
        """
        Create in-app notifications (and the emails derived from them) about
        this pull request for the given user ids.

        :param pull_request: pull request the notification is about
        :param user_ids: recipient user ids; the method is a no-op when empty
        :param role: reviewer/observer role, passed through to the email template
        :param user: acting user, recorded as the notification creator
        """
        # notification to reviewers/observers
        if not user_ids:
            return

        log.debug('Notify following %s users about pull-request %s', role, user_ids)

        pull_request_obj = pull_request
        # get the current participants of this pull request
        recipients = user_ids
        notification_type = EmailNotificationModel.TYPE_PULL_REQUEST

        pr_source_repo = pull_request_obj.source_repo
        pr_target_repo = pull_request_obj.target_repo

        # URL of the pull request itself; also reused as the email thread id
        pr_url = h.route_url('pullrequest_show',
                             repo_name=pr_target_repo.repo_name,
                             pull_request_id=pull_request_obj.pull_request_id,)

        # set some variables for email notification
        pr_target_repo_url = h.route_url(
            'repo_summary', repo_name=pr_target_repo.repo_name)

        pr_source_repo_url = h.route_url(
            'repo_summary', repo_name=pr_source_repo.repo_name)

        # pull request specifics: (raw_id, message) per PR revision, resolved
        # against the source repo
        pull_request_commits = [
            (x.raw_id, x.message)
            for x in map(pr_source_repo.get_commit, pull_request.revisions)]

        current_rhodecode_user = user
        kwargs = {
            'user': current_rhodecode_user,
            'pull_request_author': pull_request.author,
            'pull_request': pull_request_obj,
            'pull_request_commits': pull_request_commits,

            'pull_request_target_repo': pr_target_repo,
            'pull_request_target_repo_url': pr_target_repo_url,

            'pull_request_source_repo': pr_source_repo,
            'pull_request_source_repo_url': pr_source_repo_url,

            'pull_request_url': pr_url,
            'thread_ids': [pr_url],
            'user_role': role
        }

        # create notification objects, and emails
        NotificationModel().create(
            created_by=current_rhodecode_user,
            notification_subject='',  # Filled in based on the notification_type
            notification_body='',  # Filled in based on the notification_type
            notification_type=notification_type,
            recipients=recipients,
            email_kwargs=kwargs,
        )
1662 1662
1663 1663 def notify_reviewers(self, pull_request, reviewers_ids, user):
1664 1664 return self._notify_reviewers(pull_request, reviewers_ids,
1665 1665 PullRequestReviewers.ROLE_REVIEWER, user)
1666 1666
1667 1667 def notify_observers(self, pull_request, observers_ids, user):
1668 1668 return self._notify_reviewers(pull_request, observers_ids,
1669 1669 PullRequestReviewers.ROLE_OBSERVER, user)
1670 1670
    def notify_users(self, pull_request, updating_user, ancestor_commit_id,
                     commit_changes, file_changes):
        """
        Notify all pull request reviewers (except the updating user) that the
        pull request was updated, including commit and file change summaries.

        :param pull_request: the updated pull request
        :param updating_user: user who performed the update; excluded from
            the recipients
        :param ancestor_commit_id: ancestor raw_id, passed to the email
        :param commit_changes: named tuple with added/removed commit ids
        :param file_changes: named tuple with added/modified/removed files
        """
        updating_user_id = updating_user.user_id
        reviewers = set([x.user.user_id for x in pull_request.get_pull_request_reviewers()])
        # NOTE(marcink): send notification to all other users except to
        # person who updated the PR
        recipients = reviewers.difference(set([updating_user_id]))

        log.debug('Notify following recipients about pull-request update %s', recipients)

        pull_request_obj = pull_request

        # send email about the update
        changed_files = (
            file_changes.added + file_changes.modified + file_changes.removed)

        pr_source_repo = pull_request_obj.source_repo
        pr_target_repo = pull_request_obj.target_repo

        # URL of the pull request itself; also reused as the email thread id
        pr_url = h.route_url('pullrequest_show',
                             repo_name=pr_target_repo.repo_name,
                             pull_request_id=pull_request_obj.pull_request_id,)

        # set some variables for email notification
        pr_target_repo_url = h.route_url(
            'repo_summary', repo_name=pr_target_repo.repo_name)

        pr_source_repo_url = h.route_url(
            'repo_summary', repo_name=pr_source_repo.repo_name)

        email_kwargs = {
            'date': datetime.datetime.now(),
            'updating_user': updating_user,

            'pull_request': pull_request_obj,

            'pull_request_target_repo': pr_target_repo,
            'pull_request_target_repo_url': pr_target_repo_url,

            'pull_request_source_repo': pr_source_repo,
            'pull_request_source_repo_url': pr_source_repo_url,

            'pull_request_url': pr_url,

            'ancestor_commit_id': ancestor_commit_id,
            'added_commits': commit_changes.added,
            'removed_commits': commit_changes.removed,
            'changed_files': changed_files,
            'added_files': file_changes.added,
            'modified_files': file_changes.modified,
            'removed_files': file_changes.removed,
            'thread_ids': [pr_url],
        }

        # create notification objects, and emails
        NotificationModel().create(
            created_by=updating_user,
            notification_subject='',  # Filled in based on the notification_type
            notification_body='',  # Filled in based on the notification_type
            notification_type=EmailNotificationModel.TYPE_PULL_REQUEST_UPDATE,
            recipients=recipients,
            email_kwargs=email_kwargs,
        )
1735 1735
1736 1736 def delete(self, pull_request, user=None):
1737 1737 if not user:
1738 1738 user = getattr(get_current_rhodecode_user(), 'username', None)
1739 1739
1740 1740 pull_request = self.__get_pull_request(pull_request)
1741 1741 old_data = pull_request.get_api_data(with_merge_state=False)
1742 1742 self._cleanup_merge_workspace(pull_request)
1743 1743 self._log_audit_action(
1744 1744 'repo.pull_request.delete', {'old_data': old_data},
1745 1745 user, pull_request)
1746 1746 Session().delete(pull_request)
1747 1747
1748 1748 def close_pull_request(self, pull_request, user):
1749 1749 pull_request = self.__get_pull_request(pull_request)
1750 1750 self._cleanup_merge_workspace(pull_request)
1751 1751 pull_request.status = PullRequest.STATUS_CLOSED
1752 1752 pull_request.updated_on = datetime.datetime.now()
1753 1753 Session().add(pull_request)
1754 1754 self.trigger_pull_request_hook(pull_request, pull_request.author, 'close')
1755 1755
1756 1756 pr_data = pull_request.get_api_data(with_merge_state=False)
1757 1757 self._log_audit_action(
1758 1758 'repo.pull_request.close', {'data': pr_data}, user, pull_request)
1759 1759
    def close_pull_request_with_comment(
            self, pull_request, user, repo, message=None, auth_user=None):
        """
        Close a pull request together with a status-changing comment.

        The final status is APPROVED only when the calculated review status
        already has voting consent; otherwise REJECTED. Returns the created
        comment and the status that was set.

        :param message: optional closing message; a default one is rendered
            from the status label when not given
        :returns: tuple of (comment, status)
        """

        pull_request_review_status = pull_request.calculated_review_status()

        if pull_request_review_status == ChangesetStatus.STATUS_APPROVED:
            # approved only if we have voting consent
            status = ChangesetStatus.STATUS_APPROVED
        else:
            status = ChangesetStatus.STATUS_REJECTED
        status_lbl = ChangesetStatus.get_status_lbl(status)

        default_message = (
            'Closing with status change {transition_icon} {status}.'
        ).format(transition_icon='>', status=status_lbl)
        text = message or default_message

        # create a comment, and link it to new status
        comment = CommentsModel().create(
            text=text,
            repo=repo.repo_id,
            user=user.user_id,
            pull_request=pull_request.pull_request_id,
            status_change=status_lbl,
            status_change_type=status,
            closing_pr=True,
            auth_user=auth_user,
        )

        # calculate old status before we change it
        old_calculated_status = pull_request.calculated_review_status()
        ChangesetStatusModel().set_status(
            repo.repo_id,
            status,
            user.user_id,
            comment=comment,
            pull_request=pull_request.pull_request_id
        )

        # make the new comment/status visible to subsequent queries
        Session().flush()

        self.trigger_pull_request_hook(pull_request, user, 'comment',
                                       data={'comment': comment})

        # we now calculate the status of pull request again, and based on that
        # calculation trigger status change. This might happen in cases
        # that non-reviewer admin closes a pr, which means his vote doesn't
        # change the status, while if he's a reviewer this might change it.
        calculated_status = pull_request.calculated_review_status()
        if old_calculated_status != calculated_status:
            self.trigger_pull_request_hook(pull_request, user, 'review_status_change',
                                           data={'status': calculated_status})

        # finally close the PR
        PullRequestModel().close_pull_request(pull_request.pull_request_id, user)

        return comment, status
1817 1817
1818 1818 def merge_status(self, pull_request, translator=None, force_shadow_repo_refresh=False):
1819 1819 _ = translator or get_current_request().translate
1820 1820
1821 1821 if not self._is_merge_enabled(pull_request):
1822 1822 return None, False, _('Server-side pull request merging is disabled.')
1823 1823
1824 1824 if pull_request.is_closed():
1825 1825 return None, False, _('This pull request is closed.')
1826 1826
1827 1827 merge_possible, msg = self._check_repo_requirements(
1828 1828 target=pull_request.target_repo, source=pull_request.source_repo,
1829 1829 translator=_)
1830 1830 if not merge_possible:
1831 1831 return None, merge_possible, msg
1832 1832
1833 1833 try:
1834 1834 merge_response = self._try_merge(
1835 1835 pull_request, force_shadow_repo_refresh=force_shadow_repo_refresh)
1836 1836 log.debug("Merge response: %s", merge_response)
1837 1837 return merge_response, merge_response.possible, merge_response.merge_status_message
1838 1838 except NotImplementedError:
1839 1839 return None, False, _('Pull request merging is not supported.')
1840 1840
1841 1841 def _check_repo_requirements(self, target, source, translator):
1842 1842 """
1843 1843 Check if `target` and `source` have compatible requirements.
1844 1844
1845 1845 Currently this is just checking for largefiles.
1846 1846 """
1847 1847 _ = translator
1848 1848 target_has_largefiles = self._has_largefiles(target)
1849 1849 source_has_largefiles = self._has_largefiles(source)
1850 1850 merge_possible = True
1851 1851 message = u''
1852 1852
1853 1853 if target_has_largefiles != source_has_largefiles:
1854 1854 merge_possible = False
1855 1855 if source_has_largefiles:
1856 1856 message = _(
1857 1857 'Target repository large files support is disabled.')
1858 1858 else:
1859 1859 message = _(
1860 1860 'Source repository large files support is disabled.')
1861 1861
1862 1862 return merge_possible, message
1863 1863
1864 1864 def _has_largefiles(self, repo):
1865 1865 largefiles_ui = VcsSettingsModel(repo=repo).get_ui_settings(
1866 1866 'extensions', 'largefiles')
1867 1867 return largefiles_ui and largefiles_ui[0].active
1868 1868
    def _try_merge(self, pull_request, force_shadow_repo_refresh=False):
        """
        Try to merge the pull request and return the merge status.

        Returns a cached ``MergeResponse`` built from the pull request's last
        stored merge state when nothing changed; otherwise performs a dry-run
        merge refresh in the shadow repository.

        :param force_shadow_repo_refresh: always refresh the merge state,
            bypassing the cached last-merge revisions check
        """
        log.debug(
            "Trying out if the pull request %s can be merged. Force_refresh=%s",
            pull_request.pull_request_id, force_shadow_repo_refresh)
        target_vcs = pull_request.target_repo.scm_instance()
        # Refresh the target reference.
        try:
            target_ref = self._refresh_reference(
                pull_request.target_ref_parts, target_vcs)
        except CommitDoesNotExistError:
            # target ref vanished (e.g. branch deleted) -> merge impossible
            merge_state = MergeResponse(
                False, False, None, MergeFailureReason.MISSING_TARGET_REF,
                metadata={'target_ref': pull_request.target_ref_parts})
            return merge_state

        target_locked = pull_request.target_repo.locked
        if target_locked and target_locked[0]:
            locked_by = 'user:{}'.format(target_locked[0])
            log.debug("The target repository is locked by %s.", locked_by)
            merge_state = MergeResponse(
                False, False, None, MergeFailureReason.TARGET_IS_LOCKED,
                metadata={'locked_by': locked_by})
        elif force_shadow_repo_refresh or self._needs_merge_state_refresh(
                pull_request, target_ref):
            log.debug("Refreshing the merge status of the repository.")
            merge_state = self._refresh_merge_state(
                pull_request, target_vcs, target_ref)
        else:
            # cached path: rebuild the response from the stored merge state
            possible = pull_request.last_merge_status == MergeFailureReason.NONE
            metadata = {
                'unresolved_files': '',
                'target_ref': pull_request.target_ref_parts,
                'source_ref': pull_request.source_ref_parts,
            }
            if pull_request.last_merge_metadata:
                metadata.update(pull_request.last_merge_metadata_parsed)

            if not possible and target_ref.type == 'branch':
                # NOTE(marcink): case for mercurial multiple heads on branch
                heads = target_vcs._heads(target_ref.name)
                if len(heads) != 1:
                    # NOTE(review): separator '\n,' looks transposed (',\n'
                    # reads more naturally) — confirm before changing, the
                    # string may be consumed by templates as-is
                    heads = '\n,'.join(target_vcs._heads(target_ref.name))
                    metadata.update({
                        'heads': heads
                    })

            merge_state = MergeResponse(
                possible, False, None, pull_request.last_merge_status, metadata=metadata)

        return merge_state
1922 1922
1923 1923 def _refresh_reference(self, reference, vcs_repository):
1924 1924 if reference.type in self.UPDATABLE_REF_TYPES:
1925 1925 name_or_id = reference.name
1926 1926 else:
1927 1927 name_or_id = reference.commit_id
1928 1928
1929 1929 refreshed_commit = vcs_repository.get_commit(name_or_id)
1930 1930 refreshed_reference = Reference(
1931 1931 reference.type, reference.name, refreshed_commit.raw_id)
1932 1932 return refreshed_reference
1933 1933
1934 1934 def _needs_merge_state_refresh(self, pull_request, target_reference):
1935 1935 return not(
1936 1936 pull_request.revisions and
1937 1937 pull_request.revisions[0] == pull_request._last_merge_source_rev and
1938 1938 target_reference.commit_id == pull_request._last_merge_target_rev)
1939 1939
    def _refresh_merge_state(self, pull_request, target_vcs, target_reference):
        """
        Perform a dry-run merge in the shadow workspace and persist the
        resulting merge state on the pull request.

        :returns: the ``MergeResponse`` of the dry-run merge
        """
        workspace_id = self._workspace_id(pull_request)
        source_vcs = pull_request.source_repo.scm_instance()
        repo_id = pull_request.target_repo.repo_id
        use_rebase = self._use_rebase_for_merging(pull_request)
        close_branch = self._close_branch_before_merging(pull_request)
        merge_state = target_vcs.merge(
            repo_id, workspace_id,
            target_reference, source_vcs, pull_request.source_ref_parts,
            dry_run=True, use_rebase=use_rebase,
            close_branch=close_branch)

        # Do not store the response if there was an unknown error.
        if merge_state.failure_reason != MergeFailureReason.UNKNOWN:
            # cache the revisions the state was computed for, so subsequent
            # calls can skip the refresh (see _needs_merge_state_refresh)
            pull_request._last_merge_source_rev = \
                pull_request.source_ref_parts.commit_id
            pull_request._last_merge_target_rev = target_reference.commit_id
            pull_request.last_merge_status = merge_state.failure_reason
            pull_request.last_merge_metadata = merge_state.metadata

            pull_request.shadow_merge_ref = merge_state.merge_ref
            Session().add(pull_request)
            Session().commit()

        return merge_state
1965 1965
1966 1966 def _workspace_id(self, pull_request):
1967 1967 workspace_id = 'pr-%s' % pull_request.pull_request_id
1968 1968 return workspace_id
1969 1969
1970 1970 def generate_repo_data(self, repo, commit_id=None, branch=None,
1971 1971 bookmark=None, translator=None):
1972 1972 from rhodecode.model.repo import RepoModel
1973 1973
1974 1974 all_refs, selected_ref = \
1975 1975 self._get_repo_pullrequest_sources(
1976 1976 repo.scm_instance(), commit_id=commit_id,
1977 1977 branch=branch, bookmark=bookmark, translator=translator)
1978 1978
1979 1979 refs_select2 = []
1980 1980 for element in all_refs:
1981 1981 children = [{'id': x[0], 'text': x[1]} for x in element[0]]
1982 1982 refs_select2.append({'text': element[1], 'children': children})
1983 1983
1984 1984 return {
1985 1985 'user': {
1986 1986 'user_id': repo.user.user_id,
1987 1987 'username': repo.user.username,
1988 1988 'firstname': repo.user.first_name,
1989 1989 'lastname': repo.user.last_name,
1990 1990 'gravatar_link': h.gravatar_url(repo.user.email, 14),
1991 1991 },
1992 1992 'name': repo.repo_name,
1993 1993 'link': RepoModel().get_url(repo),
1994 1994 'description': h.chop_at_smart(repo.description_safe, '\n'),
1995 1995 'refs': {
1996 1996 'all_refs': all_refs,
1997 1997 'selected_ref': selected_ref,
1998 1998 'select2_refs': refs_select2
1999 1999 }
2000 2000 }
2001 2001
2002 2002 def generate_pullrequest_title(self, source, source_ref, target):
2003 2003 return u'{source}#{at_ref} to {target}'.format(
2004 2004 source=source,
2005 2005 at_ref=source_ref,
2006 2006 target=target,
2007 2007 )
2008 2008
2009 2009 def _cleanup_merge_workspace(self, pull_request):
2010 2010 # Merging related cleanup
2011 2011 repo_id = pull_request.target_repo.repo_id
2012 2012 target_scm = pull_request.target_repo.scm_instance()
2013 2013 workspace_id = self._workspace_id(pull_request)
2014 2014
2015 2015 try:
2016 2016 target_scm.cleanup_merge_workspace(repo_id, workspace_id)
2017 2017 except NotImplementedError:
2018 2018 pass
2019 2019
    def _get_repo_pullrequest_sources(
            self, repo, commit_id=None, branch=None, bookmark=None,
            translator=None):
        """
        Return a structure with repo's interesting commits, suitable for
        the selectors in pullrequest controller

        :param commit_id: a commit that must be in the list somehow
            and selected by default
        :param branch: a branch that must be in the list and selected
            by default - even if closed
        :param bookmark: a bookmark that must be in the list and selected
        :param translator: optional translation function; falls back to the
            current request's translator
        :returns: tuple ``(groups, selected)`` where ``groups`` is a list of
            ``(group_refs, group_name)`` pairs and ``selected`` is the
            ``'<type>:<name>:<id>'`` key of the preselected ref (or None)
        :raises CommitDoesNotExistError: when an explicit ref was requested
            but no matching ref exists
        :raises EmptyRepositoryError: when the repository has no commits at all
        """
        _ = translator or get_current_request().translate

        # normalize all lookups to str; empty values become None
        commit_id = safe_str(commit_id) if commit_id else None
        branch = safe_str(branch) if branch else None
        bookmark = safe_str(bookmark) if bookmark else None

        selected = None

        # order matters: first source that has commit_id in it will be selected
        sources = []
        sources.append(('book', repo.bookmarks.items(), _('Bookmarks'), bookmark))
        sources.append(('branch', repo.branches.items(), _('Branches'), branch))

        if commit_id:
            # synthesize a single-entry "Commit IDs" group for the explicit commit
            ref_commit = (h.short_id(commit_id), commit_id)
            sources.append(('rev', [ref_commit], _('Commit IDs'), commit_id))

        sources.append(
            ('branch', repo.branches_closed.items(), _('Closed Branches'), branch),
        )

        groups = []

        for group_key, ref_list, group_name, match in sources:
            group_refs = []
            for ref_name, ref_id in ref_list:
                ref_key = u'{}:{}:{}'.format(group_key, ref_name, ref_id)
                group_refs.append((ref_key, ref_name))

                # first ref matching either the requested commit or the
                # group-specific match value wins
                if not selected:
                    if set([commit_id, match]) & set([ref_id, ref_name]):
                        selected = ref_key

            if group_refs:
                groups.append((group_refs, group_name))

        if not selected:
            ref = commit_id or branch or bookmark
            if ref:
                # an explicit ref was asked for but nothing matched
                raise CommitDoesNotExistError(
                    u'No commit refs could be found matching: {}'.format(ref))
            elif repo.DEFAULT_BRANCH_NAME in repo.branches:
                # fall back to the default branch when nothing was requested
                selected = u'branch:{}:{}'.format(
                    safe_str(repo.DEFAULT_BRANCH_NAME),
                    safe_str(repo.branches[repo.DEFAULT_BRANCH_NAME])
                )
            elif repo.commit_ids:
                # make the user select in this case
                selected = None
            else:
                raise EmptyRepositoryError()
        return groups, selected
2085 2085
2086 2086 def get_diff(self, source_repo, source_ref_id, target_ref_id,
2087 2087 hide_whitespace_changes, diff_context):
2088 2088
2089 2089 return self._get_diff_from_pr_or_version(
2090 2090 source_repo, source_ref_id, target_ref_id,
2091 2091 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
2092 2092
2093 2093 def _get_diff_from_pr_or_version(
2094 2094 self, source_repo, source_ref_id, target_ref_id,
2095 2095 hide_whitespace_changes, diff_context):
2096 2096
2097 2097 target_commit = source_repo.get_commit(
2098 2098 commit_id=safe_str(target_ref_id))
2099 2099 source_commit = source_repo.get_commit(
2100 2100 commit_id=safe_str(source_ref_id), maybe_unreachable=True)
2101 2101 if isinstance(source_repo, Repository):
2102 2102 vcs_repo = source_repo.scm_instance()
2103 2103 else:
2104 2104 vcs_repo = source_repo
2105 2105
2106 2106 # TODO: johbo: In the context of an update, we cannot reach
2107 2107 # the old commit anymore with our normal mechanisms. It needs
2108 2108 # some sort of special support in the vcs layer to avoid this
2109 2109 # workaround.
2110 2110 if (source_commit.raw_id == vcs_repo.EMPTY_COMMIT_ID and
2111 2111 vcs_repo.alias == 'git'):
2112 2112 source_commit.raw_id = safe_str(source_ref_id)
2113 2113
2114 2114 log.debug('calculating diff between '
2115 2115 'source_ref:%s and target_ref:%s for repo `%s`',
2116 2116 target_ref_id, source_ref_id,
2117 2117 safe_str(vcs_repo.path))
2118 2118
2119 2119 vcs_diff = vcs_repo.get_diff(
2120 2120 commit1=target_commit, commit2=source_commit,
2121 2121 ignore_whitespace=hide_whitespace_changes, context=diff_context)
2122 2122 return vcs_diff
2123 2123
2124 2124 def _is_merge_enabled(self, pull_request):
2125 2125 return self._get_general_setting(
2126 2126 pull_request, 'rhodecode_pr_merge_enabled')
2127 2127
2128 2128 def _use_rebase_for_merging(self, pull_request):
2129 2129 repo_type = pull_request.target_repo.repo_type
2130 2130 if repo_type == 'hg':
2131 2131 return self._get_general_setting(
2132 2132 pull_request, 'rhodecode_hg_use_rebase_for_merging')
2133 2133 elif repo_type == 'git':
2134 2134 return self._get_general_setting(
2135 2135 pull_request, 'rhodecode_git_use_rebase_for_merging')
2136 2136
2137 2137 return False
2138 2138
2139 2139 def _user_name_for_merging(self, pull_request, user):
2140 2140 env_user_name_attr = os.environ.get('RC_MERGE_USER_NAME_ATTR', '')
2141 2141 if env_user_name_attr and hasattr(user, env_user_name_attr):
2142 2142 user_name_attr = env_user_name_attr
2143 2143 else:
2144 2144 user_name_attr = 'short_contact'
2145 2145
2146 2146 user_name = getattr(user, user_name_attr)
2147 2147 return user_name
2148 2148
2149 2149 def _close_branch_before_merging(self, pull_request):
2150 2150 repo_type = pull_request.target_repo.repo_type
2151 2151 if repo_type == 'hg':
2152 2152 return self._get_general_setting(
2153 2153 pull_request, 'rhodecode_hg_close_branch_before_merging')
2154 2154 elif repo_type == 'git':
2155 2155 return self._get_general_setting(
2156 2156 pull_request, 'rhodecode_git_close_branch_before_merging')
2157 2157
2158 2158 return False
2159 2159
2160 2160 def _get_general_setting(self, pull_request, settings_key, default=False):
2161 2161 settings_model = VcsSettingsModel(repo=pull_request.target_repo)
2162 2162 settings = settings_model.get_general_settings()
2163 2163 return settings.get(settings_key, default)
2164 2164
2165 2165 def _log_audit_action(self, action, action_data, user, pull_request):
2166 2166 audit_logger.store(
2167 2167 action=action,
2168 2168 action_data=action_data,
2169 2169 user=user,
2170 2170 repo=pull_request.target_repo)
2171 2171
2172 2172 def get_reviewer_functions(self):
2173 2173 """
2174 2174 Fetches functions for validation and fetching default reviewers.
2175 2175 If available we use the EE package, else we fallback to CE
2176 2176 package functions
2177 2177 """
2178 2178 try:
2179 2179 from rc_reviewers.utils import get_default_reviewers_data
2180 2180 from rc_reviewers.utils import validate_default_reviewers
2181 2181 from rc_reviewers.utils import validate_observers
2182 2182 except ImportError:
2183 2183 from rhodecode.apps.repository.utils import get_default_reviewers_data
2184 2184 from rhodecode.apps.repository.utils import validate_default_reviewers
2185 2185 from rhodecode.apps.repository.utils import validate_observers
2186 2186
2187 2187 return get_default_reviewers_data, validate_default_reviewers, validate_observers
2188 2188
2189 2189
class MergeCheck(object):
    """
    Perform Merge Checks and returns a check object which stores information
    about merge errors, and merge conditions
    """
    # keys under which individual check failures are stored in error_details
    TODO_CHECK = 'todo'
    PERM_CHECK = 'perm'
    REVIEW_CHECK = 'review'
    MERGE_CHECK = 'merge'
    WIP_CHECK = 'wip'

    def __init__(self):
        # all attributes start "empty"; validate() fills them in
        self.review_status = None
        self.merge_possible = None
        self.merge_msg = ''
        self.merge_response = None
        self.failed = None
        self.errors = []
        self.error_details = OrderedDict()
        self.source_commit = AttributeDict()
        self.target_commit = AttributeDict()
        self.reviewers_count = 0
        self.observers_count = 0

    def __repr__(self):
        return '<MergeCheck(possible:{}, failed:{}, errors:{})>'.format(
            self.merge_possible, self.failed, self.errors)

    def push_error(self, error_type, message, error_key, details):
        """
        Record a failed check: marks the whole check as failed and stores
        the message both in `errors` and, keyed, in `error_details`.
        """
        self.failed = True
        self.errors.append([error_type, message])
        self.error_details[error_key] = dict(
            details=details,
            error_type=error_type,
            message=message
        )

    @classmethod
    def validate(cls, pull_request, auth_user, translator, fail_early=False,
                 force_shadow_repo_refresh=False):
        """
        Run all merge checks for `pull_request` and return a populated
        MergeCheck instance.

        :param fail_early: return right after the first failed check instead
            of collecting all failures
        :param force_shadow_repo_refresh: passed through to merge_status()
            to force a refresh of the shadow repository
        """
        _ = translator
        merge_check = cls()

        # title has WIP:
        if pull_request.work_in_progress:
            log.debug("MergeCheck: cannot merge, title has wip: marker.")

            msg = _('WIP marker in title prevents from accidental merge.')
            merge_check.push_error('error', msg, cls.WIP_CHECK, pull_request.title)
            if fail_early:
                return merge_check

        # permissions to merge
        user_allowed_to_merge = PullRequestModel().check_user_merge(pull_request, auth_user)
        if not user_allowed_to_merge:
            log.debug("MergeCheck: cannot merge, approval is pending.")

            msg = _('User `{}` not allowed to perform merge.').format(auth_user.username)
            merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
            if fail_early:
                return merge_check

        # permission to merge into the target branch
        target_commit_id = pull_request.target_ref_parts.commit_id
        if pull_request.target_ref_parts.type == 'branch':
            branch_name = pull_request.target_ref_parts.name
        else:
            # for mercurial we can always figure out the branch from the commit
            # in case of bookmark
            target_commit = pull_request.target_repo.get_commit(target_commit_id)
            branch_name = target_commit.branch

        rule, branch_perm = auth_user.get_rule_and_branch_permission(
            pull_request.target_repo.repo_name, branch_name)
        if branch_perm and branch_perm == 'branch.none':
            msg = _('Target branch `{}` changes rejected by rule {}.').format(
                branch_name, rule)
            merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
            if fail_early:
                return merge_check

        # review status, must be always present
        review_status = pull_request.calculated_review_status()
        merge_check.review_status = review_status
        merge_check.reviewers_count = pull_request.reviewers_count
        merge_check.observers_count = pull_request.observers_count

        # approval is only required when reviewers are actually assigned
        status_approved = review_status == ChangesetStatus.STATUS_APPROVED
        if not status_approved and merge_check.reviewers_count:
            log.debug("MergeCheck: cannot merge, approval is pending.")
            msg = _('Pull request reviewer approval is pending.')

            merge_check.push_error('warning', msg, cls.REVIEW_CHECK, review_status)

            if fail_early:
                return merge_check

        # left over TODOs
        todos = CommentsModel().get_pull_request_unresolved_todos(pull_request)
        if todos:
            log.debug("MergeCheck: cannot merge, {} "
                      "unresolved TODOs left.".format(len(todos)))

            if len(todos) == 1:
                msg = _('Cannot merge, {} TODO still not resolved.').format(
                    len(todos))
            else:
                msg = _('Cannot merge, {} TODOs still not resolved.').format(
                    len(todos))

            merge_check.push_error('warning', msg, cls.TODO_CHECK, todos)

            if fail_early:
                return merge_check

        # merge possible, here is the filesystem simulation + shadow repo
        merge_response, merge_status, msg = PullRequestModel().merge_status(
            pull_request, translator=translator,
            force_shadow_repo_refresh=force_shadow_repo_refresh)

        merge_check.merge_possible = merge_status
        merge_check.merge_msg = msg
        merge_check.merge_response = merge_response

        source_ref_id = pull_request.source_ref_parts.commit_id
        target_ref_id = pull_request.target_ref_parts.commit_id

        try:
            # record whether source/target moved since the PR refs were stored
            source_commit, target_commit = PullRequestModel().get_flow_commits(pull_request)
            merge_check.source_commit.changed = source_ref_id != source_commit.raw_id
            merge_check.source_commit.ref_spec = pull_request.source_ref_parts
            merge_check.source_commit.current_raw_id = source_commit.raw_id
            merge_check.source_commit.previous_raw_id = source_ref_id

            merge_check.target_commit.changed = target_ref_id != target_commit.raw_id
            merge_check.target_commit.ref_spec = pull_request.target_ref_parts
            merge_check.target_commit.current_raw_id = target_commit.raw_id
            merge_check.target_commit.previous_raw_id = target_ref_id
        except (SourceRefMissing, TargetRefMissing):
            # refs may be gone (e.g. deleted branch); commit info stays empty
            pass

        if not merge_status:
            log.debug("MergeCheck: cannot merge, pull request merge not possible.")
            merge_check.push_error('warning', msg, cls.MERGE_CHECK, None)

            if fail_early:
                return merge_check

        log.debug('MergeCheck: is failed: %s', merge_check.failed)
        return merge_check

    @classmethod
    def get_merge_conditions(cls, pull_request, translator):
        """
        Return a dict describing how the merge would be performed (strategy
        and branch-closing behaviour), for display purposes.
        """
        _ = translator
        merge_details = {}

        model = PullRequestModel()
        use_rebase = model._use_rebase_for_merging(pull_request)

        if use_rebase:
            merge_details['merge_strategy'] = dict(
                details={},
                message=_('Merge strategy: rebase')
            )
        else:
            merge_details['merge_strategy'] = dict(
                details={},
                message=_('Merge strategy: explicit merge commit')
            )

        close_branch = model._close_branch_before_merging(pull_request)
        if close_branch:
            repo_type = pull_request.target_repo.repo_type
            close_msg = ''
            if repo_type == 'hg':
                close_msg = _('Source branch will be closed before the merge.')
            elif repo_type == 'git':
                close_msg = _('Source branch will be deleted after the merge.')

            merge_details['close_branch'] = dict(
                details={},
                message=close_msg
            )

        return merge_details
2375 2375
2376 2376
@dataclasses.dataclass
class ChangeTuple:
    """
    Delta between two states of a change set.

    NOTE(review): element semantics (presumably commit ids) are not visible
    here -- confirm at call sites.
    """
    # elements present only in the new state
    added: list
    # elements present in both states
    common: list
    # elements present only in the old state
    removed: list
    # all elements involved
    total: list
2383 2383
2384 2384
2385 2385 @dataclasses.dataclass
2386 2386 class FileChangeTuple:
2387 2387 added: list
2388 2388 modified: list
2389 2389 removed: list
@@ -1,223 +1,223 b''
1 1
2 2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 import pytest
21 21
22 from rhodecode.lib.pyramid_utils import get_app_config
22 from rhodecode.lib.config_utils import get_app_config
23 23 from rhodecode.tests.fixture import TestINI
24 24 from rhodecode.tests.server_utils import RcVCSServer
25 25
26 26
@pytest.fixture(scope='session')
def vcsserver(request, vcsserver_port, vcsserver_factory):
    """
    Session-scoped VCSServer.

    Tests that need the VCSServer should depend on this fixture so the
    server is guaranteed to be running. For special needs use
    ``vcsserver_factory`` directly, which allows adjusting the configuration
    file for the test run.

    Command line args:

    --without-vcsserver: switches this fixture off; the server must then be
        started manually.

    --vcsserver-port: expect the VCSServer to listen on this port.
    """
    with_vcsserver = request.config.getoption('with_vcsserver')
    if not with_vcsserver:
        return None

    return vcsserver_factory(request, vcsserver_port=vcsserver_port)
51 51
52 52
@pytest.fixture(scope='session')
def vcsserver_factory(tmpdir_factory):
    """
    Use this if you need a running vcsserver with a special configuration.
    """

    def factory(request, overrides=(), vcsserver_port=None,
                log_file=None, workers='2'):
        # Start a VCSServer with `overrides` applied to its INI file and
        # register a finalizer that shuts it down when the test ends.

        if vcsserver_port is None:
            vcsserver_port = get_available_port()

        overrides = list(overrides)
        # the port override must always win, so it is appended last
        overrides.append({'server:main': {'port': vcsserver_port}})

        option_name = 'vcsserver_config_http'
        override_option_name = 'vcsserver_config_override'
        config_file = get_config(
            request.config, option_name=option_name,
            override_option_name=override_option_name, overrides=overrides,
            basetemp=tmpdir_factory.getbasetemp().strpath,
            prefix='test_vcs_')

        server = RcVCSServer(config_file, log_file, workers)
        server.start()

        # shut the server down when the requesting scope is torn down
        @request.addfinalizer
        def cleanup():
            server.shutdown()

        # block until the server answers, so tests see a ready server
        server.wait_until_ready()
        return server

    return factory
87 87
88 88
89 89 def _use_log_level(config):
90 90 level = config.getoption('test_loglevel') or 'critical'
91 91 return level.upper()
92 92
93 93
@pytest.fixture(scope='session')
def ini_config(request, tmpdir_factory, rcserver_port, vcsserver_port):
    """
    Generate the RhodeCode test INI file, with ports and test-specific
    overrides applied, and return its filename.
    """
    option_name = 'pyramid_config'
    log_level = _use_log_level(request.config)

    overrides = [
        {'server:main': {'port': rcserver_port}},
        {'app:main': {
            'cache_dir': '%(here)s/rc_data',
            'vcs.server': f'localhost:{vcsserver_port}',
            # johbo: We will always start the VCSServer on our own based on the
            # fixtures of the test cases. For the test run it must always be
            # off in the INI file.
            'vcs.start_server': 'false',

            'vcs.server.protocol': 'http',
            'vcs.scm_app_implementation': 'http',
            'vcs.hooks.protocol': 'http',
            'vcs.hooks.host': '*',
            'app.service_api.token': 'service_secret_token',
        }},

        {'handler_console': {
            'class': 'StreamHandler',
            'args': '(sys.stderr,)',
            'level': log_level,
        }},

    ]

    filename = get_config(
        request.config, option_name=option_name,
        override_option_name='{}_override'.format(option_name),
        overrides=overrides,
        basetemp=tmpdir_factory.getbasetemp().strpath,
        prefix='test_rce_')
    return filename
131 131
132 132
@pytest.fixture(scope='session')
def ini_settings(ini_config):
    """Application settings parsed from the generated test INI file."""
    return get_app_config(ini_config)
137 137
138 138
def get_available_port(min_port=40000, max_port=55555):
    """Return a free TCP port within ``[min_port, max_port]``."""
    from rhodecode.lib.utils2 import get_available_port as _find_free_port
    return _find_free_port(min_port, max_port)
142 142
143 143
@pytest.fixture(scope='session')
def rcserver_port(request):
    """Session-wide free port for the RhodeCode test server."""
    chosen_port = get_available_port()
    print(f'Using rhodecode port {chosen_port}')
    return chosen_port
149 149
150 150
@pytest.fixture(scope='session')
def vcsserver_port(request):
    """VCSServer port: taken from --vcsserver-port, else a free one."""
    port = request.config.getoption('--vcsserver-port')
    if port is None:
        # no explicit port requested, pick any free one
        port = get_available_port()
    print(f'Using vcsserver port {port}')
    return port
158 158
159 159
@pytest.fixture(scope='session')
def available_port_factory():
    """
    A callable which yields a free port number on each invocation.
    """
    return get_available_port
166 166
167 167
@pytest.fixture()
def available_port(available_port_factory):
    """
    A single free port for the current test, obtained via
    "available_port_factory".
    """
    return available_port_factory()
176 176
177 177
@pytest.fixture(scope='session')
def testini_factory(tmpdir_factory, ini_config):
    """
    Factory creating INI files based on TestINI, placed in the pytest
    base temporary directory.
    """
    base_dir = tmpdir_factory.getbasetemp().strpath
    return TestIniFactory(base_dir, ini_config)
187 187
188 188
class TestIniFactory(object):
    """Callable factory producing TestINI files inside a base directory."""

    def __init__(self, basetemp, template_ini):
        # directory into which generated INI files are written
        self._basetemp = basetemp
        # template INI file every generated file is based on
        self._template_ini = template_ini

    def __call__(self, ini_params, new_file_prefix='test'):
        """Create a new INI file from the template with ``ini_params`` applied."""
        generated = TestINI(
            self._template_ini, ini_params=ini_params,
            new_file_prefix=new_file_prefix, dir=self._basetemp)
        return generated.create()
201 201
202 202
def get_config(
        config, option_name, override_option_name, overrides=None,
        basetemp=None, prefix='test'):
    """
    Find a configuration file and apply overrides for the given `prefix`.

    :param config: pytest config object used for option/ini lookups.
    :param option_name: option or ini key holding the base config file path.
    :param override_option_name: option holding an additional override.
    :param overrides: optional list of override dicts; never mutated.
    :param basetemp: directory in which the generated INI file is placed.
    :param prefix: filename prefix of the generated INI file.
    :returns: path of the newly created INI file.
    """
    config_file = (
        config.getoption(option_name) or config.getini(option_name))
    if not config_file:
        pytest.exit(
            "Configuration error, could not extract {}.".format(option_name))

    # copy, so the caller's list is not mutated by the append below
    overrides = list(overrides or [])
    config_override = config.getoption(override_option_name)
    if config_override:
        overrides.append(config_override)
    temp_ini_file = TestINI(
        config_file, ini_params=overrides, new_file_prefix=prefix,
        dir=basetemp)

    return temp_ini_file.create()
@@ -1,1750 +1,1750 b''
1 1
2 2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 import collections
21 21 import datetime
22 22 import os
23 23 import re
24 24 import pprint
25 25 import shutil
26 26 import socket
27 27 import subprocess
28 28 import time
29 29 import uuid
30 30 import dateutil.tz
31 31 import logging
32 32 import functools
33 33
34 34 import mock
35 35 import pyramid.testing
36 36 import pytest
37 37 import colander
38 38 import requests
39 39 import pyramid.paster
40 40
41 41 import rhodecode
42 42 import rhodecode.lib
43 43 from rhodecode.model.changeset_status import ChangesetStatusModel
44 44 from rhodecode.model.comment import CommentsModel
45 45 from rhodecode.model.db import (
46 46 PullRequest, PullRequestReviewers, Repository, RhodeCodeSetting, ChangesetStatus,
47 47 RepoGroup, UserGroup, RepoRhodeCodeUi, RepoRhodeCodeSetting, RhodeCodeUi)
48 48 from rhodecode.model.meta import Session
49 49 from rhodecode.model.pull_request import PullRequestModel
50 50 from rhodecode.model.repo import RepoModel
51 51 from rhodecode.model.repo_group import RepoGroupModel
52 52 from rhodecode.model.user import UserModel
53 53 from rhodecode.model.settings import VcsSettingsModel
54 54 from rhodecode.model.user_group import UserGroupModel
55 55 from rhodecode.model.integration import IntegrationModel
56 56 from rhodecode.integrations import integration_type_registry
57 57 from rhodecode.integrations.types.base import IntegrationTypeBase
58 58 from rhodecode.lib.utils import repo2db_mapper
59 59 from rhodecode.lib.str_utils import safe_bytes
60 60 from rhodecode.lib.hash_utils import sha1_safe
61 61 from rhodecode.lib.vcs.backends import get_backend
62 62 from rhodecode.lib.vcs.nodes import FileNode
63 63 from rhodecode.tests import (
64 64 login_user_session, get_new_dir, utils, TESTS_TMP_PATH,
65 65 TEST_USER_ADMIN_LOGIN, TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR2_LOGIN,
66 66 TEST_USER_REGULAR_PASS)
67 67 from rhodecode.tests.utils import CustomTestApp, set_anonymous_access
68 68 from rhodecode.tests.fixture import Fixture
69 69 from rhodecode.config import utils as config_utils
70 70
71 71 log = logging.getLogger(__name__)
72 72
73 73
def cmp(a, b):
    """Backport of the Python 2 ``cmp`` builtin (three-way comparison).

    Kept so custom code in this module can still use it under Python 3.
    """
    greater = a > b
    lesser = a < b
    return int(greater) - int(lesser)
77 77
78 78
@pytest.fixture(scope='session', autouse=True)
def activate_example_rcextensions(request):
    """
    Patch in an example rcextensions module which verifies passed in kwargs.
    """
    from rhodecode.config import rcextensions

    # swap the globally registered extensions for the example module and
    # attach a call recorder used by the `capture_rcextensions` fixture
    old_extensions = rhodecode.EXTENSIONS
    rhodecode.EXTENSIONS = rcextensions
    rhodecode.EXTENSIONS.calls = collections.defaultdict(list)

    # restore the original extensions at the end of the session
    @request.addfinalizer
    def cleanup():
        rhodecode.EXTENSIONS = old_extensions
93 93
94 94
@pytest.fixture()
def capture_rcextensions():
    """
    Returns the recorded calls to entry points in rcextensions.
    """
    recorded_calls = rhodecode.EXTENSIONS.calls
    recorded_calls.clear()
    # Note: At this moment, it is still the empty dict, but that will
    # be filled during the test run and since it is a reference this
    # is enough to make it work.
    return recorded_calls
106 106
107 107
@pytest.fixture(scope='session')
def http_environ_session():
    """
    Session-scoped variant of the "http_environ" fixture.
    """
    return plain_http_environ()
114 114
115 115
def plain_http_host_stub():
    """Host (including port) used as HTTP_HOST during the test run."""
    return 'example.com:80'
121 121
122 122
@pytest.fixture()
def http_host_stub():
    """
    Fixture wrapper exposing the HTTP_HOST value of the test run.
    """
    return plain_http_host_stub()
129 129
130 130
def plain_http_host_only_stub():
    """Host part (without port) of HTTP_HOST used in the test run."""
    host_with_port = plain_http_host_stub()
    return host_with_port.split(':')[0]
136 136
137 137
@pytest.fixture()
def http_host_only_stub():
    """
    Fixture wrapper exposing the port-less HTTP_HOST of the test run.
    """
    return plain_http_host_only_stub()
144 144
145 145
def plain_http_environ():
    """
    Extra WSGI environ keys for the test application and the pylons
    environment setup; individual test cases may override these via the
    "app" fixture.
    """
    environ = {
        'SERVER_NAME': plain_http_host_only_stub(),
        'SERVER_PORT': plain_http_host_stub().split(':')[1],
        'HTTP_HOST': plain_http_host_stub(),
        'HTTP_USER_AGENT': 'rc-test-agent',
        'REQUEST_METHOD': 'GET',
    }
    return environ
161 161
162 162
@pytest.fixture()
def http_environ():
    """
    Function-scoped wrapper around :func:`plain_http_environ`, giving each
    test its own extra-environ dict.
    """
    return plain_http_environ()
173 173
174 174
@pytest.fixture(scope='session')
def baseapp(ini_config, vcsserver, http_environ_session):
    """
    Session-scoped pyramid WSGI application built from the generated test
    INI file; depends on a running vcsserver.
    """
    # imported lazily to avoid importing the full app stack at collection time
    from rhodecode.lib.config_utils import get_app_config
    from rhodecode.config.middleware import make_pyramid_app

    log.info("Using the RhodeCode configuration:{}".format(ini_config))
    pyramid.paster.setup_logging(ini_config)

    settings = get_app_config(ini_config)
    app = make_pyramid_app({'__file__': ini_config}, **settings)

    return app
187 187
188 188
@pytest.fixture(scope='function')
def app(request, config_stub, baseapp, http_environ):
    """Function-scoped test app wrapping the session ``baseapp``."""
    test_app = CustomTestApp(baseapp, extra_environ=http_environ)
    if request.cls:
        # class-based tests also get the app as an attribute
        request.cls.app = test_app
    return test_app
197 197
198 198
@pytest.fixture(scope='session')
def app_settings(baseapp, ini_config):
    """
    Settings dictionary used to create the app.

    Parses the ini file and passes the result through the sanitize and apply
    defaults mechanism in `rhodecode.config.middleware`.
    """
    registry_settings = baseapp.config.get_settings()
    return registry_settings
208 208
209 209
@pytest.fixture(scope='session')
def db_connection(ini_settings):
    """Initialize the database connection once per test session."""
    config_utils.initialize_database(ini_settings)
214 214
215 215
# Result of a test login: the session CSRF token plus the logged-in user.
LoginData = collections.namedtuple('LoginData', ('csrf_token', 'user'))
217 217
218 218
def _autologin_user(app, *args):
    """Log a user in through ``app`` and return a :data:`LoginData` tuple."""
    session = login_user_session(app, *args)
    token = rhodecode.lib.auth.get_csrf_token(session)
    return LoginData(token, session['rhodecode_user'])
223 223
224 224
@pytest.fixture()
def autologin_user(app):
    """
    Utility fixture which makes sure that the admin user is logged in
    """
    return _autologin_user(app)
231 231
232 232
@pytest.fixture()
def autologin_regular_user(app):
    """
    Utility fixture which makes sure that the regular user is logged in
    """
    credentials = (TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS)
    return _autologin_user(app, *credentials)
240 240
241 241
@pytest.fixture(scope='function')
def csrf_token(request, autologin_user):
    """CSRF token of the auto-logged-in admin user."""
    return autologin_user.csrf_token
245 245
246 246
@pytest.fixture(scope='function')
def xhr_header(request):
    """Header dict marking a request as an XMLHttpRequest (AJAX) call."""
    return {'HTTP_X_REQUESTED_WITH': 'XMLHttpRequest'}
250 250
251 251
@pytest.fixture()
def real_crypto_backend(monkeypatch):
    """
    Enable the production crypto backend for a single test.

    Normally the test run swaps the crypto backend for a faster MD5-based
    implementation; this fixture restores the real one.
    """
    monkeypatch.setattr(rhodecode, 'is_test', False)
261 261
262 262
@pytest.fixture(scope='class')
def index_location(request, baseapp):
    """Search index location from app settings; also set on the test class."""
    location = baseapp.config.get_settings()['search.location']
    if request.cls:
        request.cls.index_location = location
    return location
269 269
270 270
@pytest.fixture(scope='session', autouse=True)
def tests_tmp_path(request):
    """
    Ensure the session-wide temporary directory exists; remove it at the end
    of the session unless --keep-tmp-path was given.
    """
    if not os.path.exists(TESTS_TMP_PATH):
        os.makedirs(TESTS_TMP_PATH)

    keep_tmp = request.config.getoption('--keep-tmp-path')
    if not keep_tmp:
        @request.addfinalizer
        def remove_tmp_path():
            shutil.rmtree(TESTS_TMP_PATH)

    return TESTS_TMP_PATH
285 285
286 286
@pytest.fixture()
def test_repo_group(request):
    """
    Create a temporary repository group, and destroy it after
    usage automatically
    """
    fixture = Fixture()
    # time-based suffix keeps the id unique across test runs
    repogroupid = 'test_repo_group_%s' % str(time.time()).replace('.', '')
    repo_group = fixture.create_repo_group(repogroupid)

    request.addfinalizer(lambda: fixture.destroy_repo_group(repogroupid))
    return repo_group
302 302
303 303
@pytest.fixture()
def test_user_group(request):
    """
    Create a temporary user group, and destroy it after
    usage automatically
    """
    fixture = Fixture()
    # Timestamp (dots stripped) keeps the name unique per invocation.
    group_id = 'test_user_group_{}'.format(str(time.time()).replace('.', ''))
    created_group = fixture.create_user_group(group_id)

    # NOTE: destroys by group object (unlike test_repo_group which uses the id).
    request.addfinalizer(lambda: fixture.destroy_user_group(created_group))
    return created_group
319 319
320 320
@pytest.fixture(scope='session')
def test_repo(request):
    # Session-wide, lazily-populated container of read-only test repositories;
    # all repos it created are destroyed at session end.
    container = TestRepoContainer()
    request.addfinalizer(container._cleanup)
    return container
326 326
327 327
class TestRepoContainer(object):
    """
    Container for test repositories which are used read only.

    Repositories will be created on demand and re-used during the lifetime
    of this object.

    Usage to get the svn test repository "minimal"::

        test_repo = TestContainer()
        repo = test_repo('minimal', 'svn')

    """

    # Maps backend alias -> helper that extracts a repository dump into a
    # working repository on disk.
    dump_extractors = {
        'git': utils.extract_git_repo_from_dump,
        'hg': utils.extract_hg_repo_from_dump,
        'svn': utils.extract_svn_repo_from_dump,
    }

    def __init__(self):
        self._cleanup_repos = []   # repo names to destroy in _cleanup()
        self._fixture = Fixture()
        self._repos = {}           # (dump_name, alias) -> repo_id cache

    def __call__(self, dump_name, backend_alias, config=None):
        # Create-on-first-use, then serve the cached repo_id. The config
        # argument only takes effect on the creating call.
        key = (dump_name, backend_alias)
        if key not in self._repos:
            repo = self._create_repo(dump_name, backend_alias, config)
            self._repos[key] = repo.repo_id
        return Repository.get(self._repos[key])

    def _create_repo(self, dump_name, backend_alias, config):
        """Extract the dump, register the repo in the DB, record it for cleanup."""
        repo_name = f'{backend_alias}-{dump_name}'
        backend = get_backend(backend_alias)
        dump_extractor = self.dump_extractors[backend_alias]
        repo_path = dump_extractor(dump_name, repo_name)

        vcs_repo = backend(repo_path, config=config)
        # Make the on-disk repository known to the database layer.
        repo2db_mapper({repo_name: vcs_repo})

        repo = RepoModel().get_by_repo_name(repo_name)
        self._cleanup_repos.append(repo_name)
        return repo

    def _cleanup(self):
        # Destroy in reverse creation order.
        for repo_name in reversed(self._cleanup_repos):
            self._fixture.destroy_repo(repo_name)
376 376
377 377
def backend_base(request, backend_alias, baseapp, test_repo):
    """
    Build a `Backend` helper for `backend_alias`, honoring the `--backends`
    selection and the xfail/skip backend markers; cleanup is registered
    automatically.
    """
    selected = request.config.getoption('--backends')
    if backend_alias not in selected:
        pytest.skip(f"Backend {backend_alias} not selected.")

    utils.check_xfail_backends(request.node, backend_alias)
    utils.check_skip_backends(request.node, backend_alias)

    helper = Backend(
        alias=backend_alias,
        repo_name=f'vcs_test_{backend_alias}',
        test_name=request.node.name,
        test_repo_container=test_repo)
    request.addfinalizer(helper.cleanup)
    return helper
393 393
394 394
@pytest.fixture()
def backend(request, backend_alias, baseapp, test_repo):
    """
    Parametrized fixture which represents a single backend implementation.

    It respects the option `--backends` to focus the test run on specific
    backend implementations.

    It also supports `pytest.mark.xfail_backends` to mark tests as failing
    for specific backends. This is intended as a utility for incremental
    development of a new backend implementation.
    """
    # Thin delegate; all logic lives in backend_base().
    return backend_base(request, backend_alias, baseapp, test_repo)
408 408
409 409
@pytest.fixture()
def backend_git(request, baseapp, test_repo):
    # Non-parametrized variant of `backend`, pinned to Git.
    return backend_base(request, 'git', baseapp, test_repo)
413 413
414 414
@pytest.fixture()
def backend_hg(request, baseapp, test_repo):
    # Non-parametrized variant of `backend`, pinned to Mercurial.
    return backend_base(request, 'hg', baseapp, test_repo)
418 418
419 419
@pytest.fixture()
def backend_svn(request, baseapp, test_repo):
    # Non-parametrized variant of `backend`, pinned to Subversion.
    return backend_base(request, 'svn', baseapp, test_repo)
423 423
424 424
@pytest.fixture()
def backend_random(backend_git):
    """
    Use this to express that your tests need "a backend.

    A few of our tests need a backend, so that we can run the code. This
    fixture is intended to be used for such cases. It will pick one of the
    backends and run the tests.

    The fixture `backend` would run the test multiple times for each
    available backend which is a pure waste of time if the test is
    independent of the backend type.
    """
    # TODO: johbo: Change this to pick a random backend
    return backend_git
440 440
441 441
@pytest.fixture()
def backend_stub(backend_git):
    """
    Use this to express that your tests need a backend stub

    TODO: mikhail: Implement a real stub logic instead of returning
    a git backend
    """
    return backend_git
451 451
452 452
@pytest.fixture()
def repo_stub(backend_stub):
    """
    Use this to express that your tests need a repository stub
    """
    # Creates a fresh (empty) repository via the stub backend.
    return backend_stub.create_repo()
459 459
460 460
class Backend(object):
    """
    Represents the test configuration for one supported backend

    Provides easy access to different test repositories based on
    `__getitem__`. Such repositories will only be created once per test
    session.
    """

    invalid_repo_name = re.compile(r'[^0-9a-zA-Z]+')
    _master_repo = None
    _master_repo_path = ''
    # NOTE(review): class-level mutable default; instances shadow it with
    # their own dict on the first _add_commits_to_repo() call, but until
    # then all instances share this one object — confirm this is intended.
    _commit_ids = {}

    def __init__(self, alias, repo_name, test_name, test_repo_container):
        self.alias = alias
        self.repo_name = repo_name
        self._cleanup_repos = []
        self._test_name = test_name
        self._test_repo_container = test_repo_container
        # TODO: johbo: Used as a delegate interim. Not yet sure if Backend or
        # Fixture will survive in the end.
        self._fixture = Fixture()

    def __getitem__(self, key):
        # Shared read-only test repository for this backend alias.
        return self._test_repo_container(key, self.alias)

    def create_test_repo(self, key, config=None):
        # Like __getitem__ but allows passing a vcs config on creation.
        return self._test_repo_container(key, self.alias, config)

    @property
    def repo_id(self):
        # just fake some repo_id
        return self.repo.repo_id

    @property
    def repo(self):
        """
        Returns the "current" repository. This is the vcs_test repo or the
        last repo which has been created with `create_repo`.
        """
        from rhodecode.model.db import Repository
        return Repository.get_by_repo_name(self.repo_name)

    @property
    def default_branch_name(self):
        # Default branch name of the underlying vcs backend class.
        VcsRepository = get_backend(self.alias)
        return VcsRepository.DEFAULT_BRANCH_NAME

    @property
    def default_head_id(self):
        """
        Returns the default head id of the underlying backend.

        This will be the default branch name in case the backend does have a
        default branch. In the other cases it will point to a valid head
        which can serve as the base to create a new commit on top of it.
        """
        vcsrepo = self.repo.scm_instance()
        head_id = (
            vcsrepo.DEFAULT_BRANCH_NAME or
            vcsrepo.commit_ids[-1])
        return head_id

    @property
    def commit_ids(self):
        """
        Returns the list of commits for the last created repository
        """
        return self._commit_ids

    def create_master_repo(self, commits):
        """
        Create a repository and remember it as a template.

        This allows to easily create derived repositories to construct
        more complex scenarios for diff, compare and pull requests.

        Returns a commit map which maps from commit message to raw_id.
        """
        self._master_repo = self.create_repo(commits=commits)
        self._master_repo_path = self._master_repo.repo_full_path

        return self._commit_ids

    def create_repo(
            self, commits=None, number_of_commits=0, heads=None,
            name_suffix='', bare=False, **kwargs):
        """
        Create a repository and record it for later cleanup.

        :param commits: Optional. A sequence of dict instances.
            Will add a commit per entry to the new repository.
        :param number_of_commits: Optional. If set to a number, this number of
            commits will be added to the new repository.
        :param heads: Optional. Can be set to a sequence of of commit
            names which shall be pulled in from the master repository.
        :param name_suffix: adds special suffix to generated repo name
        :param bare: set a repo as bare (no checkout)
        """
        self.repo_name = self._next_repo_name() + name_suffix
        repo = self._fixture.create_repo(
            self.repo_name, repo_type=self.alias, bare=bare, **kwargs)
        self._cleanup_repos.append(repo.repo_name)

        commits = commits or [
            {'message': f'Commit {x} of {self.repo_name}'}
            for x in range(number_of_commits)]
        vcs_repo = repo.scm_instance()
        vcs_repo.count()
        self._add_commits_to_repo(vcs_repo, commits)
        if heads:
            self.pull_heads(repo, heads)

        return repo

    def pull_heads(self, repo, heads, do_fetch=False):
        """
        Make sure that repo contains all commits mentioned in `heads`
        """
        vcsrepo = repo.scm_instance()
        # Disable hooks so the test pull does not trigger side effects.
        vcsrepo.config.clear_section('hooks')
        commit_ids = [self._commit_ids[h] for h in heads]
        if do_fetch:
            vcsrepo.fetch(self._master_repo_path, commit_ids=commit_ids)
        vcsrepo.pull(self._master_repo_path, commit_ids=commit_ids)

    def create_fork(self):
        # Fork the current repo under a freshly generated name; the fork is
        # recorded for cleanup and becomes the new "current" repo.
        repo_to_fork = self.repo_name
        self.repo_name = self._next_repo_name()
        repo = self._fixture.create_fork(repo_to_fork, self.repo_name)
        self._cleanup_repos.append(self.repo_name)
        return repo

    def new_repo_name(self, suffix=''):
        # Reserve (and record for cleanup) a new repo name without creating it.
        self.repo_name = self._next_repo_name() + suffix
        self._cleanup_repos.append(self.repo_name)
        return self.repo_name

    def _next_repo_name(self):
        # Sanitized test name + running counter keeps names unique per test.
        return "%s_%s" % (
            self.invalid_repo_name.sub('_', self._test_name), len(self._cleanup_repos))

    def ensure_file(self, filename, content=b'Test content\n'):
        # Commit a single file into the current (test-owned) repository.
        assert self._cleanup_repos, "Avoid writing into vcs_test repos"
        commits = [
            {'added': [
                FileNode(filename, content=content),
            ]},
        ]
        self._add_commits_to_repo(self.repo.scm_instance(), commits)

    def enable_downloads(self):
        # Persist the enable_downloads flag for the current repository.
        repo = self.repo
        repo.enable_downloads = True
        Session().add(repo)
        Session().commit()

    def cleanup(self):
        # Destroy repos in reverse creation order (forks before sources).
        for repo_name in reversed(self._cleanup_repos):
            self._fixture.destroy_repo(repo_name)

    def _add_commits_to_repo(self, repo, commits):
        # Delegates to the module-level helper, then remembers the
        # message -> raw_id map for later lookups (e.g. pull_heads).
        commit_ids = _add_commits_to_repo(repo, commits)
        if not commit_ids:
            return
        self._commit_ids = commit_ids

        # Creating refs for Git to allow fetching them from remote repository
        if self.alias == 'git':
            refs = {}
            for message in self._commit_ids:
                cleanup_message = message.replace(' ', '')
                ref_name = f'refs/test-refs/{cleanup_message}'
                refs[ref_name] = self._commit_ids[message]
            self._create_refs(repo, refs)

    def _create_refs(self, repo, refs):
        # Write each ref_name -> commit_id pair into the vcs repository.
        for ref_name, ref_val in refs.items():
            repo.set_refs(ref_name, ref_val)
641 641
642 642
class VcsBackend(object):
    """
    Represents the test configuration for one supported vcs backend.
    """

    invalid_repo_name = re.compile(r'[^0-9a-zA-Z]+')

    def __init__(self, alias, repo_path, test_name, test_repo_container):
        self.alias = alias
        self._repo_path = repo_path       # path of the "current" repository
        self._cleanup_repos = []          # vcs repo objects to rmtree on cleanup
        self._test_name = test_name
        self._test_repo_container = test_repo_container

    def __getitem__(self, key):
        # Shared read-only test repository, returned as a vcs scm instance
        # (unlike Backend.__getitem__ which returns the db model).
        return self._test_repo_container(key, self.alias).scm_instance()

    def __repr__(self):
        return f'{self.__class__.__name__}(alias={self.alias}, repo={self._repo_path})'

    @property
    def repo(self):
        """
        Returns the "current" repository. This is the vcs_test repo of the last
        repo which has been created.
        """
        Repository = get_backend(self.alias)
        return Repository(self._repo_path)

    @property
    def backend(self):
        """
        Returns the backend implementation class.
        """
        return get_backend(self.alias)

    def create_repo(self, commits=None, number_of_commits=0, _clone_repo=None,
                    bare=False):
        # Create a pure-vcs repository (no db registration), optionally
        # cloned from `_clone_repo`, and record it for cleanup.
        repo_name = self._next_repo_name()
        self._repo_path = get_new_dir(repo_name)
        repo_class = get_backend(self.alias)
        src_url = None
        if _clone_repo:
            src_url = _clone_repo.path
        repo = repo_class(self._repo_path, create=True, src_url=src_url, bare=bare)
        self._cleanup_repos.append(repo)

        commits = commits or [
            {'message': 'Commit %s of %s' % (x, repo_name)}
            for x in range(number_of_commits)]
        _add_commits_to_repo(repo, commits)
        return repo

    def clone_repo(self, repo):
        # Convenience wrapper: new repository cloned from `repo`.
        return self.create_repo(_clone_repo=repo)

    def cleanup(self):
        # Remove the working directories of every repo created here.
        for repo in self._cleanup_repos:
            shutil.rmtree(repo.path)

    def new_repo_path(self):
        # Reserve a fresh on-disk path (becomes the "current" repo path)
        # without creating a repository in it.
        repo_name = self._next_repo_name()
        self._repo_path = get_new_dir(repo_name)
        return self._repo_path

    def _next_repo_name(self):
        # Sanitized test name + running counter keeps names unique per test.

        return "{}_{}".format(
            self.invalid_repo_name.sub('_', self._test_name),
            len(self._cleanup_repos)
        )

    def add_file(self, repo, filename, content='Test content\n'):
        # Commit a single file into `repo` via its in-memory commit API.
        imc = repo.in_memory_commit
        imc.add(FileNode(safe_bytes(filename), content=safe_bytes(content)))
        imc.commit(
            message='Automatic commit from vcsbackend fixture',
            author='Automatic <automatic@rhodecode.com>')

    def ensure_file(self, filename, content='Test content\n'):
        # Only write into repos this instance created, never the shared ones.
        assert self._cleanup_repos, "Avoid writing into vcs_test repos"
        self.add_file(self.repo, filename, content)
725 725
726 726
def vcsbackend_base(request, backend_alias, tests_tmp_path, baseapp, test_repo) -> VcsBackend:
    """
    Build a `VcsBackend` helper for `backend_alias`, honoring the
    `--backends` selection and the xfail/skip backend markers; cleanup is
    registered automatically.
    """
    selected = request.config.getoption('--backends')
    if backend_alias not in selected:
        pytest.skip(f"Backend {backend_alias} not selected.")

    utils.check_xfail_backends(request.node, backend_alias)
    utils.check_skip_backends(request.node, backend_alias)

    helper = VcsBackend(
        alias=backend_alias,
        repo_path=os.path.join(tests_tmp_path, f'vcs_test_{backend_alias}'),
        test_name=request.node.name,
        test_repo_container=test_repo)
    request.addfinalizer(helper.cleanup)
    return helper
743 743
744 744
@pytest.fixture()
def vcsbackend(request, backend_alias, tests_tmp_path, baseapp, test_repo):
    """
    Parametrized fixture which represents a single vcs backend implementation.

    See the fixture `backend` for more details. This one implements the same
    concept, but on vcs level. So it does not provide model instances etc.

    Parameters are generated dynamically, see :func:`pytest_generate_tests`
    for how this works.
    """
    # Thin delegate; all logic lives in vcsbackend_base().
    return vcsbackend_base(request, backend_alias, tests_tmp_path, baseapp, test_repo)
757 757
758 758
@pytest.fixture()
def vcsbackend_git(request, tests_tmp_path, baseapp, test_repo):
    # Non-parametrized variant of `vcsbackend`, pinned to Git.
    return vcsbackend_base(request, 'git', tests_tmp_path, baseapp, test_repo)
762 762
763 763
@pytest.fixture()
def vcsbackend_hg(request, tests_tmp_path, baseapp, test_repo):
    # Non-parametrized variant of `vcsbackend`, pinned to Mercurial.
    return vcsbackend_base(request, 'hg', tests_tmp_path, baseapp, test_repo)
767 767
768 768
@pytest.fixture()
def vcsbackend_svn(request, tests_tmp_path, baseapp, test_repo):
    # Non-parametrized variant of `vcsbackend`, pinned to Subversion.
    return vcsbackend_base(request, 'svn', tests_tmp_path, baseapp, test_repo)
772 772
773 773
@pytest.fixture()
def vcsbackend_stub(vcsbackend_git):
    """
    Use this to express that your test just needs a stub of a vcsbackend.

    Plan is to eventually implement an in-memory stub to speed tests up.
    """
    return vcsbackend_git
782 782
783 783
def _add_commits_to_repo(vcs_repo, commits):
    """
    Commit each entry of `commits` into `vcs_repo` via its in-memory commit
    API and return a map of commit message -> raw commit id.

    Each `commits` entry is a dict which may carry: 'message', 'added',
    'changed', 'removed' (node lists), 'parents' (messages of earlier
    entries in this same call), 'author', 'date' and 'branch'.
    """
    commit_ids = {}
    if not commits:
        return commit_ids

    imc = vcs_repo.in_memory_commit

    for idx, commit in enumerate(commits):
        message = str(commit.get('message', f'Commit {idx}'))

        for node in commit.get('added', []):
            imc.add(FileNode(safe_bytes(node.path), content=node.content))
        for node in commit.get('changed', []):
            imc.change(FileNode(safe_bytes(node.path), content=node.content))
        for node in commit.get('removed', []):
            imc.remove(FileNode(safe_bytes(node.path)))

        # Parents are referenced by the message of a previously created
        # commit in this run, resolved through commit_ids.
        parents = [
            vcs_repo.get_commit(commit_id=commit_ids[p])
            for p in commit.get('parents', [])]

        # An empty commit would be rejected; synthesize a file so every
        # entry produces a real commit.
        operations = ('added', 'changed', 'removed')
        if not any((commit.get(o) for o in operations)):
            imc.add(FileNode(b'file_%b' % safe_bytes(str(idx)), content=safe_bytes(message)))

        # NOTE: rebinds the loop variable `commit` from the spec dict to the
        # created commit object.
        commit = imc.commit(
            message=message,
            author=str(commit.get('author', 'Automatic <automatic@rhodecode.com>')),
            date=commit.get('date'),
            branch=commit.get('branch'),
            parents=parents)

        commit_ids[commit.message] = commit.raw_id

    return commit_ids
819 819
820 820
@pytest.fixture()
def reposerver(request):
    """
    Allows to serve a backend repository
    """

    # Server processes started via RepoServer.serve() are stopped at teardown.
    repo_server = RepoServer()
    request.addfinalizer(repo_server.cleanup)
    return repo_server
830 830
831 831
class RepoServer(object):
    """
    Utility to serve a local repository for the duration of a test case.

    Supports only Subversion so far.
    """

    # URL under which the served repository is reachable; set by serve().
    url = None

    def __init__(self):
        self._cleanup_servers = []

    def serve(self, vcsrepo):
        """
        Start an ``svnserve`` daemon rooted at the repository path.

        :param vcsrepo: vcs repository instance; its alias must be 'svn'.
        :raises TypeError: if a non-Subversion repository is passed in.
        """
        if vcsrepo.alias != 'svn':
            raise TypeError("Backend %s not supported" % vcsrepo.alias)

        proc = subprocess.Popen(
            ['svnserve', '-d', '--foreground', '--listen-host', 'localhost',
             '--root', vcsrepo.path])
        self._cleanup_servers.append(proc)
        self.url = 'svn://localhost'

    def cleanup(self):
        """Terminate all started server processes and reap them."""
        # Signal every server first, then wait — previously the processes
        # were terminated but never waited on, which left zombie `svnserve`
        # children around for the rest of the test run.
        for proc in self._cleanup_servers:
            proc.terminate()
        for proc in self._cleanup_servers:
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
857 857
858 858
@pytest.fixture()
def pr_util(backend, request, config_stub):
    """
    Utility for tests of models and for functional tests around pull requests.

    It gives an instance of :class:`PRTestUtility` which provides various
    utility methods around one pull request.

    This fixture uses `backend` and inherits its parameterization.
    """

    util = PRTestUtility(backend)
    # Deletes the pull request and stops any patchers at teardown.
    request.addfinalizer(util.cleanup)

    return util
874 874
875 875
class PRTestUtility(object):
    """
    Helper that creates and manages exactly one pull request (plus its
    source/target repositories) for the lifetime of a test. Obtain it via
    the `pr_util` fixture; `cleanup()` must run at teardown.
    """

    pull_request = None
    pull_request_id = None
    mergeable_patcher = None
    mergeable_mock = None
    notification_patcher = None
    # message -> raw commit id map of the master repo; set by create_pull_request
    commit_ids: dict

    def __init__(self, backend):
        self.backend = backend

    def create_pull_request(
            self, commits=None, target_head=None, source_head=None,
            revisions=None, approved=False, author=None, mergeable=False,
            enable_notifications=True, name_suffix='', reviewers=None, observers=None,
            title="Test", description="Description"):
        """
        Lazily create the single pull request managed by this instance.

        On first call, builds a master repo from `commits`, derives target and
        source repos from `target_head`/`source_head`, and creates the PR over
        `revisions`. Subsequent calls return the already-created PR.
        """
        self.set_mergeable(mergeable)
        if not enable_notifications:
            # mock notification side effect
            self.notification_patcher = mock.patch(
                'rhodecode.model.notification.NotificationModel.create')
            self.notification_patcher.start()

        if not self.pull_request:
            if not commits:
                # Default scenario: three commits, PR brings c2 into c1.
                commits = [
                    {'message': 'c1'},
                    {'message': 'c2'},
                    {'message': 'c3'},
                ]
                target_head = 'c1'
                source_head = 'c2'
                revisions = ['c2']

            self.commit_ids = self.backend.create_master_repo(commits)
            self.target_repository = self.backend.create_repo(
                heads=[target_head], name_suffix=name_suffix)
            self.source_repository = self.backend.create_repo(
                heads=[source_head], name_suffix=name_suffix)
            self.author = author or UserModel().get_by_username(
                TEST_USER_ADMIN_LOGIN)

            model = PullRequestModel()
            self.create_parameters = {
                'created_by': self.author,
                'source_repo': self.source_repository.repo_name,
                'source_ref': self._default_branch_reference(source_head),
                'target_repo': self.target_repository.repo_name,
                'target_ref': self._default_branch_reference(target_head),
                'revisions': [self.commit_ids[r] for r in revisions],
                'reviewers': reviewers or self._get_reviewers(),
                'observers': observers or self._get_observers(),
                'title': title,
                'description': description,
            }
            self.pull_request = model.create(**self.create_parameters)
            # A freshly created PR must not have any versions yet.
            assert model.get_versions(self.pull_request) == []

            self.pull_request_id = self.pull_request.pull_request_id

            if approved:
                self.approve()

            Session().add(self.pull_request)
            Session().commit()

        return self.pull_request

    def approve(self):
        # Cast an "approved" status vote for every reviewer of the PR.
        self.create_status_votes(
            ChangesetStatus.STATUS_APPROVED,
            *self.pull_request.reviewers)

    def close(self):
        PullRequestModel().close_pull_request(self.pull_request, self.author)

    def _default_branch_reference(self, commit_message, branch: str = None) -> str:
        # Build a 'branch:<name>:<raw_id>' reference string, resolving the
        # commit by its message via self.commit_ids.
        default_branch = branch or self.backend.default_branch_name
        message = self.commit_ids[commit_message]
        reference = f'branch:{default_branch}:{message}'

        return reference

    def _get_reviewers(self):
        # Default reviewer tuples: (username, reasons, mandatory, role, rules).
        role = PullRequestReviewers.ROLE_REVIEWER
        return [
            (TEST_USER_REGULAR_LOGIN, ['default1'], False, role, []),
            (TEST_USER_REGULAR2_LOGIN, ['default2'], False, role, []),
        ]

    def _get_observers(self):
        # No observers by default.
        return [

        ]

    def update_source_repository(self, head=None, do_fetch=False):
        # Pull `head` (default 'c3') from the master repo into the source repo.
        heads = [head or 'c3']
        self.backend.pull_heads(self.source_repository, heads=heads, do_fetch=do_fetch)

    def update_target_repository(self, head=None, do_fetch=False):
        # Pull `head` (default 'c3') from the master repo into the target repo.
        heads = [head or 'c3']
        self.backend.pull_heads(self.target_repository, heads=heads, do_fetch=do_fetch)

    def set_pr_target_ref(self, ref_type: str = "branch", ref_name: str = "branch", ref_commit_id: str = "") -> str:
        # Overwrite the PR target ref; returns the full 'type:name:id' string.
        full_ref = f"{ref_type}:{ref_name}:{ref_commit_id}"
        self.pull_request.target_ref = full_ref
        return full_ref

    def set_pr_source_ref(self, ref_type: str = "branch", ref_name: str = "branch", ref_commit_id: str = "") -> str:
        # Overwrite the PR source ref; returns the full 'type:name:id' string.
        full_ref = f"{ref_type}:{ref_name}:{ref_commit_id}"
        self.pull_request.source_ref = full_ref
        return full_ref

    def add_one_commit(self, head=None):
        # Pull one more head into the source repo, update the PR, and return
        # the raw id of the single newly added revision.
        self.update_source_repository(head=head)
        old_commit_ids = set(self.pull_request.revisions)
        PullRequestModel().update_commits(self.pull_request, self.pull_request.author)
        commit_ids = set(self.pull_request.revisions)
        new_commit_ids = commit_ids - old_commit_ids
        assert len(new_commit_ids) == 1
        return new_commit_ids.pop()

    def remove_one_commit(self):
        # Strip the tip commit from the source repo, update the PR, and
        # return the raw id of the removed revision. Requires exactly two
        # revisions to start with.
        assert len(self.pull_request.revisions) == 2
        source_vcs = self.source_repository.scm_instance()
        removed_commit_id = source_vcs.commit_ids[-1]

        # TODO: johbo: Git and Mercurial have an inconsistent vcs api here,
        # remove the if once that's sorted out.
        if self.backend.alias == "git":
            kwargs = {'branch_name': self.backend.default_branch_name}
        else:
            kwargs = {}
        source_vcs.strip(removed_commit_id, **kwargs)

        PullRequestModel().update_commits(self.pull_request, self.pull_request.author)
        assert len(self.pull_request.revisions) == 1
        return removed_commit_id

    def create_comment(self, linked_to=None):
        # Create a general (non-inline) comment on the PR; optionally link
        # existing comments to the given PR version.
        comment = CommentsModel().create(
            text="Test comment",
            repo=self.target_repository.repo_name,
            user=self.author,
            pull_request=self.pull_request)
        assert comment.pull_request_version_id is None

        if linked_to:
            PullRequestModel()._link_comments_to_version(linked_to)

        return comment

    def create_inline_comment(
            self, linked_to=None, line_no='n1', file_path='file_1'):
        # Create an inline comment bound to file/line; optionally link
        # existing comments to the given PR version.
        comment = CommentsModel().create(
            text="Test comment",
            repo=self.target_repository.repo_name,
            user=self.author,
            line_no=line_no,
            f_path=file_path,
            pull_request=self.pull_request)
        assert comment.pull_request_version_id is None

        if linked_to:
            PullRequestModel()._link_comments_to_version(linked_to)

        return comment

    def create_version_of_pull_request(self):
        # Snapshot the (lazily created) PR into a new PR version.
        pull_request = self.create_pull_request()
        version = PullRequestModel()._create_version_from_snapshot(
            pull_request)
        return version

    def create_status_votes(self, status, *reviewers):
        # Record `status` on the PR for each given reviewer.
        for reviewer in reviewers:
            ChangesetStatusModel().set_status(
                repo=self.pull_request.target_repo,
                status=status,
                user=reviewer.user_id,
                pull_request=self.pull_request)

    def set_mergeable(self, value):
        # Patch VcsSettingsModel so the PR-merge setting reports `value`;
        # the patcher is started once and reused on later calls.
        if not self.mergeable_patcher:
            self.mergeable_patcher = mock.patch.object(
                VcsSettingsModel, 'get_general_settings')
            self.mergeable_mock = self.mergeable_patcher.start()
        self.mergeable_mock.return_value = {
            'rhodecode_pr_merge_enabled': value}

    def cleanup(self):
        # In case the source repository is already cleaned up, the pull
        # request will already be deleted.
        pull_request = PullRequest().get(self.pull_request_id)
        if pull_request:
            PullRequestModel().delete(pull_request, pull_request.author)
            Session().commit()

        if self.notification_patcher:
            self.notification_patcher.stop()

        if self.mergeable_patcher:
            self.mergeable_patcher.stop()
1080 1080
1081 1081
@pytest.fixture()
def user_admin(baseapp):
    """
    Provides the default admin test user as an instance of `db.User`.
    """
    return UserModel().get_by_username(TEST_USER_ADMIN_LOGIN)
1089 1089
1090 1090
@pytest.fixture()
def user_regular(baseapp):
    """
    Provides the default regular test user as an instance of `db.User`.
    """
    return UserModel().get_by_username(TEST_USER_REGULAR_LOGIN)
1098 1098
1099 1099
@pytest.fixture()
def user_util(request, db_connection):
    """
    Provides a wired instance of `UserUtility` with integrated cleanup.
    """
    # Entities created through the utility are destroyed at teardown.
    utility = UserUtility(test_name=request.node.name)
    request.addfinalizer(utility.cleanup)
    return utility
1108 1108
1109 1109
1110 1110 # TODO: johbo: Split this up into utilities per domain or something similar
1111 1111 class UserUtility(object):
1112 1112
    def __init__(self, test_name="test"):
        # Sanitized test name becomes the prefix of every generated entity name.
        self._test_name = self._sanitize_name(test_name)
        self.fixture = Fixture()
        # Ids of created entities, tracked so cleanup() can destroy them.
        self.repo_group_ids = []
        self.repos_ids = []
        self.user_ids = []
        self.user_group_ids = []
        # (entity-id, grantee-id) pairs of granted permissions, per kind.
        self.user_repo_permission_ids = []
        self.user_group_repo_permission_ids = []
        self.user_repo_group_permission_ids = []
        self.user_group_repo_group_permission_ids = []
        self.user_user_group_permission_ids = []
        self.user_group_user_group_permission_ids = []
        self.user_permissions = []
1127 1127
1128 1128 def _sanitize_name(self, name):
1129 1129 for char in ['[', ']']:
1130 1130 name = name.replace(char, '_')
1131 1131 return name
1132 1132
1133 1133 def create_repo_group(
1134 1134 self, owner=TEST_USER_ADMIN_LOGIN, auto_cleanup=True):
1135 1135 group_name = "{prefix}_repogroup_{count}".format(
1136 1136 prefix=self._test_name,
1137 1137 count=len(self.repo_group_ids))
1138 1138 repo_group = self.fixture.create_repo_group(
1139 1139 group_name, cur_user=owner)
1140 1140 if auto_cleanup:
1141 1141 self.repo_group_ids.append(repo_group.group_id)
1142 1142 return repo_group
1143 1143
1144 1144 def create_repo(self, owner=TEST_USER_ADMIN_LOGIN, parent=None,
1145 1145 auto_cleanup=True, repo_type='hg', bare=False):
1146 1146 repo_name = "{prefix}_repository_{count}".format(
1147 1147 prefix=self._test_name,
1148 1148 count=len(self.repos_ids))
1149 1149
1150 1150 repository = self.fixture.create_repo(
1151 1151 repo_name, cur_user=owner, repo_group=parent, repo_type=repo_type, bare=bare)
1152 1152 if auto_cleanup:
1153 1153 self.repos_ids.append(repository.repo_id)
1154 1154 return repository
1155 1155
1156 1156 def create_user(self, auto_cleanup=True, **kwargs):
1157 1157 user_name = "{prefix}_user_{count}".format(
1158 1158 prefix=self._test_name,
1159 1159 count=len(self.user_ids))
1160 1160 user = self.fixture.create_user(user_name, **kwargs)
1161 1161 if auto_cleanup:
1162 1162 self.user_ids.append(user.user_id)
1163 1163 return user
1164 1164
1165 1165 def create_additional_user_email(self, user, email):
1166 1166 uem = self.fixture.create_additional_user_email(user=user, email=email)
1167 1167 return uem
1168 1168
1169 1169 def create_user_with_group(self):
1170 1170 user = self.create_user()
1171 1171 user_group = self.create_user_group(members=[user])
1172 1172 return user, user_group
1173 1173
1174 1174 def create_user_group(self, owner=TEST_USER_ADMIN_LOGIN, members=None,
1175 1175 auto_cleanup=True, **kwargs):
1176 1176 group_name = "{prefix}_usergroup_{count}".format(
1177 1177 prefix=self._test_name,
1178 1178 count=len(self.user_group_ids))
1179 1179 user_group = self.fixture.create_user_group(
1180 1180 group_name, cur_user=owner, **kwargs)
1181 1181
1182 1182 if auto_cleanup:
1183 1183 self.user_group_ids.append(user_group.users_group_id)
1184 1184 if members:
1185 1185 for user in members:
1186 1186 UserGroupModel().add_user_to_group(user_group, user)
1187 1187 return user_group
1188 1188
    def grant_user_permission(self, user_name, permission_name):
        """Record a global permission grant for *user_name*; revoked on cleanup."""
        # Disable inherited defaults first so the explicit grant is the one
        # actually in effect for the test.
        self.inherit_default_user_permissions(user_name, False)
        self.user_permissions.append((user_name, permission_name))
1192 1192
1193 1193 def grant_user_permission_to_repo_group(
1194 1194 self, repo_group, user, permission_name):
1195 1195 permission = RepoGroupModel().grant_user_permission(
1196 1196 repo_group, user, permission_name)
1197 1197 self.user_repo_group_permission_ids.append(
1198 1198 (repo_group.group_id, user.user_id))
1199 1199 return permission
1200 1200
1201 1201 def grant_user_group_permission_to_repo_group(
1202 1202 self, repo_group, user_group, permission_name):
1203 1203 permission = RepoGroupModel().grant_user_group_permission(
1204 1204 repo_group, user_group, permission_name)
1205 1205 self.user_group_repo_group_permission_ids.append(
1206 1206 (repo_group.group_id, user_group.users_group_id))
1207 1207 return permission
1208 1208
1209 1209 def grant_user_permission_to_repo(
1210 1210 self, repo, user, permission_name):
1211 1211 permission = RepoModel().grant_user_permission(
1212 1212 repo, user, permission_name)
1213 1213 self.user_repo_permission_ids.append(
1214 1214 (repo.repo_id, user.user_id))
1215 1215 return permission
1216 1216
1217 1217 def grant_user_group_permission_to_repo(
1218 1218 self, repo, user_group, permission_name):
1219 1219 permission = RepoModel().grant_user_group_permission(
1220 1220 repo, user_group, permission_name)
1221 1221 self.user_group_repo_permission_ids.append(
1222 1222 (repo.repo_id, user_group.users_group_id))
1223 1223 return permission
1224 1224
1225 1225 def grant_user_permission_to_user_group(
1226 1226 self, target_user_group, user, permission_name):
1227 1227 permission = UserGroupModel().grant_user_permission(
1228 1228 target_user_group, user, permission_name)
1229 1229 self.user_user_group_permission_ids.append(
1230 1230 (target_user_group.users_group_id, user.user_id))
1231 1231 return permission
1232 1232
1233 1233 def grant_user_group_permission_to_user_group(
1234 1234 self, target_user_group, user_group, permission_name):
1235 1235 permission = UserGroupModel().grant_user_group_permission(
1236 1236 target_user_group, user_group, permission_name)
1237 1237 self.user_group_user_group_permission_ids.append(
1238 1238 (target_user_group.users_group_id, user_group.users_group_id))
1239 1239 return permission
1240 1240
    def revoke_user_permission(self, user_name, permission_name):
        """Undo grant_user_permission(): restore inheritance, revoke the perm."""
        self.inherit_default_user_permissions(user_name, True)
        UserModel().revoke_perm(user_name, permission_name)
1244 1244
1245 1245 def inherit_default_user_permissions(self, user_name, value):
1246 1246 user = UserModel().get_by_username(user_name)
1247 1247 user.inherit_default_permissions = value
1248 1248 Session().add(user)
1249 1249 Session().commit()
1250 1250
    def cleanup(self):
        """Tear down everything this utility created, in dependency order.

        Permissions first (they reference the objects), then repos before
        their groups, and users last (they may own any of the above).
        """
        self._cleanup_permissions()
        self._cleanup_repos()
        self._cleanup_repo_groups()
        self._cleanup_user_groups()
        self._cleanup_users()
1257 1257
1258 1258 def _cleanup_permissions(self):
1259 1259 if self.user_permissions:
1260 1260 for user_name, permission_name in self.user_permissions:
1261 1261 self.revoke_user_permission(user_name, permission_name)
1262 1262
1263 1263 for permission in self.user_repo_permission_ids:
1264 1264 RepoModel().revoke_user_permission(*permission)
1265 1265
1266 1266 for permission in self.user_group_repo_permission_ids:
1267 1267 RepoModel().revoke_user_group_permission(*permission)
1268 1268
1269 1269 for permission in self.user_repo_group_permission_ids:
1270 1270 RepoGroupModel().revoke_user_permission(*permission)
1271 1271
1272 1272 for permission in self.user_group_repo_group_permission_ids:
1273 1273 RepoGroupModel().revoke_user_group_permission(*permission)
1274 1274
1275 1275 for permission in self.user_user_group_permission_ids:
1276 1276 UserGroupModel().revoke_user_permission(*permission)
1277 1277
1278 1278 for permission in self.user_group_user_group_permission_ids:
1279 1279 UserGroupModel().revoke_user_group_permission(*permission)
1280 1280
1281 1281 def _cleanup_repo_groups(self):
1282 1282 def _repo_group_compare(first_group_id, second_group_id):
1283 1283 """
1284 1284 Gives higher priority to the groups with the most complex paths
1285 1285 """
1286 1286 first_group = RepoGroup.get(first_group_id)
1287 1287 second_group = RepoGroup.get(second_group_id)
1288 1288 first_group_parts = (
1289 1289 len(first_group.group_name.split('/')) if first_group else 0)
1290 1290 second_group_parts = (
1291 1291 len(second_group.group_name.split('/')) if second_group else 0)
1292 1292 return cmp(second_group_parts, first_group_parts)
1293 1293
1294 1294 sorted_repo_group_ids = sorted(
1295 1295 self.repo_group_ids, key=functools.cmp_to_key(_repo_group_compare))
1296 1296 for repo_group_id in sorted_repo_group_ids:
1297 1297 self.fixture.destroy_repo_group(repo_group_id)
1298 1298
1299 1299 def _cleanup_repos(self):
1300 1300 sorted_repos_ids = sorted(self.repos_ids)
1301 1301 for repo_id in sorted_repos_ids:
1302 1302 self.fixture.destroy_repo(repo_id)
1303 1303
1304 1304 def _cleanup_user_groups(self):
1305 1305 def _user_group_compare(first_group_id, second_group_id):
1306 1306 """
1307 1307 Gives higher priority to the groups with the most complex paths
1308 1308 """
1309 1309 first_group = UserGroup.get(first_group_id)
1310 1310 second_group = UserGroup.get(second_group_id)
1311 1311 first_group_parts = (
1312 1312 len(first_group.users_group_name.split('/'))
1313 1313 if first_group else 0)
1314 1314 second_group_parts = (
1315 1315 len(second_group.users_group_name.split('/'))
1316 1316 if second_group else 0)
1317 1317 return cmp(second_group_parts, first_group_parts)
1318 1318
1319 1319 sorted_user_group_ids = sorted(
1320 1320 self.user_group_ids, key=functools.cmp_to_key(_user_group_compare))
1321 1321 for user_group_id in sorted_user_group_ids:
1322 1322 self.fixture.destroy_user_group(user_group_id)
1323 1323
1324 1324 def _cleanup_users(self):
1325 1325 for user_id in self.user_ids:
1326 1326 self.fixture.destroy_user(user_id)
1327 1327
1328 1328
@pytest.fixture(scope='session')
def testrun():
    """Session-wide identity of this test run: uuid, ISO start time, epoch."""
    return dict(
        uuid=uuid.uuid4(),
        start=datetime.datetime.utcnow().isoformat(),
        timestamp=int(time.time()),
    )
1336 1336
1337 1337
class AppenlightClient(object):
    """Collects per-test stats/tags and pushes them to an Appenlight server."""

    # Protocol version is fixed by appending it to the configured base URL.
    url_template = '{url}?protocol_version=0.5'

    def __init__(
            self, url, api_key, add_server=True, add_timestamp=True,
            namespace=None, request=None, testrun=None):
        """
        :param url: base Appenlight endpoint; the protocol version is appended.
        :param api_key: sent as the ``X-appenlight-api-key`` header.
        :param add_server: stamp each data point with this host's FQDN.
        :param add_timestamp: stamp each data point with the current UTC time.
        :param namespace: logical test name attached to each data point.
        :param request: identifier of the current test-run request.
        :param testrun: dict with 'start'/'timestamp' keys (see testrun fixture).
        """
        self.url = self.url_template.format(url=url)
        self.api_key = api_key
        self.add_server = add_server
        self.add_timestamp = add_timestamp
        self.namespace = namespace
        self.request = request
        self.server = socket.getfqdn(socket.gethostname())
        self.tags_before = {}   # measurements taken before the test
        self.tags_after = {}    # measurements taken after the test
        self.stats = []         # collected data points, flushed by send_stats()
        self.testrun = testrun or {}

    def tag_before(self, tag, value):
        """Record a pre-test measurement under *tag*."""
        self.tags_before[tag] = value

    def tag_after(self, tag, value):
        """Record a post-test measurement under *tag*."""
        self.tags_after[tag] = value

    def collect(self, data):
        """Queue one data point, filling in the configured default fields."""
        if self.add_server:
            data.setdefault('server', self.server)
        if self.add_timestamp:
            data.setdefault('date', datetime.datetime.utcnow().isoformat())
        if self.namespace:
            data.setdefault('namespace', self.namespace)
        if self.request:
            data.setdefault('request', self.request)
        self.stats.append(data)

    def send_stats(self):
        """POST all queued stats (plus a tag summary) to the server.

        :raises Exception: when the server responds with anything but 200.
        """
        tags = [
            ('testrun', self.request),
            ('testrun.start', self.testrun['start']),
            ('testrun.timestamp', self.testrun['timestamp']),
            ('test', self.namespace),
        ]
        for key, value in self.tags_before.items():
            tags.append((key + '.before', value))
            # Emit a delta only when both sides exist and support subtraction;
            # non-numeric or missing 'after' values are silently skipped.
            try:
                delta = self.tags_after[key] - value
                tags.append((key + '.delta', delta))
            except Exception:
                pass
        for key, value in self.tags_after.items():
            tags.append((key + '.after', value))
        self.collect({
            'message': "Collected tags",
            'tags': tags,
        })

        response = requests.post(
            self.url,
            headers={
                'X-appenlight-api-key': self.api_key},
            json=self.stats,
        )

        if not response.status_code == 200:
            # Dump the payload and response to aid debugging before failing.
            pprint.pprint(self.stats)
            print(response.headers)
            print(response.text)
            raise Exception('Sending to appenlight failed')
1407 1407
1408 1408
@pytest.fixture()
def gist_util(request, db_connection):
    """
    Provides a wired instance of `GistUtility` with integrated cleanup.
    """
    util = GistUtility()
    request.addfinalizer(util.cleanup)
    return util
1417 1417
1418 1418
class GistUtility(object):
    """Creates gists through the fixture and destroys them all on cleanup."""

    def __init__(self):
        self.fixture = Fixture()
        self.gist_ids = []

    def create_gist(self, **kwargs):
        """Create a gist and remember its id for later destruction."""
        new_gist = self.fixture.create_gist(**kwargs)
        self.gist_ids.append(new_gist.gist_id)
        return new_gist

    def cleanup(self):
        """Destroy every gist created through this utility."""
        for gist_id in self.gist_ids:
            # destroy_gists expects the id as a string
            self.fixture.destroy_gists(str(gist_id))
1432 1432
1433 1433
@pytest.fixture()
def enabled_backends(request):
    """A copy of the backend names enabled via the --backends option."""
    return list(request.config.option.backends)
1438 1438
1439 1439
@pytest.fixture()
def settings_util(request, db_connection):
    """
    Provides a wired instance of `SettingsUtility` with integrated cleanup.
    """
    util = SettingsUtility()
    request.addfinalizer(util.cleanup)
    return util
1448 1448
1449 1449
class SettingsUtility(object):
    """Creates RhodeCode ui/setting DB rows and deletes them again on cleanup."""

    def __init__(self):
        self.rhodecode_ui_ids = []
        self.rhodecode_setting_ids = []
        self.repo_rhodecode_ui_ids = []
        self.repo_rhodecode_setting_ids = []

    def create_repo_rhodecode_ui(
            self, repo, section, value, key=None, active=True, cleanup=True):
        """Insert a per-repository ui row; tracked for deletion if *cleanup*."""
        ui_row = RepoRhodeCodeUi()
        ui_row.repository_id = repo.repo_id
        ui_row.ui_section = section
        ui_row.ui_value = value
        # Derive a stable unique key when the caller does not supply one.
        ui_row.ui_key = key or sha1_safe(f'{section}{value}{repo.repo_id}')
        ui_row.ui_active = active
        Session().add(ui_row)
        Session().commit()

        if cleanup:
            self.repo_rhodecode_ui_ids.append(ui_row.ui_id)
        return ui_row

    def create_rhodecode_ui(
            self, section, value, key=None, active=True, cleanup=True):
        """Insert a global ui row; tracked for deletion if *cleanup*."""
        ui_row = RhodeCodeUi()
        ui_row.ui_section = section
        ui_row.ui_value = value
        ui_row.ui_key = key or sha1_safe(f'{section}{value}')
        ui_row.ui_active = active
        Session().add(ui_row)
        Session().commit()

        if cleanup:
            self.rhodecode_ui_ids.append(ui_row.ui_id)
        return ui_row

    def create_repo_rhodecode_setting(
            self, repo, name, value, type_, cleanup=True):
        """Insert a per-repository setting row; tracked if *cleanup*."""
        setting_row = RepoRhodeCodeSetting(
            repo.repo_id, key=name, val=value, type=type_)
        Session().add(setting_row)
        Session().commit()

        if cleanup:
            self.repo_rhodecode_setting_ids.append(setting_row.app_settings_id)
        return setting_row

    def create_rhodecode_setting(self, name, value, type_, cleanup=True):
        """Insert a global setting row; tracked if *cleanup*."""
        setting_row = RhodeCodeSetting(key=name, val=value, type=type_)
        Session().add(setting_row)
        Session().commit()

        if cleanup:
            self.rhodecode_setting_ids.append(setting_row.app_settings_id)

        return setting_row

    def cleanup(self):
        """Delete every tracked row, then commit once at the end."""
        tracked = (
            (RhodeCodeUi, self.rhodecode_ui_ids),
            (RhodeCodeSetting, self.rhodecode_setting_ids),
            (RepoRhodeCodeUi, self.repo_rhodecode_ui_ids),
            (RepoRhodeCodeSetting, self.repo_rhodecode_setting_ids),
        )
        for model, row_ids in tracked:
            for id_ in row_ids:
                Session().delete(model.get(id_))

        Session().commit()
1529 1529
1530 1530
@pytest.fixture()
def no_notifications(request):
    """Silence notification creation for the duration of a test."""
    patcher = mock.patch(
        'rhodecode.model.notification.NotificationModel.create')
    patcher.start()
    request.addfinalizer(patcher.stop)
1537 1537
1538 1538
@pytest.fixture(scope='session')
def repeat(request):
    """
    The number of repetitions is based on this fixture.

    Slower calls may divide it by 10 or 100. It is chosen in a way so that the
    tests are not too slow in our default test suite.
    """
    # Value comes from the --repeat command-line option.
    return request.config.getoption('--repeat')
1548 1548
1549 1549
@pytest.fixture()
def rhodecode_fixtures():
    """Fresh, function-scoped `Fixture` helper (no automatic cleanup)."""
    return Fixture()
1553 1553
1554 1554
@pytest.fixture()
def context_stub():
    """
    Stub context object.
    """
    return pyramid.testing.DummyResource()
1562 1562
1563 1563
@pytest.fixture()
def request_stub():
    """
    Stub request object.
    """
    # local import keeps app bootstrapping out of conftest import time
    from rhodecode.lib.base import bootstrap_request
    return bootstrap_request(scheme='https')
1572 1572
1573 1573
@pytest.fixture()
def config_stub(request, request_stub):
    """
    Set up pyramid.testing and return the Configurator.
    """
    from rhodecode.lib.base import bootstrap_config
    config = bootstrap_config(request=request_stub)

    def _teardown():
        pyramid.testing.tearDown()
    request.addfinalizer(_teardown)

    return config
1587 1587
1588 1588
@pytest.fixture()
def StubIntegrationType():
    """Return (and register globally) a minimal integration type for tests."""
    class _StubIntegrationType(IntegrationTypeBase):
        """ Test integration type class """

        key = 'test'
        display_name = 'Test integration type'
        description = 'A test integration type for testing'

        @classmethod
        def icon(cls):
            return 'test_icon_html_image'

        def __init__(self, settings):
            super(_StubIntegrationType, self).__init__(settings)
            self.sent_events = []  # for testing

        def send_event(self, event):
            # Record instead of delivering, so tests can inspect what was sent.
            self.sent_events.append(event)

        def settings_schema(self):
            # Schema matches the payload of the stub_integration_settings fixture.
            class SettingsSchema(colander.Schema):
                test_string_field = colander.SchemaNode(
                    colander.String(),
                    missing=colander.required,
                    title='test string field',
                )
                test_int_field = colander.SchemaNode(
                    colander.Int(),
                    title='some integer setting',
                )
            return SettingsSchema()

    # NOTE(review): this mutates the global registry and is never undone by a
    # finalizer — repeated use re-registers the same 'test' key; confirm the
    # registry tolerates that.
    integration_type_registry.register_integration_type(_StubIntegrationType)
    return _StubIntegrationType
1625 1625
1626 1626
@pytest.fixture()
def stub_integration_settings():
    """Valid settings payload for _StubIntegrationType's schema."""
    return dict(
        test_string_field='some data',
        test_int_field=100,
    )
1633 1633
1634 1634
@pytest.fixture()
def repo_integration_stub(request, repo_stub, StubIntegrationType,
                          stub_integration_settings):
    """Stub integration scoped to a single repository; deleted on teardown."""
    integration = IntegrationModel().create(
        StubIntegrationType, settings=stub_integration_settings, enabled=True,
        name='test repo integration',
        repo=repo_stub, repo_group=None, child_repos_only=None)

    def _delete_integration():
        IntegrationModel().delete(integration)
    request.addfinalizer(_delete_integration)

    return integration
1648 1648
1649 1649
@pytest.fixture()
def repogroup_integration_stub(request, test_repo_group, StubIntegrationType,
                               stub_integration_settings):
    """Stub integration on a repo group (child repos only); deleted on teardown."""
    integration = IntegrationModel().create(
        StubIntegrationType, settings=stub_integration_settings, enabled=True,
        name='test repogroup integration',
        repo=None, repo_group=test_repo_group, child_repos_only=True)

    def _delete_integration():
        IntegrationModel().delete(integration)
    request.addfinalizer(_delete_integration)

    return integration
1663 1663
1664 1664
@pytest.fixture()
def repogroup_recursive_integration_stub(request, test_repo_group,
    StubIntegrationType, stub_integration_settings):
    """Stub integration on a repo group applied recursively; deleted on teardown."""
    integration = IntegrationModel().create(
        StubIntegrationType, settings=stub_integration_settings, enabled=True,
        name='test recursive repogroup integration',
        repo=None, repo_group=test_repo_group, child_repos_only=False)

    def _delete_integration():
        IntegrationModel().delete(integration)
    request.addfinalizer(_delete_integration)

    return integration
1678 1678
1679 1679
@pytest.fixture()
def global_integration_stub(request, StubIntegrationType,
                            stub_integration_settings):
    """Globally-scoped stub integration; deleted on teardown."""
    integration = IntegrationModel().create(
        StubIntegrationType, settings=stub_integration_settings, enabled=True,
        name='test global integration',
        repo=None, repo_group=None, child_repos_only=None)

    def _delete_integration():
        IntegrationModel().delete(integration)
    request.addfinalizer(_delete_integration)

    return integration
1693 1693
1694 1694
@pytest.fixture()
def root_repos_integration_stub(request, StubIntegrationType,
                                stub_integration_settings):
    """Global stub integration restricted to root repos; deleted on teardown."""
    # NOTE(review): the display name duplicates global_integration_stub's
    # 'test global integration' — looks like a copy/paste leftover; kept
    # as-is since tests may match on it.
    integration = IntegrationModel().create(
        StubIntegrationType, settings=stub_integration_settings, enabled=True,
        name='test global integration',
        repo=None, repo_group=None, child_repos_only=True)

    def _delete_integration():
        IntegrationModel().delete(integration)
    request.addfinalizer(_delete_integration)

    return integration
1708 1708
1709 1709
@pytest.fixture()
def local_dt_to_utc():
    """Factory turning a naive local datetime into a naive UTC datetime."""
    def _factory(dt):
        aware_local = dt.replace(tzinfo=dateutil.tz.tzlocal())
        return aware_local.astimezone(dateutil.tz.tzutc()).replace(tzinfo=None)
    return _factory
1716 1716
1717 1717
@pytest.fixture()
def disable_anonymous_user(request, baseapp):
    """Turn anonymous access off for a test and restore it afterwards."""
    set_anonymous_access(False)

    def _restore():
        set_anonymous_access(True)
    request.addfinalizer(_restore)
1725 1725
1726 1726
@pytest.fixture(scope='module')
def rc_fixture(request):
    """Module-scoped `Fixture` helper (no automatic cleanup)."""
    return Fixture()
1730 1730
1731 1731
@pytest.fixture()
def repo_groups(request):
    """Create zombie/parent/child repo groups; all destroyed on teardown."""
    fixture = Fixture()

    session = Session()
    zombie_group = fixture.create_repo_group('zombie')
    parent_group = fixture.create_repo_group('parent')
    child_group = fixture.create_repo_group('parent/child')

    # Sanity-check the setup before handing the groups to the test.
    assert len(session.query(RepoGroup).all()) == 3
    assert child_group.group_parent_id == parent_group.group_id

    def _destroy_all():
        # Child before parent so the parent is empty when removed.
        fixture.destroy_repo_group(zombie_group)
        fixture.destroy_repo_group(child_group)
        fixture.destroy_repo_group(parent_group)
    request.addfinalizer(_destroy_all)

    return zombie_group, parent_group, child_group
@@ -1,355 +1,359 b''
1 1
2 2 # Copyright (C) 2010-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 import logging
21 21 import io
22 22
23 23 import mock
24 24 import msgpack
25 25 import pytest
26 26 import tempfile
27 27
28 from rhodecode.lib import hooks_daemon
28 from rhodecode.lib.hook_daemon import http_hooks_deamon
29 from rhodecode.lib.hook_daemon import celery_hooks_deamon
30 from rhodecode.lib.hook_daemon import hook_module
31 from rhodecode.lib.hook_daemon import base as hook_base
29 32 from rhodecode.lib.str_utils import safe_bytes
30 33 from rhodecode.tests.utils import assert_message_in_log
31 34 from rhodecode.lib.ext_json import json
32 35
33 test_proto = hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO
36 test_proto = http_hooks_deamon.HooksHttpHandler.MSGPACK_HOOKS_PROTO
34 37
35 38
36 39 class TestHooks(object):
37 40 def test_hooks_can_be_used_as_a_context_processor(self):
38 hooks = hooks_daemon.Hooks()
41 hooks = hook_module.Hooks()
39 42 with hooks as return_value:
40 43 pass
41 44 assert hooks == return_value
42 45
43 46
44 47 class TestHooksHttpHandler(object):
45 48 def test_read_request_parses_method_name_and_arguments(self):
46 49 data = {
47 50 'method': 'test',
48 51 'extras': {
49 52 'param1': 1,
50 53 'param2': 'a'
51 54 }
52 55 }
53 56 request = self._generate_post_request(data)
54 57 hooks_patcher = mock.patch.object(
55 hooks_daemon.Hooks, data['method'], create=True, return_value=1)
58 hook_module.Hooks, data['method'], create=True, return_value=1)
56 59
57 60 with hooks_patcher as hooks_mock:
58 handler = hooks_daemon.HooksHttpHandler
61 handler = http_hooks_deamon.HooksHttpHandler
59 62 handler.DEFAULT_HOOKS_PROTO = test_proto
60 63 handler.wbufsize = 10240
61 64 MockServer(handler, request)
62 65
63 66 hooks_mock.assert_called_once_with(data['extras'])
64 67
65 68 def test_hooks_serialized_result_is_returned(self):
66 69 request = self._generate_post_request({})
67 70 rpc_method = 'test'
68 71 hook_result = {
69 72 'first': 'one',
70 73 'second': 2
71 74 }
72 75 extras = {}
73 76
74 77 # patching our _read to return test method and proto used
75 78 read_patcher = mock.patch.object(
76 hooks_daemon.HooksHttpHandler, '_read_request',
79 http_hooks_deamon.HooksHttpHandler, '_read_request',
77 80 return_value=(test_proto, rpc_method, extras))
78 81
79 82 # patch Hooks instance to return hook_result data on 'test' call
80 83 hooks_patcher = mock.patch.object(
81 hooks_daemon.Hooks, rpc_method, create=True,
84 hook_module.Hooks, rpc_method, create=True,
82 85 return_value=hook_result)
83 86
84 87 with read_patcher, hooks_patcher:
85 handler = hooks_daemon.HooksHttpHandler
88 handler = http_hooks_deamon.HooksHttpHandler
86 89 handler.DEFAULT_HOOKS_PROTO = test_proto
87 90 handler.wbufsize = 10240
88 91 server = MockServer(handler, request)
89 92
90 expected_result = hooks_daemon.HooksHttpHandler.serialize_data(hook_result)
93 expected_result = http_hooks_deamon.HooksHttpHandler.serialize_data(hook_result)
91 94
92 95 server.request.output_stream.seek(0)
93 96 assert server.request.output_stream.readlines()[-1] == expected_result
94 97
95 98 def test_exception_is_returned_in_response(self):
96 99 request = self._generate_post_request({})
97 100 rpc_method = 'test'
98 101
99 102 read_patcher = mock.patch.object(
100 hooks_daemon.HooksHttpHandler, '_read_request',
103 http_hooks_deamon.HooksHttpHandler, '_read_request',
101 104 return_value=(test_proto, rpc_method, {}))
102 105
103 106 hooks_patcher = mock.patch.object(
104 hooks_daemon.Hooks, rpc_method, create=True,
107 hook_module.Hooks, rpc_method, create=True,
105 108 side_effect=Exception('Test exception'))
106 109
107 110 with read_patcher, hooks_patcher:
108 handler = hooks_daemon.HooksHttpHandler
111 handler = http_hooks_deamon.HooksHttpHandler
109 112 handler.DEFAULT_HOOKS_PROTO = test_proto
110 113 handler.wbufsize = 10240
111 114 server = MockServer(handler, request)
112 115
113 116 server.request.output_stream.seek(0)
114 117 data = server.request.output_stream.readlines()
115 118 msgpack_data = b''.join(data[5:])
116 org_exc = hooks_daemon.HooksHttpHandler.deserialize_data(msgpack_data)
119 org_exc = http_hooks_deamon.HooksHttpHandler.deserialize_data(msgpack_data)
117 120 expected_result = {
118 121 'exception': 'Exception',
119 122 'exception_traceback': org_exc['exception_traceback'],
120 123 'exception_args': ['Test exception']
121 124 }
122 125 assert org_exc == expected_result
123 126
124 127 def test_log_message_writes_to_debug_log(self, caplog):
125 128 ip_port = ('0.0.0.0', 8888)
126 handler = hooks_daemon.HooksHttpHandler(
127 MockRequest('POST /'), ip_port, mock.Mock())
129 handler = http_hooks_deamon.HooksHttpHandler(MockRequest('POST /'), ip_port, mock.Mock())
128 130 fake_date = '1/Nov/2015 00:00:00'
129 131 date_patcher = mock.patch.object(
130 132 handler, 'log_date_time_string', return_value=fake_date)
131 133
132 134 with date_patcher, caplog.at_level(logging.DEBUG):
133 135 handler.log_message('Some message %d, %s', 123, 'string')
134 136
135 137 expected_message = f"HOOKS: client={ip_port} - - [{fake_date}] Some message 123, string"
136 138
137 139 assert_message_in_log(
138 140 caplog.records, expected_message,
139 levelno=logging.DEBUG, module='hooks_daemon')
141 levelno=logging.DEBUG, module='http_hooks_deamon')
140 142
141 143 def _generate_post_request(self, data, proto=test_proto):
142 if proto == hooks_daemon.HooksHttpHandler.MSGPACK_HOOKS_PROTO:
144 if proto == http_hooks_deamon.HooksHttpHandler.MSGPACK_HOOKS_PROTO:
143 145 payload = msgpack.packb(data)
144 146 else:
145 147 payload = json.dumps(data)
146 148
147 149 return b'POST / HTTP/1.0\nContent-Length: %d\n\n%b' % (
148 150 len(payload), payload)
149 151
150 152
151 153 class ThreadedHookCallbackDaemon(object):
152 154 def test_constructor_calls_prepare(self):
153 155 prepare_daemon_patcher = mock.patch.object(
154 hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
156 http_hooks_deamon.ThreadedHookCallbackDaemon, '_prepare')
155 157 with prepare_daemon_patcher as prepare_daemon_mock:
156 hooks_daemon.ThreadedHookCallbackDaemon()
158 http_hooks_deamon.ThreadedHookCallbackDaemon()
157 159 prepare_daemon_mock.assert_called_once_with()
158 160
159 161 def test_run_is_called_on_context_start(self):
160 162 patchers = mock.patch.multiple(
161 hooks_daemon.ThreadedHookCallbackDaemon,
163 http_hooks_deamon.ThreadedHookCallbackDaemon,
162 164 _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)
163 165
164 166 with patchers as mocks:
165 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
167 daemon = http_hooks_deamon.ThreadedHookCallbackDaemon()
166 168 with daemon as daemon_context:
167 169 pass
168 170 mocks['_run'].assert_called_once_with()
169 171 assert daemon_context == daemon
170 172
171 173 def test_stop_is_called_on_context_exit(self):
172 174 patchers = mock.patch.multiple(
173 hooks_daemon.ThreadedHookCallbackDaemon,
175 http_hooks_deamon.ThreadedHookCallbackDaemon,
174 176 _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)
175 177
176 178 with patchers as mocks:
177 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
179 daemon = http_hooks_deamon.ThreadedHookCallbackDaemon()
178 180 with daemon as daemon_context:
179 181 assert mocks['_stop'].call_count == 0
180 182
181 183 mocks['_stop'].assert_called_once_with()
182 184 assert daemon_context == daemon
183 185
184 186
185 187 class TestHttpHooksCallbackDaemon(object):
186 188 def test_hooks_callback_generates_new_port(self, caplog):
187 189 with caplog.at_level(logging.DEBUG):
188 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
190 daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
189 191 assert daemon._daemon.server_address == ('127.0.0.1', 8881)
190 192
191 193 with caplog.at_level(logging.DEBUG):
192 daemon = hooks_daemon.HttpHooksCallbackDaemon(host=None, port=None)
194 daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host=None, port=None)
193 195 assert daemon._daemon.server_address[1] in range(0, 66000)
194 196 assert daemon._daemon.server_address[0] != '127.0.0.1'
195 197
196 198 def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
197 199 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
198 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
200 daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
199 201 assert daemon._daemon == tcp_server
200 202
201 203 _, port = tcp_server.server_address
202 204
203 205 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
204 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
206 f"hook object: <class 'rhodecode.lib.hook_daemon.http_hooks_deamon.HooksHttpHandler'>"
205 207 assert_message_in_log(
206 caplog.records, msg, levelno=logging.DEBUG, module='hooks_daemon')
208 caplog.records, msg, levelno=logging.DEBUG, module='http_hooks_deamon')
207 209
208 210 def test_prepare_inits_hooks_uri_and_logs_it(
209 211 self, tcp_server, caplog):
210 212 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
211 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
213 daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
212 214
213 215 _, port = tcp_server.server_address
214 216 expected_uri = '{}:{}'.format('127.0.0.1', port)
215 217 assert daemon.hooks_uri == expected_uri
216 218
217 219 msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
218 f"hook object: <class 'rhodecode.lib.hooks_daemon.HooksHttpHandler'>"
220 f"hook object: <class 'rhodecode.lib.hook_daemon.http_hooks_deamon.HooksHttpHandler'>"
221
219 222 assert_message_in_log(
220 223 caplog.records, msg,
221 levelno=logging.DEBUG, module='hooks_daemon')
224 levelno=logging.DEBUG, module='http_hooks_deamon')
222 225
223 226 def test_run_creates_a_thread(self, tcp_server):
224 227 thread = mock.Mock()
225 228
226 229 with self._tcp_patcher(tcp_server):
227 daemon = hooks_daemon.HttpHooksCallbackDaemon()
230 daemon = http_hooks_deamon.HttpHooksCallbackDaemon()
228 231
229 232 with self._thread_patcher(thread) as thread_mock:
230 233 daemon._run()
231 234
232 235 thread_mock.assert_called_once_with(
233 236 target=tcp_server.serve_forever,
234 237 kwargs={'poll_interval': daemon.POLL_INTERVAL})
235 238 assert thread.daemon is True
236 239 thread.start.assert_called_once_with()
237 240
238 241 def test_run_logs(self, tcp_server, caplog):
239 242
240 243 with self._tcp_patcher(tcp_server):
241 daemon = hooks_daemon.HttpHooksCallbackDaemon()
244 daemon = http_hooks_deamon.HttpHooksCallbackDaemon()
242 245
243 246 with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
244 247 daemon._run()
245 248
246 249 assert_message_in_log(
247 250 caplog.records,
248 251 'Running thread-based loop of callback daemon in background',
249 levelno=logging.DEBUG, module='hooks_daemon')
252 levelno=logging.DEBUG, module='http_hooks_deamon')
250 253
251 254 def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
252 255 thread = mock.Mock()
253 256
254 257 with self._tcp_patcher(tcp_server):
255 daemon = hooks_daemon.HttpHooksCallbackDaemon()
258 daemon = http_hooks_deamon.HttpHooksCallbackDaemon()
256 259
257 260 with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
258 261 with daemon:
259 262 assert daemon._daemon == tcp_server
260 263 assert daemon._callback_thread == thread
261 264
262 265 assert daemon._daemon is None
263 266 assert daemon._callback_thread is None
264 267 tcp_server.shutdown.assert_called_with()
265 268 thread.join.assert_called_once_with()
266 269
267 270 assert_message_in_log(
268 271 caplog.records, 'Waiting for background thread to finish.',
269 levelno=logging.DEBUG, module='hooks_daemon')
272 levelno=logging.DEBUG, module='http_hooks_deamon')
270 273
271 274 def _tcp_patcher(self, tcp_server):
272 275 return mock.patch.object(
273 hooks_daemon, 'TCPServer', return_value=tcp_server)
276 http_hooks_deamon, 'TCPServer', return_value=tcp_server)
274 277
275 278 def _thread_patcher(self, thread):
276 279 return mock.patch.object(
277 hooks_daemon.threading, 'Thread', return_value=thread)
280 http_hooks_deamon.threading, 'Thread', return_value=thread)
278 281
279 282
class TestPrepareHooksDaemon(object):
    """Tests for ``hook_base.prepare_callback_daemon`` protocol dispatch."""

    @pytest.mark.parametrize('protocol', ('celery',))
    def test_returns_celery_hooks_callback_daemon_when_celery_protocol_specified(
            self, protocol):
        # A minimal ini file is required because the celery daemon reads its
        # broker/result-backend configuration from the referenced config file.
        with tempfile.NamedTemporaryFile(mode='w') as temp_file:
            temp_file.write("[app:main]\ncelery.broker_url = redis://redis/0\n"
                            "celery.result_backend = redis://redis/0")
            temp_file.flush()
            expected_extras = {'config': temp_file.name}
            callback, extras = hook_base.prepare_callback_daemon(
                expected_extras, protocol=protocol, host='')
            assert isinstance(callback, celery_hooks_deamon.CeleryHooksCallbackDaemon)

    @pytest.mark.parametrize('protocol, expected_class', (
        ('http', http_hooks_deamon.HttpHooksCallbackDaemon),
    ))
    def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
            self, protocol, expected_class):
        expected_extras = {
            'extra1': 'value1',
            'txn_id': 'txnid2',
            'hooks_protocol': protocol.lower(),
            'task_backend': '',
            'task_queue': ''
        }
        callback, extras = hook_base.prepare_callback_daemon(
            expected_extras.copy(), protocol=protocol, host='127.0.0.1',
            txn_id='txnid2')
        assert isinstance(callback, expected_class)
        # hooks_uri and time are generated by the daemon; drop/copy them so the
        # remaining extras can be compared verbatim.
        extras.pop('hooks_uri')
        expected_extras['time'] = extras['time']
        assert extras == expected_extras

    @pytest.mark.parametrize('protocol', (
        'invalid',
        'Http',
        'HTTP',
    ))
    def test_raises_on_invalid_protocol(self, protocol):
        expected_extras = {
            'extra1': 'value1',
            'hooks_protocol': protocol.lower()
        }
        with pytest.raises(Exception):
            callback, extras = hook_base.prepare_callback_daemon(
                expected_extras.copy(),
                protocol=protocol, host='127.0.0.1')
328 332
class MockRequest(object):
    """Fake socket request: canned input bytes plus a capturable output stream."""

    def __init__(self, request):
        self.request = request
        self.input_stream = io.BytesIO(safe_bytes(self.request))
        self.output_stream = io.BytesIO()  # make it un-closable for test investigation
        # Neutralize close() so tests can still read the buffer after the
        # handler finishes (BytesIO discards its contents on close).
        self.output_stream.close = lambda: None

    def makefile(self, mode, *args, **kwargs):
        # Handlers request 'wb' for the response and 'rb' for the request body.
        return self.output_stream if mode == 'wb' else self.input_stream
339 343
340 344
class MockServer(object):
    """Instantiate ``handler_cls`` against a MockRequest, as a socketserver would."""

    def __init__(self, handler_cls, request):
        ip_port = ('0.0.0.0', 8888)
        self.request = MockRequest(request)
        self.server_address = ip_port
        # BaseRequestHandler runs its whole request cycle inside __init__.
        self.handler = handler_cls(self.request, ip_port, self)
348 352
349 353
@pytest.fixture()
def tcp_server():
    """Mock TCP server pinned to 127.0.0.1:8881 with a write-buffer size set."""
    server = mock.Mock()
    server.server_address = ('127.0.0.1', 8881)
    server.wbufsize = 1024
    return server
1 NO CONTENT: file was removed
General Comments 0
You need to be logged in to leave comments. Login now