py3: 2to3 run
super-admin
r1044:3e31405b python3
@@ -1,207 +1,207 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import textwrap
23 23 import string
24 24 import functools
25 25 import logging
26 26 import tempfile
27 27 import logging.config
28 28 log = logging.getLogger(__name__)
29 29
30 30 # skip keys that are set here, so we don't double-process them
31 31 set_keys = {
32 32 '__file__': ''
33 33 }
34 34
35 35
36 36 def str2bool(_str):
37 37 """
38 38     returns a True/False value from the given string; tries to translate the
39 39     string into a boolean
40 40
41 41 :param _str: string value to translate into boolean
42 42 :rtype: boolean
43 43 :returns: boolean from given string
44 44 """
45 45 if _str is None:
46 46 return False
47 47 if _str in (True, False):
48 48 return _str
49 49 _str = str(_str).strip().lower()
50 50 return _str in ('t', 'true', 'y', 'yes', 'on', '1')
51 51
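A quick sanity sketch of str2bool semantics (anything outside the accepted truthy set falls through to False):

    assert str2bool('Yes') is True    # case-insensitive match on t/true/y/yes/on/1
    assert str2bool(' on ') is True   # surrounding whitespace is stripped
    assert str2bool('0') is False     # only '1' is truthy among digits
    assert str2bool(None) is False    # None short-circuits to False
    assert str2bool(True) is True     # booleans pass through unchanged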
52 52
53 53 def aslist(obj, sep=None, strip=True):
54 54 """
55 55     Returns the given string split by sep as a list
56 56
57 57     :param obj: value to convert (string, list, tuple, None or a scalar)
58 58     :param sep: separator to split a string value on
59 59     :param strip: strip whitespace from each list item
60 60 """
61 if isinstance(obj, (basestring,)):
61 if isinstance(obj, str):
62 62         if obj == '':
63 63 return []
64 64
65 65 lst = obj.split(sep)
66 66 if strip:
67 67 lst = [v.strip() for v in lst]
68 68 return lst
69 69 elif isinstance(obj, (list, tuple)):
70 70 return obj
71 71 elif obj is None:
72 72 return []
73 73 else:
74 74 return [obj]
75 75
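A short sketch of aslist across its accepted input types:

    assert aslist('a,b, c', sep=',') == ['a', 'b', 'c']  # items stripped by default
    assert aslist('', sep=',') == []                     # empty string -> empty list
    assert aslist(None) == []                            # None -> empty list
    assert aslist(('x', 'y')) == ('x', 'y')              # lists/tuples pass through as-is
    assert aslist(42) == [42]                            # any other scalar gets wrapped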
76 76
77 77 class SettingsMaker(object):
78 78
79 79 def __init__(self, app_settings):
80 80 self.settings = app_settings
81 81
82 82 @classmethod
83 83 def _bool_func(cls, input_val):
84 84         if isinstance(input_val, bytes):
85 85             input_val = input_val.decode('utf8')
86 86 return str2bool(input_val)
87 87
88 88 @classmethod
89 89 def _int_func(cls, input_val):
90 90 return int(input_val)
91 91
92 92 @classmethod
93 93 def _list_func(cls, input_val, sep=','):
94 94 return aslist(input_val, sep=sep)
95 95
96 96 @classmethod
97 97 def _string_func(cls, input_val, lower=True):
98 98 if lower:
99 99 input_val = input_val.lower()
100 100 return input_val
101 101
102 102 @classmethod
103 103 def _float_func(cls, input_val):
104 104 return float(input_val)
105 105
106 106 @classmethod
107 107 def _dir_func(cls, input_val, ensure_dir=False, mode=0o755):
108 108
109 109 # ensure we have our dir created
110 110 if not os.path.isdir(input_val) and ensure_dir:
111 111 os.makedirs(input_val, mode=mode)
112 112
113 113 if not os.path.isdir(input_val):
114 114 raise Exception('Dir at {} does not exist'.format(input_val))
115 115 return input_val
116 116
117 117 @classmethod
118 118 def _file_path_func(cls, input_val, ensure_dir=False, mode=0o755):
119 119 dirname = os.path.dirname(input_val)
120 120 cls._dir_func(dirname, ensure_dir=ensure_dir)
121 121 return input_val
122 122
123 123 @classmethod
124 124 def _key_transformator(cls, key):
125 125         return "{}_{}".format('RC', key.upper().replace('.', '_').replace('-', '_'))
126 126
127 127 def maybe_env_key(self, key):
128 128         # this key may also be set in the environment; if so, that value takes priority
129 129 transformed_key = self._key_transformator(key)
130 130 envvar_value = os.environ.get(transformed_key)
131 131 if envvar_value:
132 132 log.debug('using `%s` key instead of `%s` key for config', transformed_key, key)
133 133
134 134 return envvar_value
135 135
136 136 def env_expand(self):
137 137 replaced = {}
138 138 for k, v in self.settings.items():
139 139 if k not in set_keys:
140 140 envvar_value = self.maybe_env_key(k)
141 141 if envvar_value:
142 142 replaced[k] = envvar_value
143 143 set_keys[k] = envvar_value
144 144
145 145         # replace all keys that were updated from the environment
146 146 self.settings.update(replaced)
147 147
148 148 def enable_logging(self, logging_conf=None, level='INFO', formatter='generic'):
149 149 """
150 150         Helper to auto-configure logging on a running instance
152 152 """
153 153
154 154 if not str2bool(self.settings.get('logging.autoconfigure')):
155 155 log.info('logging configuration based on main .ini file')
156 156 return
157 157
158 158 if logging_conf is None:
159 159 logging_conf = self.settings.get('logging.logging_conf_file') or ''
160 160
161 161 if not os.path.isfile(logging_conf):
162 162 log.error('Unable to setup logging based on %s, '
163 163                   'file does not exist... specify a path via the logging.logging_conf_file config setting.', logging_conf)
164 164 return
165 165
166 166         with open(logging_conf, 'rt') as f:  # text mode: string.Template needs str, not bytes
167 167 ini_template = textwrap.dedent(f.read())
168 168 ini_template = string.Template(ini_template).safe_substitute(
169 169 RC_LOGGING_LEVEL=os.environ.get('RC_LOGGING_LEVEL', '') or level,
170 170 RC_LOGGING_FORMATTER=os.environ.get('RC_LOGGING_FORMATTER', '') or formatter
171 171 )
172 172
173 173         with tempfile.NamedTemporaryFile(mode='w', prefix='rc_logging_', suffix='.ini', delete=False) as f:
174 174 log.info('Saved Temporary LOGGING config at %s', f.name)
175 175 f.write(ini_template)
176 176
177 177 logging.config.fileConfig(f.name)
178 178 os.remove(f.name)
179 179
180 180 def make_setting(self, key, default, lower=False, default_when_empty=False, parser=None):
181 181 input_val = self.settings.get(key, default)
182 182
183 183 if default_when_empty and not input_val:
184 184 # use default value when value is set in the config but it is empty
185 185 input_val = default
186 186
187 187 parser_func = {
188 188 'bool': self._bool_func,
189 189 'int': self._int_func,
190 190 'list': self._list_func,
191 191             'list:newline': functools.partial(self._list_func, sep='\n'),
192 192 'list:spacesep': functools.partial(self._list_func, sep=' '),
193 193 'string': functools.partial(self._string_func, lower=lower),
194 194 'dir': self._dir_func,
195 195 'dir:ensured': functools.partial(self._dir_func, ensure_dir=True),
196 196 'file': self._file_path_func,
197 197 'file:ensured': functools.partial(self._file_path_func, ensure_dir=True),
198 198 None: lambda i: i
199 199 }[parser]
200 200
201 201 envvar_value = self.maybe_env_key(key)
202 202 if envvar_value:
203 203 input_val = envvar_value
204 204 set_keys[key] = input_val
205 205
206 206 self.settings[key] = parser_func(input_val)
207 207 return self.settings[key]
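A minimal usage sketch of SettingsMaker (the keys are illustrative, not actual RhodeCode settings); note how an RC_-prefixed environment variable, produced by _key_transformator, overrides the ini value:

    import os

    settings = {'app.workers': '4', 'app.tags': 'x,y', 'app.debug': 'true'}
    maker = SettingsMaker(settings)

    maker.make_setting('app.workers', default=2, parser='int')
    maker.make_setting('app.tags', default='', parser='list')
    maker.make_setting('app.debug', default=False, parser='bool')
    assert settings == {'app.workers': 4, 'app.tags': ['x', 'y'], 'app.debug': True}

    # the transformed env key takes priority over the ini value
    os.environ['RC_APP_WORKERS'] = '8'
    assert maker.make_setting('app.workers', default=2, parser='int') == 8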
@@ -1,272 +1,272 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 20 from webtest.app import TestApp as WebObTestApp
21 21 import simplejson as json
22 22
23 23 from vcsserver.git_lfs.app import create_app
24 24
25 25
26 26 @pytest.fixture(scope='function')
27 27 def git_lfs_app(tmpdir):
28 28 custom_app = WebObTestApp(create_app(
29 29 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
30 30 git_lfs_http_scheme='http'))
31 31 custom_app._store = str(tmpdir)
32 32 return custom_app
33 33
34 34
35 35 @pytest.fixture(scope='function')
36 36 def git_lfs_https_app(tmpdir):
37 37 custom_app = WebObTestApp(create_app(
38 38 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
39 39 git_lfs_http_scheme='https'))
40 40 custom_app._store = str(tmpdir)
41 41 return custom_app
42 42
43 43
44 44 @pytest.fixture()
45 45 def http_auth():
46 46 return {'HTTP_AUTHORIZATION': "Basic XXXXX"}
47 47
48 48
49 49 class TestLFSApplication(object):
50 50
51 51 def test_app_wrong_path(self, git_lfs_app):
52 52 git_lfs_app.get('/repo/info/lfs/xxx', status=404)
53 53
54 54 def test_app_deprecated_endpoint(self, git_lfs_app):
55 55 response = git_lfs_app.post('/repo/info/lfs/objects', status=501)
56 56 assert response.status_code == 501
57 assert json.loads(response.text) == {u'message': u'LFS: v1 api not supported'}
57 assert json.loads(response.text) == {'message': 'LFS: v1 api not supported'}
58 58
59 59 def test_app_lock_verify_api_not_available(self, git_lfs_app):
60 60 response = git_lfs_app.post('/repo/info/lfs/locks/verify', status=501)
61 61 assert response.status_code == 501
62 62 assert json.loads(response.text) == {
63 u'message': u'GIT LFS locking api not supported'}
63 'message': 'GIT LFS locking api not supported'}
64 64
65 65 def test_app_lock_api_not_available(self, git_lfs_app):
66 66 response = git_lfs_app.post('/repo/info/lfs/locks', status=501)
67 67 assert response.status_code == 501
68 68 assert json.loads(response.text) == {
69 u'message': u'GIT LFS locking api not supported'}
69 'message': 'GIT LFS locking api not supported'}
70 70
71 71 def test_app_batch_api_missing_auth(self, git_lfs_app):
72 72 git_lfs_app.post_json(
73 73 '/repo/info/lfs/objects/batch', params={}, status=403)
74 74
75 75 def test_app_batch_api_unsupported_operation(self, git_lfs_app, http_auth):
76 76 response = git_lfs_app.post_json(
77 77 '/repo/info/lfs/objects/batch', params={}, status=400,
78 78 extra_environ=http_auth)
79 79 assert json.loads(response.text) == {
80 u'message': u'unsupported operation mode: `None`'}
80 'message': 'unsupported operation mode: `None`'}
81 81
82 82 def test_app_batch_api_missing_objects(self, git_lfs_app, http_auth):
83 83 response = git_lfs_app.post_json(
84 84 '/repo/info/lfs/objects/batch', params={'operation': 'download'},
85 85 status=400, extra_environ=http_auth)
86 86 assert json.loads(response.text) == {
87 u'message': u'missing objects data'}
87 'message': 'missing objects data'}
88 88
89 89 def test_app_batch_api_unsupported_data_in_objects(
90 90 self, git_lfs_app, http_auth):
91 91 params = {'operation': 'download',
92 92 'objects': [{}]}
93 93 response = git_lfs_app.post_json(
94 94 '/repo/info/lfs/objects/batch', params=params, status=400,
95 95 extra_environ=http_auth)
96 96 assert json.loads(response.text) == {
97 u'message': u'unsupported data in objects'}
97 'message': 'unsupported data in objects'}
98 98
99 99 def test_app_batch_api_download_missing_object(
100 100 self, git_lfs_app, http_auth):
101 101 params = {'operation': 'download',
102 102 'objects': [{'oid': '123', 'size': '1024'}]}
103 103 response = git_lfs_app.post_json(
104 104 '/repo/info/lfs/objects/batch', params=params,
105 105 extra_environ=http_auth)
106 106
107 107 expected_objects = [
108 {u'authenticated': True,
109 u'errors': {u'error': {
110 u'code': 404,
111 u'message': u'object: 123 does not exist in store'}},
112 u'oid': u'123',
113 u'size': u'1024'}
108 {'authenticated': True,
109 'errors': {'error': {
110 'code': 404,
111 'message': 'object: 123 does not exist in store'}},
112 'oid': '123',
113 'size': '1024'}
114 114 ]
115 115 assert json.loads(response.text) == {
116 116 'objects': expected_objects, 'transfer': 'basic'}
117 117
118 118 def test_app_batch_api_download(self, git_lfs_app, http_auth):
119 119 oid = '456'
120 120 oid_path = os.path.join(git_lfs_app._store, oid)
121 121 if not os.path.isdir(os.path.dirname(oid_path)):
122 122 os.makedirs(os.path.dirname(oid_path))
123 123 with open(oid_path, 'wb') as f:
124 124             f.write(b'OID_CONTENT')  # file opened in binary mode needs bytes
125 125
126 126 params = {'operation': 'download',
127 127 'objects': [{'oid': oid, 'size': '1024'}]}
128 128 response = git_lfs_app.post_json(
129 129 '/repo/info/lfs/objects/batch', params=params,
130 130 extra_environ=http_auth)
131 131
132 132 expected_objects = [
133 {u'authenticated': True,
134 u'actions': {
135 u'download': {
136 u'header': {u'Authorization': u'Basic XXXXX'},
137 u'href': u'http://localhost/repo/info/lfs/objects/456'},
133 {'authenticated': True,
134 'actions': {
135 'download': {
136 'header': {'Authorization': 'Basic XXXXX'},
137 'href': 'http://localhost/repo/info/lfs/objects/456'},
138 138 },
139 u'oid': u'456',
140 u'size': u'1024'}
139 'oid': '456',
140 'size': '1024'}
141 141 ]
142 142 assert json.loads(response.text) == {
143 143 'objects': expected_objects, 'transfer': 'basic'}
144 144
145 145 def test_app_batch_api_upload(self, git_lfs_app, http_auth):
146 146 params = {'operation': 'upload',
147 147 'objects': [{'oid': '123', 'size': '1024'}]}
148 148 response = git_lfs_app.post_json(
149 149 '/repo/info/lfs/objects/batch', params=params,
150 150 extra_environ=http_auth)
151 151 expected_objects = [
152 {u'authenticated': True,
153 u'actions': {
154 u'upload': {
155 u'header': {u'Authorization': u'Basic XXXXX',
156 u'Transfer-Encoding': u'chunked'},
157 u'href': u'http://localhost/repo/info/lfs/objects/123'},
158 u'verify': {
159 u'header': {u'Authorization': u'Basic XXXXX'},
160 u'href': u'http://localhost/repo/info/lfs/verify'}
152 {'authenticated': True,
153 'actions': {
154 'upload': {
155 'header': {'Authorization': 'Basic XXXXX',
156 'Transfer-Encoding': 'chunked'},
157 'href': 'http://localhost/repo/info/lfs/objects/123'},
158 'verify': {
159 'header': {'Authorization': 'Basic XXXXX'},
160 'href': 'http://localhost/repo/info/lfs/verify'}
161 161 },
162 u'oid': u'123',
163 u'size': u'1024'}
162 'oid': '123',
163 'size': '1024'}
164 164 ]
165 165 assert json.loads(response.text) == {
166 166 'objects': expected_objects, 'transfer': 'basic'}
167 167
168 168 def test_app_batch_api_upload_for_https(self, git_lfs_https_app, http_auth):
169 169 params = {'operation': 'upload',
170 170 'objects': [{'oid': '123', 'size': '1024'}]}
171 171 response = git_lfs_https_app.post_json(
172 172 '/repo/info/lfs/objects/batch', params=params,
173 173 extra_environ=http_auth)
174 174 expected_objects = [
175 {u'authenticated': True,
176 u'actions': {
177 u'upload': {
178 u'header': {u'Authorization': u'Basic XXXXX',
179 u'Transfer-Encoding': u'chunked'},
180 u'href': u'https://localhost/repo/info/lfs/objects/123'},
181 u'verify': {
182 u'header': {u'Authorization': u'Basic XXXXX'},
183 u'href': u'https://localhost/repo/info/lfs/verify'}
175 {'authenticated': True,
176 'actions': {
177 'upload': {
178 'header': {'Authorization': 'Basic XXXXX',
179 'Transfer-Encoding': 'chunked'},
180 'href': 'https://localhost/repo/info/lfs/objects/123'},
181 'verify': {
182 'header': {'Authorization': 'Basic XXXXX'},
183 'href': 'https://localhost/repo/info/lfs/verify'}
184 184 },
185 u'oid': u'123',
186 u'size': u'1024'}
185 'oid': '123',
186 'size': '1024'}
187 187 ]
188 188 assert json.loads(response.text) == {
189 189 'objects': expected_objects, 'transfer': 'basic'}
190 190
191 191 def test_app_verify_api_missing_data(self, git_lfs_app):
192 192 params = {'oid': 'missing'}
193 193 response = git_lfs_app.post_json(
194 194 '/repo/info/lfs/verify', params=params,
195 195 status=400)
196 196
197 197 assert json.loads(response.text) == {
198 u'message': u'missing oid and size in request data'}
198 'message': 'missing oid and size in request data'}
199 199
200 200 def test_app_verify_api_missing_obj(self, git_lfs_app):
201 201 params = {'oid': 'missing', 'size': '1024'}
202 202 response = git_lfs_app.post_json(
203 203 '/repo/info/lfs/verify', params=params,
204 204 status=404)
205 205
206 206 assert json.loads(response.text) == {
207 u'message': u'oid `missing` does not exists in store'}
207 'message': 'oid `missing` does not exists in store'}
208 208
209 209 def test_app_verify_api_size_mismatch(self, git_lfs_app):
210 210 oid = 'existing'
211 211 oid_path = os.path.join(git_lfs_app._store, oid)
212 212 if not os.path.isdir(os.path.dirname(oid_path)):
213 213 os.makedirs(os.path.dirname(oid_path))
214 214 with open(oid_path, 'wb') as f:
215 215             f.write(b'OID_CONTENT')
216 216
217 217 params = {'oid': oid, 'size': '1024'}
218 218 response = git_lfs_app.post_json(
219 219 '/repo/info/lfs/verify', params=params, status=422)
220 220
221 221 assert json.loads(response.text) == {
222 u'message': u'requested file size mismatch '
223 u'store size:11 requested:1024'}
222 'message': 'requested file size mismatch '
223 'store size:11 requested:1024'}
224 224
225 225 def test_app_verify_api(self, git_lfs_app):
226 226 oid = 'existing'
227 227 oid_path = os.path.join(git_lfs_app._store, oid)
228 228 if not os.path.isdir(os.path.dirname(oid_path)):
229 229 os.makedirs(os.path.dirname(oid_path))
230 230 with open(oid_path, 'wb') as f:
231 231             f.write(b'OID_CONTENT')
232 232
233 233 params = {'oid': oid, 'size': 11}
234 234 response = git_lfs_app.post_json(
235 235 '/repo/info/lfs/verify', params=params)
236 236
237 237 assert json.loads(response.text) == {
238 u'message': {u'size': u'ok', u'in_store': u'ok'}}
238 'message': {'size': 'ok', 'in_store': 'ok'}}
239 239
240 240 def test_app_download_api_oid_not_existing(self, git_lfs_app):
241 241 oid = 'missing'
242 242
243 243 response = git_lfs_app.get(
244 244 '/repo/info/lfs/objects/{oid}'.format(oid=oid), status=404)
245 245
246 246 assert json.loads(response.text) == {
247 u'message': u'requested file with oid `missing` not found in store'}
247 'message': 'requested file with oid `missing` not found in store'}
248 248
249 249 def test_app_download_api(self, git_lfs_app):
250 250 oid = 'existing'
251 251 oid_path = os.path.join(git_lfs_app._store, oid)
252 252 if not os.path.isdir(os.path.dirname(oid_path)):
253 253 os.makedirs(os.path.dirname(oid_path))
254 254 with open(oid_path, 'wb') as f:
255 255             f.write(b'OID_CONTENT')
256 256
257 257 response = git_lfs_app.get(
258 258 '/repo/info/lfs/objects/{oid}'.format(oid=oid))
259 259 assert response
260 260
261 261 def test_app_upload(self, git_lfs_app):
262 262 oid = 'uploaded'
263 263
264 264 response = git_lfs_app.put(
265 265 '/repo/info/lfs/objects/{oid}'.format(oid=oid), params='CONTENT')
266 266
267 assert json.loads(response.text) == {u'upload': u'ok'}
267 assert json.loads(response.text) == {'upload': 'ok'}
268 268
269 269 # verify that we actually wrote that OID
270 270 oid_path = os.path.join(git_lfs_app._store, oid)
271 271 assert os.path.isfile(oid_path)
272 272 assert 'CONTENT' == open(oid_path).read()
@@ -1,84 +1,84 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """
19 19 Mercurial libs compatibility
20 20 """
21 21
22 22 import mercurial
23 23 from mercurial import demandimport
24 24
25 25 # patch demandimport, due to a bug in mercurial that makes it always trigger
26 26 # demandimport.enable()
27 27 demandimport.enable = lambda *args, **kwargs: 1
28 28
29 29 from mercurial import ui
30 30 from mercurial import patch
31 31 from mercurial import config
32 32 from mercurial import extensions
33 33 from mercurial import scmutil
34 34 from mercurial import archival
35 35 from mercurial import discovery
36 36 from mercurial import unionrepo
37 37 from mercurial import localrepo
38 38 from mercurial import merge as hg_merge
39 39 from mercurial import subrepo
40 40 from mercurial import subrepoutil
41 41 from mercurial import tags as hg_tag
42 42 from mercurial import util as hgutil
43 43 from mercurial.commands import clone, pull
44 44 from mercurial.node import nullid
45 45 from mercurial.context import memctx, memfilectx
46 46 from mercurial.error import (
47 47 LookupError, RepoError, RepoLookupError, Abort, InterventionRequired,
48 48 RequirementError, ProgrammingError)
49 49 from mercurial.hgweb import hgweb_mod
50 50 from mercurial.localrepo import instance
51 51 from mercurial.match import match, alwaysmatcher, patternmatcher
52 52 from mercurial.mdiff import diffopts
53 53 from mercurial.node import bin, hex
54 54 from mercurial.encoding import tolocal
55 55 from mercurial.discovery import findcommonoutgoing
56 56 from mercurial.hg import peer
57 57 from mercurial.httppeer import makepeer
58 58 from mercurial.utils.urlutil import url as hg_url
59 59 from mercurial.scmutil import revrange, revsymbol
60 60 from mercurial.node import nullrev
61 61 from mercurial import exchange
62 62 from hgext import largefiles
63 63
64 64 # these auth handlers are patched for a python 2.6.5 bug causing
65 65 # infinite looping when given invalid resources
66 66 from mercurial.url import httpbasicauthhandler, httpdigestauthhandler
67 67
68 68 # hg strip is in core now
69 69 from mercurial import strip as hgext_strip
70 70
71 71
72 72 def get_ctx(repo, ref):
73 73 try:
74 74 ctx = repo[ref]
75 75 except (ProgrammingError, TypeError):
76 76         # we're unable to find the rev using a regular lookup; fall back
77 77         # to the slower, but backward-compatible, revsymbol usage
78 78 ctx = revsymbol(repo, ref)
79 79 except (LookupError, RepoLookupError):
80 80 # Similar case as above but only for refs that are not numeric
81 if isinstance(ref, (int, long)):
81 if isinstance(ref, int):
82 82 raise
83 83 ctx = revsymbol(repo, ref)
84 84 return ctx
@@ -1,729 +1,729 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # RhodeCode VCSServer provides access to different vcs backends via network.
4 4 # Copyright (C) 2014-2020 RhodeCode GmbH
5 5 #
6 6 # This program is free software; you can redistribute it and/or modify
7 7 # it under the terms of the GNU General Public License as published by
8 8 # the Free Software Foundation; either version 3 of the License, or
9 9 # (at your option) any later version.
10 10 #
11 11 # This program is distributed in the hope that it will be useful,
12 12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 14 # GNU General Public License for more details.
15 15 #
16 16 # You should have received a copy of the GNU General Public License
17 17 # along with this program; if not, write to the Free Software Foundation,
18 18 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 19
20 20 import io
21 21 import os
22 22 import sys
23 23 import logging
24 24 import collections
25 25 import importlib
26 26 import base64
27 27
28 28 from http.client import HTTPConnection
29 29
30 30
31 31 import mercurial.scmutil
32 32 import mercurial.node
33 33 import simplejson as json
34 34
35 35 from vcsserver import exceptions, subprocessio, settings
36 36
37 37 log = logging.getLogger(__name__)
38 38
39 39
40 40 class HooksHttpClient(object):
41 41 connection = None
42 42
43 43 def __init__(self, hooks_uri):
44 44 self.hooks_uri = hooks_uri
45 45
46 46 def __call__(self, method, extras):
47 47 connection = HTTPConnection(self.hooks_uri)
48 48 body = self._serialize(method, extras)
49 49 try:
50 50 connection.request('POST', '/', body)
51 51 except Exception:
52 52 log.error('Hooks calling Connection failed on %s', connection.__dict__)
53 53 raise
54 54 response = connection.getresponse()
55 55
56 56 response_data = response.read()
57 57
58 58 try:
59 59 return json.loads(response_data)
60 60 except Exception:
61 61 log.exception('Failed to decode hook response json data. '
62 62 'response_code:%s, raw_data:%s',
63 63 response.status, response_data)
64 64 raise
65 65
66 66 def _serialize(self, hook_name, extras):
67 67 data = {
68 68 'method': hook_name,
69 69 'extras': extras
70 70 }
71 71 return json.dumps(data)
72 72
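The wire format is plain JSON POSTed to '/' on the configured hooks_uri; a small sketch (the host:port is hypothetical):

    client = HooksHttpClient('127.0.0.1:10020')
    body = client._serialize('pre_push', {'username': 'admin'})
    # body == '{"method": "pre_push", "extras": {"username": "admin"}}'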
73 73
74 74 class HooksDummyClient(object):
75 75 def __init__(self, hooks_module):
76 76 self._hooks_module = importlib.import_module(hooks_module)
77 77
78 78 def __call__(self, hook_name, extras):
79 79 with self._hooks_module.Hooks() as hooks:
80 80 return getattr(hooks, hook_name)(extras)
81 81
82 82
83 83 class HooksShadowRepoClient(object):
84 84
85 85 def __call__(self, hook_name, extras):
86 86 return {'output': '', 'status': 0}
87 87
88 88
89 89 class RemoteMessageWriter(object):
90 90 """Writer base class."""
91 91 def write(self, message):
92 92 raise NotImplementedError()
93 93
94 94
95 95 class HgMessageWriter(RemoteMessageWriter):
96 96 """Writer that knows how to send messages to mercurial clients."""
97 97
98 98 def __init__(self, ui):
99 99 self.ui = ui
100 100
101 101 def write(self, message):
102 102 # TODO: Check why the quiet flag is set by default.
103 103 old = self.ui.quiet
104 104 self.ui.quiet = False
105 105 self.ui.status(message.encode('utf-8'))
106 106 self.ui.quiet = old
107 107
108 108
109 109 class GitMessageWriter(RemoteMessageWriter):
110 110 """Writer that knows how to send messages to git clients."""
111 111
112 112 def __init__(self, stdout=None):
113 113 self.stdout = stdout or sys.stdout
114 114
115 115 def write(self, message):
116 116 self.stdout.write(message.encode('utf-8'))
117 117
118 118
119 119 class SvnMessageWriter(RemoteMessageWriter):
120 120 """Writer that knows how to send messages to svn clients."""
121 121
122 122 def __init__(self, stderr=None):
123 123 # SVN needs data sent to stderr for back-to-client messaging
124 124 self.stderr = stderr or sys.stderr
125 125
126 126 def write(self, message):
127 127 self.stderr.write(message.encode('utf-8'))
128 128
129 129
130 130 def _handle_exception(result):
131 131 exception_class = result.get('exception')
132 132 exception_traceback = result.get('exception_traceback')
133 133
134 134 if exception_traceback:
135 135 log.error('Got traceback from remote call:%s', exception_traceback)
136 136
137 137 if exception_class == 'HTTPLockedRC':
138 138 raise exceptions.RepositoryLockedException()(*result['exception_args'])
139 139 elif exception_class == 'HTTPBranchProtected':
140 140 raise exceptions.RepositoryBranchProtectedException()(*result['exception_args'])
141 141 elif exception_class == 'RepositoryError':
142 142 raise exceptions.VcsException()(*result['exception_args'])
143 143 elif exception_class:
144 144 raise Exception('Got remote exception "%s" with args "%s"' %
145 145 (exception_class, result['exception_args']))
146 146
147 147
148 148 def _get_hooks_client(extras):
149 149 hooks_uri = extras.get('hooks_uri')
150 150 is_shadow_repo = extras.get('is_shadow_repo')
151 151 if hooks_uri:
152 152 return HooksHttpClient(extras['hooks_uri'])
153 153 elif is_shadow_repo:
154 154 return HooksShadowRepoClient()
155 155 else:
156 156 return HooksDummyClient(extras['hooks_module'])
157 157
158 158
159 159 def _call_hook(hook_name, extras, writer):
160 160 hooks_client = _get_hooks_client(extras)
161 161 log.debug('Hooks, using client:%s', hooks_client)
162 162 result = hooks_client(hook_name, extras)
163 163 log.debug('Hooks got result: %s', result)
164 164
165 165 _handle_exception(result)
166 166 writer.write(result['output'])
167 167
168 168 return result['status']
169 169
170 170
171 171 def _extras_from_ui(ui):
172 172 hook_data = ui.config('rhodecode', 'RC_SCM_DATA')
173 173 if not hook_data:
174 174 # maybe it's inside environ ?
175 175 env_hook_data = os.environ.get('RC_SCM_DATA')
176 176 if env_hook_data:
177 177 hook_data = env_hook_data
178 178
179 179 extras = {}
180 180 if hook_data:
181 181 extras = json.loads(hook_data)
182 182 return extras
183 183
184 184
185 185 def _rev_range_hash(repo, node, check_heads=False):
186 186 from vcsserver.hgcompat import get_ctx
187 187
188 188 commits = []
189 189 revs = []
190 190 start = get_ctx(repo, node).rev()
191 191 end = len(repo)
192 192 for rev in range(start, end):
193 193 revs.append(rev)
194 194 ctx = get_ctx(repo, rev)
195 195 commit_id = mercurial.node.hex(ctx.node())
196 196 branch = ctx.branch()
197 197 commits.append((commit_id, branch))
198 198
199 199 parent_heads = []
200 200 if check_heads:
201 201 parent_heads = _check_heads(repo, start, end, revs)
202 202 return commits, parent_heads
203 203
204 204
205 205 def _check_heads(repo, start, end, commits):
206 206 from vcsserver.hgcompat import get_ctx
207 207 changelog = repo.changelog
208 208 parents = set()
209 209
210 210 for new_rev in commits:
211 211 for p in changelog.parentrevs(new_rev):
212 212 if p == mercurial.node.nullrev:
213 213 continue
214 214 if p < start:
215 215 parents.add(p)
216 216
217 217 for p in parents:
218 218 branch = get_ctx(repo, p).branch()
219 219 # The heads descending from that parent, on the same branch
220 220 parent_heads = set([p])
221 221 reachable = set([p])
222 222 for x in range(p + 1, end):
223 223 if get_ctx(repo, x).branch() != branch:
224 224 continue
225 225 for pp in changelog.parentrevs(x):
226 226 if pp in reachable:
227 227 reachable.add(x)
228 228 parent_heads.discard(pp)
229 229 parent_heads.add(x)
230 230 # More than one head? Suggest merging
231 231 if len(parent_heads) > 1:
232 232 return list(parent_heads)
233 233
234 234 return []
235 235
236 236
237 237 def _get_git_env():
238 238 env = {}
239 239 for k, v in os.environ.items():
240 240 if k.startswith('GIT'):
241 241 env[k] = v
242 242
243 243 # serialized version
244 244 return [(k, v) for k, v in env.items()]
245 245
246 246
247 247 def _get_hg_env(old_rev, new_rev, txnid, repo_path):
248 248 env = {}
249 249 for k, v in os.environ.items():
250 250 if k.startswith('HG'):
251 251 env[k] = v
252 252
253 253 env['HG_NODE'] = old_rev
254 254 env['HG_NODE_LAST'] = new_rev
255 255 env['HG_TXNID'] = txnid
256 256 env['HG_PENDING'] = repo_path
257 257
258 258 return [(k, v) for k, v in env.items()]
259 259
260 260
261 261 def repo_size(ui, repo, **kwargs):
262 262 extras = _extras_from_ui(ui)
263 263 return _call_hook('repo_size', extras, HgMessageWriter(ui))
264 264
265 265
266 266 def pre_pull(ui, repo, **kwargs):
267 267 extras = _extras_from_ui(ui)
268 268 return _call_hook('pre_pull', extras, HgMessageWriter(ui))
269 269
270 270
271 271 def pre_pull_ssh(ui, repo, **kwargs):
272 272 extras = _extras_from_ui(ui)
273 273 if extras and extras.get('SSH'):
274 274 return pre_pull(ui, repo, **kwargs)
275 275 return 0
276 276
277 277
278 278 def post_pull(ui, repo, **kwargs):
279 279 extras = _extras_from_ui(ui)
280 280 return _call_hook('post_pull', extras, HgMessageWriter(ui))
281 281
282 282
283 283 def post_pull_ssh(ui, repo, **kwargs):
284 284 extras = _extras_from_ui(ui)
285 285 if extras and extras.get('SSH'):
286 286 return post_pull(ui, repo, **kwargs)
287 287 return 0
288 288
289 289
290 290 def pre_push(ui, repo, node=None, **kwargs):
291 291 """
292 292 Mercurial pre_push hook
293 293 """
294 294 extras = _extras_from_ui(ui)
295 295 detect_force_push = extras.get('detect_force_push')
296 296
297 297 rev_data = []
298 298 if node and kwargs.get('hooktype') == 'pretxnchangegroup':
299 299 branches = collections.defaultdict(list)
300 300 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
301 301 for commit_id, branch in commits:
302 302 branches[branch].append(commit_id)
303 303
304 304 for branch, commits in branches.items():
305 305 old_rev = kwargs.get('node_last') or commits[0]
306 306 rev_data.append({
307 307 'total_commits': len(commits),
308 308 'old_rev': old_rev,
309 309 'new_rev': commits[-1],
310 310 'ref': '',
311 311 'type': 'branch',
312 312 'name': branch,
313 313 })
314 314
315 315 for push_ref in rev_data:
316 316 push_ref['multiple_heads'] = _heads
317 317
318 318 repo_path = os.path.join(
319 319 extras.get('repo_store', ''), extras.get('repository', ''))
320 320 push_ref['hg_env'] = _get_hg_env(
321 321 old_rev=push_ref['old_rev'],
322 322 new_rev=push_ref['new_rev'], txnid=kwargs.get('txnid'),
323 323 repo_path=repo_path)
324 324
325 325 extras['hook_type'] = kwargs.get('hooktype', 'pre_push')
326 326 extras['commit_ids'] = rev_data
327 327
328 328 return _call_hook('pre_push', extras, HgMessageWriter(ui))
329 329
330 330
331 331 def pre_push_ssh(ui, repo, node=None, **kwargs):
332 332 extras = _extras_from_ui(ui)
333 333 if extras.get('SSH'):
334 334 return pre_push(ui, repo, node, **kwargs)
335 335
336 336 return 0
337 337
338 338
339 339 def pre_push_ssh_auth(ui, repo, node=None, **kwargs):
340 340 """
341 341 Mercurial pre_push hook for SSH
342 342 """
343 343 extras = _extras_from_ui(ui)
344 344 if extras.get('SSH'):
345 345 permission = extras['SSH_PERMISSIONS']
346 346
347 347 if 'repository.write' == permission or 'repository.admin' == permission:
348 348 return 0
349 349
350 350 # non-zero ret code
351 351 return 1
352 352
353 353 return 0
354 354
355 355
356 356 def post_push(ui, repo, node, **kwargs):
357 357 """
358 358 Mercurial post_push hook
359 359 """
360 360 extras = _extras_from_ui(ui)
361 361
362 362 commit_ids = []
363 363 branches = []
364 364 bookmarks = []
365 365 tags = []
366 366
367 367 commits, _heads = _rev_range_hash(repo, node)
368 368 for commit_id, branch in commits:
369 369 commit_ids.append(commit_id)
370 370 if branch not in branches:
371 371 branches.append(branch)
372 372
373 373 if hasattr(ui, '_rc_pushkey_branches'):
374 374 bookmarks = ui._rc_pushkey_branches
375 375
376 376 extras['hook_type'] = kwargs.get('hooktype', 'post_push')
377 377 extras['commit_ids'] = commit_ids
378 378 extras['new_refs'] = {
379 379 'branches': branches,
380 380 'bookmarks': bookmarks,
381 381 'tags': tags
382 382 }
383 383
384 384 return _call_hook('post_push', extras, HgMessageWriter(ui))
385 385
386 386
387 387 def post_push_ssh(ui, repo, node, **kwargs):
388 388 """
389 389 Mercurial post_push hook for SSH
390 390 """
391 391 if _extras_from_ui(ui).get('SSH'):
392 392 return post_push(ui, repo, node, **kwargs)
393 393 return 0
394 394
395 395
396 396 def key_push(ui, repo, **kwargs):
397 397 from vcsserver.hgcompat import get_ctx
398 398 if kwargs['new'] != '0' and kwargs['namespace'] == 'bookmarks':
399 399 # store new bookmarks in our UI object propagated later to post_push
400 400 ui._rc_pushkey_branches = get_ctx(repo, kwargs['key']).bookmarks()
401 401 return
402 402
403 403
404 404 # backward compat
405 405 log_pull_action = post_pull
406 406
407 407 # backward compat
408 408 log_push_action = post_push
409 409
410 410
411 411 def handle_git_pre_receive(unused_repo_path, unused_revs, unused_env):
412 412 """
413 413 Old hook name: keep here for backward compatibility.
414 414
415 415 This is only required when the installed git hooks are not upgraded.
416 416 """
417 417 pass
418 418
419 419
420 420 def handle_git_post_receive(unused_repo_path, unused_revs, unused_env):
421 421 """
422 422 Old hook name: keep here for backward compatibility.
423 423
424 424 This is only required when the installed git hooks are not upgraded.
425 425 """
426 426 pass
427 427
428 428
429 429 HookResponse = collections.namedtuple('HookResponse', ('status', 'output'))
430 430
431 431
432 432 def git_pre_pull(extras):
433 433 """
434 434 Pre pull hook.
435 435
436 436 :param extras: dictionary containing the keys defined in simplevcs
437 437 :type extras: dict
438 438
439 439 :return: status code of the hook. 0 for success.
440 440 :rtype: int
441 441 """
442 442 if 'pull' not in extras['hooks']:
443 443 return HookResponse(0, '')
444 444
445 445 stdout = io.BytesIO()
446 446 try:
447 447 status = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
448 448 except Exception as error:
449 449 status = 128
450 450         stdout.write(('ERROR: %s\n' % error).encode('utf8'))  # stdout is a BytesIO; write bytes
451 451
452 452 return HookResponse(status, stdout.getvalue())
453 453
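When the 'pull' hook is not configured in extras, the call short-circuits without contacting the hooks server, e.g.:

    assert git_pre_pull({'hooks': ['push']}) == HookResponse(0, '')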
454 454
455 455 def git_post_pull(extras):
456 456 """
457 457 Post pull hook.
458 458
459 459 :param extras: dictionary containing the keys defined in simplevcs
460 460 :type extras: dict
461 461
462 462 :return: status code of the hook. 0 for success.
463 463 :rtype: int
464 464 """
465 465 if 'pull' not in extras['hooks']:
466 466 return HookResponse(0, '')
467 467
468 468 stdout = io.BytesIO()
469 469 try:
470 470 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
471 471 except Exception as error:
472 472 status = 128
473 473         stdout.write(('ERROR: %s\n' % error).encode('utf8'))  # stdout is a BytesIO; write bytes
474 474
475 475 return HookResponse(status, stdout.getvalue())
476 476
477 477
478 478 def _parse_git_ref_lines(revision_lines):
479 479 rev_data = []
480 480 for revision_line in revision_lines or []:
481 481 old_rev, new_rev, ref = revision_line.strip().split(' ')
482 482 ref_data = ref.split('/', 2)
483 483 if ref_data[1] in ('tags', 'heads'):
484 484 rev_data.append({
485 485 # NOTE(marcink):
486 486 # we're unable to tell total_commits for git at this point
487 487 # but we set the variable for consistency with GIT
488 488 'total_commits': -1,
489 489 'old_rev': old_rev,
490 490 'new_rev': new_rev,
491 491 'ref': ref,
492 492 'type': ref_data[1],
493 493 'name': ref_data[2],
494 494 })
495 495 return rev_data
496 496
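Each stdin line git feeds a receive hook has the form '<old-sha> <new-sha> <ref>'; a sketch of the resulting structure (the shas are illustrative):

    lines = ['1' * 40 + ' ' + '2' * 40 + ' refs/heads/master']
    assert _parse_git_ref_lines(lines) == [{
        'total_commits': -1,
        'old_rev': '1' * 40,
        'new_rev': '2' * 40,
        'ref': 'refs/heads/master',
        'type': 'heads',
        'name': 'master',
    }]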
497 497
498 498 def git_pre_receive(unused_repo_path, revision_lines, env):
499 499 """
500 500 Pre push hook.
501 501
502 502 :param extras: dictionary containing the keys defined in simplevcs
503 503 :type extras: dict
504 504
505 505 :return: status code of the hook. 0 for success.
506 506 :rtype: int
507 507 """
508 508 extras = json.loads(env['RC_SCM_DATA'])
509 509 rev_data = _parse_git_ref_lines(revision_lines)
510 510 if 'push' not in extras['hooks']:
511 511 return 0
512 512 empty_commit_id = '0' * 40
513 513
514 514 detect_force_push = extras.get('detect_force_push')
515 515
516 516 for push_ref in rev_data:
517 517 # store our git-env which holds the temp store
518 518 push_ref['git_env'] = _get_git_env()
519 519 push_ref['pruned_sha'] = ''
520 520 if not detect_force_push:
521 521 # don't check for forced-push when we don't need to
522 522 continue
523 523
524 524 type_ = push_ref['type']
525 525 new_branch = push_ref['old_rev'] == empty_commit_id
526 526 delete_branch = push_ref['new_rev'] == empty_commit_id
527 527 if type_ == 'heads' and not (new_branch or delete_branch):
528 528 old_rev = push_ref['old_rev']
529 529 new_rev = push_ref['new_rev']
530 530 cmd = [settings.GIT_EXECUTABLE, 'rev-list', old_rev, '^{}'.format(new_rev)]
531 531 stdout, stderr = subprocessio.run_command(
532 532 cmd, env=os.environ.copy())
533 533             # non-empty output means some objects are non-reachable, i.e. this was a forced push
534 534 if stdout:
535 535 push_ref['pruned_sha'] = stdout.splitlines()
536 536
537 537 extras['hook_type'] = 'pre_receive'
538 538 extras['commit_ids'] = rev_data
539 539 return _call_hook('pre_push', extras, GitMessageWriter())
540 540
541 541
542 542 def git_post_receive(unused_repo_path, revision_lines, env):
543 543 """
544 544 Post push hook.
545 545
546 546 :param extras: dictionary containing the keys defined in simplevcs
547 547 :type extras: dict
548 548
549 549 :return: status code of the hook. 0 for success.
550 550 :rtype: int
551 551 """
552 552 extras = json.loads(env['RC_SCM_DATA'])
553 553 if 'push' not in extras['hooks']:
554 554 return 0
555 555
556 556 rev_data = _parse_git_ref_lines(revision_lines)
557 557
558 558 git_revs = []
559 559
560 560 # N.B.(skreft): it is ok to just call git, as git before calling a
561 561     # subcommand sets the PATH environment variable so that it points to the
562 562 # correct version of the git executable.
563 563 empty_commit_id = '0' * 40
564 564 branches = []
565 565 tags = []
566 566 for push_ref in rev_data:
567 567 type_ = push_ref['type']
568 568
569 569 if type_ == 'heads':
570 570 if push_ref['old_rev'] == empty_commit_id:
571 571 # starting new branch case
572 572 if push_ref['name'] not in branches:
573 573 branches.append(push_ref['name'])
574 574
575 575 # Fix up head revision if needed
576 576 cmd = [settings.GIT_EXECUTABLE, 'show', 'HEAD']
577 577 try:
578 578 subprocessio.run_command(cmd, env=os.environ.copy())
579 579 except Exception:
580 580 cmd = [settings.GIT_EXECUTABLE, 'symbolic-ref', '"HEAD"',
581 581 '"refs/heads/%s"' % push_ref['name']]
582 print("Setting default branch to %s" % push_ref['name'])
582                         print("Setting default branch to %s" % push_ref['name'])
583 583 subprocessio.run_command(cmd, env=os.environ.copy())
584 584
585 585 cmd = [settings.GIT_EXECUTABLE, 'for-each-ref',
586 586 '--format=%(refname)', 'refs/heads/*']
587 587 stdout, stderr = subprocessio.run_command(
588 588 cmd, env=os.environ.copy())
589 589 heads = stdout
590 590 heads = heads.replace(push_ref['ref'], '')
591 591 heads = ' '.join(head for head
592 592 in heads.splitlines() if head) or '.'
593 593 cmd = [settings.GIT_EXECUTABLE, 'log', '--reverse',
594 594 '--pretty=format:%H', '--', push_ref['new_rev'],
595 595 '--not', heads]
596 596 stdout, stderr = subprocessio.run_command(
597 597 cmd, env=os.environ.copy())
598 598 git_revs.extend(stdout.splitlines())
599 599 elif push_ref['new_rev'] == empty_commit_id:
600 600 # delete branch case
601 601 git_revs.append('delete_branch=>%s' % push_ref['name'])
602 602 else:
603 603 if push_ref['name'] not in branches:
604 604 branches.append(push_ref['name'])
605 605
606 606 cmd = [settings.GIT_EXECUTABLE, 'log',
607 607 '{old_rev}..{new_rev}'.format(**push_ref),
608 608 '--reverse', '--pretty=format:%H']
609 609 stdout, stderr = subprocessio.run_command(
610 610 cmd, env=os.environ.copy())
611 611 git_revs.extend(stdout.splitlines())
612 612 elif type_ == 'tags':
613 613 if push_ref['name'] not in tags:
614 614 tags.append(push_ref['name'])
615 615 git_revs.append('tag=>%s' % push_ref['name'])
616 616
617 617 extras['hook_type'] = 'post_receive'
618 618 extras['commit_ids'] = git_revs
619 619 extras['new_refs'] = {
620 620 'branches': branches,
621 621 'bookmarks': [],
622 622 'tags': tags,
623 623 }
624 624
625 625 if 'repo_size' in extras['hooks']:
626 626 try:
627 627 _call_hook('repo_size', extras, GitMessageWriter())
628 628         except Exception:
629 629 pass
630 630
631 631 return _call_hook('post_push', extras, GitMessageWriter())
632 632
633 633
634 634 def _get_extras_from_txn_id(path, txn_id):
635 635 extras = {}
636 636 try:
637 637 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
638 638 '-t', txn_id,
639 639 '--revprop', path, 'rc-scm-extras']
640 640 stdout, stderr = subprocessio.run_command(
641 641 cmd, env=os.environ.copy())
642 642 extras = json.loads(base64.urlsafe_b64decode(stdout))
643 643 except Exception:
644 644 log.exception('Failed to extract extras info from txn_id')
645 645
646 646 return extras
647 647
648 648
649 649 def _get_extras_from_commit_id(commit_id, path):
650 650 extras = {}
651 651 try:
652 652 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
653 653 '-r', commit_id,
654 654 '--revprop', path, 'rc-scm-extras']
655 655 stdout, stderr = subprocessio.run_command(
656 656 cmd, env=os.environ.copy())
657 657 extras = json.loads(base64.urlsafe_b64decode(stdout))
658 658 except Exception:
659 659 log.exception('Failed to extract extras info from commit_id')
660 660
661 661 return extras
662 662
663 663
664 664 def svn_pre_commit(repo_path, commit_data, env):
665 665 path, txn_id = commit_data
666 666 branches = []
667 667 tags = []
668 668
669 669 if env.get('RC_SCM_DATA'):
670 670 extras = json.loads(env['RC_SCM_DATA'])
671 671 else:
672 672 # fallback method to read from TXN-ID stored data
673 673 extras = _get_extras_from_txn_id(path, txn_id)
674 674 if not extras:
675 675 return 0
676 676
677 677 extras['hook_type'] = 'pre_commit'
678 678 extras['commit_ids'] = [txn_id]
679 679 extras['txn_id'] = txn_id
680 680 extras['new_refs'] = {
681 681 'total_commits': 1,
682 682 'branches': branches,
683 683 'bookmarks': [],
684 684 'tags': tags,
685 685 }
686 686
687 687 return _call_hook('pre_push', extras, SvnMessageWriter())
688 688
689 689
690 690 def svn_post_commit(repo_path, commit_data, env):
691 691 """
692 692 commit_data is path, rev, txn_id
693 693 """
694 694 if len(commit_data) == 3:
695 695 path, commit_id, txn_id = commit_data
696 696 elif len(commit_data) == 2:
697 697 log.error('Failed to extract txn_id from commit_data using legacy method. '
698 698 'Some functionality might be limited')
699 699 path, commit_id = commit_data
700 700 txn_id = None
701 701
702 702 branches = []
703 703 tags = []
704 704
705 705 if env.get('RC_SCM_DATA'):
706 706 extras = json.loads(env['RC_SCM_DATA'])
707 707 else:
708 708 # fallback method to read from TXN-ID stored data
709 709 extras = _get_extras_from_commit_id(commit_id, path)
710 710 if not extras:
711 711 return 0
712 712
713 713 extras['hook_type'] = 'post_commit'
714 714 extras['commit_ids'] = [commit_id]
715 715 extras['txn_id'] = txn_id
716 716 extras['new_refs'] = {
717 717 'branches': branches,
718 718 'bookmarks': [],
719 719 'tags': tags,
720 720 'total_commits': 1,
721 721 }
722 722
723 723 if 'repo_size' in extras['hooks']:
724 724 try:
725 725 _call_hook('repo_size', extras, SvnMessageWriter())
726 726 except Exception:
727 727 pass
728 728
729 729 return _call_hook('post_push', extras, SvnMessageWriter())
@@ -1,52 +1,52 b''
1 from __future__ import absolute_import, division, unicode_literals
1
2 2
3 3 import logging
4 4
5 5 from .stream import TCPStatsClient, UnixSocketStatsClient # noqa
6 6 from .udp import StatsClient # noqa
7 7
8 8 HOST = 'localhost'
9 9 PORT = 8125
10 10 IPV6 = False
11 11 PREFIX = None
12 12 MAXUDPSIZE = 512
13 13
14 14 log = logging.getLogger('rhodecode.statsd')
15 15
16 16
17 17 def statsd_config(config, prefix='statsd.'):
18 18 _config = {}
19 19 for key in config.keys():
20 20 if key.startswith(prefix):
21 21 _config[key[len(prefix):]] = config[key]
22 22 return _config
23 23
24 24
25 25 def client_from_config(configuration, prefix='statsd.', **kwargs):
26 26 from pyramid.settings import asbool
27 27
28 28 _config = statsd_config(configuration, prefix)
29 29 statsd_enabled = asbool(_config.pop('enabled', False))
30 30 if not statsd_enabled:
31 31         log.debug('statsd client not enabled by the statsd.enabled flag, skipping...')
32 32 return
33 33
34 34 host = _config.pop('statsd_host', HOST)
35 35 port = _config.pop('statsd_port', PORT)
36 36 prefix = _config.pop('statsd_prefix', PREFIX)
37 37 maxudpsize = _config.pop('statsd_maxudpsize', MAXUDPSIZE)
38 38 ipv6 = asbool(_config.pop('statsd_ipv6', IPV6))
39 39 log.debug('configured statsd client %s:%s', host, port)
40 40
41 41 try:
42 42 client = StatsClient(
43 43 host=host, port=port, prefix=prefix, maxudpsize=maxudpsize, ipv6=ipv6)
44 44 except Exception:
45 45 log.exception('StatsD is enabled, but failed to connect to statsd server, fallback: disable statsd')
46 46 client = None
47 47
48 48 return client
49 49
50 50
51 51 def get_statsd_client(request):
52 52 return client_from_config(request.registry.settings)
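A hypothetical ini-style settings dict showing how client_from_config consumes the statsd.* keys; statsd.enabled gates client creation and the remaining keys map onto StatsClient arguments:

    config = {
        'statsd.enabled': 'true',
        'statsd.statsd_host': '127.0.0.1',
        'statsd.statsd_port': '8125',
        'statsd.statsd_prefix': 'vcsserver',
    }
    client = client_from_config(config)  # returns None when statsd.enabled is false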
@@ -1,156 +1,156 b''
1 from __future__ import absolute_import, division, unicode_literals
1
2 2
3 3 import re
4 4 import random
5 5 from collections import deque
6 6 from datetime import timedelta
7 7 from repoze.lru import lru_cache
8 8
9 9 from .timer import Timer
10 10
11 11 TAG_INVALID_CHARS_RE = re.compile(
12 12 r"[^\w\d_\-:/\.]",
13 13 #re.UNICODE
14 14 )
15 15 TAG_INVALID_CHARS_SUBS = "_"
16 16
17 17 # we save and expose the stat buckets sent to statsd, for discovery
18 18 buckets_dict = {
19 19
20 20 }
21 21
22 22
23 23 @lru_cache(maxsize=500)
24 24 def _normalize_tags_with_cache(tag_list):
25 25 return [TAG_INVALID_CHARS_RE.sub(TAG_INVALID_CHARS_SUBS, tag) for tag in tag_list]
26 26
27 27
28 28 def normalize_tags(tag_list):
29 29 # We have to turn our input tag list into a non-mutable tuple for it to
30 30 # be hashable (and thus usable) by the @lru_cache decorator.
31 31 return _normalize_tags_with_cache(tuple(tag_list))
32 32
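Normalization replaces any character outside the allowed set (word chars, '_', '-', ':', '/', '.') with '_', e.g.:

    assert normalize_tags(['env:prod', 'host name!']) == ['env:prod', 'host_name_']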
33 33
34 34 class StatsClientBase(object):
35 35 """A Base class for various statsd clients."""
36 36
37 37 def close(self):
38 38 """Used to close and clean up any underlying resources."""
39 39 raise NotImplementedError()
40 40
41 41 def _send(self):
42 42 raise NotImplementedError()
43 43
44 44 def pipeline(self):
45 45 raise NotImplementedError()
46 46
47 47 def timer(self, stat, rate=1, tags=None, auto_send=True):
48 48 """
49 49 statsd = StatsdClient.statsd
50 50 with statsd.timer('bucket_name', auto_send=True) as tmr:
51 51 # This block will be timed.
52 52             for i in range(0, 100000):
53 53 i ** 2
54 54 # you can access time here...
55 55 elapsed_ms = tmr.ms
56 56 """
57 57 return Timer(self, stat, rate, tags, auto_send=auto_send)
58 58
59 59 def timing(self, stat, delta, rate=1, tags=None, use_decimals=True):
60 60 """
61 61 Send new timing information.
62 62
63 63 `delta` can be either a number of milliseconds or a timedelta.
64 64 """
65 65 if isinstance(delta, timedelta):
66 66 # Convert timedelta to number of milliseconds.
67 67 delta = delta.total_seconds() * 1000.
68 68 if use_decimals:
69 69 fmt = '%0.6f|ms'
70 70 else:
71 71 fmt = '%s|ms'
72 72 self._send_stat(stat, fmt % delta, rate, tags)
73 73
74 74 def incr(self, stat, count=1, rate=1, tags=None):
75 75 """Increment a stat by `count`."""
76 76 self._send_stat(stat, '%s|c' % count, rate, tags)
77 77
78 78 def decr(self, stat, count=1, rate=1, tags=None):
79 79 """Decrement a stat by `count`."""
80 80 self.incr(stat, -count, rate, tags)
81 81
82 82 def gauge(self, stat, value, rate=1, delta=False, tags=None):
83 83 """Set a gauge value."""
84 84 if value < 0 and not delta:
85 85 if rate < 1:
86 86 if random.random() > rate:
87 87 return
88 88 with self.pipeline() as pipe:
89 89 pipe._send_stat(stat, '0|g', 1)
90 90 pipe._send_stat(stat, '%s|g' % value, 1)
91 91 else:
92 92 prefix = '+' if delta and value >= 0 else ''
93 93 self._send_stat(stat, '%s%s|g' % (prefix, value), rate, tags)
94 94
95 95 def set(self, stat, value, rate=1):
96 96 """Set a set value."""
97 97 self._send_stat(stat, '%s|s' % value, rate)
98 98
99 99 def histogram(self, stat, value, rate=1, tags=None):
100 100 """Set a histogram"""
101 101 self._send_stat(stat, '%s|h' % value, rate, tags)
102 102
103 103 def _send_stat(self, stat, value, rate, tags=None):
104 104 self._after(self._prepare(stat, value, rate, tags))
105 105
106 106 def _prepare(self, stat, value, rate, tags=None):
107 107 global buckets_dict
108 108 buckets_dict[stat] = 1
109 109
110 110 if rate < 1:
111 111 if random.random() > rate:
112 112 return
113 113 value = '%s|@%s' % (value, rate)
114 114
115 115 if self._prefix:
116 116 stat = '%s.%s' % (self._prefix, stat)
117 117
118 118 res = '%s:%s%s' % (
119 119 stat,
120 120 value,
121 121 ("|#" + ",".join(normalize_tags(tags))) if tags else "",
122 122 )
123 123 return res
124 124
125 125 def _after(self, data):
126 126 if data:
127 127 self._send(data)
128 128
129 129
130 130 class PipelineBase(StatsClientBase):
131 131
132 132 def __init__(self, client):
133 133 self._client = client
134 134 self._prefix = client._prefix
135 135 self._stats = deque()
136 136
137 137 def _send(self):
138 138 raise NotImplementedError()
139 139
140 140 def _after(self, data):
141 141 if data is not None:
142 142 self._stats.append(data)
143 143
144 144 def __enter__(self):
145 145 return self
146 146
147 147 def __exit__(self, typ, value, tb):
148 148 self.send()
149 149
150 150 def send(self):
151 151 if not self._stats:
152 152 return
153 153 self._send()
154 154
155 155 def pipeline(self):
156 156 return self.__class__(self)
@@ -1,75 +1,75 b''
1 from __future__ import absolute_import, division, unicode_literals
1
2 2
3 3 import socket
4 4
5 5 from .base import StatsClientBase, PipelineBase
6 6
7 7
8 8 class StreamPipeline(PipelineBase):
9 9 def _send(self):
10 10 self._client._after('\n'.join(self._stats))
11 11 self._stats.clear()
12 12
13 13
14 14 class StreamClientBase(StatsClientBase):
15 15 def connect(self):
16 16 raise NotImplementedError()
17 17
18 18 def close(self):
19 19 if self._sock and hasattr(self._sock, 'close'):
20 20 self._sock.close()
21 21 self._sock = None
22 22
23 23 def reconnect(self):
24 24 self.close()
25 25 self.connect()
26 26
27 27 def pipeline(self):
28 28 return StreamPipeline(self)
29 29
30 30 def _send(self, data):
31 31 """Send data to statsd."""
32 32 if not self._sock:
33 33 self.connect()
34 34 self._do_send(data)
35 35
36 36 def _do_send(self, data):
37 37 self._sock.sendall(data.encode('ascii') + b'\n')
38 38
39 39
40 40 class TCPStatsClient(StreamClientBase):
41 41 """TCP version of StatsClient."""
42 42
43 43 def __init__(self, host='localhost', port=8125, prefix=None,
44 44 timeout=None, ipv6=False):
45 45 """Create a new client."""
46 46 self._host = host
47 47 self._port = port
48 48 self._ipv6 = ipv6
49 49 self._timeout = timeout
50 50 self._prefix = prefix
51 51 self._sock = None
52 52
53 53 def connect(self):
54 54 fam = socket.AF_INET6 if self._ipv6 else socket.AF_INET
55 55 family, _, _, _, addr = socket.getaddrinfo(
56 56 self._host, self._port, fam, socket.SOCK_STREAM)[0]
57 57 self._sock = socket.socket(family, socket.SOCK_STREAM)
58 58 self._sock.settimeout(self._timeout)
59 59 self._sock.connect(addr)
60 60
61 61
62 62 class UnixSocketStatsClient(StreamClientBase):
63 63 """Unix domain socket version of StatsClient."""
64 64
65 65 def __init__(self, socket_path, prefix=None, timeout=None):
66 66 """Create a new client."""
67 67 self._socket_path = socket_path
68 68 self._timeout = timeout
69 69 self._prefix = prefix
70 70 self._sock = None
71 71
72 72 def connect(self):
73 73 self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
74 74 self._sock.settimeout(self._timeout)
75 75 self._sock.connect(self._socket_path)
@@ -1,75 +1,75 b''
1 from __future__ import absolute_import, division, unicode_literals
1
2 2
3 3 import functools
4 4
5 5 # Use timer that's not susceptible to time of day adjustments.
6 6 try:
7 7 # perf_counter is only present on Py3.3+
8 8 from time import perf_counter as time_now
9 9 except ImportError:
10 10 # fall back to using time
11 11 from time import time as time_now
12 12
13 13
14 14 def safe_wraps(wrapper, *args, **kwargs):
15 15 """Safely wraps partial functions."""
16 16 while isinstance(wrapper, functools.partial):
17 17 wrapper = wrapper.func
18 18 return functools.wraps(wrapper, *args, **kwargs)
19 19
20 20
21 21 class Timer(object):
22 22 """A context manager/decorator for statsd.timing()."""
23 23
24 24 def __init__(self, client, stat, rate=1, tags=None, use_decimals=True, auto_send=True):
25 25 self.client = client
26 26 self.stat = stat
27 27 self.rate = rate
28 28 self.tags = tags
29 29 self.ms = None
30 30 self._sent = False
31 31 self._start_time = None
32 32 self.use_decimals = use_decimals
33 33 self.auto_send = auto_send
34 34
35 35 def __call__(self, f):
36 36 """Thread-safe timing function decorator."""
37 37 @safe_wraps(f)
38 38 def _wrapped(*args, **kwargs):
39 39 start_time = time_now()
40 40 try:
41 41 return f(*args, **kwargs)
42 42 finally:
43 43 elapsed_time_ms = 1000.0 * (time_now() - start_time)
44 44 self.client.timing(self.stat, elapsed_time_ms, self.rate, self.tags, self.use_decimals)
45 45 self._sent = True
46 46 return _wrapped
47 47
48 48 def __enter__(self):
49 49 return self.start()
50 50
51 51 def __exit__(self, typ, value, tb):
52 52 self.stop(send=self.auto_send)
53 53
54 54 def start(self):
55 55 self.ms = None
56 56 self._sent = False
57 57 self._start_time = time_now()
58 58 return self
59 59
60 60 def stop(self, send=True):
61 61 if self._start_time is None:
62 62 raise RuntimeError('Timer has not started.')
63 63 dt = time_now() - self._start_time
64 64 self.ms = 1000.0 * dt # Convert to milliseconds.
65 65 if send:
66 66 self.send()
67 67 return self
68 68
69 69 def send(self):
70 70 if self.ms is None:
71 71 raise RuntimeError('No data recorded.')
72 72 if self._sent:
73 73 raise RuntimeError('Already sent data.')
74 74 self._sent = True
75 75 self.client.timing(self.stat, self.ms, self.rate, self.tags, self.use_decimals)
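
A hedged usage sketch of Timer in its three modes; `client` stands for any object exposing a statsd-style timing(stat, ms, rate, tags, use_decimals) method, which is all Timer requires, and the workload calls are hypothetical:

    # As a context manager: start() on __enter__, stop() + send() on __exit__.
    with Timer(client, 'vcs.clone') as t:
        run_clone()          # hypothetical workload
    print(t.ms)              # elapsed milliseconds, recorded by stop()

    # As a decorator: every call reports its duration via client.timing().
    @Timer(client, 'vcs.fetch')
    def fetch():
        ...

    # Manual mode: defer the send, e.g. until after post-processing.
    t = Timer(client, 'vcs.push', auto_send=False).start()
    run_push()               # hypothetical workload
    t.stop(send=False)
    t.send()                 # raises RuntimeError if data was already sent
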
@@ -1,55 +1,55 b''
1 from __future__ import absolute_import, division, unicode_literals
1
2 2
3 3 import socket
4 4
5 5 from .base import StatsClientBase, PipelineBase
6 6
7 7
8 8 class Pipeline(PipelineBase):
9 9
10 10 def __init__(self, client):
11 11 super(Pipeline, self).__init__(client)
12 12 self._maxudpsize = client._maxudpsize
13 13
14 14 def _send(self):
15 15 data = self._stats.popleft()
16 16 while self._stats:
17 17 # Use popleft to preserve the order of the stats.
18 18 stat = self._stats.popleft()
19 19 if len(stat) + len(data) + 1 >= self._maxudpsize:
20 20 self._client._after(data)
21 21 data = stat
22 22 else:
23 23 data += '\n' + stat
24 24 self._client._after(data)
25 25
26 26
27 27 class StatsClient(StatsClientBase):
28 28 """A client for statsd."""
29 29
30 30 def __init__(self, host='localhost', port=8125, prefix=None,
31 31 maxudpsize=512, ipv6=False):
32 32 """Create a new client."""
33 33 fam = socket.AF_INET6 if ipv6 else socket.AF_INET
34 34 family, _, _, _, addr = socket.getaddrinfo(
35 35 host, port, fam, socket.SOCK_DGRAM)[0]
36 36 self._addr = addr
37 37 self._sock = socket.socket(family, socket.SOCK_DGRAM)
38 38 self._prefix = prefix
39 39 self._maxudpsize = maxudpsize
40 40
41 41 def _send(self, data):
42 42 """Send data to statsd."""
43 43 try:
44 44 self._sock.sendto(data.encode('ascii'), self._addr)
45 45 except (socket.error, RuntimeError):
46 46 # No time for love, Dr. Jones!
47 47 pass
48 48
49 49 def close(self):
50 50 if self._sock and hasattr(self._sock, 'close'):
51 51 self._sock.close()
52 52 self._sock = None
53 53
54 54 def pipeline(self):
55 55 return Pipeline(self)
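
Pipeline._send() above packs queued stats into datagrams that stay under maxudpsize. The same grouping logic as a standalone, hedged sketch; pack() is a hypothetical helper, not part of this module:

    # Sketch only: group newline-joined stats into payloads under the UDP limit.
    def pack(stats, maxudpsize=512):
        payloads = []
        data = stats[0]
        for stat in stats[1:]:
            if len(stat) + len(data) + 1 >= maxudpsize:
                payloads.append(data)   # flush the current datagram
                data = stat
            else:
                data += '\n' + stat     # statsd accepts newline-separated metrics
        payloads.append(data)
        return payloads

    assert pack(['a:1|c'] * 3, maxudpsize=12) == ['a:1|c\na:1|c', 'a:1|c']
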
@@ -1,329 +1,329 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import time
19 19 import errno
20 20 import logging
21 21 import pickle
22 22 import msgpack
23 23 import redis
24 24
25 25 from dogpile.cache.api import CachedValue
26 26 from dogpile.cache.backends import memory as memory_backend
27 27 from dogpile.cache.backends import file as file_backend
28 28 from dogpile.cache.backends import redis as redis_backend
29 29 from dogpile.cache.backends.file import NO_VALUE, FileLock
30 30 from dogpile.cache.util import memoized_property
31 31
32 32 from pyramid.settings import asbool
33 33
34 34 from vcsserver.lib.memory_lru_dict import LRUDict, LRUDictDebug
35 35 from vcsserver.utils import safe_str, safe_unicode
36 36
37 37
38 38 _default_max_size = 1024
39 39
40 40 log = logging.getLogger(__name__)
41 41
42 42
43 43 class LRUMemoryBackend(memory_backend.MemoryBackend):
44 44 key_prefix = 'lru_mem_backend'
45 45 pickle_values = False
46 46
47 47 def __init__(self, arguments):
48 48 max_size = arguments.pop('max_size', _default_max_size)
49 49
50 50 LRUDictClass = LRUDict
51 51 if arguments.pop('log_key_count', None):
52 52 LRUDictClass = LRUDictDebug
53 53
54 54 arguments['cache_dict'] = LRUDictClass(max_size)
55 55 super(LRUMemoryBackend, self).__init__(arguments)
56 56
57 57 def delete(self, key):
58 58 try:
59 59 del self._cache[key]
60 60 except KeyError:
61 61 # we don't care if key isn't there at deletion
62 62 pass
63 63
64 64 def delete_multi(self, keys):
65 65 for key in keys:
66 66 self.delete(key)
67 67
68 68
69 69 class PickleSerializer(object):
70 70
71 71 def _dumps(self, value, safe=False):
72 72 try:
73 73 return pickle.dumps(value)
74 74 except Exception:
75 75 if safe:
76 76 return NO_VALUE
77 77 else:
78 78 raise
79 79
80 80 def _loads(self, value, safe=True):
81 81 try:
82 82 return pickle.loads(value)
83 83 except Exception:
84 84 if safe:
85 85 return NO_VALUE
86 86 else:
87 87 raise
88 88
89 89
90 90 class MsgPackSerializer(object):
91 91
92 92 def _dumps(self, value, safe=False):
93 93 try:
94 94 return msgpack.packb(value)
95 95 except Exception:
96 96 if safe:
97 97 return NO_VALUE
98 98 else:
99 99 raise
100 100
101 101 def _loads(self, value, safe=True):
102 102 """
103 103 pickle maintained the `CachedValue` wrapper of the tuple
104 104 msgpack does not, so it must be added back in.
105 105 """
106 106 try:
107 107 value = msgpack.unpackb(value, use_list=False)
108 108 return CachedValue(*value)
109 109 except Exception:
110 110 if safe:
111 111 return NO_VALUE
112 112 else:
113 113 raise
114 114
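
The re-wrapping described in the _loads docstring can be seen in a short round-trip. A hedged sketch, assuming a msgpack version that round-trips str dict keys:

    # Sketch only: msgpack serializes the CachedValue as a plain tuple;
    # _loads() re-wraps it, matching the docstring above.
    s = MsgPackSerializer()
    original = CachedValue('payload', {'ct': 1600000000.0, 'v': 1})
    raw = s._dumps(original)        # packed as a 2-item msgpack array
    assert s._loads(raw) == original
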
115 115
116 116 import fcntl
117 117 flock_org = fcntl.flock
118 118
119 119
120 120 class CustomLockFactory(FileLock):
121 121
122 122 pass
123 123
124 124
125 125 class FileNamespaceBackend(PickleSerializer, file_backend.DBMBackend):
126 126 key_prefix = 'file_backend'
127 127
128 128 def __init__(self, arguments):
129 129 arguments['lock_factory'] = CustomLockFactory
130 130 db_file = arguments.get('filename')
131 131
132 132 log.debug('initializing %s DB in %s', self.__class__.__name__, db_file)
133 133 try:
134 134 super(FileNamespaceBackend, self).__init__(arguments)
135 135 except Exception:
136 136 log.exception('Failed to initialize db at: %s', db_file)
137 137 raise
138 138
139 139 def __repr__(self):
140 140 return '{} `{}`'.format(self.__class__, self.filename)
141 141
142 142 def list_keys(self, prefix=''):
143 143 prefix = '{}:{}'.format(self.key_prefix, prefix)
144 144
145 145 def cond(v):
146 146 if not prefix:
147 147 return True
148 148
149 149 if v.startswith(prefix):
150 150 return True
151 151 return False
152 152
153 153 with self._dbm_file(True) as dbm:
154 154 try:
155 155 return list(filter(cond, dbm.keys()))
156 156 except Exception:
157 157 log.error('Failed to fetch DBM keys from DB: %s', self.get_store())
158 158 raise
159 159
160 160 def get_store(self):
161 161 return self.filename
162 162
163 163 def _dbm_get(self, key):
164 164 with self._dbm_file(False) as dbm:
165 165 if hasattr(dbm, 'get'):
166 166 value = dbm.get(key, NO_VALUE)
167 167 else:
168 168 # gdbm objects lack a .get method
169 169 try:
170 170 value = dbm[key]
171 171 except KeyError:
172 172 value = NO_VALUE
173 173 if value is not NO_VALUE:
174 174 value = self._loads(value)
175 175 return value
176 176
177 177 def get(self, key):
178 178 try:
179 179 return self._dbm_get(key)
180 180 except Exception:
181 181 log.error('Failed to fetch DBM key %s from DB: %s', key, self.get_store())
182 182 raise
183 183
184 184 def set(self, key, value):
185 185 with self._dbm_file(True) as dbm:
186 186 dbm[key] = self._dumps(value)
187 187
188 188 def set_multi(self, mapping):
189 189 with self._dbm_file(True) as dbm:
190 190 for key, value in mapping.items():
191 191 dbm[key] = self._dumps(value)
192 192
193 193
194 194 class BaseRedisBackend(redis_backend.RedisBackend):
195 195 key_prefix = ''
196 196
197 197 def __init__(self, arguments):
198 198 super(BaseRedisBackend, self).__init__(arguments)
199 199 self._lock_timeout = self.lock_timeout
200 200 self._lock_auto_renewal = asbool(arguments.pop("lock_auto_renewal", True))
201 201
202 202 if self._lock_auto_renewal and not self._lock_timeout:
203 203 # set default timeout for auto_renewal
204 204 self._lock_timeout = 30
205 205
206 206 def _create_client(self):
207 207 args = {}
208 208
209 209 if self.url is not None:
210 210 args.update(url=self.url)
211 211
212 212 else:
213 213 args.update(
214 214 host=self.host, password=self.password,
215 215 port=self.port, db=self.db
216 216 )
217 217
218 218 connection_pool = redis.ConnectionPool(**args)
219 219
220 220 return redis.StrictRedis(connection_pool=connection_pool)
221 221
222 222 def list_keys(self, prefix=''):
223 223 prefix = '{}:{}*'.format(self.key_prefix, prefix)
224 224 return self.client.keys(prefix)
225 225
226 226 def get_store(self):
227 227 return self.client.connection_pool
228 228
229 229 def get(self, key):
230 230 value = self.client.get(key)
231 231 if value is None:
232 232 return NO_VALUE
233 233 return self._loads(value)
234 234
235 235 def get_multi(self, keys):
236 236 if not keys:
237 237 return []
238 238 values = self.client.mget(keys)
239 239 loads = self._loads
240 240 return [
241 241 loads(v) if v is not None else NO_VALUE
242 242 for v in values]
243 243
244 244 def set(self, key, value):
245 245 if self.redis_expiration_time:
246 246 self.client.setex(key, self.redis_expiration_time,
247 247 self._dumps(value))
248 248 else:
249 249 self.client.set(key, self._dumps(value))
250 250
251 251 def set_multi(self, mapping):
252 252 dumps = self._dumps
253 253 mapping = dict(
254 254 (k, dumps(v))
255 255 for k, v in mapping.items()
256 256 )
257 257
258 258 if not self.redis_expiration_time:
259 259 self.client.mset(mapping)
260 260 else:
261 261 pipe = self.client.pipeline()
262 262 for key, value in mapping.items():
263 263 pipe.setex(key, self.redis_expiration_time, value)
264 264 pipe.execute()
265 265
266 266 def get_mutex(self, key):
267 267 if self.distributed_lock:
268 lock_key = u'_lock_{0}'.format(safe_unicode(key))
268 lock_key = '_lock_{0}'.format(safe_unicode(key))
269 269 return get_mutex_lock(self.client, lock_key, self._lock_timeout,
270 270 auto_renewal=self._lock_auto_renewal)
271 271 else:
272 272 return None
273 273
274 274
275 275 class RedisPickleBackend(PickleSerializer, BaseRedisBackend):
276 276 key_prefix = 'redis_pickle_backend'
277 277 pass
278 278
279 279
280 280 class RedisMsgPackBackend(MsgPackSerializer, BaseRedisBackend):
281 281 key_prefix = 'redis_msgpack_backend'
282 282 pass
283 283
284 284
285 285 def get_mutex_lock(client, lock_key, lock_timeout, auto_renewal=False):
286 286 import redis_lock
287 287
288 288 class _RedisLockWrapper(object):
289 289 """LockWrapper for redis_lock"""
290 290
291 291 @classmethod
292 292 def get_lock(cls):
293 293 return redis_lock.Lock(
294 294 redis_client=client,
295 295 name=lock_key,
296 296 expire=lock_timeout,
297 297 auto_renewal=auto_renewal,
298 298 strict=True,
299 299 )
300 300
301 301 def __repr__(self):
302 302 return "{}:{}".format(self.__class__.__name__, lock_key)
303 303
304 304 def __str__(self):
305 305 return "{}:{}".format(self.__class__.__name__, lock_key)
306 306
307 307 def __init__(self):
308 308 self.lock = self.get_lock()
309 309 self.lock_key = lock_key
310 310
311 311 def acquire(self, wait=True):
312 312 log.debug('Trying to acquire Redis lock for key %s', self.lock_key)
313 313 try:
314 314 acquired = self.lock.acquire(wait)
315 315 log.debug('Got lock for key %s, %s', self.lock_key, acquired)
316 316 return acquired
317 317 except redis_lock.AlreadyAcquired:
318 318 return False
319 319 except redis_lock.AlreadyStarted:
320 320 # refresh thread exists, but it also means we acquired the lock
321 321 return True
322 322
323 323 def release(self):
324 324 try:
325 325 self.lock.release()
326 326 except redis_lock.NotAcquired:
327 327 pass
328 328
329 329 return _RedisLockWrapper()
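
A hedged sketch of driving get_mutex_lock directly; it assumes a reachable Redis and the python-redis-lock package, and the lock key is illustrative:

    # Sketch only: serialize a cache rebuild across processes via Redis.
    import redis

    client = redis.StrictRedis(host='localhost', port=6379, db=0)
    mutex = get_mutex_lock(client, '_lock_repo:my-repo', lock_timeout=30,
                           auto_renewal=True)
    if mutex.acquire(wait=True):
        try:
            pass  # critical section: recompute the cached value here
        finally:
            mutex.release()   # NotAcquired is swallowed by the wrapper
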
@@ -1,261 +1,261 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import time
20 20 import logging
21 21 import functools
22 22
23 23 from dogpile.cache import CacheRegion
24 24
25 25 from vcsserver.utils import safe_str, sha1
26 26 from vcsserver.lib.rc_cache import region_meta
27 27
28 28 log = logging.getLogger(__name__)
29 29
30 30
31 31 class RhodeCodeCacheRegion(CacheRegion):
32 32
33 33 def conditional_cache_on_arguments(
34 34 self, namespace=None,
35 35 expiration_time=None,
36 36 should_cache_fn=None,
37 37 to_str=str,
38 38 function_key_generator=None,
39 39 condition=True):
40 40 """
41 41 Custom conditional decorator that will not touch any dogpile internals if
42 42 the condition isn't met. This works a bit differently than should_cache_fn,
43 43 and it's faster in cases where we never want to compute cached values.
44 44 """
45 45 expiration_time_is_callable = callable(expiration_time)
46 46
47 47 if function_key_generator is None:
48 48 function_key_generator = self.function_key_generator
49 49
50 50 # workaround for py2 and cython problems, this block should be removed
51 51 # once we've migrated to py3
52 52 if 'cython' == 'cython':
53 53 def decorator(fn):
54 54 if to_str is str:
55 55 # backwards compatible
56 56 key_generator = function_key_generator(namespace, fn)
57 57 else:
58 58 key_generator = function_key_generator(namespace, fn, to_str=to_str)
59 59
60 60 @functools.wraps(fn)
61 61 def decorate(*arg, **kw):
62 62 key = key_generator(*arg, **kw)
63 63
64 64 @functools.wraps(fn)
65 65 def creator():
66 66 return fn(*arg, **kw)
67 67
68 68 if not condition:
69 69 return creator()
70 70
71 71 timeout = expiration_time() if expiration_time_is_callable \
72 72 else expiration_time
73 73
74 74 return self.get_or_create(key, creator, timeout, should_cache_fn)
75 75
76 76 def invalidate(*arg, **kw):
77 77 key = key_generator(*arg, **kw)
78 78 self.delete(key)
79 79
80 80 def set_(value, *arg, **kw):
81 81 key = key_generator(*arg, **kw)
82 82 self.set(key, value)
83 83
84 84 def get(*arg, **kw):
85 85 key = key_generator(*arg, **kw)
86 86 return self.get(key)
87 87
88 88 def refresh(*arg, **kw):
89 89 key = key_generator(*arg, **kw)
90 90 value = fn(*arg, **kw)
91 91 self.set(key, value)
92 92 return value
93 93
94 94 decorate.set = set_
95 95 decorate.invalidate = invalidate
96 96 decorate.refresh = refresh
97 97 decorate.get = get
98 98 decorate.original = fn
99 99 decorate.key_generator = key_generator
100 100 decorate.__wrapped__ = fn
101 101
102 102 return decorate
103 103 return decorator
104 104
105 105 def get_or_create_for_user_func(key_generator, user_func, *arg, **kw):
106 106
107 107 if not condition:
108 log.debug('Calling un-cached method:%s', user_func.func_name)
108 log.debug('Calling un-cached method:%s', user_func.__name__)
109 109 start = time.time()
110 110 result = user_func(*arg, **kw)
111 111 total = time.time() - start
112 log.debug('un-cached method:%s took %.4fs', user_func.func_name, total)
112 log.debug('un-cached method:%s took %.4fs', user_func.__name__, total)
113 113 return result
114 114
115 115 key = key_generator(*arg, **kw)
116 116
117 117 timeout = expiration_time() if expiration_time_is_callable \
118 118 else expiration_time
119 119
120 log.debug('Calling cached method:`%s`', user_func.func_name)
120 log.debug('Calling cached method:`%s`', user_func.__name__)
121 121 return self.get_or_create(key, user_func, timeout, should_cache_fn, (arg, kw))
122 122
123 123 def cache_decorator(user_func):
124 124 if to_str is str:
125 125 # backwards compatible
126 126 key_generator = function_key_generator(namespace, user_func)
127 127 else:
128 128 key_generator = function_key_generator(namespace, user_func, to_str=to_str)
129 129
130 130 def refresh(*arg, **kw):
131 131 """
132 132 Like invalidate, but regenerates the value instead
133 133 """
134 134 key = key_generator(*arg, **kw)
135 135 value = user_func(*arg, **kw)
136 136 self.set(key, value)
137 137 return value
138 138
139 139 def invalidate(*arg, **kw):
140 140 key = key_generator(*arg, **kw)
141 141 self.delete(key)
142 142
143 143 def set_(value, *arg, **kw):
144 144 key = key_generator(*arg, **kw)
145 145 self.set(key, value)
146 146
147 147 def get(*arg, **kw):
148 148 key = key_generator(*arg, **kw)
149 149 return self.get(key)
150 150
151 151 user_func.set = set_
152 152 user_func.invalidate = invalidate
153 153 user_func.get = get
154 154 user_func.refresh = refresh
155 155 user_func.key_generator = key_generator
156 156 user_func.original = user_func
157 157
158 158 # Use `decorate` to preserve the signature of :param:`user_func`.
159 159 return decorator.decorate(user_func, functools.partial(
160 160 get_or_create_for_user_func, key_generator))
161 161
162 162 return cache_decorator
163 163
164 164
165 165 def make_region(*arg, **kw):
166 166 return RhodeCodeCacheRegion(*arg, **kw)
167 167
168 168
169 169 def get_default_cache_settings(settings, prefixes=None):
170 170 prefixes = prefixes or []
171 171 cache_settings = {}
172 172 for key in settings.keys():
173 173 for prefix in prefixes:
174 174 if key.startswith(prefix):
175 175 name = key.split(prefix)[1].strip()
176 176 val = settings[key]
177 177 if isinstance(val, str):
178 178 val = val.strip()
179 179 cache_settings[name] = val
180 180 return cache_settings
181 181
182 182
183 183 def compute_key_from_params(*args):
184 184 """
185 185 Helper to compute key from given params to be used in cache manager
186 186 """
187 187 return sha1("_".join(map(safe_str, args)))
188 188
189 189
190 190 def backend_key_generator(backend):
191 191 """
192 192 Special wrapper that also sends over the backend to the key generator
193 193 """
194 194 def wrapper(namespace, fn):
195 195 return key_generator(backend, namespace, fn)
196 196 return wrapper
197 197
198 198
199 199 def key_generator(backend, namespace, fn):
200 200 fname = fn.__name__
201 201
202 202 def generate_key(*args):
203 203 backend_prefix = getattr(backend, 'key_prefix', None) or 'backend_prefix'
204 204 namespace_pref = namespace or 'default_namespace'
205 205 arg_key = compute_key_from_params(*args)
206 206 final_key = "{}:{}:{}_{}".format(backend_prefix, namespace_pref, fname, arg_key)
207 207
208 208 return final_key
209 209
210 210 return generate_key
211 211
212 212
213 213 def get_or_create_region(region_name, region_namespace=None):
214 214 from vcsserver.lib.rc_cache.backends import FileNamespaceBackend
215 215 region_obj = region_meta.dogpile_cache_regions.get(region_name)
216 216 if not region_obj:
217 217 raise EnvironmentError(
218 218 'Region `{}` not found in configured regions: {}.'.format(
219 219 region_name, region_meta.dogpile_cache_regions.keys()))
220 220
221 221 region_uid_name = '{}:{}'.format(region_name, region_namespace)
222 222 if isinstance(region_obj.actual_backend, FileNamespaceBackend):
223 223 region_exist = region_meta.dogpile_cache_regions.get(region_namespace)
224 224 if region_exist:
225 225 log.debug('Using already configured region: %s', region_namespace)
226 226 return region_exist
227 227 cache_dir = region_meta.dogpile_config_defaults['cache_dir']
228 228 expiration_time = region_obj.expiration_time
229 229
230 230 if not os.path.isdir(cache_dir):
231 231 os.makedirs(cache_dir)
232 232 new_region = make_region(
233 233 name=region_uid_name,
234 234 function_key_generator=backend_key_generator(region_obj.actual_backend)
235 235 )
236 236 namespace_filename = os.path.join(
237 237 cache_dir, "{}.cache.dbm".format(region_namespace))
238 238 # special type that allows 1db per namespace
239 239 new_region.configure(
240 240 backend='dogpile.cache.rc.file_namespace',
241 241 expiration_time=expiration_time,
242 242 arguments={"filename": namespace_filename}
243 243 )
244 244
245 245 # create and save in region caches
246 246 log.debug('configuring new region: %s', region_uid_name)
247 247 region_obj = region_meta.dogpile_cache_regions[region_namespace] = new_region
248 248
249 249 return region_obj
250 250
251 251
252 252 def clear_cache_namespace(cache_region, cache_namespace_uid, invalidate=False):
253 253 region = get_or_create_region(cache_region, cache_namespace_uid)
254 254 cache_keys = region.backend.list_keys(prefix=cache_namespace_uid)
255 255 num_delete_keys = len(cache_keys)
256 256 if invalidate:
257 257 region.invalidate(hard=False)
258 258 else:
259 259 if num_delete_keys:
260 260 region.delete_multi(cache_keys)
261 261 return num_delete_keys
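
Tying the helpers together, a hedged sketch of the intended call pattern. The region name and namespace are illustrative and the region must already be configured in region_meta.dogpile_cache_regions (normally done at app startup); do_expensive_work is hypothetical:

    # Sketch only: decorate an expensive call with the conditional cache.
    region = get_or_create_region('repo_object', region_namespace='repo:my-repo')

    @region.conditional_cache_on_arguments(namespace='repo:my-repo',
                                           expiration_time=300,
                                           condition=True)
    def expensive_lookup(repo_id):
        return do_expensive_work(repo_id)   # hypothetical workload

    expensive_lookup(42)              # computed, then cached for 300s
    expensive_lookup(42)              # served from the region backend
    expensive_lookup.invalidate(42)   # drops just this key
    clear_cache_namespace('repo_object', 'repo:my-repo')  # drops the namespace
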
@@ -1,235 +1,235 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import logging
20 20 import itertools
21 21
22 22 import mercurial
23 23 import mercurial.error
24 24 import mercurial.wireprotoserver
25 25 import mercurial.hgweb.common
26 26 import mercurial.hgweb.hgweb_mod
27 27 import webob.exc
28 28
29 29 from vcsserver import pygrack, exceptions, settings, git_lfs
30 30
31 31
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 # propagated from mercurial documentation
36 36 HG_UI_SECTIONS = [
37 37 'alias', 'auth', 'decode/encode', 'defaults', 'diff', 'email', 'extensions',
38 38 'format', 'merge-patterns', 'merge-tools', 'hooks', 'http_proxy', 'smtp',
39 39 'patch', 'paths', 'profiling', 'server', 'trusted', 'ui', 'web',
40 40 ]
41 41
42 42
43 43 class HgWeb(mercurial.hgweb.hgweb_mod.hgweb):
44 44 """Extension of hgweb that simplifies some functions."""
45 45
46 46 def _get_view(self, repo):
47 47 """Views are not supported."""
48 48 return repo
49 49
50 50 def loadsubweb(self):
51 51 """The result is only used in the templater method which is not used."""
52 52 return None
53 53
54 54 def run(self):
55 55 """Unused function so raise an exception if accidentally called."""
56 56 raise NotImplementedError
57 57
58 58 def templater(self, req):
59 59 """Function used in an unreachable code path.
60 60
61 61 This code is unreachable because we guarantee that the HTTP request
62 62 corresponds to a Mercurial command. See the is_hg method. So, we are
63 63 never going to get a user-visible URL.
64 64 """
65 65 raise NotImplementedError
66 66
67 67 def archivelist(self, nodeid):
68 68 """Unused function so raise an exception if accidentally called."""
69 69 raise NotImplementedError
70 70
71 71 def __call__(self, environ, start_response):
72 72 """Run the WSGI application.
73 73
74 74 This may be called by multiple threads.
75 75 """
76 76 from mercurial.hgweb import request as requestmod
77 77 req = requestmod.parserequestfromenv(environ)
78 78 res = requestmod.wsgiresponse(req, start_response)
79 79 gen = self.run_wsgi(req, res)
80 80
81 81 first_chunk = None
82 82
83 83 try:
84 data = gen.next()
84 data = next(gen)
85 85
86 86 def first_chunk():
87 87 yield data
88 88 except StopIteration:
89 89 pass
90 90
91 91 if first_chunk:
92 92 return itertools.chain(first_chunk(), gen)
93 93 return gen
94 94
95 95 def _runwsgi(self, req, res, repo):
96 96
97 97 cmd = req.qsparams.get('cmd', '')
98 98 if not mercurial.wireprotoserver.iscmd(cmd):
99 99 # NOTE(marcink): for unsupported commands, we return bad request
100 100 # internally from HG
101 101 from mercurial.hgweb.common import statusmessage
102 102 res.status = statusmessage(mercurial.hgweb.common.HTTP_BAD_REQUEST)
103 103 res.setbodybytes(b'')
104 104 return res.sendresponse()
105 105
106 106 return super(HgWeb, self)._runwsgi(req, res, repo)
107 107
108 108
109 109 def make_hg_ui_from_config(repo_config):
110 110 baseui = mercurial.ui.ui()
111 111
112 112 # clean the baseui object
113 113 baseui._ocfg = mercurial.config.config()
114 114 baseui._ucfg = mercurial.config.config()
115 115 baseui._tcfg = mercurial.config.config()
116 116
117 117 for section, option, value in repo_config:
118 118 baseui.setconfig(section, option, value)
119 119
120 120 # make our hgweb quiet so it doesn't print output
121 121 baseui.setconfig('ui', 'quiet', 'true')
122 122
123 123 return baseui
124 124
125 125
126 126 def update_hg_ui_from_hgrc(baseui, repo_path):
127 127 path = os.path.join(repo_path, '.hg', 'hgrc')
128 128
129 129 if not os.path.isfile(path):
130 130 log.debug('hgrc file is not present at %s, skipping...', path)
131 131 return
132 132 log.debug('reading hgrc from %s', path)
133 133 cfg = mercurial.config.config()
134 134 cfg.read(path)
135 135 for section in HG_UI_SECTIONS:
136 136 for k, v in cfg.items(section):
137 137 log.debug('setting ui from file: [%s] %s=%s', section, k, v)
138 138 baseui.setconfig(section, k, v)
139 139
140 140
141 141 def create_hg_wsgi_app(repo_path, repo_name, config):
142 142 """
143 143 Prepares a WSGI application to handle Mercurial requests.
144 144
145 145 :param config: is a list of 3-item tuples representing a ConfigObject
146 146 (it is the serialized version of the config object).
147 147 """
148 148 log.debug("Creating Mercurial WSGI application")
149 149
150 150 baseui = make_hg_ui_from_config(config)
151 151 update_hg_ui_from_hgrc(baseui, repo_path)
152 152
153 153 try:
154 154 return HgWeb(repo_path, name=repo_name, baseui=baseui)
155 155 except mercurial.error.RequirementError as e:
156 156 raise exceptions.RequirementException(e)(e)
157 157
158 158
159 159 class GitHandler(object):
160 160 """
161 161 Handler for Git operations like push/pull etc
162 162 """
163 163 def __init__(self, repo_location, repo_name, git_path, update_server_info,
164 164 extras):
165 165 if not os.path.isdir(repo_location):
166 166 raise OSError(repo_location)
167 167 self.content_path = repo_location
168 168 self.repo_name = repo_name
169 169 self.repo_location = repo_location
170 170 self.extras = extras
171 171 self.git_path = git_path
172 172 self.update_server_info = update_server_info
173 173
174 174 def __call__(self, environ, start_response):
175 175 app = webob.exc.HTTPNotFound()
176 176 candidate_paths = (
177 177 self.content_path, os.path.join(self.content_path, '.git'))
178 178
179 179 for content_path in candidate_paths:
180 180 try:
181 181 app = pygrack.GitRepository(
182 182 self.repo_name, content_path, self.git_path,
183 183 self.update_server_info, self.extras)
184 184 break
185 185 except OSError:
186 186 continue
187 187
188 188 return app(environ, start_response)
189 189
190 190
191 191 def create_git_wsgi_app(repo_path, repo_name, config):
192 192 """
193 193 Creates a WSGI application to handle Git requests.
194 194
195 195 :param config: is a dictionary holding the extras.
196 196 """
197 197 git_path = settings.GIT_EXECUTABLE
198 198 update_server_info = config.pop('git_update_server_info')
199 199 app = GitHandler(
200 200 repo_path, repo_name, git_path, update_server_info, config)
201 201
202 202 return app
203 203
204 204
205 205 class GitLFSHandler(object):
206 206 """
207 207 Handler for Git LFS operations
208 208 """
209 209
210 210 def __init__(self, repo_location, repo_name, git_path, update_server_info,
211 211 extras):
212 212 if not os.path.isdir(repo_location):
213 213 raise OSError(repo_location)
214 214 self.content_path = repo_location
215 215 self.repo_name = repo_name
216 216 self.repo_location = repo_location
217 217 self.extras = extras
218 218 self.git_path = git_path
219 219 self.update_server_info = update_server_info
220 220
221 221 def get_app(self, git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
222 222 app = git_lfs.create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
223 223 return app
224 224
225 225
226 226 def create_git_lfs_wsgi_app(repo_path, repo_name, config):
227 227 git_path = settings.GIT_EXECUTABLE
228 228 update_server_info = config.pop('git_update_server_info')
229 229 git_lfs_enabled = config.pop('git_lfs_enabled')
230 230 git_lfs_store_path = config.pop('git_lfs_store_path')
231 231 git_lfs_http_scheme = config.pop('git_lfs_http_scheme', 'http')
232 232 app = GitLFSHandler(
233 233 repo_path, repo_name, git_path, update_server_info, config)
234 234
235 235 return app.get_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
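
For reference, a hedged sketch of standing up the Git WSGI app locally; the repository path is a placeholder, and it assumes settings.GIT_EXECUTABLE resolves on this machine:

    # Sketch only: the remaining config keys (here none) pass through to
    # GitHandler as `extras`.
    from wsgiref.simple_server import make_server

    config = {'git_update_server_info': True}
    app = create_git_wsgi_app('/srv/repos/myrepo.git', 'myrepo', config)
    make_server('127.0.0.1', 8090, app).serve_forever()
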
@@ -1,519 +1,519 b''
1 1 """
2 2 Module provides a class that wraps communication over subprocess.Popen
3 3 input, output, and error streams in a meaningful, non-blocking, concurrent
4 4 stream processor, exposing the output data as an iterator suitable as a
5 5 return value passed by a WSGI application to a WSGI server per PEP 3333.
6 6
7 7 Copyright (c) 2011 Daniel Dotsenko <dotsa[at]hotmail.com>
8 8
9 9 This file is part of git_http_backend.py Project.
10 10
11 11 git_http_backend.py Project is free software: you can redistribute it and/or
12 12 modify it under the terms of the GNU Lesser General Public License as
13 13 published by the Free Software Foundation, either version 2.1 of the License,
14 14 or (at your option) any later version.
15 15
16 16 git_http_backend.py Project is distributed in the hope that it will be useful,
17 17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 19 GNU Lesser General Public License for more details.
20 20
21 21 You should have received a copy of the GNU Lesser General Public License
22 22 along with git_http_backend.py Project.
23 23 If not, see <http://www.gnu.org/licenses/>.
24 24 """
25 25 import os
26 26 import logging
27 27 import subprocess
28 28 from collections import deque
29 29 from threading import Event, Thread
30 30
31 31 log = logging.getLogger(__name__)
32 32
33 33
34 34 class StreamFeeder(Thread):
35 35 """
36 36 Normal writing into a pipe-like object blocks once the buffer is filled.
37 37 This thread allows feeding data from a file-like object into a pipe
38 38 without blocking the main thread.
39 39 We close the input pipe once the end of the source stream is reached.
40 40 """
41 41
42 42 def __init__(self, source):
43 43 super(StreamFeeder, self).__init__()
44 44 self.daemon = True
45 45 filelike = False
46 46 self.bytes = bytes()
47 47 if isinstance(source, (str, bytes, bytearray)): # string-like
48 48 self.bytes = source.encode('utf-8') if isinstance(source, str) else bytes(source)
49 49 else: # can be either file pointer or file-like
50 if type(source) in (int, long): # file pointer it is
50 if isinstance(source, int): # file pointer it is
51 51 # converting file descriptor (int) stdin into file-like
52 52 try:
53 53 source = os.fdopen(source, 'rb', 16384)
54 54 except Exception:
55 55 pass
56 56 # let's see if source is file-like by now
57 57 try:
58 58 filelike = source.read
59 59 except Exception:
60 60 pass
61 61 if not filelike and not self.bytes:
62 62 raise TypeError("StreamFeeder's source object must be a readable "
63 63 "file-like, a file descriptor, or a string-like.")
64 64 self.source = source
65 65 self.readiface, self.writeiface = os.pipe()
66 66
67 67 def run(self):
68 68 t = self.writeiface
69 69 try:
70 70 if self.bytes:
71 71 os.write(t, self.bytes)
72 72 else:
73 73 s = self.source
74 74 b = s.read(4096)
75 75 while b:
76 76 os.write(t, b)
77 77 b = s.read(4096)
78 78 finally:
79 79 os.close(t)
80 80
81 81 @property
82 82 def output(self):
83 83 return self.readiface
84 84
85 85
86 86 class InputStreamChunker(Thread):
87 87 def __init__(self, source, target, buffer_size, chunk_size):
88 88
89 89 super(InputStreamChunker, self).__init__()
90 90
91 91 self.daemon = True # die die die.
92 92
93 93 self.source = source
94 94 self.target = target
95 95 self.chunk_count_max = int(buffer_size / chunk_size) + 1
96 96 self.chunk_size = chunk_size
97 97
98 98 self.data_added = Event()
99 99 self.data_added.clear()
100 100
101 101 self.keep_reading = Event()
102 102 self.keep_reading.set()
103 103
104 104 self.EOF = Event()
105 105 self.EOF.clear()
106 106
107 107 self.go = Event()
108 108 self.go.set()
109 109
110 110 def stop(self):
111 111 self.go.clear()
112 112 self.EOF.set()
113 113 try:
114 114 # this is not proper, but is done to force the reader thread let
115 115 # go of the input because, if successful, .close() will send EOF
116 116 # down the pipe.
117 117 self.source.close()
118 118 except:
119 119 pass
120 120
121 121 def run(self):
122 122 s = self.source
123 123 t = self.target
124 124 cs = self.chunk_size
125 125 chunk_count_max = self.chunk_count_max
126 126 keep_reading = self.keep_reading
127 127 da = self.data_added
128 128 go = self.go
129 129
130 130 try:
131 131 b = s.read(cs)
132 132 except ValueError:
133 133 b = b''
134 134
135 135 timeout_input = 20
136 136 while b and go.is_set():
137 137 if len(t) > chunk_count_max:
138 138 keep_reading.clear()
139 139 keep_reading.wait(timeout_input)
140 140 if len(t) > chunk_count_max + timeout_input:
141 141 log.error("Timed out while waiting for input from subprocess.")
142 142 os._exit(-1) # this will cause the worker to recycle itself
143 143
144 144 t.append(b)
145 145 da.set()
146 146
147 147 try:
148 148 b = s.read(cs)
149 149 except ValueError:
150 150 b = b''
151 151
152 152 self.EOF.set()
153 153 da.set() # for cases when done but there was no input.
154 154
155 155
156 156 class BufferedGenerator(object):
157 157 """
158 158 Class behaves as a non-blocking, buffered pipe reader.
159 159 Reads chunks of data (through a thread) from a blocking pipe and
160 160 attaches these to a deque of chunks.
161 161 Reading is halted in the thread once the maximum number of chunks is buffered.
162 162 The .__next__() method may operate in blocking or non-blocking fashion, by
163 163 yielding '' if no data is ready to be sent, or by not returning until
164 164 there is some data to send.
165 165 When we get EOF from the underlying source pipe, we raise the marker that
166 166 raises StopIteration after the last chunk of data is yielded.
167 167 """
168 168
169 169 def __init__(self, source, buffer_size=65536, chunk_size=4096,
170 170 starting_values=None, bottomless=False):
171 171 starting_values = starting_values or []
172 172
173 173 if bottomless:
174 174 maxlen = int(buffer_size / chunk_size)
175 175 else:
176 176 maxlen = None
177 177
178 178 self.data = deque(starting_values, maxlen)
179 179 self.worker = InputStreamChunker(source, self.data, buffer_size,
180 180 chunk_size)
181 181 if starting_values:
182 182 self.worker.data_added.set()
183 183 self.worker.start()
184 184
185 185 ####################
186 186 # Generator's methods
187 187 ####################
188 188
189 189 def __iter__(self):
190 190 return self
191 191
192 def next(self):
192 def __next__(self):
193 193 while not len(self.data) and not self.worker.EOF.is_set():
194 194 self.worker.data_added.clear()
195 195 self.worker.data_added.wait(0.2)
196 196 if len(self.data):
197 197 self.worker.keep_reading.set()
198 198 return bytes(self.data.popleft())
199 199 elif self.worker.EOF.is_set():
200 200 raise StopIteration
201 201
202 202 def throw(self, exc_type, value=None, traceback=None):
203 203 if not self.worker.EOF.is_set():
204 204 raise exc_type(value)
205 205
206 206 def start(self):
207 207 self.worker.start()
208 208
209 209 def stop(self):
210 210 self.worker.stop()
211 211
212 212 def close(self):
213 213 try:
214 214 self.worker.stop()
215 215 self.throw(GeneratorExit)
216 216 except (GeneratorExit, StopIteration):
217 217 pass
218 218
219 219 ####################
220 220 # Threaded reader's infrastructure.
221 221 ####################
222 222 @property
223 223 def input(self):
224 224 return self.worker.w
225 225
226 226 @property
227 227 def data_added_event(self):
228 228 return self.worker.data_added
229 229
230 230 @property
231 231 def data_added(self):
232 232 return self.worker.data_added.is_set()
233 233
234 234 @property
235 235 def reading_paused(self):
236 236 return not self.worker.keep_reading.is_set()
237 237
238 238 @property
239 239 def done_reading_event(self):
240 240 """
241 241 done_reading does not mean that the iterator's buffer is empty.
242 242 The iterator might be done reading from the underlying source, but the read
243 243 chunks might still be available for serving through the .__next__() method.
244 244
245 245 :returns: An Event class instance.
246 246 """
247 247 return self.worker.EOF
248 248
249 249 @property
250 250 def done_reading(self):
251 251 """
252 252 done_reading does not mean that the iterator's buffer is empty.
253 253 The iterator might be done reading from the underlying source, but the read
254 254 chunks might still be available for serving through the .__next__() method.
255 255
256 256 :returns: A bool value.
257 257 """
258 258 return self.worker.EOF.is_set()
259 259
260 260 @property
261 261 def length(self):
262 262 """
263 263 returns int.
264 264
265 265 This is the length of the queue of chunks, not the length of
266 266 the combined contents in those chunks.
267 267
268 268 __len__() cannot be meaningfully implemented because this
269 269 reader is just flying through a bottomless pit of content and
270 270 can only know the length of what it has already seen.
271 271
272 272 If __len__() on a WSGI server per PEP 3333 returns a value,
273 273 the response's length will be set to that. In order not to
274 274 confuse WSGI PEP 3333 servers, we will not implement __len__
275 275 at all.
276 276 """
277 277 return len(self.data)
278 278
279 279 def prepend(self, x):
280 280 self.data.appendleft(x)
281 281
282 282 def append(self, x):
283 283 self.data.append(x)
284 284
285 285 def extend(self, o):
286 286 self.data.extend(o)
287 287
288 288 def __getitem__(self, i):
289 289 return self.data[i]
290 290
291 291
292 292 class SubprocessIOChunker(object):
293 293 """
294 294 Processor class wrapping handling of subprocess IO.
295 295
296 296 .. important::
297 297
298 298 Watch out for the method `__del__` on this class. If this object
299 299 is deleted, it will kill the subprocess, so avoid to
300 300 return the `output` attribute or usage of it like in the following
301 301 example::
302 302
303 303 # `args` expected to run a program that produces a lot of output
304 304 output = ''.join(SubprocessIOChunker(
305 305 args, shell=False, inputstream=inputstream, env=environ).output)
306 306
307 307 # `output` will not contain all the data, because the __del__ method
308 308 # has already killed the subprocess in this case before all output
309 309 # has been consumed.
310 310
311 311
312 312
313 313 In a way, this is a "communicate()" replacement with a twist.
314 314
315 315 - We are multithreaded. Writing in and reading out, err are all sep threads.
316 316 - We support concurrent (in and out) stream processing.
317 317 - The output is not a stream. It's a queue of read string (bytes, not unicode)
318 318 chunks. The object behaves as an iterable. You can "for chunk in obj:" it.
319 319 - We are non-blocking in more respects than communicate()
320 320 (reading from the subprocess's stdout pauses when the internal buffer is full, but
321 321 does not block the parent calling code. On the flip side, reading from a
322 322 slow-yielding subprocess may block the iteration until data shows up. This
323 323 does not block the inpipe reading occurring in a parallel thread.)
324 324
325 325 The purpose of the object is to allow us to wrap subprocess interactions into
326 326 an iterable that can be passed to a WSGI server as the application's return
327 327 value. Because of its stream-processing ability, WSGI does not have to read ALL
328 328 of the subprocess's output and buffer it before handing it to the WSGI server for
329 329 the HTTP response. Instead, the class initializer reads just a bit of the stream
330 330 to figure out if an error occurred or is likely to occur and, if not, just hands the
331 331 further iteration over subprocess output to the server for completion of HTTP
332 332 response.
333 333
334 334 The real or perceived subprocess error is trapped and raised as one of
335 335 the EnvironmentError family of exceptions.
336 336
337 337 Example usage:
338 338 # try:
339 339 # answer = SubprocessIOChunker(
340 340 # cmd,
341 341 # input,
342 342 # buffer_size = 65536,
343 343 # chunk_size = 4096
344 344 # )
345 345 # except (EnvironmentError) as e:
346 346 # print(str(e))
347 347 # raise e
348 348 #
349 349 # return answer
350 350
351 351
352 352 """
353 353
354 354 # TODO: johbo: This is used to make sure that the open end of the PIPE
355 355 # is closed in the end. It would be way better to wrap this into an
356 356 # object, so that it is closed automatically once it is consumed or
357 357 # something similar.
358 358 _close_input_fd = None
359 359
360 360 _closed = False
361 361
362 362 def __init__(self, cmd, inputstream=None, buffer_size=65536,
363 363 chunk_size=4096, starting_values=None, fail_on_stderr=True,
364 364 fail_on_return_code=True, **kwargs):
365 365 """
366 366 Initializes SubprocessIOChunker
367 367
368 368 :param cmd: A Subprocess.Popen style "cmd". Can be string or array of strings
369 369 :param inputstream: (Default: None) A file-like, string, or file pointer.
370 370 :param buffer_size: (Default: 65536) A size of total buffer per stream in bytes.
371 371 :param chunk_size: (Default: 4096) A max size of a chunk. Actual chunk may be smaller.
372 372 :param starting_values: (Default: []) An array of strings to put in front of the output queue.
373 373 :param fail_on_stderr: (Default: True) Whether to raise an exception in
374 374 case something is written to stderr.
375 375 :param fail_on_return_code: (Default: True) Whether to raise an
376 376 exception if the return code is not 0.
377 377 """
378 378
379 379 starting_values = starting_values or []
380 380 if inputstream:
381 381 input_streamer = StreamFeeder(inputstream)
382 382 input_streamer.start()
383 383 inputstream = input_streamer.output
384 384 self._close_input_fd = inputstream
385 385
386 386 self._fail_on_stderr = fail_on_stderr
387 387 self._fail_on_return_code = fail_on_return_code
388 388
389 389 _shell = kwargs.get('shell', True)
390 390 kwargs['shell'] = _shell
391 391
392 392 _p = subprocess.Popen(cmd, bufsize=-1,
393 393 stdin=inputstream,
394 394 stdout=subprocess.PIPE,
395 395 stderr=subprocess.PIPE,
396 396 **kwargs)
397 397
398 398 bg_out = BufferedGenerator(_p.stdout, buffer_size, chunk_size,
399 399 starting_values)
400 400 bg_err = BufferedGenerator(_p.stderr, 16000, 1, bottomless=True)
401 401
402 402 while not bg_out.done_reading and not bg_out.reading_paused and not bg_err.length:
403 403 # doing this until we reach either end of file, or end of buffer.
404 404 bg_out.data_added_event.wait(1)
405 405 bg_out.data_added_event.clear()
406 406
407 407 # at this point it's still ambiguous if we are done reading or just full buffer.
408 408 # Either way, if error (returned by ended process, or implied based on
409 409 # presence of stuff in stderr output) we error out.
410 410 # Else, we are happy.
411 411 _returncode = _p.poll()
412 412
413 413 if ((_returncode and fail_on_return_code) or
414 414 (fail_on_stderr and _returncode is None and bg_err.length)):
415 415 try:
416 416 _p.terminate()
417 417 except Exception:
418 418 pass
419 419 bg_out.stop()
420 420 bg_err.stop()
421 421 if fail_on_stderr:
422 422 err = b''.join(bg_err).decode(errors='replace')
423 423 raise EnvironmentError(
424 424 "Subprocess exited due to an error:\n" + err)
425 425 if _returncode and fail_on_return_code:
426 426 err = b''.join(bg_err).decode(errors='replace')
427 427 if not err:
428 428 # maybe get empty stderr, try stdout instead
429 429 # in many cases git reports the errors on stdout too
430 430 err = b''.join(bg_out).decode(errors='replace')
431 431 raise EnvironmentError(
432 432 "Subprocess exited with non 0 ret code:%s: stderr:%s" % (
433 433 _returncode, err))
434 434
435 435 self.process = _p
436 436 self.output = bg_out
437 437 self.error = bg_err
438 438 self.inputstream = inputstream
439 439
440 440 def __iter__(self):
441 441 return self
442 442
443 def next(self):
443 def __next__(self):
444 444 # Note: mikhail: We need to be sure that we are checking the return
445 445 # code after the stdout stream is closed. Some processes, e.g. git
446 446 # are doing some magic in between closing stdout and terminating the
447 447 # process and, as a result, we are not getting return code on "slow"
448 448 # systems.
449 449 result = None
450 450 stop_iteration = None
451 451 try:
452 result = self.output.next()
452 result = next(self.output)
453 453 except StopIteration as e:
454 454 stop_iteration = e
455 455
456 456 if self.process.poll() and self._fail_on_return_code:
457 457 err = b''.join(self.error).decode(errors='replace')
458 458 raise EnvironmentError(
459 459 "Subprocess exited due to an error:\n" + err)
460 460
461 461 if stop_iteration:
462 462 raise stop_iteration
463 463 return result
464 464
465 465 def throw(self, type, value=None, traceback=None):
466 466 if self.output.length or not self.output.done_reading:
467 467 raise type(value)
468 468
469 469 def close(self):
470 470 if self._closed:
471 471 return
472 472 self._closed = True
473 473 try:
474 474 self.process.terminate()
475 475 except Exception:
476 476 pass
477 477 if self._close_input_fd:
478 478 os.close(self._close_input_fd)
479 479 try:
480 480 self.output.close()
481 481 except Exception:
482 482 pass
483 483 try:
484 484 self.error.close()
485 485 except Exception:
486 486 pass
487 487 try:
488 488 os.close(self.inputstream)
489 489 except Exception:
490 490 pass
491 491
492 492
493 493 def run_command(arguments, env=None):
494 494 """
495 495 Run the specified command and return the stdout.
496 496
497 497 :param arguments: sequence of program arguments (including the program name)
498 498 :type arguments: list[str]
499 499 """
500 500
501 501 cmd = arguments
502 502 log.debug('Running subprocessio command %s', cmd)
503 503 proc = None
504 504 try:
505 505 _opts = {'shell': False, 'fail_on_stderr': False}
506 506 if env:
507 507 _opts.update({'env': env})
508 508 proc = SubprocessIOChunker(cmd, **_opts)
509 509 return b''.join(proc), b''.join(proc.error)
510 510 except (EnvironmentError, OSError) as err:
511 511 cmd = ' '.join(cmd) # human friendly CMD
512 512 tb_err = ("Couldn't run subprocessio command (%s).\n"
513 513 "Original error was:%s\n" % (cmd, err))
514 514 log.exception(tb_err)
515 515 raise Exception(tb_err)
516 516 finally:
517 517 if proc:
518 518 proc.close()
519 519
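A hedged usage sketch of the chunker; the command is illustrative and consume() is a hypothetical callback. Per the class docstring, keep the object referenced until its output is fully consumed:

    # Sketch only: stream a subprocess's stdout chunk by chunk.
    chunker = SubprocessIOChunker(
        ['git', '--version'], shell=False, fail_on_stderr=False)
    try:
        for chunk in chunker:   # bytes chunks read off the worker thread
            consume(chunk)      # hypothetical consumer
    finally:
        chunker.close()         # terminates the process, closes both ends
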
@@ -1,56 +1,56 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import socket
19 19 import pytest
20 20
21 21
22 22 def pytest_addoption(parser):
23 23 parser.addoption(
24 24 '--perf-repeat-vcs', type=int, default=100,
25 25 help="Number of repetitions in performance tests.")
26 26
27 27
28 28 @pytest.fixture(scope='session')
29 29 def repeat(request):
30 30 """
31 31 The number of repetitions is based on this fixture.
32 32
33 33 Slower calls may divide it by 10 or 100. It is chosen so that the
34 34 tests are not too slow in our default test suite.
35 35 """
36 36 return request.config.getoption('--perf-repeat-vcs')
37 37
38 38
39 39 @pytest.fixture(scope='session')
40 40 def vcsserver_port(request):
41 41 port = get_available_port()
42 print('Using vcsserver port %s' % (port, ))
42 print('Using vcsserver port %s' % port)
43 43 return port
44 44
45 45
46 46 def get_available_port():
47 47 family = socket.AF_INET
48 48 socktype = socket.SOCK_STREAM
49 49 host = '127.0.0.1'
50 50
51 51 mysocket = socket.socket(family, socktype)
52 52 mysocket.bind((host, 0))
53 53 port = mysocket.getsockname()[1]
54 54 mysocket.close()
55 55 del mysocket
56 56 return port
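
A hedged sketch of a test consuming both session fixtures; start_vcsserver() is a hypothetical helper standing in for the suite's real bootstrapping:

    # Sketch only: the fixtures inject the reserved port and repeat count.
    def test_server_answers(vcsserver_port, repeat):
        server = start_vcsserver(port=vcsserver_port)   # hypothetical helper
        for _ in range(repeat // 10):                   # slower call: scale down
            assert server.ping()
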
@@ -1,159 +1,159 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import inspect
19 19
20 20 import pytest
21 21 import dulwich.errors
22 22 from mock import Mock, patch
23 23
24 24 from vcsserver.remote import git
25 25
26 26 SAMPLE_REFS = {
27 27 'HEAD': 'fd627b9e0dd80b47be81af07c4a98518244ed2f7',
28 28 'refs/tags/v0.1.9': '341d28f0eec5ddf0b6b77871e13c2bbd6bec685c',
29 29 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
30 30 'refs/tags/v0.1.1': 'e6ea6d16e2f26250124a1f4b4fe37a912f9d86a0',
31 31 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
32 32 }
33 33
34 34
35 35 @pytest.fixture
36 36 def git_remote():
37 37 """
38 38 A GitRemote instance with a mock factory.
39 39 """
40 40 factory = Mock()
41 41 remote = git.GitRemote(factory)
42 42 return remote
43 43
44 44
45 45 def test_discover_git_version(git_remote):
46 46 version = git_remote.discover_git_version()
47 47 assert version
48 48
49 49
50 50 class TestGitFetch(object):
51 51 def setup(self):
52 52 self.mock_repo = Mock()
53 53 factory = Mock()
54 54 factory.repo = Mock(return_value=self.mock_repo)
55 55 self.remote_git = git.GitRemote(factory)
56 56
57 57 def test_fetches_all_when_no_commit_ids_specified(self):
58 58 def side_effect(determine_wants, *args, **kwargs):
59 59 determine_wants(SAMPLE_REFS)
60 60
61 61 with patch('dulwich.client.LocalGitClient.fetch') as mock_fetch:
62 62 mock_fetch.side_effect = side_effect
63 63 self.remote_git.pull(wire={}, url='/tmp/', apply_refs=False)
64 64 determine_wants = self.mock_repo.object_store.determine_wants_all
65 65 determine_wants.assert_called_once_with(SAMPLE_REFS)
66 66
67 67 def test_fetches_specified_commits(self):
68 68 selected_refs = {
69 69 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
70 70 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
71 71 }
72 72
73 73 def side_effect(determine_wants, *args, **kwargs):
74 74 result = determine_wants(SAMPLE_REFS)
75 75 assert sorted(result) == sorted(selected_refs.values())
76 76 return result
77 77
78 78 with patch('dulwich.client.LocalGitClient.fetch') as mock_fetch:
79 79 mock_fetch.side_effect = side_effect
80 80 self.remote_git.pull(
81 81 wire={}, url='/tmp/', apply_refs=False,
82 82 refs=selected_refs.keys())
83 83 determine_wants = self.mock_repo.object_store.determine_wants_all
84 84 assert determine_wants.call_count == 0
85 85
86 86 def test_get_remote_refs(self):
87 87 factory = Mock()
88 88 remote_git = git.GitRemote(factory)
89 89 url = 'http://example.com/test/test.git'
90 90 sample_refs = {
91 91 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
92 92 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
93 93 }
94 94
95 95 with patch('vcsserver.git.Repo', create=False) as mock_repo:
96 96 mock_repo().get_refs.return_value = sample_refs
97 97 remote_refs = remote_git.get_remote_refs(wire={}, url=url)
98 98 mock_repo().get_refs.assert_called_once_with()
99 99 assert remote_refs == sample_refs
100 100
101 101
102 102 class TestReraiseSafeExceptions(object):
103 103
104 104 def test_method_decorated_with_reraise_safe_exceptions(self):
105 105 factory = Mock()
106 106 git_remote = git.GitRemote(factory)
107 107
108 108 def fake_function():
109 109 return None
110 110
111 111 decorator = git.reraise_safe_exceptions(fake_function)
112 112
113 113 methods = inspect.getmembers(git_remote, predicate=inspect.ismethod)
114 114 for method_name, method in methods:
115 115 if not method_name.startswith('_') and method_name not in ['vcsserver_invalidate_cache']:
116 assert method.im_func.__code__ == decorator.__code__
116 assert method.__func__.__code__ == decorator.__code__
117 117
118 118 @pytest.mark.parametrize('side_effect, expected_type', [
119 119 (dulwich.errors.ChecksumMismatch('0000000', 'deadbeef'), 'lookup'),
120 120 (dulwich.errors.NotCommitError('deadbeef'), 'lookup'),
121 121 (dulwich.errors.MissingCommitError('deadbeef'), 'lookup'),
122 122 (dulwich.errors.ObjectMissing('deadbeef'), 'lookup'),
123 123 (dulwich.errors.HangupException(), 'error'),
124 124 (dulwich.errors.UnexpectedCommandError('test-cmd'), 'error'),
125 125 ])
126 126 def test_safe_exceptions_reraised(self, side_effect, expected_type):
127 127 @git.reraise_safe_exceptions
128 128 def fake_method():
129 129 raise side_effect
130 130
131 131 with pytest.raises(Exception) as exc_info:
132 132 fake_method()
133 133 assert type(exc_info.value) == Exception
134 134 assert exc_info.value._vcs_kind == expected_type
135 135
136 136
137 137 class TestDulwichRepoWrapper(object):
138 138 def test_calls_close_on_delete(self):
139 139 isdir_patcher = patch('dulwich.repo.os.path.isdir', return_value=True)
140 140 with isdir_patcher:
141 141 repo = git.Repo('/tmp/abcde')
142 142 with patch.object(git.DulwichRepo, 'close') as close_mock:
143 143 del repo
144 144 close_mock.assert_called_once_with()
145 145
146 146
147 147 class TestGitFactory(object):
148 148 def test_create_repo_returns_dulwich_wrapper(self):
149 149
150 150 with patch('vcsserver.lib.rc_cache.region_meta.dogpile_cache_regions') as mock:
151 151 mock.side_effect = {'repo_objects': ''}
152 152 factory = git.GitFactory()
153 153 wire = {
154 154 'path': '/tmp/abcde'
155 155 }
156 156 isdir_patcher = patch('dulwich.repo.os.path.isdir', return_value=True)
157 157 with isdir_patcher:
158 158 result = factory._create_repo(wire, True)
159 159 assert isinstance(result, git.Repo)
@@ -1,109 +1,109 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import inspect
19 19 import sys
20 20 import traceback
21 21
22 22 import pytest
23 23 from mercurial.error import LookupError
24 24 from mock import Mock, patch
25 25
26 26 from vcsserver import exceptions, hgcompat
27 27 from vcsserver.remote import hg
28 28
29 29
30 30 class TestDiff(object):
31 31 def test_raising_safe_exception_when_lookup_failed(self):
32 32
33 33 factory = Mock()
34 34 hg_remote = hg.HgRemote(factory)
35 35 with patch('mercurial.patch.diff') as diff_mock:
36 36 diff_mock.side_effect = LookupError(
37 37 'deadbeef', 'index', 'message')
38 38 with pytest.raises(Exception) as exc_info:
39 39 hg_remote.diff(
40 40 wire={}, commit_id_1='deadbeef', commit_id_2='deadbee1',
41 41 file_filter=None, opt_git=True, opt_ignorews=True,
42 42 context=3)
43 43 assert type(exc_info.value) == Exception
44 44 assert exc_info.value._vcs_kind == 'lookup'
45 45
46 46
47 47 class TestReraiseSafeExceptions(object):
48 48 def test_method_decorated_with_reraise_safe_exceptions(self):
49 49 factory = Mock()
50 50 hg_remote = hg.HgRemote(factory)
51 51 methods = inspect.getmembers(hg_remote, predicate=inspect.ismethod)
52 52 decorator = hg.reraise_safe_exceptions(None)
53 53 for method_name, method in methods:
54 54 if not method_name.startswith('_') and method_name not in ['vcsserver_invalidate_cache']:
55 assert method.im_func.__code__ == decorator.__code__
55 assert method.__func__.__code__ == decorator.__code__
56 56
57 57 @pytest.mark.parametrize('side_effect, expected_type', [
58 58 (hgcompat.Abort(), 'abort'),
59 59 (hgcompat.InterventionRequired(), 'abort'),
60 60 (hgcompat.RepoLookupError(), 'lookup'),
61 61 (hgcompat.LookupError('deadbeef', 'index', 'message'), 'lookup'),
62 62 (hgcompat.RepoError(), 'error'),
63 63 (hgcompat.RequirementError(), 'requirement'),
64 64 ])
65 65 def test_safe_exceptions_reraised(self, side_effect, expected_type):
66 66 @hg.reraise_safe_exceptions
67 67 def fake_method():
68 68 raise side_effect
69 69
70 70 with pytest.raises(Exception) as exc_info:
71 71 fake_method()
72 72 assert type(exc_info.value) == Exception
73 73 assert exc_info.value._vcs_kind == expected_type
74 74
75 75 def test_keeps_original_traceback(self):
76 76 @hg.reraise_safe_exceptions
77 77 def fake_method():
78 78 try:
79 79 raise hgcompat.Abort()
80 80 except:
81 81 self.original_traceback = traceback.format_tb(
82 82 sys.exc_info()[2])
83 83 raise
84 84
85 85 try:
86 86 fake_method()
87 87 except Exception:
88 88 new_traceback = traceback.format_tb(sys.exc_info()[2])
89 89
90 90 new_traceback_tail = new_traceback[-len(self.original_traceback):]
91 91 assert new_traceback_tail == self.original_traceback
92 92
93 93 def test_maps_unknown_exceptions_to_unhandled(self):
94 94 @hg.reraise_safe_exceptions
95 95 def stub_method():
96 96 raise ValueError('stub')
97 97
98 98 with pytest.raises(Exception) as exc_info:
99 99 stub_method()
100 100 assert exc_info.value._vcs_kind == 'unhandled'
101 101
102 102 def test_does_not_map_known_exceptions(self):
103 103 @hg.reraise_safe_exceptions
104 104 def stub_method():
105 105 raise exceptions.LookupException()('stub')
106 106
107 107 with pytest.raises(Exception) as exc_info:
108 108 stub_method()
109 109 assert exc_info.value._vcs_kind == 'lookup'
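Both `TestReraiseSafeExceptions` classes above pin down the same contract: a decorated remote method re-raises backend-specific errors as a plain `Exception` carrying a `_vcs_kind` marker, maps anything unknown to `'unhandled'`, and keeps the original traceback. A rough sketch of that decorator pattern, using an illustrative exception mapping rather than the actual vcsserver table:

    import functools

    def reraise_safe_exceptions(func):
        # Illustrative mapping; the real decorator matches
        # mercurial/dulwich exception types instead.
        safe_map = {LookupError: 'lookup', ValueError: 'error'}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                new_exc = Exception(*exc.args)
                new_exc._vcs_kind = safe_map.get(type(exc), 'unhandled')
                raise new_exc from exc  # chains the original traceback
        return wrapper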
@@ -1,124 +1,124 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import mock
19 19 import pytest
20 20
21 21 from vcsserver import hgcompat, hgpatches
22 22
23 23
24 24 LARGEFILES_CAPABILITY = 'largefiles=serve'
25 25
26 26
27 27 def test_patch_largefiles_capabilities_applies_patch(
28 28 patched_capabilities):
29 29 lfproto = hgcompat.largefiles.proto
30 30 hgpatches.patch_largefiles_capabilities()
31 assert lfproto._capabilities.func_name == '_dynamic_capabilities'
31 assert lfproto._capabilities.__name__ == '_dynamic_capabilities'
32 32
33 33
34 34 def test_dynamic_capabilities_uses_original_function_if_not_enabled(
35 35 stub_repo, stub_proto, stub_ui, stub_extensions, patched_capabilities,
36 36 orig_capabilities):
37 37 dynamic_capabilities = hgpatches._dynamic_capabilities_wrapper(
38 38 hgcompat.largefiles.proto, stub_extensions)
39 39
40 40 caps = dynamic_capabilities(orig_capabilities, stub_repo, stub_proto)
41 41
42 42 stub_extensions.assert_called_once_with(stub_ui)
43 43 assert LARGEFILES_CAPABILITY not in caps
44 44
45 45
46 46 def test_dynamic_capabilities_ignores_updated_capabilities(
47 47 stub_repo, stub_proto, stub_ui, stub_extensions, patched_capabilities,
48 48 orig_capabilities):
49 49 stub_extensions.return_value = [('largefiles', mock.Mock())]
50 50 dynamic_capabilities = hgpatches._dynamic_capabilities_wrapper(
51 51 hgcompat.largefiles.proto, stub_extensions)
52 52
53 53 # This happens when the extension is loaded for the first time, important
54 54 # to ensure that an updated function is correctly picked up.
55 55 hgcompat.largefiles.proto._capabilities = mock.Mock(
56 56 side_effect=Exception('Must not be called'))
57 57
58 58 dynamic_capabilities(orig_capabilities, stub_repo, stub_proto)
59 59
60 60
61 61 def test_dynamic_capabilities_uses_largefiles_if_enabled(
62 62 stub_repo, stub_proto, stub_ui, stub_extensions, patched_capabilities,
63 63 orig_capabilities):
64 64 stub_extensions.return_value = [('largefiles', mock.Mock())]
65 65
66 66 dynamic_capabilities = hgpatches._dynamic_capabilities_wrapper(
67 67 hgcompat.largefiles.proto, stub_extensions)
68 68
69 69 caps = dynamic_capabilities(orig_capabilities, stub_repo, stub_proto)
70 70
71 71 stub_extensions.assert_called_once_with(stub_ui)
72 72 assert LARGEFILES_CAPABILITY in caps
73 73
74 74
75 75 def test_hgsubversion_import():
76 76 from hgsubversion import svnrepo
77 77 assert svnrepo
78 78
79 79
80 80 @pytest.fixture
81 81 def patched_capabilities(request):
82 82 """
82 82 Save the original `_capabilities` function and restore it after the test.
84 84 """
85 85 lfproto = hgcompat.largefiles.proto
86 86 orig_capabilities = lfproto._capabilities
87 87
88 88 @request.addfinalizer
89 89 def restore():
90 90 lfproto._capabilities = orig_capabilities
91 91
92 92
93 93 @pytest.fixture
94 94 def stub_repo(stub_ui):
95 95 repo = mock.Mock()
96 96 repo.ui = stub_ui
97 97 return repo
98 98
99 99
100 100 @pytest.fixture
101 101 def stub_proto(stub_ui):
102 102 proto = mock.Mock()
103 103 proto.ui = stub_ui
104 104 return proto
105 105
106 106
107 107 @pytest.fixture
108 108 def orig_capabilities():
109 109 from mercurial.wireprotov1server import wireprotocaps
110 110
111 111 def _capabilities(repo, proto):
112 112 return wireprotocaps
113 113 return _capabilities
114 114
115 115
116 116 @pytest.fixture
117 117 def stub_ui():
118 118 return hgcompat.ui.ui()
119 119
120 120
121 121 @pytest.fixture
122 122 def stub_extensions():
123 123 extensions = mock.Mock(return_value=tuple())
124 124 return extensions
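The `func_name` → `__name__` hunk above comes from 2to3's `funcattrs` fixer: the Python 2 aliases `func_name`, `func_code` and `func_defaults` were removed in Python 3 in favour of the dunder forms `__name__`, `__code__` and `__defaults__`, which exist in both versions. A one-liner check in the spirit of the patched assertion (function name is illustrative):

    def _dynamic_capabilities(repo, proto):
        return []

    assert _dynamic_capabilities.__name__ == '_dynamic_capabilities'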
@@ -1,155 +1,155 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21
22 22 import pytest
23 23
24 24 from vcsserver import subprocessio
25 25
26 26
27 27 class KindaFilelike(object): # pragma: no cover
28 28
29 29 def __init__(self, data, size):
30 chunks = size / len(data)
30 chunks = size // len(data)
31 31
32 32 self.stream = self._get_stream(data, chunks)
33 33
34 34 def _get_stream(self, data, chunks):
35 35 for x in range(chunks):
36 36 yield data
37 37
38 38 def read(self, n):
39 39
40 40 buffer_stream = ''
41 41 for chunk in self.stream:
42 42 buffer_stream += chunk
43 43 if len(buffer_stream) >= n:
44 44 break
45 45
46 46 # self.stream = self.bytes[n:]
47 47 return buffer_stream
48 48
49 49
50 50 @pytest.fixture(scope='module')
51 51 def environ():
52 52 """Delete coverage variables, as they make the tests fail."""
53 53 env = dict(os.environ)
54 for key in env.keys():
54 for key in list(env.keys()):
55 55 if key.startswith('COV_CORE_'):
56 56 del env[key]
57 57
58 58 return env
59 59
60 60
61 61 def _get_python_args(script):
62 62 return [sys.executable, '-c', 'import sys; import time; import shutil; ' + script]
63 63
64 64
65 65 def test_raise_exception_on_non_zero_return_code(environ):
66 66 args = _get_python_args('sys.exit(1)')
67 67 with pytest.raises(EnvironmentError):
68 68 list(subprocessio.SubprocessIOChunker(args, shell=False, env=environ))
69 69
70 70
71 71 def test_does_not_fail_on_non_zero_return_code(environ):
72 72 args = _get_python_args('sys.exit(1)')
73 73 output = ''.join(
74 74 subprocessio.SubprocessIOChunker(
75 75 args, shell=False, fail_on_return_code=False, env=environ
76 76 )
77 77 )
78 78
79 79 assert output == ''
80 80
81 81
82 82 def test_raise_exception_on_stderr(environ):
83 83 args = _get_python_args('sys.stderr.write("X"); time.sleep(1);')
84 84 with pytest.raises(EnvironmentError) as excinfo:
85 85 list(subprocessio.SubprocessIOChunker(args, shell=False, env=environ))
86 86
87 87 assert 'exited due to an error:\nX' in str(excinfo.value)
88 88
89 89
90 90 def test_does_not_fail_on_stderr(environ):
91 91 args = _get_python_args('sys.stderr.write("X"); time.sleep(1);')
92 92 output = ''.join(
93 93 subprocessio.SubprocessIOChunker(
94 94 args, shell=False, fail_on_stderr=False, env=environ
95 95 )
96 96 )
97 97
98 98 assert output == ''
99 99
100 100
101 101 @pytest.mark.parametrize('size', [1, 10 ** 5])
102 102 def test_output_with_no_input(size, environ):
103 print(type(environ))
103 print((type(environ)))
104 104 data = 'X'
105 105 args = _get_python_args('sys.stdout.write("%s" * %d)' % (data, size))
106 106 output = ''.join(subprocessio.SubprocessIOChunker(args, shell=False, env=environ))
107 107
108 108 assert output == data * size
109 109
110 110
111 111 @pytest.mark.parametrize('size', [1, 10 ** 5])
112 112 def test_output_with_no_input_does_not_fail(size, environ):
113 113 data = 'X'
114 114 args = _get_python_args('sys.stdout.write("%s" * %d); sys.exit(1)' % (data, size))
115 115 output = ''.join(
116 116 subprocessio.SubprocessIOChunker(
117 117 args, shell=False, fail_on_return_code=False, env=environ
118 118 )
119 119 )
120 120
121 print("{} {}".format(len(data * size), len(output)))
121 print(("{} {}".format(len(data * size), len(output))))
122 122 assert output == data * size
123 123
124 124
125 125 @pytest.mark.parametrize('size', [1, 10 ** 5])
126 126 def test_output_with_input(size, environ):
127 127 data_len = size
128 128 inputstream = KindaFilelike('X', size)
129 129
130 130 # This acts like the cat command.
131 131 args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
132 132 output = ''.join(
133 133 subprocessio.SubprocessIOChunker(
134 134 args, shell=False, inputstream=inputstream, env=environ
135 135 )
136 136 )
137 137
138 138 assert len(output) == data_len
139 139
140 140
141 141 @pytest.mark.parametrize('size', [1, 10 ** 5])
142 142 def test_output_with_input_skipping_iterator(size, environ):
143 143 data_len = size
144 144 inputstream = KindaFilelike('X', size)
145 145
146 146 # This acts like the cat command.
147 147 args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
148 148
149 149 # Note: assigning the chunker makes sure that it is not deleted too early
150 150 chunker = subprocessio.SubprocessIOChunker(
151 151 args, shell=False, inputstream=inputstream, env=environ
152 152 )
153 153 output = ''.join(chunker.output)
154 154
155 155 assert len(output) == data_len
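The doubled parentheses in the two `print` hunks above are a known 2to3 artifact rather than a deliberate style choice: because `print (a, b)` in Python 2 printed a tuple, the print fixer conservatively wraps the old argument in an extra pair of parentheses, which is redundant when there is only one argument. The wrapped form is harmless and can be reduced by hand in a follow-up:

    environ = {'PATH': '/usr/bin'}   # illustrative stand-in
    print((type(environ)))           # 2to3 output: prints <class 'dict'>
    print(type(environ))             # equivalent and cleaner

Note also that a mechanical 2to3 run misses semantic changes: `/` on two ints now yields a float (hence the `//` fix in `KindaFilelike`, since `range()` rejects floats), and deleting keys while iterating a dict view raises `RuntimeError` (hence `list(env.keys())` in the `environ` fixture).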