##// END OF EJS Templates
packages: move the str utils to its own module
super-admin -
r1060:680d7e36 python3
parent child Browse files
Show More
@@ -0,0 +1,127 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
18 import logging
19
20
21 log = logging.getLogger(__name__)
22
23
def safe_int(val, default=None) -> int:
    """
    Convert ``val`` to an ``int``, returning ``default`` when the value
    cannot be converted (non-numeric string, None, unsupported type).

    :param val: value to convert
    :param default: fallback returned on conversion failure
    """
    try:
        return int(val)
    except (ValueError, TypeError):
        return default
39
40
def safe_str(str_, to_encoding=None) -> str:
    """
    Safe str function. Does few trick to turn unicode_ into string.

    ``str`` input is returned untouched and non-bytes input goes through
    plain ``str()``. For ``bytes`` input each candidate encoding is tried
    in order; when all fail, a lossy 'replace' decode with the first
    encoding is used as last resort.

    :param str_: str to encode
    :param to_encoding: encoding (or list of encodings) to try, UTF8 default
    :rtype: str
    :returns: str object
    """
    if isinstance(str_, str):
        return str_

    if not isinstance(str_, bytes):
        # neither str nor bytes: fall back to the plain string conversion
        return str(str_)

    encodings = to_encoding or ['utf8']
    if not isinstance(encodings, (list, tuple)):
        encodings = [encodings]

    for encoding in encodings:
        try:
            return str(str_, encoding)
        except UnicodeDecodeError:
            continue

    # lossy fallback: undecodable bytes become U+FFFD replacement chars
    return str(str_, encodings[0], 'replace')
68
69
def safe_bytes(str_, from_encoding=None) -> bytes:
    """
    safe bytes function. Does few trick to turn str_ into bytes string:

    ``bytes`` input is returned untouched; any other non-str type raises
    ``ValueError``. Each candidate encoding is tried in order; when all
    fail, a lossy 'replace' encode with the first encoding is used.

    :param str_: string to encode
    :param from_encoding: encoding (or list of encodings) to try, UTF8 default
    :rtype: bytes
    :returns: bytes object
    :raises ValueError: when ``str_`` is neither str nor bytes
    """
    if isinstance(str_, bytes):
        return str_

    if not isinstance(str_, str):
        raise ValueError('safe_bytes cannot convert other types than str: got: {}'.format(type(str_)))

    from_encoding = from_encoding or ['utf8']
    if not isinstance(from_encoding, (list, tuple)):
        from_encoding = [from_encoding]

    for enc in from_encoding:
        try:
            return str_.encode(enc)
        # BUGFIX: str.encode() raises UnicodeEncodeError, not
        # UnicodeDecodeError; the old except clause could never match, so
        # the multi-encoding fallback below was unreachable.
        except UnicodeEncodeError:
            pass

    return str_.encode(from_encoding[0], 'replace')
96
97
def ascii_bytes(str_, allow_bytes=False) -> bytes:
    """
    Simple conversion from str to bytes, with assumption that str_ is pure ASCII.
    Fails with UnicodeError on invalid input.
    This should be used where encoding and "safe" ambiguity should be avoided.
    Where strings already have been encoded in other ways but still are unicode
    string - for example to hex, base64, json, urlencoding, or are known to be
    identifiers.
    """
    if isinstance(str_, bytes) and allow_bytes:
        return str_

    if not isinstance(str_, str):
        raise ValueError(f'ascii_bytes cannot convert other types than str: got: {type(str_)}')

    return str_.encode('ascii')
113
114
def ascii_str(str_) -> str:
    """
    Simple conversion from bytes to str, with assumption that str_ is pure ASCII.
    Fails with UnicodeError on invalid input.
    This should be used where encoding and "safe" ambiguity should be avoided.
    Where strings are encoded but also in other ways are known to be ASCII, and
    where a unicode string is wanted without caring about encoding. For example
    to hex, base64, urlencoding, or are known to be identifiers.

    :param str_: bytes object to decode
    :raises ValueError: when ``str_`` is not bytes
    """
    # CONSISTENCY: added the missing `-> str` return annotation so all
    # converters in this module declare their return type.
    if not isinstance(str_, bytes):
        raise ValueError('ascii_str cannot convert other types than bytes: got: {}'.format(type(str_)))
    return str_.decode('ascii')
@@ -1,292 +1,292 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import re
19 19 import logging
20 20 from wsgiref.util import FileWrapper
21 21
22 22 from pyramid.config import Configurator
23 23 from pyramid.response import Response, FileIter
24 24 from pyramid.httpexceptions import (
25 25 HTTPBadRequest, HTTPNotImplemented, HTTPNotFound, HTTPForbidden,
26 26 HTTPUnprocessableEntity)
27 27
28 28 from vcsserver.lib.rc_json import json
29 29 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
30 30 from vcsserver.git_lfs.utils import safe_result, get_cython_compat_decorator
31 from vcsserver.utils import safe_int
31 from vcsserver.str_utils import safe_int
32 32
33 33 log = logging.getLogger(__name__)
34 34
35 35
36 36 GIT_LFS_CONTENT_TYPE = 'application/vnd.git-lfs' #+json ?
37 37 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
38 38
39 39
def write_response_error(http_exception, text=None):
    """
    Build an HTTP error response of the given exception class, optionally
    carrying a JSON ``{'message': text}`` body, and log the event.
    """
    json_content_type = GIT_LFS_CONTENT_TYPE + '+json'
    error = http_exception(content_type=json_content_type)
    error.content_type = json_content_type
    if text:
        error.text = json.dumps({'message': text})
    log.debug('LFS: writing response of type %s to client with text:%s',
              http_exception, text)
    return error
49 49
50 50
class AuthHeaderRequired(object):
    """
    Decorator to check if request has proper auth-header; responds with
    HTTP 403 when the Authorization header is absent.
    """

    def __call__(self, func):
        # wrap through the cython-compat helper so decoration also works
        # on cython-compiled view callables
        return get_cython_compat_decorator(self.__wrapper, func)

    def __wrapper(self, func, *fargs, **fkwargs):
        # NOTE(review): fargs[1] is treated as the pyramid request and
        # fargs[0] is dropped when calling through -- presumably an artifact
        # of get_cython_compat_decorator's calling convention; confirm there.
        request = fargs[1]
        auth = request.authorization
        if not auth:
            return write_response_error(HTTPForbidden)
        return func(*fargs[1:], **fkwargs)
65 65
66 66
67 67 # views
68 68
def lfs_objects(request):
    """Deprecated LFS v1 objects endpoint; always answers 501 Not Implemented."""
    # indicate not supported, V1 API
    log.warning('LFS: v1 api not supported, reporting it back to client')
    return write_response_error(HTTPNotImplemented, 'LFS: v1 api not supported')
73 73
74 74
@AuthHeaderRequired()
def lfs_objects_batch(request):
    """
    The client sends the following information to the Batch endpoint to transfer some objects:

    operation - Should be download or upload.
    transfers - An optional Array of String identifiers for transfer
        adapters that the client has configured. If omitted, the basic
        transfer adapter MUST be assumed by the server.
    objects - An Array of objects to download.
        oid - String OID of the LFS object.
        size - Integer byte size of the LFS object. Must be at least zero.

    Returns a dict with per-object 'actions' (download/upload hrefs) or
    per-object 'errors'; request-level problems come back as JSON HTTP errors.
    """
    request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
    auth = request.authorization
    repo = request.matchdict.get('repo')
    data = request.json
    operation = data.get('operation')
    # scheme comes from registry config so generated hrefs work behind proxies
    http_scheme = request.registry.git_lfs_http_scheme

    if operation not in ('download', 'upload'):
        log.debug('LFS: unsupported operation:%s', operation)
        return write_response_error(
            HTTPBadRequest, 'unsupported operation mode: `%s`' % operation)

    if 'objects' not in data:
        log.debug('LFS: missing objects data')
        return write_response_error(
            HTTPBadRequest, 'missing objects data')

    log.debug('LFS: handling operation of type: %s', operation)

    objects = []
    for o in data['objects']:
        try:
            oid = o['oid']
            obj_size = o['size']
        except KeyError:
            # a single malformed object aborts the whole batch with 400
            log.exception('LFS, failed to extract data')
            return write_response_error(
                HTTPBadRequest, 'unsupported data in objects')

        obj_data = {'oid': oid}

        # callback URLs the client will hit for transfer and verification
        obj_href = request.route_url('lfs_objects_oid', repo=repo, oid=oid,
                                     _scheme=http_scheme)
        obj_verify_href = request.route_url('lfs_objects_verify', repo=repo,
                                            _scheme=http_scheme)
        store = LFSOidStore(
            oid, repo, store_location=request.registry.git_lfs_store_path)
        handler = OidHandler(
            store, repo, auth, oid, obj_size, obj_data,
            obj_href, obj_verify_href)

        # this verifies also OIDs
        actions, errors = handler.exec_operation(operation)
        if errors:
            # per-object errors are reported inline, not as HTTP failures
            log.warning('LFS: got following errors: %s', errors)
            obj_data['errors'] = errors

        if actions:
            obj_data['actions'] = actions

        obj_data['size'] = obj_size
        obj_data['authenticated'] = True
        objects.append(obj_data)

    result = {'objects': objects, 'transfer': 'basic'}
    log.debug('LFS Response %s', safe_result(result))

    return result
146 146
147 147
def lfs_objects_oid_upload(request):
    """
    Receive an LFS object body and stream it into the oid store under the
    oid taken from the URL, reading the WSGI input in fixed-size chunks.
    """
    request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
    repo = request.matchdict.get('repo')
    oid = request.matchdict.get('oid')
    store = LFSOidStore(
        oid, repo, store_location=request.registry.git_lfs_store_path)
    engine = store.get_engine(mode='wb')
    log.debug('LFS: starting chunked write of LFS oid: %s to storage', oid)

    body = request.environ['wsgi.input']

    # read in 64kb chunks as the stream comes in from Gunicorn; this is a
    # specific Gunicorn support function and might work differently on waitress
    chunk_size = 64 * 1024
    with engine as destination:
        for chunk in iter(lambda: body.read(chunk_size), b''):
            destination.write(chunk)

    return {'upload': 'ok'}
171 171
172 172
def lfs_objects_oid_download(request):
    """
    Stream the raw content of a stored LFS object back to the client as
    application/octet-stream; 404 JSON error when the oid is not in store.
    """
    repo = request.matchdict.get('repo')
    oid = request.matchdict.get('oid')

    store = LFSOidStore(
        oid, repo, store_location=request.registry.git_lfs_store_path)
    if not store.has_oid():
        log.debug('LFS: oid %s does not exists in store', oid)
        return write_response_error(
            HTTPNotFound, 'requested file with oid `%s` not found in store' % oid)

    # TODO(marcink): support range header ?
    # Range: bytes=0-, `bytes=(\d+)\-.*`

    # NOTE(review): file handle is deliberately not closed here; presumably
    # FileIter closes it once the response body is exhausted -- confirm
    f = open(store.oid_path, 'rb')
    response = Response(
        content_type='application/octet-stream', app_iter=FileIter(f))
    response.headers.add('X-RC-LFS-Response-Oid', str(oid))
    return response
192 192
193 193
def lfs_objects_verify(request):
    """
    Verify that a previously uploaded oid exists in the store and that its
    on-disk size matches the size the client reports.
    """
    request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
    repo = request.matchdict.get('repo')

    payload = request.json
    oid = payload.get('oid')
    size = safe_int(payload.get('size'))

    # both fields are mandatory; size of 0 is treated as missing as well
    if not (oid and size):
        return write_response_error(
            HTTPBadRequest, 'missing oid and size in request data')

    store = LFSOidStore(
        oid, repo, store_location=request.registry.git_lfs_store_path)
    if not store.has_oid():
        log.debug('LFS: oid %s does not exists in store', oid)
        return write_response_error(
            HTTPNotFound, 'oid `%s` does not exists in store' % oid)

    stored_size = store.size_oid()
    if stored_size != size:
        return write_response_error(
            HTTPUnprocessableEntity,
            'requested file size mismatch store size:%s requested:%s' % (
                stored_size, size))

    return {'message': {'size': 'ok', 'in_store': 'ok'}}
221 221
222 222
def lfs_objects_lock(request):
    """LFS locking API is not implemented; always answers 501."""
    return write_response_error(
        HTTPNotImplemented, 'GIT LFS locking api not supported')
226 226
227 227
def not_found(request):
    """JSON 404 handler for unknown LFS API paths."""
    return write_response_error(HTTPNotFound, 'request path not found')
231 231
232 232
def lfs_disabled(request):
    """Catch-all view used when LFS support is switched off for a repo."""
    return write_response_error(
        HTTPNotImplemented, 'GIT LFS disabled for this repo')
236 236
237 237
def git_lfs_app(config):
    """
    Register all git-lfs routes and views on the given pyramid Configurator.

    The repo pattern ``{repo:.*?[^/]}`` lets repository names contain
    slashes but forbids a trailing slash.
    """

    # v1 API deprecation endpoint
    config.add_route('lfs_objects',
                     '/{repo:.*?[^/]}/info/lfs/objects')
    config.add_view(lfs_objects, route_name='lfs_objects',
                    request_method='POST', renderer='json')

    # locking API (reports not-implemented, see lfs_objects_lock)
    config.add_route('lfs_objects_lock',
                     '/{repo:.*?[^/]}/info/lfs/locks')
    config.add_view(lfs_objects_lock, route_name='lfs_objects_lock',
                    request_method=('POST', 'GET'), renderer='json')

    config.add_route('lfs_objects_lock_verify',
                     '/{repo:.*?[^/]}/info/lfs/locks/verify')
    config.add_view(lfs_objects_lock, route_name='lfs_objects_lock_verify',
                    request_method=('POST', 'GET'), renderer='json')

    # batch API
    config.add_route('lfs_objects_batch',
                     '/{repo:.*?[^/]}/info/lfs/objects/batch')
    config.add_view(lfs_objects_batch, route_name='lfs_objects_batch',
                    request_method='POST', renderer='json')

    # oid upload/download API (same route, dispatched on request method)
    config.add_route('lfs_objects_oid',
                     '/{repo:.*?[^/]}/info/lfs/objects/{oid}')
    config.add_view(lfs_objects_oid_upload, route_name='lfs_objects_oid',
                    request_method='PUT', renderer='json')
    config.add_view(lfs_objects_oid_download, route_name='lfs_objects_oid',
                    request_method='GET', renderer='json')

    # verification API
    config.add_route('lfs_objects_verify',
                     '/{repo:.*?[^/]}/info/lfs/verify')
    config.add_view(lfs_objects_verify, route_name='lfs_objects_verify',
                    request_method='POST', renderer='json')

    # not found handler for API
    config.add_notfound_view(not_found, renderer='json')
279 279
280 280
def create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
    """
    Assemble the WSGI application serving the git-lfs API. With LFS
    disabled every request answers with a 501 'disabled' JSON response.
    """
    config = Configurator()
    if not git_lfs_enabled:
        # not found handler for API, reporting disabled LFS support
        config.add_notfound_view(lfs_disabled, renderer='json')
    else:
        config.include(git_lfs_app)
        config.registry.git_lfs_store_path = git_lfs_store_path
        config.registry.git_lfs_http_scheme = git_lfs_http_scheme
    return config.make_wsgi_app()
@@ -1,273 +1,273 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 20 from webtest.app import TestApp as WebObTestApp
21 21
22 22 from vcsserver.lib.rc_json import json
23 from vcsserver.utils import safe_bytes
23 from vcsserver.str_utils import safe_bytes
24 24 from vcsserver.git_lfs.app import create_app
25 25
26 26
@pytest.fixture(scope='function')
def git_lfs_app(tmpdir):
    """WebTest app with LFS enabled over plain http, backed by a tmpdir store."""
    store_path = str(tmpdir)
    app = WebObTestApp(create_app(
        git_lfs_enabled=True, git_lfs_store_path=store_path,
        git_lfs_http_scheme='http'))
    app._store = store_path
    return app
34 34
35 35
@pytest.fixture(scope='function')
def git_lfs_https_app(tmpdir):
    """WebTest app with LFS enabled over https, backed by a tmpdir store."""
    store_path = str(tmpdir)
    app = WebObTestApp(create_app(
        git_lfs_enabled=True, git_lfs_store_path=store_path,
        git_lfs_http_scheme='https'))
    app._store = store_path
    return app
43 43
44 44
@pytest.fixture()
def http_auth():
    """Extra WSGI environ carrying a dummy basic-auth Authorization header."""
    return {'HTTP_AUTHORIZATION': "Basic XXXXX"}
48 48
49 49
class TestLFSApplication(object):
    """
    End-to-end tests for the git-lfs WSGI app, driven through WebTest
    against a temporary on-disk oid store.
    """

    def test_app_wrong_path(self, git_lfs_app):
        git_lfs_app.get('/repo/info/lfs/xxx', status=404)

    def test_app_deprecated_endpoint(self, git_lfs_app):
        # v1 objects API is intentionally not implemented
        response = git_lfs_app.post('/repo/info/lfs/objects', status=501)
        assert response.status_code == 501
        assert json.loads(response.text) == {'message': 'LFS: v1 api not supported'}

    def test_app_lock_verify_api_not_available(self, git_lfs_app):
        response = git_lfs_app.post('/repo/info/lfs/locks/verify', status=501)
        assert response.status_code == 501
        assert json.loads(response.text) == {
            'message': 'GIT LFS locking api not supported'}

    def test_app_lock_api_not_available(self, git_lfs_app):
        response = git_lfs_app.post('/repo/info/lfs/locks', status=501)
        assert response.status_code == 501
        assert json.loads(response.text) == {
            'message': 'GIT LFS locking api not supported'}

    def test_app_batch_api_missing_auth(self, git_lfs_app):
        # batch API requires an Authorization header -> 403 without it
        git_lfs_app.post_json(
            '/repo/info/lfs/objects/batch', params={}, status=403)

    def test_app_batch_api_unsupported_operation(self, git_lfs_app, http_auth):
        response = git_lfs_app.post_json(
            '/repo/info/lfs/objects/batch', params={}, status=400,
            extra_environ=http_auth)
        assert json.loads(response.text) == {
            'message': 'unsupported operation mode: `None`'}

    def test_app_batch_api_missing_objects(self, git_lfs_app, http_auth):
        response = git_lfs_app.post_json(
            '/repo/info/lfs/objects/batch', params={'operation': 'download'},
            status=400, extra_environ=http_auth)
        assert json.loads(response.text) == {
            'message': 'missing objects data'}

    def test_app_batch_api_unsupported_data_in_objects(
            self, git_lfs_app, http_auth):
        params = {'operation': 'download',
                  'objects': [{}]}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/objects/batch', params=params, status=400,
            extra_environ=http_auth)
        assert json.loads(response.text) == {
            'message': 'unsupported data in objects'}

    def test_app_batch_api_download_missing_object(
            self, git_lfs_app, http_auth):
        params = {'operation': 'download',
                  'objects': [{'oid': '123', 'size': '1024'}]}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/objects/batch', params=params,
            extra_environ=http_auth)

        # missing oids are reported per-object, not as an HTTP-level error
        expected_objects = [
            {'authenticated': True,
             'errors': {'error': {
                 'code': 404,
                 'message': 'object: 123 does not exist in store'}},
             'oid': '123',
             'size': '1024'}
        ]
        assert json.loads(response.text) == {
            'objects': expected_objects, 'transfer': 'basic'}

    def test_app_batch_api_download(self, git_lfs_app, http_auth):
        # seed the store with the oid so download actions are offered
        oid = '456'
        oid_path = os.path.join(git_lfs_app._store, oid)
        if not os.path.isdir(os.path.dirname(oid_path)):
            os.makedirs(os.path.dirname(oid_path))
        with open(oid_path, 'wb') as f:
            f.write(safe_bytes('OID_CONTENT'))

        params = {'operation': 'download',
                  'objects': [{'oid': oid, 'size': '1024'}]}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/objects/batch', params=params,
            extra_environ=http_auth)

        expected_objects = [
            {'authenticated': True,
             'actions': {
                 'download': {
                     'header': {'Authorization': 'Basic XXXXX'},
                     'href': 'http://localhost/repo/info/lfs/objects/456'},
             },
             'oid': '456',
             'size': '1024'}
        ]
        assert json.loads(response.text) == {
            'objects': expected_objects, 'transfer': 'basic'}

    def test_app_batch_api_upload(self, git_lfs_app, http_auth):
        params = {'operation': 'upload',
                  'objects': [{'oid': '123', 'size': '1024'}]}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/objects/batch', params=params,
            extra_environ=http_auth)
        expected_objects = [
            {'authenticated': True,
             'actions': {
                 'upload': {
                     'header': {'Authorization': 'Basic XXXXX',
                                'Transfer-Encoding': 'chunked'},
                     'href': 'http://localhost/repo/info/lfs/objects/123'},
                 'verify': {
                     'header': {'Authorization': 'Basic XXXXX'},
                     'href': 'http://localhost/repo/info/lfs/verify'}
             },
             'oid': '123',
             'size': '1024'}
        ]
        assert json.loads(response.text) == {
            'objects': expected_objects, 'transfer': 'basic'}

    def test_app_batch_api_upload_for_https(self, git_lfs_https_app, http_auth):
        # generated hrefs must carry the configured https scheme
        params = {'operation': 'upload',
                  'objects': [{'oid': '123', 'size': '1024'}]}
        response = git_lfs_https_app.post_json(
            '/repo/info/lfs/objects/batch', params=params,
            extra_environ=http_auth)
        expected_objects = [
            {'authenticated': True,
             'actions': {
                 'upload': {
                     'header': {'Authorization': 'Basic XXXXX',
                                'Transfer-Encoding': 'chunked'},
                     'href': 'https://localhost/repo/info/lfs/objects/123'},
                 'verify': {
                     'header': {'Authorization': 'Basic XXXXX'},
                     'href': 'https://localhost/repo/info/lfs/verify'}
             },
             'oid': '123',
             'size': '1024'}
        ]
        assert json.loads(response.text) == {
            'objects': expected_objects, 'transfer': 'basic'}

    def test_app_verify_api_missing_data(self, git_lfs_app):
        params = {'oid': 'missing'}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/verify', params=params,
            status=400)

        assert json.loads(response.text) == {
            'message': 'missing oid and size in request data'}

    def test_app_verify_api_missing_obj(self, git_lfs_app):
        params = {'oid': 'missing', 'size': '1024'}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/verify', params=params,
            status=404)

        assert json.loads(response.text) == {
            'message': 'oid `missing` does not exists in store'}

    def test_app_verify_api_size_mismatch(self, git_lfs_app):
        oid = 'existing'
        oid_path = os.path.join(git_lfs_app._store, oid)
        if not os.path.isdir(os.path.dirname(oid_path)):
            os.makedirs(os.path.dirname(oid_path))
        with open(oid_path, 'wb') as f:
            f.write(safe_bytes('OID_CONTENT'))

        # stored content is 11 bytes; reporting 1024 must yield 422
        params = {'oid': oid, 'size': '1024'}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/verify', params=params, status=422)

        assert json.loads(response.text) == {
            'message': 'requested file size mismatch '
                       'store size:11 requested:1024'}

    def test_app_verify_api(self, git_lfs_app):
        oid = 'existing'
        oid_path = os.path.join(git_lfs_app._store, oid)
        if not os.path.isdir(os.path.dirname(oid_path)):
            os.makedirs(os.path.dirname(oid_path))
        with open(oid_path, 'wb') as f:
            f.write(safe_bytes('OID_CONTENT'))

        params = {'oid': oid, 'size': 11}
        response = git_lfs_app.post_json(
            '/repo/info/lfs/verify', params=params)

        assert json.loads(response.text) == {
            'message': {'size': 'ok', 'in_store': 'ok'}}

    def test_app_download_api_oid_not_existing(self, git_lfs_app):
        oid = 'missing'

        response = git_lfs_app.get(
            '/repo/info/lfs/objects/{oid}'.format(oid=oid), status=404)

        assert json.loads(response.text) == {
            'message': 'requested file with oid `missing` not found in store'}

    def test_app_download_api(self, git_lfs_app):
        oid = 'existing'
        oid_path = os.path.join(git_lfs_app._store, oid)
        if not os.path.isdir(os.path.dirname(oid_path)):
            os.makedirs(os.path.dirname(oid_path))
        with open(oid_path, 'wb') as f:
            f.write(safe_bytes('OID_CONTENT'))

        response = git_lfs_app.get(
            '/repo/info/lfs/objects/{oid}'.format(oid=oid))
        assert response

    def test_app_upload(self, git_lfs_app):
        oid = 'uploaded'

        response = git_lfs_app.put(
            '/repo/info/lfs/objects/{oid}'.format(oid=oid), params='CONTENT')

        assert json.loads(response.text) == {'upload': 'ok'}

        # verify that we actually wrote that OID
        oid_path = os.path.join(git_lfs_app._store, oid)
        assert os.path.isfile(oid_path)
        assert 'CONTENT' == open(oid_path).read()
@@ -1,142 +1,142 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 from vcsserver.utils import safe_bytes
20 from vcsserver.str_utils import safe_bytes
21 21 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
22 22
23 23
@pytest.fixture()
def lfs_store(tmpdir):
    """An LFSOidStore for a fixed test repo/oid, rooted in a tmp directory."""
    return LFSOidStore(oid='123456789', repo='test', store_location=str(tmpdir))
30 30
31 31
@pytest.fixture()
def oid_handler(lfs_store):
    """An OidHandler wired to the ``lfs_store`` fixture with dummy auth/hrefs."""
    handler = OidHandler(
        store=lfs_store,
        repo_name=lfs_store.repo,
        auth=('basic', 'xxxx'),
        oid=lfs_store.oid,
        obj_size='1024',
        obj_data={},
        obj_href='http://localhost/handle_oid',
        obj_verify_href='http://localhost/verify')
    return handler
44 44
45 45
class TestOidHandler(object):
    """
    Unit tests for OidHandler download/upload operation handling against a
    temporary on-disk store.
    """

    @pytest.mark.parametrize('exec_action', [
        'download',
        'upload',
    ])
    def test_exec_action(self, exec_action, oid_handler):
        handler = oid_handler.exec_operation(exec_action)
        assert handler

    def test_exec_action_undefined(self, oid_handler):
        # unknown operation names raise AttributeError
        with pytest.raises(AttributeError):
            oid_handler.exec_operation('wrong')

    def test_download_oid_not_existing(self, oid_handler):
        response, has_errors = oid_handler.exec_operation('download')

        assert response is None
        assert has_errors['error'] == {
            'code': 404,
            'message': 'object: 123456789 does not exist in store'}

    def test_download_oid(self, oid_handler):
        store = oid_handler.get_store()
        if not os.path.isdir(os.path.dirname(store.oid_path)):
            os.makedirs(os.path.dirname(store.oid_path))

        with open(store.oid_path, 'wb') as f:
            f.write(safe_bytes('CONTENT'))

        response, has_errors = oid_handler.exec_operation('download')

        assert has_errors is None
        assert response['download'] == {
            'header': {'Authorization': 'basic xxxx'},
            'href': 'http://localhost/handle_oid'
        }

    def test_upload_oid_that_exists(self, oid_handler):
        store = oid_handler.get_store()
        if not os.path.isdir(os.path.dirname(store.oid_path)):
            os.makedirs(os.path.dirname(store.oid_path))

        with open(store.oid_path, 'wb') as f:
            f.write(safe_bytes('CONTENT'))
        # declared size matches the 7 stored bytes -> nothing to upload
        oid_handler.obj_size = 7
        response, has_errors = oid_handler.exec_operation('upload')
        assert has_errors is None
        assert response is None

    def test_upload_oid_that_exists_but_has_wrong_size(self, oid_handler):
        store = oid_handler.get_store()
        if not os.path.isdir(os.path.dirname(store.oid_path)):
            os.makedirs(os.path.dirname(store.oid_path))

        with open(store.oid_path, 'wb') as f:
            f.write(safe_bytes('CONTENT'))

        # declared size differs from stored size -> upload action offered
        oid_handler.obj_size = 10240
        response, has_errors = oid_handler.exec_operation('upload')
        assert has_errors is None
        assert response['upload'] == {
            'header': {'Authorization': 'basic xxxx',
                       'Transfer-Encoding': 'chunked'},
            'href': 'http://localhost/handle_oid',
        }

    def test_upload_oid(self, oid_handler):
        response, has_errors = oid_handler.exec_operation('upload')
        assert has_errors is None
        assert response['upload'] == {
            'header': {'Authorization': 'basic xxxx',
                       'Transfer-Encoding': 'chunked'},
            'href': 'http://localhost/handle_oid'
        }
122 122
class TestLFSStore(object):
    """Tests for LFSOidStore write-engine and oid presence detection."""

    def test_write_oid(self, lfs_store):
        oid_location = lfs_store.oid_path

        assert not os.path.isfile(oid_location)

        # writing through the engine context manager materializes the file
        engine = lfs_store.get_engine(mode='wb')
        with engine as f:
            f.write(safe_bytes('CONTENT'))

        assert os.path.isfile(oid_location)

    def test_detect_has_oid(self, lfs_store):

        assert lfs_store.has_oid() is False
        engine = lfs_store.get_engine(mode='wb')
        with engine as f:
            f.write(safe_bytes('CONTENT'))

        assert lfs_store.has_oid() is True
@@ -1,87 +1,87 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """
19 19 Mercurial libs compatibility
20 20 """
21 21
22 22 import mercurial
23 23 from mercurial import demandimport
24 24
25 25 # patch demandimport, due to bug in mercurial when it always triggers
26 26 # demandimport.enable()
27 from vcsserver.utils import safe_bytes
27 from vcsserver.str_utils import safe_bytes
28 28
29 29 demandimport.enable = lambda *args, **kwargs: 1
30 30
31 31 from mercurial import ui
32 32 from mercurial import patch
33 33 from mercurial import config
34 34 from mercurial import extensions
35 35 from mercurial import scmutil
36 36 from mercurial import archival
37 37 from mercurial import discovery
38 38 from mercurial import unionrepo
39 39 from mercurial import localrepo
40 40 from mercurial import merge as hg_merge
41 41 from mercurial import subrepo
42 42 from mercurial import subrepoutil
43 43 from mercurial import tags as hg_tag
44 44 from mercurial import util as hgutil
45 45 from mercurial.commands import clone, pull
46 46 from mercurial.node import nullid
47 47 from mercurial.context import memctx, memfilectx
48 48 from mercurial.error import (
49 49 LookupError, RepoError, RepoLookupError, Abort, InterventionRequired,
50 50 RequirementError, ProgrammingError)
51 51 from mercurial.hgweb import hgweb_mod
52 52 from mercurial.localrepo import instance
53 53 from mercurial.match import match, alwaysmatcher, patternmatcher
54 54 from mercurial.mdiff import diffopts
55 55 from mercurial.node import bin, hex
56 56 from mercurial.encoding import tolocal
57 57 from mercurial.discovery import findcommonoutgoing
58 58 from mercurial.hg import peer
59 59 from mercurial.httppeer import makepeer
60 60 from mercurial.utils.urlutil import url as hg_url
61 61 from mercurial.scmutil import revrange, revsymbol
62 62 from mercurial.node import nullrev
63 63 from mercurial import exchange
64 64 from hgext import largefiles
65 65
66 66 # those authnadlers are patched for python 2.6.5 bug an
67 67 # infinit looping when given invalid resources
68 68 from mercurial.url import httpbasicauthhandler, httpdigestauthhandler
69 69
70 70 # hg strip is in core now
71 71 from mercurial import strip as hgext_strip
72 72
73 73
def get_ctx(repo, ref):
    """
    Resolve ``ref`` (revision number, hash, or symbolic name) to a
    mercurial changectx of ``repo``, falling back to the slower
    ``revsymbol`` lookup when direct ``repo[ref]`` access fails.
    """
    ref = safe_bytes(ref)
    try:
        ctx = repo[ref]
    except (ProgrammingError, TypeError):
        # we're unable to find the rev using a regular lookup, we fallback
        # to slower, but backward compat revsymbol usage
        ctx = revsymbol(repo, ref)
    except (LookupError, RepoLookupError):
        # Similar case as above but only for refs that are not numeric
        # NOTE(review): ref already passed through safe_bytes above, so this
        # isinstance check may never match an int input — confirm what
        # safe_bytes does with integers.
        if isinstance(ref, int):
            raise
        ctx = revsymbol(repo, ref)
    return ctx
@@ -1,204 +1,204 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # RhodeCode VCSServer provides access to different vcs backends via network.
4 4 # Copyright (C) 2014-2020 RhodeCode GmbH
5 5 #
6 6 # This program is free software; you can redistribute it and/or modify
7 7 # it under the terms of the GNU General Public License as published by
8 8 # the Free Software Foundation; either version 3 of the License, or
9 9 # (at your option) any later version.
10 10 #
11 11 # This program is distributed in the hope that it will be useful,
12 12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 14 # GNU General Public License for more details.
15 15 #
16 16 # You should have received a copy of the GNU General Public License
17 17 # along with this program; if not, write to the Free Software Foundation,
18 18 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 19
20 20 import re
21 21 import os
22 22 import sys
23 23 import datetime
24 24 import logging
25 25 import pkg_resources
26 26
27 27 import vcsserver
28 from vcsserver.utils import safe_bytes
28 from vcsserver.str_utils import safe_bytes
29 29
30 30 log = logging.getLogger(__name__)
31 31
32 32
def get_git_hooks_path(repo_path, bare):
    """
    Return the hooks directory of a git repository.

    :param repo_path: path to the repository root
    :param bare: True for a bare repo (hooks live at the top level),
        False for a working copy (hooks live under ``.git``)
    """
    if bare:
        return os.path.join(repo_path, 'hooks')
    return os.path.join(repo_path, '.git', 'hooks')
39 39
40 40
def install_git_hooks(repo_path, bare, executable=None, force_create=False):
    """
    Creates a RhodeCode hook inside a git repository

    :param repo_path: path to repository
    :param bare: whether the repo is bare (affects hooks dir location)
    :param executable: binary executable to put in the hooks
    :param force_create: Create even if same name hook exists
    :returns: True (hooks that could not be written are only logged)
    """
    executable = executable or sys.executable
    hooks_path = get_git_hooks_path(repo_path, bare)

    if not os.path.isdir(hooks_path):
        os.makedirs(hooks_path, mode=0o777)

    # hook templates ship inside the vcsserver package as bytes
    tmpl_post = pkg_resources.resource_string(
        'vcsserver', '/'.join(
            ('hook_utils', 'hook_templates', 'git_post_receive.py.tmpl')))
    tmpl_pre = pkg_resources.resource_string(
        'vcsserver', '/'.join(
            ('hook_utils', 'hook_templates', 'git_pre_receive.py.tmpl')))

    path = ''  # not used for now
    timestamp = datetime.datetime.utcnow().isoformat()

    for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
        log.debug('Installing git hook in repo %s', repo_path)
        _hook_file = os.path.join(hooks_path, '%s-receive' % h_type)
        # only overwrite hooks that RhodeCode itself wrote (or none exist)
        _rhodecode_hook = check_rhodecode_hook(_hook_file)

        if _rhodecode_hook or force_create:
            log.debug('writing git %s hook file at %s !', h_type, _hook_file)
            try:
                with open(_hook_file, 'wb') as f:
                    # substitute template placeholders with runtime values
                    template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
                    template = template.replace(b'_DATE_', safe_bytes(timestamp))
                    template = template.replace(b'_ENV_', safe_bytes(executable))
                    template = template.replace(b'_PATH_', safe_bytes(path))
                    f.write(template)
                os.chmod(_hook_file, 0o755)
            except IOError:
                # best-effort: log and continue with the next hook
                log.exception('error writing hook file %s', _hook_file)
        else:
            log.debug('skipping writing hook file')

    return True
86 86
87 87
def get_svn_hooks_path(repo_path):
    """Return the hooks directory of a subversion repository."""
    return os.path.join(repo_path, 'hooks')
92 92
93 93
def install_svn_hooks(repo_path, executable=None, force_create=False):
    """
    Creates RhodeCode hooks inside a svn repository

    :param repo_path: path to repository
    :param executable: binary executable to put in the hooks
    :param force_create: Create even if same name hook exists
    :returns: True (hooks that could not be written are only logged)
    """
    executable = executable or sys.executable
    hooks_path = get_svn_hooks_path(repo_path)
    if not os.path.isdir(hooks_path):
        os.makedirs(hooks_path, mode=0o777)

    # hook templates ship inside the vcsserver package as bytes
    tmpl_post = pkg_resources.resource_string(
        'vcsserver', '/'.join(
            ('hook_utils', 'hook_templates', 'svn_post_commit_hook.py.tmpl')))
    tmpl_pre = pkg_resources.resource_string(
        'vcsserver', '/'.join(
            ('hook_utils', 'hook_templates', 'svn_pre_commit_hook.py.tmpl')))

    path = ''  # not used for now
    timestamp = datetime.datetime.utcnow().isoformat()

    for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
        log.debug('Installing svn hook in repo %s', repo_path)
        _hook_file = os.path.join(hooks_path, '%s-commit' % h_type)
        # only overwrite hooks that RhodeCode itself wrote (or none exist)
        _rhodecode_hook = check_rhodecode_hook(_hook_file)

        if _rhodecode_hook or force_create:
            log.debug('writing svn %s hook file at %s !', h_type, _hook_file)

            try:
                with open(_hook_file, 'wb') as f:
                    # substitute template placeholders with runtime values
                    template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
                    template = template.replace(b'_DATE_', safe_bytes(timestamp))
                    template = template.replace(b'_ENV_', safe_bytes(executable))
                    template = template.replace(b'_PATH_', safe_bytes(path))

                    f.write(template)
                os.chmod(_hook_file, 0o755)
            except IOError:
                # best-effort: log and continue with the next hook
                log.exception('error writing hook file %s', _hook_file)
        else:
            log.debug('skipping writing hook file')

    return True
140 140
141 141
def get_version_from_hook(hook_path):
    """
    Extract the ``RC_HOOK_VER`` marker from a hook file.

    :returns: version as bytes with quotes stripped, or ``b''`` when the
        file is missing or carries no marker.
    """
    hook_content = read_hook_content(hook_path)
    match = re.search(rb'(?:RC_HOOK_VER)\s*=\s*(.*)', hook_content)
    if not match:
        return b''
    version = b''
    try:
        version = match.groups()[0]
        log.debug('got version %s from hooks.', version)
    except Exception:
        log.exception("Exception while reading the hook version.")
    return version.replace(b"'", b"")
153 153
154 154
def check_rhodecode_hook(hook_path):
    """
    Check if the hook was created by RhodeCode.

    Returns True when it is safe to (over)write the hook: either nothing
    exists at ``hook_path`` yet, or the existing file carries a RhodeCode
    version marker.
    """
    if not os.path.exists(hook_path):
        return True

    log.debug('hook exists, checking if it is from RhodeCode')
    return bool(get_version_from_hook(hook_path))
169 169
170 170
def read_hook_content(hook_path) -> bytes:
    """
    Read and return the raw bytes of the hook file at ``hook_path``.

    Returns ``b''`` when the file does not exist, so callers always get
    bytes. (Previously a missing file returned the str ``''``, which made
    get_version_from_hook's bytes-pattern re.search raise TypeError.)
    """
    content = b''
    if os.path.isfile(hook_path):
        with open(hook_path, 'rb') as f:
            content = f.read()
    return content
177 177
178 178
def get_git_pre_hook_version(repo_path, bare):
    """Version marker stored in the git pre-receive hook (b'' if absent)."""
    hook_file = os.path.join(get_git_hooks_path(repo_path, bare), 'pre-receive')
    return get_version_from_hook(hook_file)
184 184
185 185
def get_git_post_hook_version(repo_path, bare):
    """Version marker stored in the git post-receive hook (b'' if absent)."""
    hook_file = os.path.join(get_git_hooks_path(repo_path, bare), 'post-receive')
    return get_version_from_hook(hook_file)
191 191
192 192
def get_svn_pre_hook_version(repo_path):
    """Version marker stored in the svn pre-commit hook (b'' if absent)."""
    hook_file = os.path.join(get_svn_hooks_path(repo_path), 'pre-commit')
    return get_version_from_hook(hook_file)
198 198
199 199
def get_svn_post_hook_version(repo_path):
    """Version marker stored in the svn post-commit hook (b'' if absent)."""
    hook_file = os.path.join(get_svn_hooks_path(repo_path), 'post-commit')
    return get_version_from_hook(hook_file)
@@ -1,738 +1,738 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # RhodeCode VCSServer provides access to different vcs backends via network.
4 4 # Copyright (C) 2014-2020 RhodeCode GmbH
5 5 #
6 6 # This program is free software; you can redistribute it and/or modify
7 7 # it under the terms of the GNU General Public License as published by
8 8 # the Free Software Foundation; either version 3 of the License, or
9 9 # (at your option) any later version.
10 10 #
11 11 # This program is distributed in the hope that it will be useful,
12 12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 14 # GNU General Public License for more details.
15 15 #
16 16 # You should have received a copy of the GNU General Public License
17 17 # along with this program; if not, write to the Free Software Foundation,
18 18 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 19
20 20 import io
21 21 import os
22 22 import sys
23 23 import logging
24 24 import collections
25 25 import importlib
26 26 import base64
27 27 import msgpack
28 28
29 29 from http.client import HTTPConnection
30 30
31 31
32 32 import mercurial.scmutil
33 33 import mercurial.node
34 34
35 35 from vcsserver.lib.rc_json import json
36 36 from vcsserver import exceptions, subprocessio, settings
37 from vcsserver.utils import safe_bytes
37 from vcsserver.str_utils import safe_bytes
38 38
39 39 log = logging.getLogger(__name__)
40 40
41 41
class HooksHttpClient(object):
    """
    Hooks client that POSTs msgpack-serialized calls to the RhodeCode
    hooks HTTP daemon and decodes the msgpack response.
    """
    # protocol marker sent in the 'rc-hooks-protocol' header
    proto = 'msgpack.v1'
    connection = None

    def __init__(self, hooks_uri):
        # host[:port] of the hooks daemon, passed straight to HTTPConnection
        self.hooks_uri = hooks_uri

    def __call__(self, method, extras):
        """Invoke remote hook ``method`` with ``extras``; return decoded result."""
        connection = HTTPConnection(self.hooks_uri)
        # binary msgpack body
        headers, body = self._serialize(method, extras)
        try:
            connection.request('POST', '/', body, headers)
        except Exception as error:
            log.error('Hooks calling Connection failed on %s, org error: %s', connection.__dict__, error)
            raise
        response = connection.getresponse()
        try:
            # stream-decode the msgpack payload directly from the response
            return msgpack.load(response, raw=False)
        except Exception:
            response_data = response.read()
            log.exception('Failed to decode hook response json data. '
                          'response_code:%s, raw_data:%s',
                          response.status, response_data)
            raise

    @classmethod
    def _serialize(cls, hook_name, extras):
        """Build the (headers, msgpack-packed body) pair for a hook call."""
        data = {
            'method': hook_name,
            'extras': extras
        }
        headers = {
            'rc-hooks-protocol': cls.proto
        }
        return headers, msgpack.packb(data)
78 78
79 79
class HooksDummyClient(object):
    """
    In-process hooks client: dispatches calls to a ``Hooks`` class
    imported from the configured python module instead of going over HTTP.
    """

    def __init__(self, hooks_module):
        self._hooks_module = importlib.import_module(hooks_module)

    def __call__(self, hook_name, extras):
        hooks_cls = self._hooks_module.Hooks
        with hooks_cls() as hooks:
            hook_method = getattr(hooks, hook_name)
            return hook_method(extras)
87 87
88 88
class HooksShadowRepoClient(object):
    """No-op hooks client for shadow repositories: every call succeeds."""

    def __call__(self, hook_name, extras):
        # shadow repos never fire real hooks; report empty, successful output
        return dict(output='', status=0)
93 93
94 94
class RemoteMessageWriter(object):
    """Writer base class; subclasses implement the transport-specific write."""

    def write(self, message):
        raise NotImplementedError()
99 99
100 100
class HgMessageWriter(RemoteMessageWriter):
    """Writer that knows how to send messages to mercurial clients."""

    def __init__(self, ui):
        # mercurial ui object used for client-facing status output
        self.ui = ui

    def write(self, message):
        # TODO: Check why the quiet flag is set by default.
        # Temporarily lift ui.quiet (which suppresses status output),
        # emit the message, then restore the previous value.
        old = self.ui.quiet
        self.ui.quiet = False
        self.ui.status(message.encode('utf-8'))
        self.ui.quiet = old
113 113
114 114
class GitMessageWriter(RemoteMessageWriter):
    """Writer that knows how to send messages to git clients."""

    def __init__(self, stdout=None):
        # default to the process stdout when no stream is supplied
        self.stdout = stdout if stdout is not None else sys.stdout

    def write(self, message):
        # git expects bytes on the wire
        self.stdout.write(safe_bytes(message))
123 123
124 124
class SvnMessageWriter(RemoteMessageWriter):
    """Writer that knows how to send messages to svn clients."""

    def __init__(self, stderr=None):
        # SVN needs data sent to stderr for back-to-client messaging
        self.stderr = stderr or sys.stderr

    def write(self, message):
        # NOTE(review): writes utf-8 bytes to the stream; if this is the
        # default text-mode sys.stderr, bytes writes would fail — confirm
        # callers pass a binary stream (e.g. sys.stderr.buffer).
        self.stderr.write(message.encode('utf-8'))
134 134
135 135
def _handle_exception(result):
    """Re-raise locally any remote-hook failure described in ``result``."""
    exception_class = result.get('exception')
    exception_traceback = result.get('exception_traceback')

    if exception_traceback:
        log.error('Got traceback from remote call:%s', exception_traceback)

    if not exception_class:
        return

    args = result['exception_args']
    # map well-known remote exception names onto local exception types
    if exception_class == 'HTTPLockedRC':
        raise exceptions.RepositoryLockedException()(*args)
    if exception_class == 'HTTPBranchProtected':
        raise exceptions.RepositoryBranchProtectedException()(*args)
    if exception_class == 'RepositoryError':
        raise exceptions.VcsException()(*args)
    # anything else becomes a generic error
    raise Exception('Got remote exception "%s" with args "%s"' %
                    (exception_class, args))
152 152
153 153
def _get_hooks_client(extras):
    """Select the hooks client implementation based on ``extras``."""
    if extras.get('hooks_uri'):
        # remote HTTP hooks daemon configured
        return HooksHttpClient(extras['hooks_uri'])
    if extras.get('is_shadow_repo'):
        # shadow repositories get a no-op client
        return HooksShadowRepoClient()
    # fall back to in-process python module dispatch
    return HooksDummyClient(extras['hooks_module'])
163 163
164 164
def _call_hook(hook_name, extras, writer):
    """
    Execute ``hook_name`` via the configured hooks client, forward its
    output to ``writer``, and return the remote status code. Raises if the
    remote side reported an exception.
    """
    client = _get_hooks_client(extras)
    log.debug('Hooks, using client:%s', client)
    result = client(hook_name, extras)
    log.debug('Hooks got result: %s', result)

    _handle_exception(result)
    writer.write(result['output'])
    return result['status']
175 175
176 176
def _extras_from_ui(ui):
    """
    Load the RC_SCM_DATA json payload from the hg ui config, falling back
    to the process environment; returns {} when neither is set.
    """
    hook_data = ui.config(b'rhodecode', b'RC_SCM_DATA')
    if not hook_data:
        # maybe it's inside environ ?
        hook_data = os.environ.get('RC_SCM_DATA')

    if not hook_data:
        return {}
    return json.loads(hook_data)
189 189
190 190
def _rev_range_hash(repo, node, check_heads=False):
    """
    Collect ``(commit_id, branch)`` pairs for every revision from ``node``
    up to the repository tip.

    :param check_heads: when True, also run _check_heads over the collected
        revisions to detect multiple-head parents (forced-push indicator)
    :returns: (commits, parent_heads) — parent_heads is [] unless
        check_heads found multiple heads
    """
    from vcsserver.hgcompat import get_ctx

    commits = []
    revs = []
    start = get_ctx(repo, node).rev()
    end = len(repo)
    for rev in range(start, end):
        revs.append(rev)
        ctx = get_ctx(repo, rev)
        # hex commit hash + branch name for each pushed revision
        commit_id = mercurial.node.hex(ctx.node())
        branch = ctx.branch()
        commits.append((commit_id, branch))

    parent_heads = []
    if check_heads:
        parent_heads = _check_heads(repo, start, end, revs)
    return commits, parent_heads
209 209
210 210
def _check_heads(repo, start, end, commits):
    """
    Detect whether the incoming revisions create multiple heads on any
    branch (a sign that a merge is needed / a force push occurred).

    :param start: first new revision number
    :param end: one past the last revision number to consider
    :param commits: list of new revision numbers
    :returns: list of head revision numbers when more than one head is
        found on a branch, otherwise []
    """
    from vcsserver.hgcompat import get_ctx
    changelog = repo.changelog
    parents = set()

    # collect the pre-existing parents (revs below `start`) of new commits
    for new_rev in commits:
        for p in changelog.parentrevs(new_rev):
            if p == mercurial.node.nullrev:
                continue
            if p < start:
                parents.add(p)

    for p in parents:
        branch = get_ctx(repo, p).branch()
        # The heads descending from that parent, on the same branch
        parent_heads = set([p])
        reachable = set([p])
        for x in range(p + 1, end):
            if get_ctx(repo, x).branch() != branch:
                continue
            for pp in changelog.parentrevs(x):
                if pp in reachable:
                    reachable.add(x)
                    # x supersedes its parent pp as a head candidate
                    parent_heads.discard(pp)
                    parent_heads.add(x)
        # More than one head? Suggest merging
        if len(parent_heads) > 1:
            return list(parent_heads)

    return []
241 241
242 242
243 243 def _get_git_env():
244 244 env = {}
245 245 for k, v in os.environ.items():
246 246 if k.startswith('GIT'):
247 247 env[k] = v
248 248
249 249 # serialized version
250 250 return [(k, v) for k, v in env.items()]
251 251
252 252
253 253 def _get_hg_env(old_rev, new_rev, txnid, repo_path):
254 254 env = {}
255 255 for k, v in os.environ.items():
256 256 if k.startswith('HG'):
257 257 env[k] = v
258 258
259 259 env['HG_NODE'] = old_rev
260 260 env['HG_NODE_LAST'] = new_rev
261 261 env['HG_TXNID'] = txnid
262 262 env['HG_PENDING'] = repo_path
263 263
264 264 return [(k, v) for k, v in env.items()]
265 265
266 266
def repo_size(ui, repo, **kwargs):
    """Mercurial hook: delegate repository size reporting to RhodeCode."""
    return _call_hook('repo_size', _extras_from_ui(ui), HgMessageWriter(ui))
270 270
271 271
def pre_pull(ui, repo, **kwargs):
    """Mercurial pre-pull hook: notify RhodeCode before a pull."""
    return _call_hook('pre_pull', _extras_from_ui(ui), HgMessageWriter(ui))
275 275
276 276
def pre_pull_ssh(ui, repo, **kwargs):
    """SSH variant of pre_pull: only fires when extras mark an SSH session."""
    extras = _extras_from_ui(ui)
    if not (extras and extras.get('SSH')):
        return 0
    return pre_pull(ui, repo, **kwargs)
282 282
283 283
def post_pull(ui, repo, **kwargs):
    """Mercurial post-pull hook: notify RhodeCode after a pull."""
    return _call_hook('post_pull', _extras_from_ui(ui), HgMessageWriter(ui))
287 287
288 288
def post_pull_ssh(ui, repo, **kwargs):
    """SSH variant of post_pull: only fires when extras mark an SSH session."""
    extras = _extras_from_ui(ui)
    if not (extras and extras.get('SSH')):
        return 0
    return post_pull(ui, repo, **kwargs)
294 294
295 295
def pre_push(ui, repo, node=None, **kwargs):
    """
    Mercurial pre_push hook

    Collects per-branch commit ranges for the incoming changegroup,
    optionally detects forced pushes (multiple heads), and forwards the
    data to RhodeCode's pre_push hook.
    """
    extras = _extras_from_ui(ui)
    detect_force_push = extras.get('detect_force_push')

    rev_data = []
    # only the pretxnchangegroup stage carries actual revisions
    if node and kwargs.get('hooktype') == 'pretxnchangegroup':
        branches = collections.defaultdict(list)
        commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
        for commit_id, branch in commits:
            branches[branch].append(commit_id)

        for branch, commits in branches.items():
            old_rev = kwargs.get('node_last') or commits[0]
            rev_data.append({
                'total_commits': len(commits),
                'old_rev': old_rev,
                'new_rev': commits[-1],
                'ref': '',
                'type': 'branch',
                'name': branch,
            })

        for push_ref in rev_data:
            # non-empty _heads marks a multiple-head (force push) situation
            push_ref['multiple_heads'] = _heads

            repo_path = os.path.join(
                extras.get('repo_store', ''), extras.get('repository', ''))
            push_ref['hg_env'] = _get_hg_env(
                old_rev=push_ref['old_rev'],
                new_rev=push_ref['new_rev'], txnid=kwargs.get('txnid'),
                repo_path=repo_path)

    extras['hook_type'] = kwargs.get('hooktype', 'pre_push')
    extras['commit_ids'] = rev_data

    return _call_hook('pre_push', extras, HgMessageWriter(ui))
335 335
336 336
def pre_push_ssh(ui, repo, node=None, **kwargs):
    """SSH variant of pre_push: only fires when extras mark an SSH session."""
    if not _extras_from_ui(ui).get('SSH'):
        return 0
    return pre_push(ui, repo, node, **kwargs)
343 343
344 344
def pre_push_ssh_auth(ui, repo, node=None, **kwargs):
    """
    Mercurial pre_push hook for SSH: a pure permission gate.

    Returns 0 (allow) for write/admin permission, 1 (deny) otherwise;
    non-SSH sessions are always allowed through.
    """
    extras = _extras_from_ui(ui)
    if not extras.get('SSH'):
        return 0

    permission = extras['SSH_PERMISSIONS']
    if permission in ('repository.write', 'repository.admin'):
        return 0
    # non-zero ret code denies the push
    return 1
360 360
361 361
def post_push(ui, repo, node, **kwargs):
    """
    Mercurial post_push hook

    Collects pushed commit ids, touched branches and bookmarks and
    forwards them to RhodeCode's post_push hook.
    """
    extras = _extras_from_ui(ui)

    commit_ids = []
    branches = []
    bookmarks = []
    tags = []

    commits, _heads = _rev_range_hash(repo, node)
    for commit_id, branch in commits:
        commit_ids.append(commit_id)
        if branch not in branches:
            branches.append(branch)

    # bookmarks stashed on the ui object by the key_push hook, if any
    if hasattr(ui, '_rc_pushkey_branches'):
        bookmarks = ui._rc_pushkey_branches

    extras['hook_type'] = kwargs.get('hooktype', 'post_push')
    extras['commit_ids'] = commit_ids
    extras['new_refs'] = {
        'branches': branches,
        'bookmarks': bookmarks,
        'tags': tags
    }

    return _call_hook('post_push', extras, HgMessageWriter(ui))
391 391
392 392
def post_push_ssh(ui, repo, node, **kwargs):
    """SSH variant of post_push: only fires when extras mark an SSH session."""
    if not _extras_from_ui(ui).get('SSH'):
        return 0
    return post_push(ui, repo, node, **kwargs)
400 400
401 401
def key_push(ui, repo, **kwargs):
    """
    Mercurial pushkey hook: records newly pushed bookmarks on the ui
    object so the later post_push hook can report them.
    """
    from vcsserver.hgcompat import get_ctx
    if kwargs['new'] != '0' and kwargs['namespace'] == 'bookmarks':
        # store new bookmarks in our UI object propagated later to post_push
        ui._rc_pushkey_branches = get_ctx(repo, kwargs['key']).bookmarks()
    return
408 408
409 409
# backward compat: old hook entry-point name, kept for existing configs
log_pull_action = post_pull

# backward compat: old hook entry-point name, kept for existing configs
log_push_action = post_push
415 415
416 416
def handle_git_pre_receive(unused_repo_path, unused_revs, unused_env):
    """
    Old hook name: keep here for backward compatibility.

    This is only required when the installed git hooks are not upgraded.
    Intentionally a no-op.
    """
424 424
425 425
def handle_git_post_receive(unused_repo_path, unused_revs, unused_env):
    """
    Old hook name: keep here for backward compatibility.

    This is only required when the installed git hooks are not upgraded.
    Intentionally a no-op.
    """
433 433
434 434
# (status: int, output) pair returned by the git pull hook wrappers below
HookResponse = collections.namedtuple('HookResponse', ('status', 'output'))
436 436
437 437
def git_pre_pull(extras):
    """
    Pre pull hook.

    :param extras: dictionary containing the keys defined in simplevcs
    :type extras: dict

    :return: status code of the hook. 0 for success.
    :rtype: int
    """
    if 'pull' not in extras['hooks']:
        # pull hook not enabled for this repository
        return HookResponse(0, '')

    buffer = io.BytesIO()
    try:
        status = _call_hook('pre_pull', extras, GitMessageWriter(buffer))
    except Exception as error:
        log.exception('Failed to call pre_pull hook')
        # 128 signals a hard failure to the git client
        status = 128
        buffer.write(safe_bytes(f'ERROR: {error}\n'))

    return HookResponse(status, buffer.getvalue())
462 462
463 463
def git_post_pull(extras):
    """
    Post pull hook.

    :param extras: dictionary containing the keys defined in simplevcs
    :type extras: dict

    :return: status code of the hook. 0 for success.
    :rtype: int
    """
    if 'pull' not in extras['hooks']:
        return HookResponse(0, '')

    stdout = io.BytesIO()
    try:
        status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
    except Exception as error:
        # log the traceback for diagnosis (consistent with git_pre_pull,
        # which previously logged while this hook swallowed it silently)
        log.exception('Failed to call post_pull hook')
        status = 128
        stdout.write(safe_bytes(f'ERROR: {error}\n'))

    return HookResponse(status, stdout.getvalue())
485 485
486 486
487 487 def _parse_git_ref_lines(revision_lines):
488 488 rev_data = []
489 489 for revision_line in revision_lines or []:
490 490 old_rev, new_rev, ref = revision_line.strip().split(' ')
491 491 ref_data = ref.split('/', 2)
492 492 if ref_data[1] in ('tags', 'heads'):
493 493 rev_data.append({
494 494 # NOTE(marcink):
495 495 # we're unable to tell total_commits for git at this point
496 496 # but we set the variable for consistency with GIT
497 497 'total_commits': -1,
498 498 'old_rev': old_rev,
499 499 'new_rev': new_rev,
500 500 'ref': ref,
501 501 'type': ref_data[1],
502 502 'name': ref_data[2],
503 503 })
504 504 return rev_data
505 505
506 506
def git_pre_receive(unused_repo_path, revision_lines, env):
    """
    Pre push hook.

    :param revision_lines: stdin lines from git ('<old> <new> <ref>')
    :param env: process environment; must carry RC_SCM_DATA json

    :return: status code of the hook. 0 for success.
    :rtype: int
    """
    extras = json.loads(env['RC_SCM_DATA'])
    rev_data = _parse_git_ref_lines(revision_lines)
    if 'push' not in extras['hooks']:
        return 0
    # all-zero sha marks branch creation/deletion in receive lines
    empty_commit_id = '0' * 40

    detect_force_push = extras.get('detect_force_push')

    for push_ref in rev_data:
        # store our git-env which holds the temp store
        push_ref['git_env'] = _get_git_env()
        push_ref['pruned_sha'] = ''
        if not detect_force_push:
            # don't check for forced-push when we don't need to
            continue

        type_ = push_ref['type']
        new_branch = push_ref['old_rev'] == empty_commit_id
        delete_branch = push_ref['new_rev'] == empty_commit_id
        if type_ == 'heads' and not (new_branch or delete_branch):
            old_rev = push_ref['old_rev']
            new_rev = push_ref['new_rev']
            # commits reachable from old_rev but not new_rev = rewritten history
            cmd = [settings.GIT_EXECUTABLE, 'rev-list', old_rev, '^{}'.format(new_rev)]
            stdout, stderr = subprocessio.run_command(
                cmd, env=os.environ.copy())
            # means we're having some non-reachable objects, this forced push was used
            if stdout:
                push_ref['pruned_sha'] = stdout.splitlines()

    extras['hook_type'] = 'pre_receive'
    extras['commit_ids'] = rev_data
    return _call_hook('pre_push', extras, GitMessageWriter())
549 549
550 550
def git_post_receive(unused_repo_path, revision_lines, env):
    """
    Post push hook.

    Expands the pushed ref lines into concrete commit ids (consulting git
    itself for new/updated branches), then reports them to RhodeCode's
    post_push hook.

    :param revision_lines: stdin lines from git ('<old> <new> <ref>')
    :param env: process environment; must carry RC_SCM_DATA json

    :return: status code of the hook. 0 for success.
    :rtype: int
    """
    extras = json.loads(env['RC_SCM_DATA'])
    if 'push' not in extras['hooks']:
        return 0

    rev_data = _parse_git_ref_lines(revision_lines)

    git_revs = []

    # N.B.(skreft): it is ok to just call git, as git before calling a
    # subcommand sets the PATH environment variable so that it point to the
    # correct version of the git executable.
    empty_commit_id = '0' * 40
    branches = []
    tags = []
    for push_ref in rev_data:
        type_ = push_ref['type']

        if type_ == 'heads':
            if push_ref['old_rev'] == empty_commit_id:
                # starting new branch case
                if push_ref['name'] not in branches:
                    branches.append(push_ref['name'])

                # Fix up head revision if needed
                cmd = [settings.GIT_EXECUTABLE, 'show', 'HEAD']
                try:
                    subprocessio.run_command(cmd, env=os.environ.copy())
                except Exception:
                    # HEAD is dangling (empty repo): point it at this branch
                    cmd = [settings.GIT_EXECUTABLE, 'symbolic-ref', '"HEAD"',
                           '"refs/heads/%s"' % push_ref['name']]
                    print(("Setting default branch to %s" % push_ref['name']))
                    subprocessio.run_command(cmd, env=os.environ.copy())

                # commits on the new branch not reachable from other heads
                cmd = [settings.GIT_EXECUTABLE, 'for-each-ref',
                       '--format=%(refname)', 'refs/heads/*']
                stdout, stderr = subprocessio.run_command(
                    cmd, env=os.environ.copy())
                heads = stdout
                heads = heads.replace(push_ref['ref'], '')
                heads = ' '.join(head for head
                                 in heads.splitlines() if head) or '.'
                cmd = [settings.GIT_EXECUTABLE, 'log', '--reverse',
                       '--pretty=format:%H', '--', push_ref['new_rev'],
                       '--not', heads]
                stdout, stderr = subprocessio.run_command(
                    cmd, env=os.environ.copy())
                git_revs.extend(stdout.splitlines())
            elif push_ref['new_rev'] == empty_commit_id:
                # delete branch case
                git_revs.append('delete_branch=>%s' % push_ref['name'])
            else:
                # regular branch update: enumerate the pushed range
                if push_ref['name'] not in branches:
                    branches.append(push_ref['name'])

                cmd = [settings.GIT_EXECUTABLE, 'log',
                       '{old_rev}..{new_rev}'.format(**push_ref),
                       '--reverse', '--pretty=format:%H']
                stdout, stderr = subprocessio.run_command(
                    cmd, env=os.environ.copy())
                git_revs.extend(stdout.splitlines())
        elif type_ == 'tags':
            if push_ref['name'] not in tags:
                tags.append(push_ref['name'])
            git_revs.append('tag=>%s' % push_ref['name'])

    extras['hook_type'] = 'post_receive'
    extras['commit_ids'] = git_revs
    extras['new_refs'] = {
        'branches': branches,
        'bookmarks': [],
        'tags': tags,
    }

    if 'repo_size' in extras['hooks']:
        try:
            _call_hook('repo_size', extras, GitMessageWriter())
        except Exception:
            # repo_size is best-effort; never fail the push on it.
            # (was a bare `except:`, which also swallowed SystemExit /
            # KeyboardInterrupt — narrowed to match svn_post_commit)
            pass

    return _call_hook('post_push', extras, GitMessageWriter())
641 641
642 642
def _get_extras_from_txn_id(path, txn_id):
    """
    Fetch the base64-encoded rc-scm-extras revprop for an svn transaction
    via svnlook; returns {} on any failure (logged, not raised).
    """
    try:
        cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
               '-t', txn_id,
               '--revprop', path, 'rc-scm-extras']
        stdout, stderr = subprocessio.run_command(
            cmd, env=os.environ.copy())
        return json.loads(base64.urlsafe_b64decode(stdout))
    except Exception:
        log.exception('Failed to extract extras info from txn_id')
        return {}
656 656
657 657
def _get_extras_from_commit_id(commit_id, path):
    """
    Fetch the base64-encoded rc-scm-extras revprop for a committed svn
    revision via svnlook; returns {} on any failure (logged, not raised).
    """
    try:
        cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
               '-r', commit_id,
               '--revprop', path, 'rc-scm-extras']
        stdout, stderr = subprocessio.run_command(
            cmd, env=os.environ.copy())
        return json.loads(base64.urlsafe_b64decode(stdout))
    except Exception:
        log.exception('Failed to extract extras info from commit_id')
        return {}
671 671
672 672
def svn_pre_commit(repo_path, commit_data, env):
    """
    SVN pre-commit hook: forwards the pending transaction to RhodeCode's
    pre_push hook.

    :param commit_data: (path, txn_id) pair identifying the transaction
    :param env: process environment; RC_SCM_DATA json preferred, with a
        fallback to extras stored on the transaction itself
    :return: hook status code; 0 for success (also when no extras found)
    """
    path, txn_id = commit_data
    branches = []
    tags = []

    if env.get('RC_SCM_DATA'):
        extras = json.loads(env['RC_SCM_DATA'])
    else:
        # fallback method to read from TXN-ID stored data
        extras = _get_extras_from_txn_id(path, txn_id)
        if not extras:
            # nothing to report against — allow the commit
            return 0

    extras['hook_type'] = 'pre_commit'
    extras['commit_ids'] = [txn_id]
    extras['txn_id'] = txn_id
    extras['new_refs'] = {
        'total_commits': 1,
        'branches': branches,
        'bookmarks': [],
        'tags': tags,
    }

    return _call_hook('pre_push', extras, SvnMessageWriter())
697 697
698 698
def svn_post_commit(repo_path, commit_data, env):
    """
    SVN post-commit hook: reports the committed revision to RhodeCode's
    post_push hook.

    commit_data is (path, rev, txn_id); a legacy 2-tuple (path, rev) is
    accepted with limited functionality.

    :raises ValueError: when commit_data has an unexpected shape (was an
        UnboundLocalError crash further down previously)
    """
    if len(commit_data) == 3:
        path, commit_id, txn_id = commit_data
    elif len(commit_data) == 2:
        log.error('Failed to extract txn_id from commit_data using legacy method. '
                  'Some functionality might be limited')
        path, commit_id = commit_data
        txn_id = None
    else:
        raise ValueError(
            'Unexpected commit_data format: %r' % (commit_data,))

    branches = []
    tags = []

    if env.get('RC_SCM_DATA'):
        extras = json.loads(env['RC_SCM_DATA'])
    else:
        # fallback method to read from TXN-ID stored data
        extras = _get_extras_from_commit_id(commit_id, path)
        if not extras:
            # nothing to report against — treat as success
            return 0

    extras['hook_type'] = 'post_commit'
    extras['commit_ids'] = [commit_id]
    extras['txn_id'] = txn_id
    extras['new_refs'] = {
        'branches': branches,
        'bookmarks': [],
        'tags': tags,
        'total_commits': 1,
    }

    if 'repo_size' in extras['hooks']:
        try:
            _call_hook('repo_size', extras, SvnMessageWriter())
        except Exception:
            # repo_size is best-effort; never fail the commit on it
            pass

    return _call_hook('post_push', extras, SvnMessageWriter())
@@ -1,740 +1,740 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21 import base64
22 22 import locale
23 23 import logging
24 24 import uuid
25 25 import time
26 26 import wsgiref.util
27 27 import traceback
28 28 import tempfile
29 29 import psutil
30 30
31 31 from itertools import chain
32 32
33 33 import msgpack
34 34 import configparser
35 35
36 36 from pyramid.config import Configurator
37 37 from pyramid.wsgi import wsgiapp
38 38 from pyramid.response import Response
39 39
40 40 from vcsserver.lib.rc_json import json
41 41 from vcsserver.config.settings_maker import SettingsMaker
42 from vcsserver.utils import safe_int
42 from vcsserver.str_utils import safe_int
43 43 from vcsserver.lib.statsd_client import StatsdClient
44 44
45 45 log = logging.getLogger(__name__)
46 46
47 47 # due to Mercurial/glibc2.27 problems we need to detect if locale settings are
48 48 # causing problems and "fix" it in case they do and fallback to LC_ALL = C
49 49
50 50 try:
51 51 locale.setlocale(locale.LC_ALL, '')
52 52 except locale.Error as e:
53 53 log.error(
54 54 'LOCALE ERROR: failed to set LC_ALL, fallback to LC_ALL=C, org error: %s', e)
55 55 os.environ['LC_ALL'] = 'C'
56 56
57 57
58 58 import vcsserver
59 59 from vcsserver import remote_wsgi, scm_app, settings, hgpatches
60 60 from vcsserver.git_lfs.app import GIT_LFS_CONTENT_TYPE, GIT_LFS_PROTO_PAT
61 61 from vcsserver.echo_stub import remote_wsgi as remote_wsgi_stub
62 62 from vcsserver.echo_stub.echo_app import EchoApp
63 63 from vcsserver.exceptions import HTTPRepoLocked, HTTPRepoBranchProtected
64 64 from vcsserver.lib.exc_tracking import store_exception
65 65 from vcsserver.server import VcsServer
66 66
67 67 strict_vcs = True
68 68
69 69 git_import_err = None
70 70 try:
71 71 from vcsserver.remote.git import GitFactory, GitRemote
72 72 except ImportError as e:
73 73 GitFactory = None
74 74 GitRemote = None
75 75 git_import_err = e
76 76 if strict_vcs:
77 77 raise
78 78
79 79
80 80 hg_import_err = None
81 81 try:
82 82 from vcsserver.remote.hg import MercurialFactory, HgRemote
83 83 except ImportError as e:
84 84 MercurialFactory = None
85 85 HgRemote = None
86 86 hg_import_err = e
87 87 if strict_vcs:
88 88 raise
89 89
90 90
91 91 svn_import_err = None
92 92 try:
93 93 from vcsserver.remote.svn import SubversionFactory, SvnRemote
94 94 except ImportError as e:
95 95 SubversionFactory = None
96 96 SvnRemote = None
97 97 svn_import_err = e
98 98 if strict_vcs:
99 99 raise
100 100
101 101
102 102 def _is_request_chunked(environ):
103 103 stream = environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked'
104 104 return stream
105 105
106 106
def log_max_fd():
    """Best-effort: log the process hard limit on open file descriptors."""
    try:
        hard_limit = psutil.Process().rlimit(psutil.RLIMIT_NOFILE)[1]
        log.info('Max file descriptors value: %s', hard_limit)
    except Exception:
        # rlimit support is platform dependent; silently skip when missing
        pass
113 113
114 114
class VCS(object):
    """
    Container wiring up the per-backend remote objects (git/hg/svn) plus the
    generic VcsServer, using whichever backend factories imported successfully
    at module load time.
    """

    def __init__(self, locale_conf=None, cache_config=None):
        # locale string to force, or falsy to derive from environment vars
        self.locale = locale_conf
        self.cache_config = cache_config
        self._configure_locale()

        log_max_fd()

        # Each backend is optional: a failed import leaves the module-level
        # factory as None and we only log the recorded import error.
        if GitFactory and GitRemote:
            git_factory = GitFactory()
            self._git_remote = GitRemote(git_factory)
        else:
            log.error("Git client import failed: %s", git_import_err)

        if MercurialFactory and HgRemote:
            hg_factory = MercurialFactory()
            self._hg_remote = HgRemote(hg_factory)
        else:
            log.error("Mercurial client import failed: %s", hg_import_err)

        if SubversionFactory and SvnRemote:
            svn_factory = SubversionFactory()

            # hg factory is used for svn url validation
            hg_factory = MercurialFactory()
            self._svn_remote = SvnRemote(svn_factory, hg_factory=hg_factory)
        else:
            log.error("Subversion client import failed: %s", svn_import_err)

        self._vcsserver = VcsServer()

    def _configure_locale(self):
        """Apply ``self.locale`` via ``locale.setlocale``; log-only on failure."""
        if self.locale:
            log.info('Settings locale: `LC_ALL` to %s', self.locale)
        else:
            log.info('Configuring locale subsystem based on environment variables')
        try:
            # If self.locale is the empty string, then the locale
            # module will use the environment variables. See the
            # documentation of the package `locale`.
            locale.setlocale(locale.LC_ALL, self.locale)

            language_code, encoding = locale.getlocale()
            log.info(
                'Locale set to language code "%s" with encoding "%s".',
                language_code, encoding)
        except locale.Error:
            log.exception('Cannot set locale, not configuring the locale system')
163 163
164 164
class WsgiProxy(object):
    """
    WSGI app that unpacks a msgpack-encoded request body, dispatches it to
    the wrapped remote-wsgi handler, and streams the result back as a
    sequence of msgpack-encoded chunks: [error, status, headers, *data].
    """

    def __init__(self, wsgi):
        # object exposing .handle(environment, input_data, *args, **kwargs)
        self.wsgi = wsgi

    def __call__(self, environ, start_response):
        input_data = environ['wsgi.input'].read()
        input_data = msgpack.unpackb(input_data)

        error = None
        try:
            data, status, headers = self.wsgi.handle(
                input_data['environment'], input_data['input_data'],
                *input_data['args'], **input_data['kwargs'])
        except Exception as e:
            # errors are serialized into the response stream, not raised
            data, status, headers = [], None, None
            error = {
                'message': str(e),
                '_vcs_kind': getattr(e, '_vcs_kind', None)
            }

        # NOTE(review): plain WSGI expects a status *string* and a header
        # list here, not (200, {}) — presumably the consuming proxy peer
        # ignores these values and reads only the msgpack stream; confirm.
        start_response(200, {})
        return self._iterator(error, status, headers, data)

    def _iterator(self, error, status, headers, data):
        """Yield msgpack-packed: error, status, headers, then each data chunk."""
        initial_data = [
            error,
            status,
            headers,
        ]

        for d in chain(initial_data, data):
            yield msgpack.packb(d)
197 197
198 198
def not_found(request):
    """JSON body returned for any route that did not match."""
    return dict(status='404 NOT FOUND')
201 201
202 202
class VCSViewPredicate(object):
    """
    Pyramid view predicate: the view matches only when the ``backend``
    URL segment names one of the configured remotes.
    """

    def __init__(self, val, config):
        # mapping of backend name -> remote object
        self.remotes = val

    def text(self):
        return 'vcs view method = %s' % (list(self.remotes.keys()),)

    # pyramid uses phash for predicate uniqueness; reuse the text repr
    phash = text

    def __call__(self, context, request):
        """
        View predicate that returns true if given backend is supported by
        defined remotes.
        """
        return request.matchdict.get('backend') in self.remotes
219 219
220 220
class HTTPApplication(object):
    """
    The pyramid-based VCSServer HTTP application.

    Configures routes for RPC (``/{backend}``), streaming RPC
    (``/{backend}/stream``), clone/push streaming (``/stream/git|hg``),
    proxies, status/service endpoints, and global exception handling.
    """

    # exception class names that are passed through to the client verbatim
    ALLOWED_EXCEPTIONS = ('KeyError', 'URLError')

    remote_wsgi = remote_wsgi
    _use_echo_app = False

    def __init__(self, settings=None, global_config=None):

        self.config = Configurator(settings=settings)
        # Init our statsd at very start
        self.config.registry.statsd = StatsdClient.statsd

        self.global_config = global_config
        self.config.include('vcsserver.lib.rc_cache')

        settings_locale = settings.get('locale', '') or 'en_US.UTF-8'
        vcs = VCS(locale_conf=settings_locale, cache_config=settings)
        # backend name -> remote dispatch table used by the vcs views
        self._remotes = {
            'hg': vcs._hg_remote,
            'git': vcs._git_remote,
            'svn': vcs._svn_remote,
            'server': vcs._vcsserver,
        }
        if settings.get('dev.use_echo_app', 'false').lower() == 'true':
            self._use_echo_app = True
            log.warning("Using EchoApp for VCS operations.")
            self.remote_wsgi = remote_wsgi_stub

        self._configure_settings(global_config, settings)

        self._configure()

    def _configure_settings(self, global_config, app_settings):
        """
        Configure the settings module.
        """
        settings_merged = global_config.copy()
        settings_merged.update(app_settings)

        git_path = app_settings.get('git_path', None)
        if git_path:
            settings.GIT_EXECUTABLE = git_path
        binary_dir = app_settings.get('core.binary_dir', None)
        if binary_dir:
            settings.BINARY_DIR = binary_dir

        # Store the settings to make them available to other modules.
        vcsserver.PYRAMID_SETTINGS = settings_merged
        vcsserver.CONFIG = settings_merged

    def _configure(self):
        """Register renderers, routes, views, predicates and tweens."""
        self.config.add_renderer(name='msgpack', factory=self._msgpack_renderer_factory)

        self.config.add_route('service', '/_service')
        self.config.add_route('status', '/status')
        self.config.add_route('hg_proxy', '/proxy/hg')
        self.config.add_route('git_proxy', '/proxy/git')

        # rpc methods
        self.config.add_route('vcs', '/{backend}')

        # streaming rpc remote methods
        self.config.add_route('vcs_stream', '/{backend}/stream')

        # vcs operations clone/push as streaming
        self.config.add_route('stream_git', '/stream/git/*repo_name')
        self.config.add_route('stream_hg', '/stream/hg/*repo_name')

        self.config.add_view(self.status_view, route_name='status', renderer='json')
        self.config.add_view(self.service_view, route_name='service', renderer='msgpack')

        self.config.add_view(self.hg_proxy(), route_name='hg_proxy')
        self.config.add_view(self.git_proxy(), route_name='git_proxy')
        self.config.add_view(self.vcs_view, route_name='vcs', renderer='msgpack',
                             vcs_view=self._remotes)
        self.config.add_view(self.vcs_stream_view, route_name='vcs_stream',
                             vcs_view=self._remotes)

        self.config.add_view(self.hg_stream(), route_name='stream_hg')
        self.config.add_view(self.git_stream(), route_name='stream_git')

        self.config.add_view_predicate('vcs_view', VCSViewPredicate)

        self.config.add_notfound_view(not_found, renderer='json')

        # catch-all: any unhandled exception goes through handle_vcs_exception
        self.config.add_view(self.handle_vcs_exception, context=Exception)

        self.config.add_tween(
            'vcsserver.tweens.request_wrapper.RequestWrapperTween',
        )
        self.config.add_request_method(
            'vcsserver.lib.request_counter.get_request_counter',
            'request_count')

    def wsgi_app(self):
        """Return the fully configured pyramid WSGI application."""
        return self.config.make_wsgi_app()

    def _vcs_view_params(self, request):
        """
        Decode a msgpack RPC request body into (payload, remote, method,
        args, kwargs); also injects the `wire` dict as the first positional
        argument when present, and emits debug logging / statsd counters.
        """
        remote = self._remotes[request.matchdict['backend']]
        payload = msgpack.unpackb(request.body, use_list=True, raw=False)

        method = payload.get('method')
        params = payload['params']
        wire = params.get('wire')
        args = params.get('args')
        kwargs = params.get('kwargs')
        context_uid = None

        if wire:
            try:
                wire['context'] = context_uid = uuid.UUID(wire['context'])
            except KeyError:
                pass
            args.insert(0, wire)
        repo_state_uid = wire.get('repo_state_uid') if wire else None

        # NOTE(marcink): trading complexity for slight performance
        if log.isEnabledFor(logging.DEBUG):
            # methods listed here get their args hidden from the debug log
            no_args_methods = [

            ]
            if method in no_args_methods:
                call_args = ''
            else:
                call_args = args[1:]

            log.debug('Method requested:`%s` with args:%s kwargs:%s context_uid: %s, repo_state_uid:%s',
                      method, call_args, kwargs, context_uid, repo_state_uid)

        statsd = request.registry.statsd
        if statsd:
            statsd.incr(
                'vcsserver_method_total', tags=[
                    "method:{}".format(method),
                ])
        return payload, remote, method, args, kwargs

    def vcs_view(self, request):
        """
        Dispatch a single RPC call to the selected remote and return a
        JSON-RPC-style dict: {'id': ..., 'result': ...} on success or
        {'id': ..., 'error': {...}} on failure (exceptions are serialized,
        never propagated to the client).
        """
        payload, remote, method, args, kwargs = self._vcs_view_params(request)
        payload_id = payload.get('id')

        try:
            resp = getattr(remote, method)(*args, **kwargs)
        except Exception as e:
            exc_info = list(sys.exc_info())
            exc_type, exc_value, exc_traceback = exc_info

            # remotes may wrap the original exception; unwrap for reporting
            org_exc = getattr(e, '_org_exc', None)
            org_exc_name = None
            org_exc_tb = ''
            if org_exc:
                org_exc_name = org_exc.__class__.__name__
                org_exc_tb = getattr(e, '_org_exc_tb', '')
                # replace our "faked" exception with our org
                exc_info[0] = org_exc.__class__
                exc_info[1] = org_exc

            should_store_exc = True
            if org_exc:
                def get_exc_fqn(_exc_obj):
                    module_name = getattr(org_exc.__class__, '__module__', 'UNKNOWN')
                    return module_name + '.' + org_exc_name

                exc_fqn = get_exc_fqn(org_exc)

                # expected lookup failures are not persisted to the exc store
                if exc_fqn in ['mercurial.error.RepoLookupError',
                               'vcsserver.exceptions.RefNotFoundException']:
                    should_store_exc = False

            if should_store_exc:
                store_exception(id(exc_info), exc_info, request_path=request.path)

            tb_info = ''.join(
                traceback.format_exception(exc_type, exc_value, exc_traceback))

            type_ = e.__class__.__name__
            if type_ not in self.ALLOWED_EXCEPTIONS:
                type_ = None

            resp = {
                'id': payload_id,
                'error': {
                    'message': str(e),
                    'traceback': tb_info,
                    'org_exc': org_exc_name,
                    'org_exc_tb': org_exc_tb,
                    'type': type_
                }
            }

            try:
                resp['error']['_vcs_kind'] = getattr(e, '_vcs_kind', None)
            except AttributeError:
                pass
        else:
            resp = {
                'id': payload_id,
                'result': resp
            }

        return resp

    def vcs_stream_view(self, request):
        """
        Like :meth:`vcs_view` but for ``stream:``-prefixed methods: the
        remote's (bytes) result is chunked into an octet-stream response.
        """
        payload, remote, method, args, kwargs = self._vcs_view_params(request)
        # this method has a stream: marker we remove it here
        method = method.split('stream:')[-1]
        chunk_size = safe_int(payload.get('chunk_size')) or 4096

        try:
            resp = getattr(remote, method)(*args, **kwargs)
        except Exception as e:
            raise

        def get_chunked_data(method_resp):
            # re-chunk the full response body for streaming delivery
            stream = io.BytesIO(method_resp)
            while 1:
                chunk = stream.read(chunk_size)
                if not chunk:
                    break
                yield chunk

        response = Response(app_iter=get_chunked_data(resp))
        response.content_type = 'application/octet-stream'

        return response

    def status_view(self, request):
        """Health endpoint: status, server version and worker pid."""
        import vcsserver
        return {'status': 'OK', 'vcsserver_version': vcsserver.__version__,
                'pid': os.getpid()}

    def service_view(self, request):
        """
        Introspection endpoint: echoes the payload plus version, the parsed
        .ini sections ('server:main' / 'app:main') and the process environ.
        """
        import vcsserver

        payload = msgpack.unpackb(request.body, use_list=True)
        server_config, app_config = {}, {}

        try:
            path = self.global_config['__file__']
            config = configparser.RawConfigParser()

            config.read(path)

            if config.has_section('server:main'):
                server_config = dict(config.items('server:main'))
            if config.has_section('app:main'):
                app_config = dict(config.items('app:main'))

        except Exception:
            # display-only data; a broken .ini must not break the endpoint
            log.exception('Failed to read .ini file for display')

        environ = list(os.environ.items())

        resp = {
            'id': payload.get('id'),
            'result': dict(
                version=vcsserver.__version__,
                config=server_config,
                app_config=app_config,
                environ=environ,
                payload=payload,
            )
        }
        return resp

    def _msgpack_renderer_factory(self, info):
        """Pyramid renderer serializing view results with msgpack."""
        def _render(value, system):
            request = system.get('request')
            if request is not None:
                response = request.response
                ct = response.content_type
                # only override when the view did not set a content type itself
                if ct == response.default_content_type:
                    response.content_type = 'application/x-msgpack'
            return msgpack.packb(value)
        return _render

    def set_env_from_config(self, environ, config):
        """
        Copy RhodeCode auth/client data from the 'rhodecode' config entry
        into the WSGI environ (REMOTE_USER/HGUSER/REMOTE_HOST), and flag
        chunked bodies for webob compatibility.
        """
        dict_conf = {}
        try:
            for elem in config:
                if elem[0] == 'rhodecode':
                    dict_conf = json.loads(elem[2])
                    break
        except Exception:
            log.exception('Failed to fetch SCM CONFIG')
            return

        username = dict_conf.get('username')
        if username:
            environ['REMOTE_USER'] = username
            # mercurial specific, some extension api rely on this
            environ['HGUSER'] = username

        ip = dict_conf.get('ip')
        if ip:
            environ['REMOTE_HOST'] = ip

        if _is_request_chunked(environ):
            # set the compatibility flag for webob
            environ['wsgi.input_terminated'] = True

    def hg_proxy(self):
        """Return a WSGI app proxying msgpack RPC to the hg remote-wsgi."""
        @wsgiapp
        def _hg_proxy(environ, start_response):
            app = WsgiProxy(self.remote_wsgi.HgRemoteWsgi())
            return app(environ, start_response)
        return _hg_proxy

    def git_proxy(self):
        """Return a WSGI app proxying msgpack RPC to the git remote-wsgi."""
        @wsgiapp
        def _git_proxy(environ, start_response):
            app = WsgiProxy(self.remote_wsgi.GitRemoteWsgi())
            return app(environ, start_response)
        return _git_proxy

    def hg_stream(self):
        """
        Return the WSGI app serving hg clone/push streams; the target repo
        and its config arrive in X-RC-* request headers.
        """
        if self._use_echo_app:
            @wsgiapp
            def _hg_stream(environ, start_response):
                app = EchoApp('fake_path', 'fake_name', None)
                return app(environ, start_response)
            return _hg_stream
        else:
            @wsgiapp
            def _hg_stream(environ, start_response):
                log.debug('http-app: handling hg stream')
                repo_path = environ['HTTP_X_RC_REPO_PATH']
                repo_name = environ['HTTP_X_RC_REPO_NAME']
                packed_config = base64.b64decode(
                    environ['HTTP_X_RC_REPO_CONFIG'])
                config = msgpack.unpackb(packed_config)
                app = scm_app.create_hg_wsgi_app(
                    repo_path, repo_name, config)

                # Consistent path information for hgweb
                environ['PATH_INFO'] = environ['HTTP_X_RC_PATH_INFO']
                environ['REPO_NAME'] = repo_name
                self.set_env_from_config(environ, config)

                log.debug('http-app: starting app handler '
                          'with %s and process request', app)
                return app(environ, ResponseFilter(start_response))
            return _hg_stream

    def git_stream(self):
        """
        Return the WSGI app serving git clone/push streams, dispatching LFS
        requests (detected by content type or path) to the LFS app.
        """
        if self._use_echo_app:
            @wsgiapp
            def _git_stream(environ, start_response):
                app = EchoApp('fake_path', 'fake_name', None)
                return app(environ, start_response)
            return _git_stream
        else:
            @wsgiapp
            def _git_stream(environ, start_response):
                log.debug('http-app: handling git stream')
                repo_path = environ['HTTP_X_RC_REPO_PATH']
                repo_name = environ['HTTP_X_RC_REPO_NAME']
                packed_config = base64.b64decode(
                    environ['HTTP_X_RC_REPO_CONFIG'])
                config = msgpack.unpackb(packed_config, raw=False)

                environ['PATH_INFO'] = environ['HTTP_X_RC_PATH_INFO']
                self.set_env_from_config(environ, config)

                content_type = environ.get('CONTENT_TYPE', '')

                path = environ['PATH_INFO']
                is_lfs_request = GIT_LFS_CONTENT_TYPE in content_type
                log.debug(
                    'LFS: Detecting if request `%s` is LFS server path based '
                    'on content type:`%s`, is_lfs:%s',
                    path, content_type, is_lfs_request)

                if not is_lfs_request:
                    # fallback detection by path
                    if GIT_LFS_PROTO_PAT.match(path):
                        is_lfs_request = True
                    log.debug(
                        'LFS: fallback detection by path of: `%s`, is_lfs:%s',
                        path, is_lfs_request)

                if is_lfs_request:
                    app = scm_app.create_git_lfs_wsgi_app(
                        repo_path, repo_name, config)
                else:
                    app = scm_app.create_git_wsgi_app(
                        repo_path, repo_name, config)

                log.debug('http-app: starting app handler '
                          'with %s and process request', app)

                return app(environ, start_response)

        return _git_stream

    def handle_vcs_exception(self, exception, request):
        """
        Global exception view: maps known _vcs_kind markers to dedicated
        HTTP responses, otherwise stores/logs the traceback, bumps statsd
        and re-raises.
        """
        _vcs_kind = getattr(exception, '_vcs_kind', '')
        if _vcs_kind == 'repo_locked':
            # Get custom repo-locked status code if present.
            status_code = request.headers.get('X-RC-Locked-Status-Code')
            return HTTPRepoLocked(
                title=exception.message, status_code=status_code)

        elif _vcs_kind == 'repo_branch_protected':
            # Get custom repo-branch-protected status code if present.
            return HTTPRepoBranchProtected(title=exception.message)

        exc_info = request.exc_info
        store_exception(id(exc_info), exc_info)

        traceback_info = 'unavailable'
        if request.exc_info:
            exc_type, exc_value, exc_tb = request.exc_info
            traceback_info = ''.join(traceback.format_exception(exc_type, exc_value, exc_tb))

        log.error(
            'error occurred handling this request for path: %s, \n tb: %s',
            request.path, traceback_info)

        statsd = request.registry.statsd
        if statsd:
            exc_type = "{}.{}".format(exception.__class__.__module__, exception.__class__.__name__)
            statsd.incr('vcsserver_exception_total',
                        tags=["type:{}".format(exc_type)])
        raise exception
647 647
648 648
class ResponseFilter(object):
    """``start_response`` wrapper that drops hop-by-hop headers (RFC 2616)."""

    def __init__(self, start_response):
        self._start_response = start_response

    def __call__(self, status, response_headers, exc_info=None):
        end_to_end = []
        for name, value in response_headers:
            if wsgiref.util.is_hop_by_hop(name):
                continue
            end_to_end.append((name, value))
        return self._start_response(status, tuple(end_to_end), exc_info)
659 659
660 660
def sanitize_settings_and_apply_defaults(global_config, settings):
    """
    Fill in default values, coerce setting types, ensure cache directories
    exist, and expand environment variables in the ``settings`` dict
    (mutated in place via SettingsMaker).

    :param global_config: paste/ini global config (provides ``__file__``)
    :param settings: app settings dict to sanitize
    """
    # NOTE(review): global_settings_maker appears unused afterwards —
    # presumably kept for a constructor side effect; confirm before removal.
    global_settings_maker = SettingsMaker(global_config)
    settings_maker = SettingsMaker(settings)

    settings_maker.make_setting('logging.autoconfigure', False, parser='bool')

    # logging.ini is expected next to the main .ini file
    logging_conf = os.path.join(os.path.dirname(global_config.get('__file__')), 'logging.ini')
    settings_maker.enable_logging(logging_conf)

    # Default includes, possible to change as a user
    pyramid_includes = settings_maker.make_setting('pyramid.includes', [], parser='list:newline')
    log.debug("Using the following pyramid.includes: %s", pyramid_includes)

    settings_maker.make_setting('__file__', global_config.get('__file__'))

    settings_maker.make_setting('pyramid.default_locale_name', 'en')
    settings_maker.make_setting('locale', 'en_US.UTF-8')

    settings_maker.make_setting('core.binary_dir', '')

    temp_store = tempfile.gettempdir()
    default_cache_dir = os.path.join(temp_store, 'rc_cache')
    # save default, cache dir, and use it for all backends later.
    default_cache_dir = settings_maker.make_setting(
        'cache_dir',
        default=default_cache_dir, default_when_empty=True,
        parser='dir:ensured')

    # exception store cache
    settings_maker.make_setting(
        'exception_tracker.store_path',
        default=os.path.join(default_cache_dir, 'exc_store'), default_when_empty=True,
        parser='dir:ensured'
    )

    # repo_object cache defaults
    settings_maker.make_setting(
        'rc_cache.repo_object.backend',
        default='dogpile.cache.rc.file_namespace',
        parser='string')
    settings_maker.make_setting(
        'rc_cache.repo_object.expiration_time',
        default=30 * 24 * 60 * 60,  # 30days
        parser='int')
    settings_maker.make_setting(
        'rc_cache.repo_object.arguments.filename',
        default=os.path.join(default_cache_dir, 'vcsserver_cache_repo_object.db'),
        parser='string')

    # statsd
    settings_maker.make_setting('statsd.enabled', False, parser='bool')
    settings_maker.make_setting('statsd.statsd_host', 'statsd-exporter', parser='string')
    settings_maker.make_setting('statsd.statsd_port', 9125, parser='int')
    settings_maker.make_setting('statsd.statsd_prefix', '')
    settings_maker.make_setting('statsd.statsd_ipv6', False, parser='bool')

    # finally expand ${ENV} style references in all values
    settings_maker.env_expand()
718 718
719 719
def main(global_config, **settings):
    """
    Paste/pserve entry point: patch mercurial (when available), sanitize
    settings, bootstrap statsd, and build the pyramid WSGI application.

    :param global_config: ini-file global section dict
    :param settings: app:main settings
    :returns: configured WSGI application
    """
    start_time = time.time()
    log.info('Pyramid app config starting')

    # hg patches must be applied before any repository objects are created
    if MercurialFactory:
        hgpatches.patch_largefiles_capabilities()
        hgpatches.patch_subrepo_type_mapping()

    # Fill in and sanitize the defaults & do ENV expansion
    sanitize_settings_and_apply_defaults(global_config, settings)

    # init and bootstrap StatsdClient
    StatsdClient.setup(settings)

    pyramid_app = HTTPApplication(settings=settings, global_config=global_config).wsgi_app()
    total_time = time.time() - start_time
    log.info('Pyramid app `%s` created and configured in %.2fs',
             getattr(pyramid_app, 'func_name', 'pyramid_app'), total_time)
    return pyramid_app
739 739
740 740
@@ -1,65 +1,65 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # RhodeCode VCSServer provides access to different vcs backends via network.
4 4 # Copyright (C) 2014-2020 RhodeCode GmbH
5 5 #
6 6 # This program is free software; you can redistribute it and/or modify
7 7 # it under the terms of the GNU General Public License as published by
8 8 # the Free Software Foundation; either version 3 of the License, or
9 9 # (at your option) any later version.
10 10 #
11 11 # This program is distributed in the hope that it will be useful,
12 12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 14 # GNU General Public License for more details.
15 15 #
16 16 # You should have received a copy of the GNU General Public License
17 17 # along with this program; if not, write to the Free Software Foundation,
18 18 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 19
20 20
21 21 import logging
22 22
23 23 from repoze.lru import LRUCache
24 24
25 from vcsserver.utils import safe_str
25 from vcsserver.str_utils import safe_str
26 26
27 27 log = logging.getLogger(__name__)
28 28
29 29
class LRUDict(LRUCache):
    """
    LRUCache subclass exposing a partial dict-style mapping interface.
    """

    def __setitem__(self, key, value):
        # delegate to the LRU put so eviction bookkeeping still runs
        return self.put(key, value)

    def __getitem__(self, key):
        # NOTE: returns None for a missing key instead of raising KeyError
        return self.get(key)

    def __contains__(self, key):
        # truthiness-based: a stored falsy value reports as absent
        if self.get(key):
            return True
        return False

    def __delitem__(self, key):
        del self.data[key]

    def keys(self):
        return [key for key in self.data.keys()]
49 49
50 50
class LRUDictDebug(LRUDict):
    """
    LRUDict variant that logs all current cache keys on every read.
    """

    def _report_keys(self):
        # "filled/capacity" summary, e.g. "3/100"
        elems_cnt = '%s/%s' % (len(list(self.keys())), self.size)
        # build a newline-separated listing so pformat/log output stays readable
        listing = '\n'
        for idx, elem in enumerate(self.keys(), 1):
            listing += '%s - %s\n' % (idx, safe_str(elem))
        log.debug('current LRU keys (%s):%s', elems_cnt, listing)

    def __getitem__(self, key):
        self._report_keys()
        return self.get(key)
@@ -1,330 +1,330 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import time
19 19 import errno
20 20 import logging
21 21
22 22 import msgpack
23 23 import redis
24 24 import pickle
25 25
26 26 from dogpile.cache.api import CachedValue
27 27 from dogpile.cache.backends import memory as memory_backend
28 28 from dogpile.cache.backends import file as file_backend
29 29 from dogpile.cache.backends import redis as redis_backend
30 30 from dogpile.cache.backends.file import NO_VALUE, FileLock
31 31 from dogpile.cache.util import memoized_property
32 32
33 33 from pyramid.settings import asbool
34 34
35 35 from vcsserver.lib.memory_lru_dict import LRUDict, LRUDictDebug
36 from vcsserver.utils import safe_str
36 from vcsserver.str_utils import safe_str
37 37
38 38
39 39 _default_max_size = 1024
40 40
41 41 log = logging.getLogger(__name__)
42 42
43 43
44 44 class LRUMemoryBackend(memory_backend.MemoryBackend):
45 45 key_prefix = 'lru_mem_backend'
46 46 pickle_values = False
47 47
48 48 def __init__(self, arguments):
49 49 max_size = arguments.pop('max_size', _default_max_size)
50 50
51 51 LRUDictClass = LRUDict
52 52 if arguments.pop('log_key_count', None):
53 53 LRUDictClass = LRUDictDebug
54 54
55 55 arguments['cache_dict'] = LRUDictClass(max_size)
56 56 super(LRUMemoryBackend, self).__init__(arguments)
57 57
58 58 def delete(self, key):
59 59 try:
60 60 del self._cache[key]
61 61 except KeyError:
62 62 # we don't care if key isn't there at deletion
63 63 pass
64 64
65 65 def delete_multi(self, keys):
66 66 for key in keys:
67 67 self.delete(key)
68 68
69 69
70 70 class PickleSerializer(object):
71 71
72 72 def _dumps(self, value, safe=False):
73 73 try:
74 74 return pickle.dumps(value)
75 75 except Exception:
76 76 if safe:
77 77 return NO_VALUE
78 78 else:
79 79 raise
80 80
81 81 def _loads(self, value, safe=True):
82 82 try:
83 83 return pickle.loads(value)
84 84 except Exception:
85 85 if safe:
86 86 return NO_VALUE
87 87 else:
88 88 raise
89 89
90 90
91 91 class MsgPackSerializer(object):
92 92
93 93 def _dumps(self, value, safe=False):
94 94 try:
95 95 return msgpack.packb(value)
96 96 except Exception:
97 97 if safe:
98 98 return NO_VALUE
99 99 else:
100 100 raise
101 101
102 102 def _loads(self, value, safe=True):
103 103 """
104 104 pickle maintained the `CachedValue` wrapper of the tuple
105 105 msgpack does not, so it must be added back in.
106 106 """
107 107 try:
108 108 value = msgpack.unpackb(value, use_list=False)
109 109 return CachedValue(*value)
110 110 except Exception:
111 111 if safe:
112 112 return NO_VALUE
113 113 else:
114 114 raise
115 115
116 116
117 117 import fcntl
118 118 flock_org = fcntl.flock
119 119
120 120
121 121 class CustomLockFactory(FileLock):
122 122
123 123 pass
124 124
125 125
126 126 class FileNamespaceBackend(PickleSerializer, file_backend.DBMBackend):
127 127 key_prefix = 'file_backend'
128 128
129 129 def __init__(self, arguments):
130 130 arguments['lock_factory'] = CustomLockFactory
131 131 db_file = arguments.get('filename')
132 132
133 133 log.debug('initialing %s DB in %s', self.__class__.__name__, db_file)
134 134 try:
135 135 super(FileNamespaceBackend, self).__init__(arguments)
136 136 except Exception:
137 137 log.exception('Failed to initialize db at: %s', db_file)
138 138 raise
139 139
140 140 def __repr__(self):
141 141 return '{} `{}`'.format(self.__class__, self.filename)
142 142
143 143 def list_keys(self, prefix=''):
144 144 prefix = '{}:{}'.format(self.key_prefix, prefix)
145 145
146 146 def cond(v):
147 147 if not prefix:
148 148 return True
149 149
150 150 if v.startswith(prefix):
151 151 return True
152 152 return False
153 153
154 154 with self._dbm_file(True) as dbm:
155 155 try:
156 156 return list(filter(cond, list(dbm.keys())))
157 157 except Exception:
158 158 log.error('Failed to fetch DBM keys from DB: %s', self.get_store())
159 159 raise
160 160
161 161 def get_store(self):
162 162 return self.filename
163 163
164 164 def _dbm_get(self, key):
165 165 with self._dbm_file(False) as dbm:
166 166 if hasattr(dbm, 'get'):
167 167 value = dbm.get(key, NO_VALUE)
168 168 else:
169 169 # gdbm objects lack a .get method
170 170 try:
171 171 value = dbm[key]
172 172 except KeyError:
173 173 value = NO_VALUE
174 174 if value is not NO_VALUE:
175 175 value = self._loads(value)
176 176 return value
177 177
178 178 def get(self, key):
179 179 try:
180 180 return self._dbm_get(key)
181 181 except Exception:
182 182 log.error('Failed to fetch DBM key %s from DB: %s', key, self.get_store())
183 183 raise
184 184
185 185 def set(self, key, value):
186 186 with self._dbm_file(True) as dbm:
187 187 dbm[key] = self._dumps(value)
188 188
189 189 def set_multi(self, mapping):
190 190 with self._dbm_file(True) as dbm:
191 191 for key, value in mapping.items():
192 192 dbm[key] = self._dumps(value)
193 193
194 194
195 195 class BaseRedisBackend(redis_backend.RedisBackend):
196 196 key_prefix = ''
197 197
198 198 def __init__(self, arguments):
199 199 super(BaseRedisBackend, self).__init__(arguments)
200 200 self._lock_timeout = self.lock_timeout
201 201 self._lock_auto_renewal = asbool(arguments.pop("lock_auto_renewal", True))
202 202
203 203 if self._lock_auto_renewal and not self._lock_timeout:
204 204 # set default timeout for auto_renewal
205 205 self._lock_timeout = 30
206 206
207 207 def _create_client(self):
208 208 args = {}
209 209
210 210 if self.url is not None:
211 211 args.update(url=self.url)
212 212
213 213 else:
214 214 args.update(
215 215 host=self.host, password=self.password,
216 216 port=self.port, db=self.db
217 217 )
218 218
219 219 connection_pool = redis.ConnectionPool(**args)
220 220
221 221 return redis.StrictRedis(connection_pool=connection_pool)
222 222
223 223 def list_keys(self, prefix=''):
224 224 prefix = '{}:{}*'.format(self.key_prefix, prefix)
225 225 return self.client.keys(prefix)
226 226
227 227 def get_store(self):
228 228 return self.client.connection_pool
229 229
230 230 def get(self, key):
231 231 value = self.client.get(key)
232 232 if value is None:
233 233 return NO_VALUE
234 234 return self._loads(value)
235 235
236 236 def get_multi(self, keys):
237 237 if not keys:
238 238 return []
239 239 values = self.client.mget(keys)
240 240 loads = self._loads
241 241 return [
242 242 loads(v) if v is not None else NO_VALUE
243 243 for v in values]
244 244
245 245 def set(self, key, value):
246 246 if self.redis_expiration_time:
247 247 self.client.setex(key, self.redis_expiration_time,
248 248 self._dumps(value))
249 249 else:
250 250 self.client.set(key, self._dumps(value))
251 251
252 252 def set_multi(self, mapping):
253 253 dumps = self._dumps
254 254 mapping = dict(
255 255 (k, dumps(v))
256 256 for k, v in mapping.items()
257 257 )
258 258
259 259 if not self.redis_expiration_time:
260 260 self.client.mset(mapping)
261 261 else:
262 262 pipe = self.client.pipeline()
263 263 for key, value in mapping.items():
264 264 pipe.setex(key, self.redis_expiration_time, value)
265 265 pipe.execute()
266 266
267 267 def get_mutex(self, key):
268 268 if self.distributed_lock:
269 269 lock_key = '_lock_{0}'.format(safe_str(key))
270 270 return get_mutex_lock(self.client, lock_key, self._lock_timeout,
271 271 auto_renewal=self._lock_auto_renewal)
272 272 else:
273 273 return None
274 274
275 275
276 276 class RedisPickleBackend(PickleSerializer, BaseRedisBackend):
277 277 key_prefix = 'redis_pickle_backend'
278 278 pass
279 279
280 280
281 281 class RedisMsgPackBackend(MsgPackSerializer, BaseRedisBackend):
282 282 key_prefix = 'redis_msgpack_backend'
283 283 pass
284 284
285 285
286 286 def get_mutex_lock(client, lock_key, lock_timeout, auto_renewal=False):
287 287 import redis_lock
288 288
289 289 class _RedisLockWrapper(object):
290 290 """LockWrapper for redis_lock"""
291 291
292 292 @classmethod
293 293 def get_lock(cls):
294 294 return redis_lock.Lock(
295 295 redis_client=client,
296 296 name=lock_key,
297 297 expire=lock_timeout,
298 298 auto_renewal=auto_renewal,
299 299 strict=True,
300 300 )
301 301
302 302 def __repr__(self):
303 303 return "{}:{}".format(self.__class__.__name__, lock_key)
304 304
305 305 def __str__(self):
306 306 return "{}:{}".format(self.__class__.__name__, lock_key)
307 307
308 308 def __init__(self):
309 309 self.lock = self.get_lock()
310 310 self.lock_key = lock_key
311 311
312 312 def acquire(self, wait=True):
313 313 log.debug('Trying to acquire Redis lock for key %s', self.lock_key)
314 314 try:
315 315 acquired = self.lock.acquire(wait)
316 316 log.debug('Got lock for key %s, %s', self.lock_key, acquired)
317 317 return acquired
318 318 except redis_lock.AlreadyAcquired:
319 319 return False
320 320 except redis_lock.AlreadyStarted:
321 321 # refresh thread exists, but it also means we acquired the lock
322 322 return True
323 323
324 324 def release(self):
325 325 try:
326 326 self.lock.release()
327 327 except redis_lock.NotAcquired:
328 328 pass
329 329
330 330 return _RedisLockWrapper()
@@ -1,207 +1,208 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import time
20 20 import logging
21 21 import functools
22 22 import decorator
23 23
24 24 from dogpile.cache import CacheRegion
25 25
26 from vcsserver.utils import safe_bytes, sha1
26 from vcsserver.str_utils import safe_bytes
27 from vcsserver.utils import sha1
27 28 from vcsserver.lib.rc_cache import region_meta
28 29
29 30 log = logging.getLogger(__name__)
30 31
31 32
32 33 class RhodeCodeCacheRegion(CacheRegion):
33 34
34 35 def conditional_cache_on_arguments(
35 36 self, namespace=None,
36 37 expiration_time=None,
37 38 should_cache_fn=None,
38 39 to_str=str,
39 40 function_key_generator=None,
40 41 condition=True):
41 42 """
42 43 Custom conditional decorator, that will not touch any dogpile internals if
43 44 condition isn't meet. This works a bit different than should_cache_fn
44 45 And it's faster in cases we don't ever want to compute cached values
45 46 """
46 47 expiration_time_is_callable = callable(expiration_time)
47 48
48 49 if function_key_generator is None:
49 50 function_key_generator = self.function_key_generator
50 51
51 52 def get_or_create_for_user_func(key_generator, user_func, *arg, **kw):
52 53
53 54 if not condition:
54 55 log.debug('Calling un-cached method:%s', user_func.__name__)
55 56 start = time.time()
56 57 result = user_func(*arg, **kw)
57 58 total = time.time() - start
58 59 log.debug('un-cached method:%s took %.4fs', user_func.__name__, total)
59 60 return result
60 61
61 62 key = key_generator(*arg, **kw)
62 63
63 64 timeout = expiration_time() if expiration_time_is_callable \
64 65 else expiration_time
65 66
66 67 log.debug('Calling cached method:`%s`', user_func.__name__)
67 68 return self.get_or_create(key, user_func, timeout, should_cache_fn, (arg, kw))
68 69
69 70 def cache_decorator(user_func):
70 71 if to_str is str:
71 72 # backwards compatible
72 73 key_generator = function_key_generator(namespace, user_func)
73 74 else:
74 75 key_generator = function_key_generator(namespace, user_func, to_str=to_str)
75 76
76 77 def refresh(*arg, **kw):
77 78 """
78 79 Like invalidate, but regenerates the value instead
79 80 """
80 81 key = key_generator(*arg, **kw)
81 82 value = user_func(*arg, **kw)
82 83 self.set(key, value)
83 84 return value
84 85
85 86 def invalidate(*arg, **kw):
86 87 key = key_generator(*arg, **kw)
87 88 self.delete(key)
88 89
89 90 def set_(value, *arg, **kw):
90 91 key = key_generator(*arg, **kw)
91 92 self.set(key, value)
92 93
93 94 def get(*arg, **kw):
94 95 key = key_generator(*arg, **kw)
95 96 return self.get(key)
96 97
97 98 user_func.set = set_
98 99 user_func.invalidate = invalidate
99 100 user_func.get = get
100 101 user_func.refresh = refresh
101 102 user_func.key_generator = key_generator
102 103 user_func.original = user_func
103 104
104 105 # Use `decorate` to preserve the signature of :param:`user_func`.
105 106 return decorator.decorate(user_func, functools.partial(
106 107 get_or_create_for_user_func, key_generator))
107 108
108 109 return cache_decorator
109 110
110 111
111 112 def make_region(*arg, **kw):
112 113 return RhodeCodeCacheRegion(*arg, **kw)
113 114
114 115
115 116 def get_default_cache_settings(settings, prefixes=None):
116 117 prefixes = prefixes or []
117 118 cache_settings = {}
118 119 for key in settings.keys():
119 120 for prefix in prefixes:
120 121 if key.startswith(prefix):
121 122 name = key.split(prefix)[1].strip()
122 123 val = settings[key]
123 124 if isinstance(val, str):
124 125 val = val.strip()
125 126 cache_settings[name] = val
126 127 return cache_settings
127 128
128 129
129 130 def compute_key_from_params(*args):
130 131 """
131 132 Helper to compute key from given params to be used in cache manager
132 133 """
133 134 return sha1(safe_bytes("_".join(map(str, args))))
134 135
135 136
136 137 def backend_key_generator(backend):
137 138 """
138 139 Special wrapper that also sends over the backend to the key generator
139 140 """
140 141 def wrapper(namespace, fn):
141 142 return key_generator(backend, namespace, fn)
142 143 return wrapper
143 144
144 145
145 146 def key_generator(backend, namespace, fn):
146 147 fname = fn.__name__
147 148
148 149 def generate_key(*args):
149 150 backend_prefix = getattr(backend, 'key_prefix', None) or 'backend_prefix'
150 151 namespace_pref = namespace or 'default_namespace'
151 152 arg_key = compute_key_from_params(*args)
152 153 final_key = "{}:{}:{}_{}".format(backend_prefix, namespace_pref, fname, arg_key)
153 154
154 155 return final_key
155 156
156 157 return generate_key
157 158
158 159
159 160 def get_or_create_region(region_name, region_namespace=None):
160 161 from vcsserver.lib.rc_cache.backends import FileNamespaceBackend
161 162 region_obj = region_meta.dogpile_cache_regions.get(region_name)
162 163 if not region_obj:
163 164 raise EnvironmentError(
164 165 'Region `{}` not in configured: {}.'.format(
165 166 region_name, list(region_meta.dogpile_cache_regions.keys())))
166 167
167 168 region_uid_name = '{}:{}'.format(region_name, region_namespace)
168 169 if isinstance(region_obj.actual_backend, FileNamespaceBackend):
169 170 region_exist = region_meta.dogpile_cache_regions.get(region_namespace)
170 171 if region_exist:
171 172 log.debug('Using already configured region: %s', region_namespace)
172 173 return region_exist
173 174 cache_dir = region_meta.dogpile_config_defaults['cache_dir']
174 175 expiration_time = region_obj.expiration_time
175 176
176 177 if not os.path.isdir(cache_dir):
177 178 os.makedirs(cache_dir)
178 179 new_region = make_region(
179 180 name=region_uid_name,
180 181 function_key_generator=backend_key_generator(region_obj.actual_backend)
181 182 )
182 183 namespace_filename = os.path.join(
183 184 cache_dir, "{}.cache.dbm".format(region_namespace))
184 185 # special type that allows 1db per namespace
185 186 new_region.configure(
186 187 backend='dogpile.cache.rc.file_namespace',
187 188 expiration_time=expiration_time,
188 189 arguments={"filename": namespace_filename}
189 190 )
190 191
191 192 # create and save in region caches
192 193 log.debug('configuring new region: %s', region_uid_name)
193 194 region_obj = region_meta.dogpile_cache_regions[region_namespace] = new_region
194 195
195 196 return region_obj
196 197
197 198
198 199 def clear_cache_namespace(cache_region, cache_namespace_uid, invalidate=False):
199 200 region = get_or_create_region(cache_region, cache_namespace_uid)
200 201 cache_keys = region.backend.list_keys(prefix=cache_namespace_uid)
201 202 num_delete_keys = len(cache_keys)
202 203 if invalidate:
203 204 region.invalidate(hard=False)
204 205 else:
205 206 if num_delete_keys:
206 207 region.delete_multi(cache_keys)
207 208 return num_delete_keys
@@ -1,160 +1,160 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import tempfile
20 20
21 21 from svn import client
22 22 from svn import core
23 23 from svn import ra
24 24
25 25 from mercurial import error
26 26
27 from vcsserver.utils import safe_bytes
27 from vcsserver.str_utils import safe_bytes
28 28
29 29 core.svn_config_ensure(None)
30 30 svn_config = core.svn_config_get_config(None)
31 31
32 32
33 33 class RaCallbacks(ra.Callbacks):
34 34 @staticmethod
35 35 def open_tmp_file(pool): # pragma: no cover
36 36 (fd, fn) = tempfile.mkstemp()
37 37 os.close(fd)
38 38 return fn
39 39
40 40 @staticmethod
41 41 def get_client_string(pool):
42 42 return b'RhodeCode-subversion-url-checker'
43 43
44 44
45 45 class SubversionException(Exception):
46 46 pass
47 47
48 48
49 49 class SubversionConnectionException(SubversionException):
50 50 """Exception raised when a generic error occurs when connecting to a repository."""
51 51
52 52
53 53 def normalize_url(url):
54 54 if not url:
55 55 return url
56 56 if url.startswith(b'svn+http://') or url.startswith(b'svn+https://'):
57 57 url = url[4:]
58 58 url = url.rstrip(b'/')
59 59 return url
60 60
61 61
62 62 def _create_auth_baton(pool):
63 63 """Create a Subversion authentication baton. """
64 64 # Give the client context baton a suite of authentication
65 65 # providers.h
66 66 platform_specific = [
67 67 'svn_auth_get_gnome_keyring_simple_provider',
68 68 'svn_auth_get_gnome_keyring_ssl_client_cert_pw_provider',
69 69 'svn_auth_get_keychain_simple_provider',
70 70 'svn_auth_get_keychain_ssl_client_cert_pw_provider',
71 71 'svn_auth_get_kwallet_simple_provider',
72 72 'svn_auth_get_kwallet_ssl_client_cert_pw_provider',
73 73 'svn_auth_get_ssl_client_cert_file_provider',
74 74 'svn_auth_get_windows_simple_provider',
75 75 'svn_auth_get_windows_ssl_server_trust_provider',
76 76 ]
77 77
78 78 providers = []
79 79
80 80 for p in platform_specific:
81 81 if getattr(core, p, None) is not None:
82 82 try:
83 83 providers.append(getattr(core, p)())
84 84 except RuntimeError:
85 85 pass
86 86
87 87 providers += [
88 88 client.get_simple_provider(),
89 89 client.get_username_provider(),
90 90 client.get_ssl_client_cert_file_provider(),
91 91 client.get_ssl_client_cert_pw_file_provider(),
92 92 client.get_ssl_server_trust_file_provider(),
93 93 ]
94 94
95 95 return core.svn_auth_open(providers, pool)
96 96
97 97
98 98 class SubversionRepo(object):
99 99 """Wrapper for a Subversion repository.
100 100
101 101 It uses the SWIG Python bindings, see above for requirements.
102 102 """
103 103 def __init__(self, svn_url: bytes = b'', username: bytes = b'', password: bytes = b''):
104 104
105 105 self.username = username
106 106 self.password = password
107 107 self.svn_url = core.svn_path_canonicalize(svn_url)
108 108
109 109 self.auth_baton_pool = core.Pool()
110 110 self.auth_baton = _create_auth_baton(self.auth_baton_pool)
111 111 # self.init_ra_and_client() assumes that a pool already exists
112 112 self.pool = core.Pool()
113 113
114 114 self.ra = self.init_ra_and_client()
115 115 self.uuid = ra.get_uuid(self.ra, self.pool)
116 116
117 117 def init_ra_and_client(self):
118 118 """Initializes the RA and client layers, because sometimes getting
119 119 unified diffs runs the remote server out of open files.
120 120 """
121 121
122 122 if self.username:
123 123 core.svn_auth_set_parameter(self.auth_baton,
124 124 core.SVN_AUTH_PARAM_DEFAULT_USERNAME,
125 125 self.username)
126 126 if self.password:
127 127 core.svn_auth_set_parameter(self.auth_baton,
128 128 core.SVN_AUTH_PARAM_DEFAULT_PASSWORD,
129 129 self.password)
130 130
131 131 callbacks = RaCallbacks()
132 132 callbacks.auth_baton = self.auth_baton
133 133
134 134 try:
135 135 return ra.open2(self.svn_url, callbacks, svn_config, self.pool)
136 136 except SubversionException as e:
137 137 # e.child contains a detailed error messages
138 138 msglist = []
139 139 svn_exc = e
140 140 while svn_exc:
141 141 if svn_exc.args[0]:
142 142 msglist.append(svn_exc.args[0])
143 143 svn_exc = svn_exc.child
144 144 msg = '\n'.join(msglist)
145 145 raise SubversionConnectionException(msg)
146 146
147 147
148 148 class svnremoterepo(object):
149 149 """ the dumb wrapper for actual Subversion repositories """
150 150
151 151 def __init__(self, username: bytes = b'', password: bytes = b'', svn_url: bytes = b''):
152 152 self.username = username or b''
153 153 self.password = password or b''
154 154 self.path = normalize_url(svn_url)
155 155
156 156 def svn(self):
157 157 try:
158 158 return SubversionRepo(self.path, self.username, self.password)
159 159 except SubversionConnectionException as e:
160 160 raise error.Abort(safe_bytes(e))
@@ -1,413 +1,413 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Handles the Git smart protocol."""
19 19
20 20 import os
21 21 import socket
22 22 import logging
23 23
24 24 import dulwich.protocol
25 25 from dulwich.protocol import CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K
26 26 from webob import Request, Response, exc
27 27
28 28 from vcsserver.lib.rc_json import json
29 29 from vcsserver import hooks, subprocessio
30 from vcsserver.utils import ascii_bytes
30 from vcsserver.str_utils import ascii_bytes
31 31
32 32
33 33 log = logging.getLogger(__name__)
34 34
35 35
36 36 class FileWrapper(object):
37 37 """File wrapper that ensures how much data is read from it."""
38 38
39 39 def __init__(self, fd, content_length):
40 40 self.fd = fd
41 41 self.content_length = content_length
42 42 self.remain = content_length
43 43
44 44 def read(self, size):
45 45 if size <= self.remain:
46 46 try:
47 47 data = self.fd.read(size)
48 48 except socket.error:
49 49 raise IOError(self)
50 50 self.remain -= size
51 51 elif self.remain:
52 52 data = self.fd.read(self.remain)
53 53 self.remain = 0
54 54 else:
55 55 data = None
56 56 return data
57 57
58 58 def __repr__(self):
59 59 return '<FileWrapper %s len: %s, read: %s>' % (
60 60 self.fd, self.content_length, self.content_length - self.remain
61 61 )
62 62
63 63
64 64 class GitRepository(object):
65 65 """WSGI app for handling Git smart protocol endpoints."""
66 66
67 67 git_folder_signature = frozenset(('config', 'head', 'info', 'objects', 'refs'))
68 68 commands = frozenset(('git-upload-pack', 'git-receive-pack'))
69 69 valid_accepts = frozenset(('application/x-{}-result'.format(c) for c in commands))
70 70
71 71 # The last bytes are the SHA1 of the first 12 bytes.
72 72 EMPTY_PACK = (
73 73 b'PACK\x00\x00\x00\x02\x00\x00\x00\x00\x02\x9d\x08' +
74 74 b'\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e'
75 75 )
76 76 FLUSH_PACKET = b"0000"
77 77
78 78 SIDE_BAND_CAPS = frozenset((CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K))
79 79
80 80 def __init__(self, repo_name, content_path, git_path, update_server_info, extras):
81 81 files = frozenset(f.lower() for f in os.listdir(content_path))
82 82 valid_dir_signature = self.git_folder_signature.issubset(files)
83 83
84 84 if not valid_dir_signature:
85 85 raise OSError('%s missing git signature' % content_path)
86 86
87 87 self.content_path = content_path
88 88 self.repo_name = repo_name
89 89 self.extras = extras
90 90 self.git_path = git_path
91 91 self.update_server_info = update_server_info
92 92
93 93 def _get_fixedpath(self, path):
94 94 """
95 95 Small fix for repo_path
96 96
97 97 :param path:
98 98 """
99 99 path = path.split(self.repo_name, 1)[-1]
100 100 if path.startswith('.git'):
101 101 # for bare repos we still get the .git prefix inside, we skip it
102 102 # here, and remove from the service command
103 103 path = path[4:]
104 104
105 105 return path.strip('/')
106 106
107 107 def inforefs(self, request, unused_environ):
108 108 """
109 109 WSGI Response producer for HTTP GET Git Smart
110 110 HTTP /info/refs request.
111 111 """
112 112
113 113 git_command = request.GET.get('service')
114 114 if git_command not in self.commands:
115 115 log.debug('command %s not allowed', git_command)
116 116 return exc.HTTPForbidden()
117 117
118 118 # please, resist the urge to add '\n' to git capture and increment
119 119 # line count by 1.
120 120 # by git docs: Documentation/technical/http-protocol.txt#L214 \n is
121 121 # a part of protocol.
122 122 # The code in Git client not only does NOT need '\n', but actually
123 123 # blows up if you sprinkle "flush" (0000) as "0001\n".
124 124 # It reads binary, per number of bytes specified.
125 125 # if you do add '\n' as part of data, count it.
126 126 server_advert = '# service=%s\n' % git_command
127 127 packet_len = hex(len(server_advert) + 4)[2:].rjust(4, '0').lower()
128 128 try:
129 129 gitenv = dict(os.environ)
130 130 # forget all configs
131 131 gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
132 132 command = [self.git_path, git_command[4:], '--stateless-rpc',
133 133 '--advertise-refs', self.content_path]
134 134 out = subprocessio.SubprocessIOChunker(
135 135 command,
136 136 env=gitenv,
137 137 starting_values=[ascii_bytes(packet_len + server_advert) + self.FLUSH_PACKET],
138 138 shell=False
139 139 )
140 140 except OSError:
141 141 log.exception('Error processing command')
142 142 raise exc.HTTPExpectationFailed()
143 143
144 144 resp = Response()
145 145 resp.content_type = f'application/x-{git_command}-advertisement'
146 146 resp.charset = None
147 147 resp.app_iter = out
148 148
149 149 return resp
150 150
151 151 def _get_want_capabilities(self, request):
152 152 """Read the capabilities found in the first want line of the request."""
153 153 pos = request.body_file_seekable.tell()
154 154 first_line = request.body_file_seekable.readline()
155 155 request.body_file_seekable.seek(pos)
156 156
157 157 return frozenset(
158 158 dulwich.protocol.extract_want_line_capabilities(first_line)[1])
159 159
160 160 def _build_failed_pre_pull_response(self, capabilities, pre_pull_messages):
161 161 """
162 162 Construct a response with an empty PACK file.
163 163
164 164 We use an empty PACK file, as that would trigger the failure of the pull
165 165 or clone command.
166 166
167 167 We also print in the error output a message explaining why the command
168 168 was aborted.
169 169
170 170 If additionally, the user is accepting messages we send them the output
171 171 of the pre-pull hook.
172 172
173 173 Note that for clients not supporting side-band we just send them the
174 174 emtpy PACK file.
175 175 """
176 176
177 177 if self.SIDE_BAND_CAPS.intersection(capabilities):
178 178 response = []
179 179 proto = dulwich.protocol.Protocol(None, response.append)
180 180 proto.write_pkt_line(dulwich.protocol.NAK_LINE)
181 181
182 182 self._write_sideband_to_proto(proto, ascii_bytes(pre_pull_messages, allow_bytes=True), capabilities)
183 183 # N.B.(skreft): Do not change the sideband channel to 3, as that
184 184 # produces a fatal error in the client:
185 185 # fatal: error in sideband demultiplexer
186 186 proto.write_sideband(
187 187 dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS,
188 188 ascii_bytes('Pre pull hook failed: aborting\n', allow_bytes=True))
189 189 proto.write_sideband(
190 190 dulwich.protocol.SIDE_BAND_CHANNEL_DATA,
191 191 ascii_bytes(self.EMPTY_PACK, allow_bytes=True))
192 192
193 193 # writes b"0000" as default
194 194 proto.write_pkt_line(None)
195 195
196 196 return response
197 197 else:
198 198 return [ascii_bytes(self.EMPTY_PACK, allow_bytes=True)]
199 199
200 200 def _build_post_pull_response(self, response, capabilities, start_message, end_message):
201 201 """
202 202 Given a list response we inject the post-pull messages.
203 203
204 204 We only inject the messages if the client supports sideband, and the
205 205 response has the format:
206 206 0008NAK\n...0000
207 207
208 208 Note that we do not check the no-progress capability as by default, git
209 209 sends it, which effectively would block all messages.
210 210 """
211 211
212 212 if not self.SIDE_BAND_CAPS.intersection(capabilities):
213 213 return response
214 214
215 215 if not start_message and not end_message:
216 216 return response
217 217
218 218 try:
219 219 iter(response)
220 220 # iterator probably will work, we continue
221 221 except TypeError:
222 222 raise TypeError(f'response must be an iterator: got {type(response)}')
223 223 if isinstance(response, (list, tuple)):
224 224 raise TypeError(f'response must be an iterator: got {type(response)}')
225 225
226 226 def injected_response():
227 227
228 228 do_loop = 1
229 229 header_injected = 0
230 230 next_item = None
231 231 has_item = False
232 232 while do_loop:
233 233
234 234 try:
235 235 next_item = next(response)
236 236 except StopIteration:
237 237 do_loop = 0
238 238
239 239 if has_item:
240 240 # last item ! alter it now
241 241 if do_loop == 0 and item.endswith(self.FLUSH_PACKET):
242 242 new_response = [item[:-4]]
243 243 new_response.extend(self._get_messages(end_message, capabilities))
244 244 new_response.append(self.FLUSH_PACKET)
245 245 item = b''.join(new_response)
246 246
247 247 yield item
248 248 has_item = True
249 249 item = next_item
250 250
251 251 # alter item if it's the initial chunk
252 252 if not header_injected and item.startswith(b'0008NAK\n'):
253 253 new_response = [b'0008NAK\n']
254 254 new_response.extend(self._get_messages(start_message, capabilities))
255 255 new_response.append(item[8:])
256 256 item = b''.join(new_response)
257 257 header_injected = 1
258 258
259 259 return injected_response()
260 260
261 261 def _write_sideband_to_proto(self, proto, data, capabilities):
262 262 """
263 263 Write the data to the proto's sideband number 2 == SIDE_BAND_CHANNEL_PROGRESS
264 264
265 265 We do not use dulwich's write_sideband directly as it only supports
266 266 side-band-64k.
267 267 """
268 268 if not data:
269 269 return
270 270
271 271 # N.B.(skreft): The values below are explained in the pack protocol
272 272 # documentation, section Packfile Data.
273 273 # https://github.com/git/git/blob/master/Documentation/technical/pack-protocol.txt
274 274 if CAPABILITY_SIDE_BAND_64K in capabilities:
275 275 chunk_size = 65515
276 276 elif CAPABILITY_SIDE_BAND in capabilities:
277 277 chunk_size = 995
278 278 else:
279 279 return
280 280
281 281 chunker = (data[i:i + chunk_size] for i in range(0, len(data), chunk_size))
282 282
283 283 for chunk in chunker:
284 284 proto.write_sideband(dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS, ascii_bytes(chunk, allow_bytes=True))
285 285
286 286 def _get_messages(self, data, capabilities):
287 287 """Return a list with packets for sending data in sideband number 2."""
288 288 response = []
289 289 proto = dulwich.protocol.Protocol(None, response.append)
290 290
291 291 self._write_sideband_to_proto(proto, data, capabilities)
292 292
293 293 return response
294 294
    def backend(self, request, environ):
        """
        WSGI Response producer for HTTP POST Git Smart HTTP requests.
        Reads commands and data from HTTP POST's body.
        returns an iterator obj with contents of git command's
        response to stdout
        """
        # TODO(skreft): think how we could detect an HTTPLockedException, as
        # we probably want to have the same mechanism used by mercurial and
        # simplevcs.
        # For that we would need to parse the output of the command looking for
        # some signs of the HTTPLockedError, parse the data and reraise it in
        # pygrack. However, that would interfere with the streaming.
        #
        # Now the output of a blocked push is:
        # Pushing to http://test_regular:test12@127.0.0.1:5001/vcs_test_git
        # POST git-receive-pack (1047 bytes)
        # remote: ERROR: Repository `vcs_test_git` locked by user `test_admin`. Reason:`lock_auto`
        # To http://test_regular:test12@127.0.0.1:5001/vcs_test_git
        # ! [remote rejected] master -> master (pre-receive hook declined)
        # error: failed to push some refs to 'http://test_regular:test12@127.0.0.1:5001/vcs_test_git'

        # only whitelisted smart-HTTP commands are served
        git_command = self._get_fixedpath(request.path_info)
        if git_command not in self.commands:
            log.debug('command %s not allowed', git_command)
            return exc.HTTPForbidden()

        capabilities = None
        if git_command == 'git-upload-pack':
            capabilities = self._get_want_capabilities(request)

        if 'CONTENT_LENGTH' in environ:
            inputstream = FileWrapper(request.body_file_seekable,
                                      request.content_length)
        else:
            inputstream = request.body_file_seekable

        resp = Response()
        resp.content_type = 'application/x-{}-result'.format(git_command)
        resp.charset = None

        pre_pull_messages = ''
        # Upload-pack == clone
        if git_command == 'git-upload-pack':
            # pre-pull hook may deny the operation; then we answer with the
            # hook's messages plus an empty pack instead of running git
            status, pre_pull_messages = hooks.git_pre_pull(self.extras)
            if status != 0:
                resp.app_iter = self._build_failed_pre_pull_response(
                    capabilities, pre_pull_messages)
                return resp

        gitenv = dict(os.environ)
        # forget all configs
        gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
        gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
        cmd = [self.git_path, git_command[4:], '--stateless-rpc',
               self.content_path]
        log.debug('handling cmd %s', cmd)

        # lazily streams git's stdout back to the client
        out = subprocessio.SubprocessIOChunker(
            cmd,
            input_stream=inputstream,
            env=gitenv,
            cwd=self.content_path,
            shell=False,
            fail_on_stderr=False,
            fail_on_return_code=False
        )

        if self.update_server_info and git_command == 'git-receive-pack':
            # We need to fully consume the iterator here, as the
            # update-server-info command needs to be run after the push.
            out = list(out)

            # Updating refs manually after each push.
            # This is required as some clients are exposing Git repos internally
            # with the dumb protocol.
            cmd = [self.git_path, 'update-server-info']
            log.debug('handling cmd %s', cmd)
            # NOTE(review): inputstream has already been consumed by the first
            # git call above - presumably update-server-info ignores stdin;
            # confirm.
            output = subprocessio.SubprocessIOChunker(
                cmd,
                input_stream=inputstream,
                env=gitenv,
                cwd=self.content_path,
                shell=False,
                fail_on_stderr=False,
                fail_on_return_code=False
            )
            # Consume all the output so the subprocess finishes
            for _ in output:
                pass

        # Upload-pack == clone
        if git_command == 'git-upload-pack':
            unused_status, post_pull_messages = hooks.git_post_pull(self.extras)

            resp.app_iter = self._build_post_pull_response(out, capabilities, pre_pull_messages, post_pull_messages)
        else:
            resp.app_iter = out

        return resp
395 395
396 396 def __call__(self, environ, start_response):
397 397 request = Request(environ)
398 398 _path = self._get_fixedpath(request.path_info)
399 399 if _path.startswith('info/refs'):
400 400 app = self.inforefs
401 401 else:
402 402 app = self.backend
403 403
404 404 try:
405 405 resp = app(request, environ)
406 406 except exc.HTTPException as error:
407 407 log.exception('HTTP Error')
408 408 resp = error
409 409 except Exception:
410 410 log.exception('Unknown error')
411 411 resp = exc.HTTPInternalServerError()
412 412
413 413 return resp(environ, start_response)
@@ -1,1317 +1,1317 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import collections
19 19 import logging
20 20 import os
21 21 import posixpath as vcspath
22 22 import re
23 23 import stat
24 24 import traceback
25 25 import urllib.request, urllib.parse, urllib.error
26 26 import urllib.request, urllib.error, urllib.parse
27 27 from functools import wraps
28 28
29 29 import more_itertools
30 30 import pygit2
31 31 from pygit2 import Repository as LibGit2Repo
32 32 from pygit2 import index as LibGit2Index
33 33 from dulwich import index, objects
34 34 from dulwich.client import HttpGitClient, LocalGitClient
35 35 from dulwich.errors import (
36 36 NotGitRepository, ChecksumMismatch, WrongObjectException,
37 37 MissingCommitError, ObjectMissing, HangupException,
38 38 UnexpectedCommandError)
39 39 from dulwich.repo import Repo as DulwichRepo
40 40 from dulwich.server import update_server_info
41 41
42 42 from vcsserver import exceptions, settings, subprocessio
43 from vcsserver.utils import safe_str, safe_int
43 from vcsserver.str_utils import safe_str, safe_int
44 44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, archive_repo
45 45 from vcsserver.hgcompat import (
46 46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
47 47 from vcsserver.git_lfs.lib import LFSOidStore
48 48 from vcsserver.vcs_base import RemoteBase
49 49
# Git tree-entry mode helpers
DIR_STAT = stat.S_IFDIR  # mode bits marking a directory entry
FILE_MODE = stat.S_IFMT  # mask extracting the file-type bits from a mode
GIT_LINK = objects.S_IFGITLINK  # mode marking a submodule (gitlink) entry
PEELED_REF_MARKER = '^{}'  # suffix git uses for peeled (dereferenced) tag refs


log = logging.getLogger(__name__)
57 57
58 58
def reraise_safe_exceptions(func):
    """Decorator translating Dulwich exceptions into neutral vcsserver ones."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
            lookup_exc = exceptions.LookupException(org_exc=e)
            raise lookup_exc(safe_str(e))
        except (HangupException, UnexpectedCommandError) as e:
            vcs_exc = exceptions.VcsException(org_exc=e)
            raise vcs_exc(safe_str(e))
        except Exception as e:
            # NOTE(marcink): becuase of how dulwich handles some exceptions
            # (KeyError on empty repos), we cannot track this and catch all
            # exceptions, it's an exceptions from other handlers
            #if not hasattr(e, '_vcs_kind'):
            #log.exception("Unhandled exception in git remote call")
            #raise_from_original(exceptions.UnhandledException)
            raise

    return wrapper
81 81
82 82
class Repo(DulwichRepo):
    """
    A wrapper for dulwich Repo class.

    Since dulwich is sometimes keeping .idx file descriptors open, it leads to
    "Too many open files" error. We need to close all opened file descriptors
    once the repo object is destroyed.
    """

    def __del__(self):
        # object_store may be missing if __init__ failed part-way through
        if hasattr(self, 'object_store'):
            self.close()
94 94
95 95
class Repository(LibGit2Repo):
    """pygit2 Repository usable as a context manager that frees libgit2 state."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # always release the underlying libgit2 handle, error or not
        self.free()
103 103
104 104
class GitFactory(RepoFactory):
    """Factory producing dulwich- or libgit2-backed repository objects."""

    repo_type = 'git'

    def _create_repo(self, wire, create, use_libgit2=False):
        # libgit2 repos take the raw path; dulwich ones get a wire-encoded one
        if use_libgit2:
            return Repository(wire['path'])
        repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
        return Repo(repo_path)

    def repo(self, wire, create=False, use_libgit2=False):
        """
        Get a repository instance for the given path.
        """
        return self._create_repo(wire, create, use_libgit2)

    def repo_libgit2(self, wire):
        return self.repo(wire, use_libgit2=True)
123 123
124 124
125 125 class GitRemote(RemoteBase):
126 126
127 127 def __init__(self, factory):
128 128 self._factory = factory
129 129 self._bulk_methods = {
130 130 "date": self.date,
131 131 "author": self.author,
132 132 "branch": self.branch,
133 133 "message": self.message,
134 134 "parents": self.parents,
135 135 "_commit": self.revision,
136 136 }
137 137
138 138 def _wire_to_config(self, wire):
139 139 if 'config' in wire:
140 140 return dict([(x[0] + '_' + x[1], x[2]) for x in wire['config']])
141 141 return {}
142 142
143 143 def _remote_conf(self, config):
144 144 params = [
145 145 '-c', 'core.askpass=""',
146 146 ]
147 147 ssl_cert_dir = config.get('vcs_ssl_dir')
148 148 if ssl_cert_dir:
149 149 params.extend(['-c', 'http.sslCAinfo={}'.format(ssl_cert_dir)])
150 150 return params
151 151
152 152 @reraise_safe_exceptions
153 153 def discover_git_version(self):
154 154 stdout, _ = self.run_git_command(
155 155 {}, ['--version'], _bare=True, _safe=True)
156 156 prefix = b'git version'
157 157 if stdout.startswith(prefix):
158 158 stdout = stdout[len(prefix):]
159 159 return stdout.strip()
160 160
161 161 @reraise_safe_exceptions
162 162 def is_empty(self, wire):
163 163 repo_init = self._factory.repo_libgit2(wire)
164 164 with repo_init as repo:
165 165
166 166 try:
167 167 has_head = repo.head.name
168 168 if has_head:
169 169 return False
170 170
171 171 # NOTE(marcink): check again using more expensive method
172 172 return repo.is_empty
173 173 except Exception:
174 174 pass
175 175
176 176 return True
177 177
178 178 @reraise_safe_exceptions
179 179 def assert_correct_path(self, wire):
180 180 cache_on, context_uid, repo_id = self._cache_on(wire)
181 181 region = self._region(wire)
182 182
183 183 @region.conditional_cache_on_arguments(condition=cache_on)
184 184 def _assert_correct_path(_context_uid, _repo_id):
185 185 try:
186 186 repo_init = self._factory.repo_libgit2(wire)
187 187 with repo_init as repo:
188 188 pass
189 189 except pygit2.GitError:
190 190 path = wire.get('path')
191 191 tb = traceback.format_exc()
192 192 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
193 193 return False
194 194
195 195 return True
196 196 return _assert_correct_path(context_uid, repo_id)
197 197
198 198 @reraise_safe_exceptions
199 199 def bare(self, wire):
200 200 repo_init = self._factory.repo_libgit2(wire)
201 201 with repo_init as repo:
202 202 return repo.is_bare
203 203
204 204 @reraise_safe_exceptions
205 205 def blob_as_pretty_string(self, wire, sha):
206 206 repo_init = self._factory.repo_libgit2(wire)
207 207 with repo_init as repo:
208 208 blob_obj = repo[sha]
209 209 blob = blob_obj.data
210 210 return blob
211 211
212 212 @reraise_safe_exceptions
213 213 def blob_raw_length(self, wire, sha):
214 214 cache_on, context_uid, repo_id = self._cache_on(wire)
215 215 region = self._region(wire)
216 216
217 217 @region.conditional_cache_on_arguments(condition=cache_on)
218 218 def _blob_raw_length(_repo_id, _sha):
219 219
220 220 repo_init = self._factory.repo_libgit2(wire)
221 221 with repo_init as repo:
222 222 blob = repo[sha]
223 223 return blob.size
224 224
225 225 return _blob_raw_length(repo_id, sha)
226 226
227 227 def _parse_lfs_pointer(self, raw_content):
228 228 spec_string = b'version https://git-lfs.github.com/spec'
229 229 if raw_content and raw_content.startswith(spec_string):
230 230
231 231 pattern = re.compile(rb"""
232 232 (?:\n)?
233 233 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
234 234 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
235 235 ^size[ ](?P<oid_size>[0-9]+)\n
236 236 (?:\n)?
237 237 """, re.VERBOSE | re.MULTILINE)
238 238 match = pattern.match(raw_content)
239 239 if match:
240 240 return match.groupdict()
241 241
242 242 return {}
243 243
244 244 @reraise_safe_exceptions
245 245 def is_large_file(self, wire, commit_id):
246 246 cache_on, context_uid, repo_id = self._cache_on(wire)
247 247 region = self._region(wire)
248 248
249 249 @region.conditional_cache_on_arguments(condition=cache_on)
250 250 def _is_large_file(_repo_id, _sha):
251 251 repo_init = self._factory.repo_libgit2(wire)
252 252 with repo_init as repo:
253 253 blob = repo[commit_id]
254 254 if blob.is_binary:
255 255 return {}
256 256
257 257 return self._parse_lfs_pointer(blob.data)
258 258
259 259 return _is_large_file(repo_id, commit_id)
260 260
261 261 @reraise_safe_exceptions
262 262 def is_binary(self, wire, tree_id):
263 263 cache_on, context_uid, repo_id = self._cache_on(wire)
264 264 region = self._region(wire)
265 265
266 266 @region.conditional_cache_on_arguments(condition=cache_on)
267 267 def _is_binary(_repo_id, _tree_id):
268 268 repo_init = self._factory.repo_libgit2(wire)
269 269 with repo_init as repo:
270 270 blob_obj = repo[tree_id]
271 271 return blob_obj.is_binary
272 272
273 273 return _is_binary(repo_id, tree_id)
274 274
275 275 @reraise_safe_exceptions
276 276 def in_largefiles_store(self, wire, oid):
277 277 conf = self._wire_to_config(wire)
278 278 repo_init = self._factory.repo_libgit2(wire)
279 279 with repo_init as repo:
280 280 repo_name = repo.path
281 281
282 282 store_location = conf.get('vcs_git_lfs_store_location')
283 283 if store_location:
284 284
285 285 store = LFSOidStore(
286 286 oid=oid, repo=repo_name, store_location=store_location)
287 287 return store.has_oid()
288 288
289 289 return False
290 290
291 291 @reraise_safe_exceptions
292 292 def store_path(self, wire, oid):
293 293 conf = self._wire_to_config(wire)
294 294 repo_init = self._factory.repo_libgit2(wire)
295 295 with repo_init as repo:
296 296 repo_name = repo.path
297 297
298 298 store_location = conf.get('vcs_git_lfs_store_location')
299 299 if store_location:
300 300 store = LFSOidStore(
301 301 oid=oid, repo=repo_name, store_location=store_location)
302 302 return store.oid_path
303 303 raise ValueError('Unable to fetch oid with path {}'.format(oid))
304 304
305 305 @reraise_safe_exceptions
306 306 def bulk_request(self, wire, rev, pre_load):
307 307 cache_on, context_uid, repo_id = self._cache_on(wire)
308 308 region = self._region(wire)
309 309
310 310 @region.conditional_cache_on_arguments(condition=cache_on)
311 311 def _bulk_request(_repo_id, _rev, _pre_load):
312 312 result = {}
313 313 for attr in pre_load:
314 314 try:
315 315 method = self._bulk_methods[attr]
316 316 args = [wire, rev]
317 317 result[attr] = method(*args)
318 318 except KeyError as e:
319 319 raise exceptions.VcsException(e)(
320 320 "Unknown bulk attribute: %s" % attr)
321 321 return result
322 322
323 323 return _bulk_request(repo_id, rev, sorted(pre_load))
324 324
325 325 def _build_opener(self, url):
326 326 handlers = []
327 327 url_obj = url_parser(url)
328 328 _, authinfo = url_obj.authinfo()
329 329
330 330 if authinfo:
331 331 # create a password manager
332 332 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
333 333 passmgr.add_password(*authinfo)
334 334
335 335 handlers.extend((httpbasicauthhandler(passmgr),
336 336 httpdigestauthhandler(passmgr)))
337 337
338 338 return urllib.request.build_opener(*handlers)
339 339
340 340 def _type_id_to_name(self, type_id: int):
341 341 return {
342 342 1: 'commit',
343 343 2: 'tree',
344 344 3: 'blob',
345 345 4: 'tag'
346 346 }[type_id]
347 347
348 348 @reraise_safe_exceptions
349 349 def check_url(self, url, config):
350 350 url_obj = url_parser(url)
351 351 test_uri, _ = url_obj.authinfo()
352 352 url_obj.passwd = '*****' if url_obj.passwd else url_obj.passwd
353 353 url_obj.query = obfuscate_qs(url_obj.query)
354 354 cleaned_uri = str(url_obj)
355 355 log.info("Checking URL for remote cloning/import: %s", cleaned_uri)
356 356
357 357 if not test_uri.endswith('info/refs'):
358 358 test_uri = test_uri.rstrip('/') + '/info/refs'
359 359
360 360 o = self._build_opener(url)
361 361 o.addheaders = [('User-Agent', 'git/1.7.8.0')] # fake some git
362 362
363 363 q = {"service": 'git-upload-pack'}
364 364 qs = '?%s' % urllib.parse.urlencode(q)
365 365 cu = "%s%s" % (test_uri, qs)
366 366 req = urllib.request.Request(cu, None, {})
367 367
368 368 try:
369 369 log.debug("Trying to open URL %s", cleaned_uri)
370 370 resp = o.open(req)
371 371 if resp.code != 200:
372 372 raise exceptions.URLError()('Return Code is not 200')
373 373 except Exception as e:
374 374 log.warning("URL cannot be opened: %s", cleaned_uri, exc_info=True)
375 375 # means it cannot be cloned
376 376 raise exceptions.URLError(e)("[%s] org_exc: %s" % (cleaned_uri, e))
377 377
378 378 # now detect if it's proper git repo
379 379 gitdata = resp.read()
380 380 if 'service=git-upload-pack' in gitdata:
381 381 pass
382 382 elif re.findall(r'[0-9a-fA-F]{40}\s+refs', gitdata):
383 383 # old style git can return some other format !
384 384 pass
385 385 else:
386 386 raise exceptions.URLError()(
387 387 "url [%s] does not look like an git" % (cleaned_uri,))
388 388
389 389 return True
390 390
    @reraise_safe_exceptions
    def clone(self, wire, url, deferred, valid_refs, update_after_clone):
        """
        Clone by pulling all refs from `url` into the wire repo.

        :param deferred: ref-name suffix(es) to skip ("deferred tags")
        :param valid_refs: ref-name prefix or list of prefixes to copy
        :param update_after_clone: when True also check out HEAD into the
            working tree
        """
        # TODO(marcink): deprecate this method. Last i checked we don't use it anymore
        remote_refs = self.pull(wire, url, apply_refs=False)
        repo = self._factory.repo(wire)
        # str.startswith accepts a tuple, not a list
        if isinstance(valid_refs, list):
            valid_refs = tuple(valid_refs)

        for k in remote_refs:
            # only parse heads/tags and skip so called deferred tags
            if k.startswith(valid_refs) and not k.endswith(deferred):
                repo[k] = remote_refs[k]

        if update_after_clone:
            # we want to checkout HEAD
            repo["HEAD"] = remote_refs["HEAD"]
            index.build_index_from_tree(repo.path, repo.index_path(),
                                        repo.object_store, repo["HEAD"].tree)
409 409
410 410 @reraise_safe_exceptions
411 411 def branch(self, wire, commit_id):
412 412 cache_on, context_uid, repo_id = self._cache_on(wire)
413 413 region = self._region(wire)
414 414 @region.conditional_cache_on_arguments(condition=cache_on)
415 415 def _branch(_context_uid, _repo_id, _commit_id):
416 416 regex = re.compile('^refs/heads')
417 417
418 418 def filter_with(ref):
419 419 return regex.match(ref[0]) and ref[1] == _commit_id
420 420
421 421 branches = list(filter(filter_with, list(self.get_refs(wire).items())))
422 422 return [x[0].split('refs/heads/')[-1] for x in branches]
423 423
424 424 return _branch(context_uid, repo_id, commit_id)
425 425
426 426 @reraise_safe_exceptions
427 427 def commit_branches(self, wire, commit_id):
428 428 cache_on, context_uid, repo_id = self._cache_on(wire)
429 429 region = self._region(wire)
430 430 @region.conditional_cache_on_arguments(condition=cache_on)
431 431 def _commit_branches(_context_uid, _repo_id, _commit_id):
432 432 repo_init = self._factory.repo_libgit2(wire)
433 433 with repo_init as repo:
434 434 branches = [x for x in repo.branches.with_commit(_commit_id)]
435 435 return branches
436 436
437 437 return _commit_branches(context_uid, repo_id, commit_id)
438 438
439 439 @reraise_safe_exceptions
440 440 def add_object(self, wire, content):
441 441 repo_init = self._factory.repo_libgit2(wire)
442 442 with repo_init as repo:
443 443 blob = objects.Blob()
444 444 blob.set_raw_string(content)
445 445 repo.object_store.add_object(blob)
446 446 return blob.id
447 447
    # TODO: this is quite complex, check if that can be simplified
    @reraise_safe_exceptions
    def commit(self, wire, commit_data, branch, commit_tree, updated, removed):
        """
        Create a commit on `branch` from an in-memory description of changes.

        :param commit_data: dict of dulwich Commit attributes (author,
            message, parents, ...) applied via setattr at the end
        :param branch: branch name that will point at the new commit
        :param commit_tree: sha of an existing tree to base on, or falsy to
            start from an empty root tree
        :param updated: dicts with keys path/node_path/content/mode for each
            added or modified file
        :param removed: list of repo-relative paths to delete
        :returns: id (sha) of the newly created commit
        """
        # Defines the root tree
        class _Root(object):
            def __repr__(self):
                return 'ROOT TREE'
        ROOT = _Root()

        repo = self._factory.repo(wire)
        object_store = repo.object_store

        # Create tree and populates it with blobs

        if commit_tree and repo[commit_tree]:
            # base on the first parent's root tree
            git_commit = repo[commit_data['parents'][0]]
            commit_tree = repo[git_commit.tree]  # root tree
        else:
            commit_tree = objects.Tree()

        for node in updated:
            # Compute subdirs if needed
            dirpath, nodename = vcspath.split(node['path'])
            dirnames = list(map(safe_str, dirpath and dirpath.split('/') or []))
            parent = commit_tree
            ancestors = [('', parent)]

            # Tries to dig for the deepest existing tree
            while dirnames:
                curdir = dirnames.pop(0)
                try:
                    dir_id = parent[curdir][1]
                except KeyError:
                    # put curdir back into dirnames and stops
                    dirnames.insert(0, curdir)
                    break
                else:
                    # If found, updates parent
                    parent = repo[dir_id]
                    ancestors.append((curdir, parent))
            # Now parent is deepest existing tree and we need to create
            # subtrees for dirnames (in reverse order)
            # [this only applies for nodes from added]
            new_trees = []

            blob = objects.Blob.from_string(node['content'])

            if dirnames:
                # If there are trees which should be created we need to build
                # them now (in reverse order)
                reversed_dirnames = list(reversed(dirnames))
                curtree = objects.Tree()
                curtree[node['node_path']] = node['mode'], blob.id
                new_trees.append(curtree)
                for dirname in reversed_dirnames[:-1]:
                    newtree = objects.Tree()
                    newtree[dirname] = (DIR_STAT, curtree.id)
                    new_trees.append(newtree)
                    curtree = newtree
                parent[reversed_dirnames[-1]] = (DIR_STAT, curtree.id)
            else:
                parent.add(name=node['node_path'], mode=node['mode'], hexsha=blob.id)

            new_trees.append(parent)
            # Update ancestors
            reversed_ancestors = reversed(
                [(a[1], b[1], b[0]) for a, b in zip(ancestors, ancestors[1:])])
            for parent, tree, path in reversed_ancestors:
                parent[path] = (DIR_STAT, tree.id)
                object_store.add_object(tree)

            object_store.add_object(blob)
            for tree in new_trees:
                object_store.add_object(tree)

        for node_path in removed:
            paths = node_path.split('/')
            tree = commit_tree  # start with top-level
            trees = [{'tree': tree, 'path': ROOT}]
            # Traverse deep into the forest...
            # resolve final tree by iterating the path.
            # e.g a/b/c.txt will get
            # - root as tree then
            # - 'a' as tree,
            # - 'b' as tree,
            # - stop at c as blob.
            for path in paths:
                try:
                    obj = repo[tree[path][1]]
                    if isinstance(obj, objects.Tree):
                        trees.append({'tree': obj, 'path': path})
                        tree = obj
                except KeyError:
                    break
            #PROBLEM:
            """
            We're not editing same reference tree object
            """
            # Cut down the blob and all rotten trees on the way back...
            for path, tree_data in reversed(list(zip(paths, trees))):
                tree = tree_data['tree']
                tree.__delitem__(path)
                # This operation edits the tree, we need to mark new commit back

                if len(tree) > 0:
                    # This tree still has elements - don't remove it or any
                    # of it's parents
                    break

            object_store.add_object(commit_tree)

        # Create commit
        commit = objects.Commit()
        commit.tree = commit_tree.id
        for k, v in commit_data.items():
            setattr(commit, k, v)
        object_store.add_object(commit)

        self.create_branch(wire, branch, commit.id)

        # dulwich set-ref
        ref = 'refs/heads/%s' % branch
        repo.refs[ref] = commit.id

        return commit.id
573 573
    @reraise_safe_exceptions
    def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
        """
        Fetch from `url` into the wire repo using dulwich clients.

        :param apply_refs: when True, write the fetched refs into the repo
        :param refs: optional list of ref names to restrict the fetch to
        :param update_after: when True, also check out HEAD into the working
            tree after the fetch
        :returns: dict of remote refs (filtered by `refs` when given)
        """
        # local paths (no scheme) use the local client, otherwise HTTP
        if url != 'default' and '://' not in url:
            client = LocalGitClient(url)
        else:
            url_obj = url_parser(url)
            o = self._build_opener(url)
            url, _ = url_obj.authinfo()
            client = HttpGitClient(base_url=url, opener=o)
        repo = self._factory.repo(wire)

        determine_wants = repo.object_store.determine_wants_all
        if refs:
            def determine_wants_requested(references):
                return [references[r] for r in references if r in refs]
            determine_wants = determine_wants_requested

        try:
            remote_refs = client.fetch(
                path=url, target=repo, determine_wants=determine_wants)
        except NotGitRepository as e:
            log.warning(
                'Trying to fetch from "%s" failed, not a Git repository.', url)
            # Exception can contain unicode which we convert
            raise exceptions.AbortException(e)(repr(e))

        # mikhail: client.fetch() returns all the remote refs, but fetches only
        # refs filtered by `determine_wants` function. We need to filter result
        # as well
        if refs:
            remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}

        if apply_refs:
            # TODO: johbo: Needs proper test coverage with a git repository
            # that contains a tag object, so that we would end up with
            # a peeled ref at this point.
            for k in remote_refs:
                if k.endswith(PEELED_REF_MARKER):
                    log.debug("Skipping peeled reference %s", k)
                    continue
                repo[k] = remote_refs[k]

            if refs and not update_after:
                # mikhail: explicitly set the head to the last ref.
                repo["HEAD"] = remote_refs[refs[-1]]

        if update_after:
            # we want to checkout HEAD
            repo["HEAD"] = remote_refs["HEAD"]
            index.build_index_from_tree(repo.path, repo.index_path(),
                                        repo.object_store, repo["HEAD"].tree)
        return remote_refs
626 626
    @reraise_safe_exceptions
    def sync_fetch(self, wire, url, refs=None, all_refs=False):
        """
        Fetch refs from `url` using the git binary (ls-remote + fetch).

        :param refs: optional ref names/shas to restrict the fetch to
        :param all_refs: when False only heads and tags are listed/fetched
        :returns: OrderedDict of {ref_name: sha} found on the remote
        """
        repo = self._factory.repo(wire)
        if refs and not isinstance(refs, (list, tuple)):
            refs = [refs]

        config = self._wire_to_config(wire)
        # get all remote refs we'll use to fetch later
        cmd = ['ls-remote']
        if not all_refs:
            cmd += ['--heads', '--tags']
        cmd += [url]
        output, __ = self.run_git_command(
            wire, cmd, fail_on_stderr=False,
            _copts=self._remote_conf(config),
            extra_env={'GIT_TERMINAL_PROMPT': '0'})

        remote_refs = collections.OrderedDict()
        fetch_refs = []

        # ls-remote output format: "<sha>\t<ref_name>" per line
        # NOTE(review): the '\t' split assumes run_git_command returns text
        # here - confirm str vs bytes on python3.
        for ref_line in output.splitlines():
            sha, ref = ref_line.split('\t')
            sha = sha.strip()
            if ref in remote_refs:
                # duplicate, skip
                continue
            if ref.endswith(PEELED_REF_MARKER):
                log.debug("Skipping peeled reference %s", ref)
                continue
            # don't sync HEAD
            if ref in ['HEAD']:
                continue

            remote_refs[ref] = sha

            if refs and sha in refs:
                # we filter fetch using our specified refs
                fetch_refs.append('{}:{}'.format(ref, ref))
            elif not refs:
                fetch_refs.append('{}:{}'.format(ref, ref))
        log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))

        if fetch_refs:
            # fetch in chunks to keep the command line length bounded
            for chunk in more_itertools.chunked(fetch_refs, 1024 * 4):
                fetch_refs_chunks = list(chunk)
                log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
                self.run_git_command(
                    wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
                    fail_on_stderr=False,
                    _copts=self._remote_conf(config),
                    extra_env={'GIT_TERMINAL_PROMPT': '0'})

        return remote_refs
680 680
681 681 @reraise_safe_exceptions
682 682 def sync_push(self, wire, url, refs=None):
683 683 if not self.check_url(url, wire):
684 684 return
685 685 config = self._wire_to_config(wire)
686 686 self._factory.repo(wire)
687 687 self.run_git_command(
688 688 wire, ['push', url, '--mirror'], fail_on_stderr=False,
689 689 _copts=self._remote_conf(config),
690 690 extra_env={'GIT_TERMINAL_PROMPT': '0'})
691 691
692 692 @reraise_safe_exceptions
693 693 def get_remote_refs(self, wire, url):
694 694 repo = Repo(url)
695 695 return repo.get_refs()
696 696
697 697 @reraise_safe_exceptions
698 698 def get_description(self, wire):
699 699 repo = self._factory.repo(wire)
700 700 return repo.get_description()
701 701
702 702 @reraise_safe_exceptions
703 703 def get_missing_revs(self, wire, rev1, rev2, path2):
704 704 repo = self._factory.repo(wire)
705 705 LocalGitClient(thin_packs=False).fetch(path2, repo)
706 706
707 707 wire_remote = wire.copy()
708 708 wire_remote['path'] = path2
709 709 repo_remote = self._factory.repo(wire_remote)
710 710 LocalGitClient(thin_packs=False).fetch(wire["path"], repo_remote)
711 711
712 712 revs = [
713 713 x.commit.id
714 714 for x in repo_remote.get_walker(include=[rev2], exclude=[rev1])]
715 715 return revs
716 716
    @reraise_safe_exceptions
    def get_object(self, wire, sha, maybe_unreachable=False):
        """
        Resolve ``sha`` (any rev-parse-able expression) to an object descriptor.

        :param wire: connection descriptor with the repo path
        :param sha: revision expression to resolve
        :param maybe_unreachable: when True, skip the dangling-commit check
        :returns: dict with 'id', 'type', 'commit_id', 'idx' keys
        :raises LookupException: when the sha is unknown or dangling
        """
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _get_object(_context_uid, _repo_id, _sha):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:

                missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
                try:
                    commit = repo.revparse_single(sha)
                except KeyError:
                    # NOTE(marcink): KeyError doesn't give us any meaningful information
                    # here, we instead give something more explicit
                    e = exceptions.RefNotFoundException('SHA: %s not found', sha)
                    raise exceptions.LookupException(e)(missing_commit_err)
                except ValueError as e:
                    raise exceptions.LookupException(e)(missing_commit_err)

                is_tag = False
                if isinstance(commit, pygit2.Tag):
                    # annotated tag: peel it to the tagged commit
                    commit = repo.get(commit.target)
                    is_tag = True

                check_dangling = True
                if is_tag:
                    check_dangling = False

                if check_dangling and maybe_unreachable:
                    check_dangling = False

                # we used a reference and it parsed means we're not having a dangling commit
                if sha != commit.hex:
                    check_dangling = False

                if check_dangling:
                    # check for dangling commit
                    for branch in repo.branches.with_commit(commit.hex):
                        if branch:
                            break
                    else:
                        # NOTE(marcink): Empty error doesn't give us any meaningful information
                        # here, we instead give something more explicit
                        e = exceptions.RefNotFoundException('SHA: %s not found in branches', sha)
                        raise exceptions.LookupException(e)(missing_commit_err)

                commit_id = commit.hex
                type_id = commit.type

                return {
                    'id': commit_id,
                    'type': self._type_id_to_name(type_id),
                    'commit_id': commit_id,
                    'idx': 0
                }

        return _get_object(context_uid, repo_id, sha)
776 776
777 777 @reraise_safe_exceptions
778 778 def get_refs(self, wire):
779 779 cache_on, context_uid, repo_id = self._cache_on(wire)
780 780 region = self._region(wire)
781 781
782 782 @region.conditional_cache_on_arguments(condition=cache_on)
783 783 def _get_refs(_context_uid, _repo_id):
784 784
785 785 repo_init = self._factory.repo_libgit2(wire)
786 786 with repo_init as repo:
787 787 regex = re.compile('^refs/(heads|tags)/')
788 788 return {x.name: x.target.hex for x in
789 789 [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}
790 790
791 791 return _get_refs(context_uid, repo_id)
792 792
793 793 @reraise_safe_exceptions
794 794 def get_branch_pointers(self, wire):
795 795 cache_on, context_uid, repo_id = self._cache_on(wire)
796 796 region = self._region(wire)
797 797
798 798 @region.conditional_cache_on_arguments(condition=cache_on)
799 799 def _get_branch_pointers(_context_uid, _repo_id):
800 800
801 801 repo_init = self._factory.repo_libgit2(wire)
802 802 regex = re.compile('^refs/heads')
803 803 with repo_init as repo:
804 804 branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
805 805 return {x.target.hex: x.shorthand for x in branches}
806 806
807 807 return _get_branch_pointers(context_uid, repo_id)
808 808
809 809 @reraise_safe_exceptions
810 810 def head(self, wire, show_exc=True):
811 811 cache_on, context_uid, repo_id = self._cache_on(wire)
812 812 region = self._region(wire)
813 813
814 814 @region.conditional_cache_on_arguments(condition=cache_on)
815 815 def _head(_context_uid, _repo_id, _show_exc):
816 816 repo_init = self._factory.repo_libgit2(wire)
817 817 with repo_init as repo:
818 818 try:
819 819 return repo.head.peel().hex
820 820 except Exception:
821 821 if show_exc:
822 822 raise
823 823 return _head(context_uid, repo_id, show_exc)
824 824
825 825 @reraise_safe_exceptions
826 826 def init(self, wire):
827 827 repo_path = str_to_dulwich(wire['path'])
828 828 self.repo = Repo.init(repo_path)
829 829
830 830 @reraise_safe_exceptions
831 831 def init_bare(self, wire):
832 832 repo_path = str_to_dulwich(wire['path'])
833 833 self.repo = Repo.init_bare(repo_path)
834 834
835 835 @reraise_safe_exceptions
836 836 def revision(self, wire, rev):
837 837
838 838 cache_on, context_uid, repo_id = self._cache_on(wire)
839 839 region = self._region(wire)
840 840
841 841 @region.conditional_cache_on_arguments(condition=cache_on)
842 842 def _revision(_context_uid, _repo_id, _rev):
843 843 repo_init = self._factory.repo_libgit2(wire)
844 844 with repo_init as repo:
845 845 commit = repo[rev]
846 846 obj_data = {
847 847 'id': commit.id.hex,
848 848 }
849 849 # tree objects itself don't have tree_id attribute
850 850 if hasattr(commit, 'tree_id'):
851 851 obj_data['tree'] = commit.tree_id.hex
852 852
853 853 return obj_data
854 854 return _revision(context_uid, repo_id, rev)
855 855
856 856 @reraise_safe_exceptions
857 857 def date(self, wire, commit_id):
858 858 cache_on, context_uid, repo_id = self._cache_on(wire)
859 859 region = self._region(wire)
860 860
861 861 @region.conditional_cache_on_arguments(condition=cache_on)
862 862 def _date(_repo_id, _commit_id):
863 863 repo_init = self._factory.repo_libgit2(wire)
864 864 with repo_init as repo:
865 865 commit = repo[commit_id]
866 866
867 867 if hasattr(commit, 'commit_time'):
868 868 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
869 869 else:
870 870 commit = commit.get_object()
871 871 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
872 872
873 873 # TODO(marcink): check dulwich difference of offset vs timezone
874 874 return [commit_time, commit_time_offset]
875 875 return _date(repo_id, commit_id)
876 876
877 877 @reraise_safe_exceptions
878 878 def author(self, wire, commit_id):
879 879 cache_on, context_uid, repo_id = self._cache_on(wire)
880 880 region = self._region(wire)
881 881
882 882 @region.conditional_cache_on_arguments(condition=cache_on)
883 883 def _author(_repo_id, _commit_id):
884 884 repo_init = self._factory.repo_libgit2(wire)
885 885 with repo_init as repo:
886 886 commit = repo[commit_id]
887 887
888 888 if hasattr(commit, 'author'):
889 889 author = commit.author
890 890 else:
891 891 author = commit.get_object().author
892 892
893 893 if author.email:
894 894 return "{} <{}>".format(author.name, author.email)
895 895
896 896 try:
897 897 return "{}".format(author.name)
898 898 except Exception:
899 899 return "{}".format(safe_str(author.raw_name))
900 900
901 901 return _author(repo_id, commit_id)
902 902
903 903 @reraise_safe_exceptions
904 904 def message(self, wire, commit_id):
905 905 cache_on, context_uid, repo_id = self._cache_on(wire)
906 906 region = self._region(wire)
907 907 @region.conditional_cache_on_arguments(condition=cache_on)
908 908 def _message(_repo_id, _commit_id):
909 909 repo_init = self._factory.repo_libgit2(wire)
910 910 with repo_init as repo:
911 911 commit = repo[commit_id]
912 912 return commit.message
913 913 return _message(repo_id, commit_id)
914 914
915 915 @reraise_safe_exceptions
916 916 def parents(self, wire, commit_id):
917 917 cache_on, context_uid, repo_id = self._cache_on(wire)
918 918 region = self._region(wire)
919 919 @region.conditional_cache_on_arguments(condition=cache_on)
920 920 def _parents(_repo_id, _commit_id):
921 921 repo_init = self._factory.repo_libgit2(wire)
922 922 with repo_init as repo:
923 923 commit = repo[commit_id]
924 924 if hasattr(commit, 'parent_ids'):
925 925 parent_ids = commit.parent_ids
926 926 else:
927 927 parent_ids = commit.get_object().parent_ids
928 928
929 929 return [x.hex for x in parent_ids]
930 930 return _parents(repo_id, commit_id)
931 931
932 932 @reraise_safe_exceptions
933 933 def children(self, wire, commit_id):
934 934 cache_on, context_uid, repo_id = self._cache_on(wire)
935 935 region = self._region(wire)
936 936
937 937 @region.conditional_cache_on_arguments(condition=cache_on)
938 938 def _children(_repo_id, _commit_id):
939 939 output, __ = self.run_git_command(
940 940 wire, ['rev-list', '--all', '--children'])
941 941
942 942 child_ids = []
943 943 pat = re.compile(r'^%s' % commit_id)
944 944 for l in output.splitlines():
945 945 if pat.match(l):
946 946 found_ids = l.split(' ')[1:]
947 947 child_ids.extend(found_ids)
948 948
949 949 return child_ids
950 950 return _children(repo_id, commit_id)
951 951
952 952 @reraise_safe_exceptions
953 953 def set_refs(self, wire, key, value):
954 954 repo_init = self._factory.repo_libgit2(wire)
955 955 with repo_init as repo:
956 956 repo.references.create(key, value, force=True)
957 957
958 958 @reraise_safe_exceptions
959 959 def create_branch(self, wire, branch_name, commit_id, force=False):
960 960 repo_init = self._factory.repo_libgit2(wire)
961 961 with repo_init as repo:
962 962 commit = repo[commit_id]
963 963
964 964 if force:
965 965 repo.branches.local.create(branch_name, commit, force=force)
966 966 elif not repo.branches.get(branch_name):
967 967 # create only if that branch isn't existing
968 968 repo.branches.local.create(branch_name, commit, force=force)
969 969
970 970 @reraise_safe_exceptions
971 971 def remove_ref(self, wire, key):
972 972 repo_init = self._factory.repo_libgit2(wire)
973 973 with repo_init as repo:
974 974 repo.references.delete(key)
975 975
976 976 @reraise_safe_exceptions
977 977 def tag_remove(self, wire, tag_name):
978 978 repo_init = self._factory.repo_libgit2(wire)
979 979 with repo_init as repo:
980 980 key = 'refs/tags/{}'.format(tag_name)
981 981 repo.references.delete(key)
982 982
983 983 @reraise_safe_exceptions
984 984 def tree_changes(self, wire, source_id, target_id):
985 985 # TODO(marcink): remove this seems it's only used by tests
986 986 repo = self._factory.repo(wire)
987 987 source = repo[source_id].tree if source_id else None
988 988 target = repo[target_id].tree
989 989 result = repo.object_store.tree_changes(source, target)
990 990 return list(result)
991 991
992 992 @reraise_safe_exceptions
993 993 def tree_and_type_for_path(self, wire, commit_id, path):
994 994
995 995 cache_on, context_uid, repo_id = self._cache_on(wire)
996 996 region = self._region(wire)
997 997
998 998 @region.conditional_cache_on_arguments(condition=cache_on)
999 999 def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
1000 1000 repo_init = self._factory.repo_libgit2(wire)
1001 1001
1002 1002 with repo_init as repo:
1003 1003 commit = repo[commit_id]
1004 1004 try:
1005 1005 tree = commit.tree[path]
1006 1006 except KeyError:
1007 1007 return None, None, None
1008 1008
1009 1009 return tree.id.hex, tree.type_str, tree.filemode
1010 1010 return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)
1011 1011
1012 1012 @reraise_safe_exceptions
1013 1013 def tree_items(self, wire, tree_id):
1014 1014 cache_on, context_uid, repo_id = self._cache_on(wire)
1015 1015 region = self._region(wire)
1016 1016
1017 1017 @region.conditional_cache_on_arguments(condition=cache_on)
1018 1018 def _tree_items(_repo_id, _tree_id):
1019 1019
1020 1020 repo_init = self._factory.repo_libgit2(wire)
1021 1021 with repo_init as repo:
1022 1022 try:
1023 1023 tree = repo[tree_id]
1024 1024 except KeyError:
1025 1025 raise ObjectMissing('No tree with id: {}'.format(tree_id))
1026 1026
1027 1027 result = []
1028 1028 for item in tree:
1029 1029 item_sha = item.hex
1030 1030 item_mode = item.filemode
1031 1031 item_type = item.type_str
1032 1032
1033 1033 if item_type == 'commit':
1034 1034 # NOTE(marcink): submodules we translate to 'link' for backward compat
1035 1035 item_type = 'link'
1036 1036
1037 1037 result.append((item.name, item_mode, item_sha, item_type))
1038 1038 return result
1039 1039 return _tree_items(repo_id, tree_id)
1040 1040
1041 1041 @reraise_safe_exceptions
1042 1042 def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1043 1043 """
1044 1044 Old version that uses subprocess to call diff
1045 1045 """
1046 1046
1047 1047 flags = [
1048 1048 '-U%s' % context, '--patch',
1049 1049 '--binary',
1050 1050 '--find-renames',
1051 1051 '--no-indent-heuristic',
1052 1052 # '--indent-heuristic',
1053 1053 #'--full-index',
1054 1054 #'--abbrev=40'
1055 1055 ]
1056 1056
1057 1057 if opt_ignorews:
1058 1058 flags.append('--ignore-all-space')
1059 1059
1060 1060 if commit_id_1 == self.EMPTY_COMMIT:
1061 1061 cmd = ['show'] + flags + [commit_id_2]
1062 1062 else:
1063 1063 cmd = ['diff'] + flags + [commit_id_1, commit_id_2]
1064 1064
1065 1065 if file_filter:
1066 1066 cmd.extend(['--', file_filter])
1067 1067
1068 1068 diff, __ = self.run_git_command(wire, cmd)
1069 1069 # If we used 'show' command, strip first few lines (until actual diff
1070 1070 # starts)
1071 1071 if commit_id_1 == self.EMPTY_COMMIT:
1072 1072 lines = diff.splitlines()
1073 1073 x = 0
1074 1074 for line in lines:
1075 1075 if line.startswith(b'diff'):
1076 1076 break
1077 1077 x += 1
1078 1078 # Append new line just like 'diff' command do
1079 1079 diff = '\n'.join(lines[x:]) + '\n'
1080 1080 return diff
1081 1081
    @reraise_safe_exceptions
    def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
        """
        Produce a patch between two commits using pygit2.

        :param commit_id_1: base commit (EMPTY_COMMIT for initial diffs)
        :param commit_id_2: target commit
        :param file_filter: when set, return only the patch for that one path
        :param opt_ignorews: ignore all whitespace changes
        :param context: number of context lines
        :returns: the patch text, '' when nothing matches
        """
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            # swap=True reverses old/new sides so the patch reads in the
            # expected direction for diff_to_tree
            swap = True
            flags = 0
            flags |= pygit2.GIT_DIFF_SHOW_BINARY

            if opt_ignorews:
                flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE

            if commit_id_1 == self.EMPTY_COMMIT:
                # diff against the empty tree (initial commit case)
                comm1 = repo[commit_id_2]
                diff_obj = comm1.tree.diff_to_tree(
                    flags=flags, context_lines=context, swap=swap)

            else:
                comm1 = repo[commit_id_2]
                comm2 = repo[commit_id_1]
                diff_obj = comm1.tree.diff_to_tree(
                    comm2.tree, flags=flags, context_lines=context, swap=swap)
                similar_flags = 0
                similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
                diff_obj.find_similar(flags=similar_flags)

            if file_filter:
                for p in diff_obj:
                    if p.delta.old_file.path == file_filter:
                        return p.patch or ''
                # no matching path == no diff
                return ''
            return diff_obj.patch or ''
1114 1114
1115 1115 @reraise_safe_exceptions
1116 1116 def node_history(self, wire, commit_id, path, limit):
1117 1117 cache_on, context_uid, repo_id = self._cache_on(wire)
1118 1118 region = self._region(wire)
1119 1119
1120 1120 @region.conditional_cache_on_arguments(condition=cache_on)
1121 1121 def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
1122 1122 # optimize for n==1, rev-list is much faster for that use-case
1123 1123 if limit == 1:
1124 1124 cmd = ['rev-list', '-1', commit_id, '--', path]
1125 1125 else:
1126 1126 cmd = ['log']
1127 1127 if limit:
1128 1128 cmd.extend(['-n', str(safe_int(limit, 0))])
1129 1129 cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])
1130 1130
1131 1131 output, __ = self.run_git_command(wire, cmd)
1132 1132 commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)
1133 1133
1134 1134 return [x for x in commit_ids]
1135 1135 return _node_history(context_uid, repo_id, commit_id, path, limit)
1136 1136
    @reraise_safe_exceptions
    def node_annotate_legacy(self, wire, commit_id, path):
        """
        Blame ``path`` at ``commit_id`` by shelling out to git.

        :returns: list of (line_no, blame_commit_id, line) tuples; the id and
            line are bytes since they come from raw command output
        """
        #note: replaced by pygit2 impelementation
        cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
        # -l ==> outputs long shas (and we need all 40 characters)
        # --root ==> doesn't put '^' character for boundaries
        # -r commit_id ==> blames for the given commit
        output, __ = self.run_git_command(wire, cmd)

        result = []
        for i, blame_line in enumerate(output.splitlines()[:-1]):
            line_no = i + 1
            # split off the leading sha; the remainder is the annotated line
            blame_commit_id, line = re.split(rb' ', blame_line, 1)
            result.append((line_no, blame_commit_id, line))

        return result
1153 1153
1154 1154 @reraise_safe_exceptions
1155 1155 def node_annotate(self, wire, commit_id, path):
1156 1156
1157 1157 result_libgit = []
1158 1158 repo_init = self._factory.repo_libgit2(wire)
1159 1159 with repo_init as repo:
1160 1160 commit = repo[commit_id]
1161 1161 blame_obj = repo.blame(path, newest_commit=commit_id)
1162 1162 for i, line in enumerate(commit.tree[path].data.splitlines()):
1163 1163 line_no = i + 1
1164 1164 hunk = blame_obj.for_line(line_no)
1165 1165 blame_commit_id = hunk.final_commit_id.hex
1166 1166
1167 1167 result_libgit.append((line_no, blame_commit_id, line))
1168 1168
1169 1169 return result_libgit
1170 1170
1171 1171 @reraise_safe_exceptions
1172 1172 def update_server_info(self, wire):
1173 1173 repo = self._factory.repo(wire)
1174 1174 update_server_info(repo)
1175 1175
1176 1176 @reraise_safe_exceptions
1177 1177 def get_all_commit_ids(self, wire):
1178 1178
1179 1179 cache_on, context_uid, repo_id = self._cache_on(wire)
1180 1180 region = self._region(wire)
1181 1181
1182 1182 @region.conditional_cache_on_arguments(condition=cache_on)
1183 1183 def _get_all_commit_ids(_context_uid, _repo_id):
1184 1184
1185 1185 cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
1186 1186 try:
1187 1187 output, __ = self.run_git_command(wire, cmd)
1188 1188 return output.splitlines()
1189 1189 except Exception:
1190 1190 # Can be raised for empty repositories
1191 1191 return []
1192 1192
1193 1193 @region.conditional_cache_on_arguments(condition=cache_on)
1194 1194 def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
1195 1195 repo_init = self._factory.repo_libgit2(wire)
1196 1196 from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
1197 1197 results = []
1198 1198 with repo_init as repo:
1199 1199 for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
1200 1200 results.append(commit.id.hex)
1201 1201
1202 1202 return _get_all_commit_ids(context_uid, repo_id)
1203 1203
1204 1204 @reraise_safe_exceptions
1205 1205 def run_git_command(self, wire, cmd, **opts):
1206 1206 path = wire.get('path', None)
1207 1207
1208 1208 if path and os.path.isdir(path):
1209 1209 opts['cwd'] = path
1210 1210
1211 1211 if '_bare' in opts:
1212 1212 _copts = []
1213 1213 del opts['_bare']
1214 1214 else:
1215 1215 _copts = ['-c', 'core.quotepath=false', ]
1216 1216 safe_call = False
1217 1217 if '_safe' in opts:
1218 1218 # no exc on failure
1219 1219 del opts['_safe']
1220 1220 safe_call = True
1221 1221
1222 1222 if '_copts' in opts:
1223 1223 _copts.extend(opts['_copts'] or [])
1224 1224 del opts['_copts']
1225 1225
1226 1226 gitenv = os.environ.copy()
1227 1227 gitenv.update(opts.pop('extra_env', {}))
1228 1228 # need to clean fix GIT_DIR !
1229 1229 if 'GIT_DIR' in gitenv:
1230 1230 del gitenv['GIT_DIR']
1231 1231 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1232 1232 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1233 1233
1234 1234 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1235 1235 _opts = {'env': gitenv, 'shell': False}
1236 1236
1237 1237 proc = None
1238 1238 try:
1239 1239 _opts.update(opts)
1240 1240 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1241 1241
1242 1242 return b''.join(proc), b''.join(proc.stderr)
1243 1243 except OSError as err:
1244 1244 cmd = ' '.join(cmd) # human friendly CMD
1245 1245 tb_err = ("Couldn't run git command (%s).\n"
1246 1246 "Original error was:%s\n"
1247 1247 "Call options:%s\n"
1248 1248 % (cmd, err, _opts))
1249 1249 log.exception(tb_err)
1250 1250 if safe_call:
1251 1251 return '', err
1252 1252 else:
1253 1253 raise exceptions.VcsException()(tb_err)
1254 1254 finally:
1255 1255 if proc:
1256 1256 proc.close()
1257 1257
1258 1258 @reraise_safe_exceptions
1259 1259 def install_hooks(self, wire, force=False):
1260 1260 from vcsserver.hook_utils import install_git_hooks
1261 1261 bare = self.bare(wire)
1262 1262 path = wire['path']
1263 1263 return install_git_hooks(path, bare, force_create=force)
1264 1264
1265 1265 @reraise_safe_exceptions
1266 1266 def get_hooks_info(self, wire):
1267 1267 from vcsserver.hook_utils import (
1268 1268 get_git_pre_hook_version, get_git_post_hook_version)
1269 1269 bare = self.bare(wire)
1270 1270 path = wire['path']
1271 1271 return {
1272 1272 'pre_version': get_git_pre_hook_version(path, bare),
1273 1273 'post_version': get_git_post_hook_version(path, bare),
1274 1274 }
1275 1275
1276 1276 @reraise_safe_exceptions
1277 1277 def set_head_ref(self, wire, head_name):
1278 1278 log.debug('Setting refs/head to `%s`', head_name)
1279 1279 cmd = ['symbolic-ref', '"HEAD"', '"refs/heads/%s"' % head_name]
1280 1280 output, __ = self.run_git_command(wire, cmd)
1281 1281 return [head_name] + output.splitlines()
1282 1282
    @reraise_safe_exceptions
    def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
                     archive_dir_name, commit_id):
        """
        Write an archive of the repo at ``commit_id`` to ``archive_dest_path``.

        :param kind: archive format, forwarded to the shared archive_repo helper
        :param mtime: modification time to stamp on archive entries
        :param archive_at_path: subtree to archive ('' or '/' for the root)
        :param archive_dir_name: top-level directory name inside the archive
        :returns: whatever the shared ``archive_repo`` helper returns
        """

        def file_walker(_commit_id, path):
            # generator yielding ArchiveNode entries for every file reachable
            # under ``path`` in the given commit
            repo_init = self._factory.repo_libgit2(wire)

            with repo_init as repo:
                commit = repo[commit_id]

                if path in ['', '/']:
                    tree = commit.tree
                else:
                    tree = commit.tree[path.rstrip('/')]
                tree_id = tree.id.hex
                try:
                    tree = repo[tree_id]
                except KeyError:
                    raise ObjectMissing('No tree with id: {}'.format(tree_id))

                # flatten the tree into file entries via an in-memory index
                index = LibGit2Index.Index()
                index.read_tree(tree)
                file_iter = index

                for fn in file_iter:
                    file_path = fn.path
                    mode = fn.mode
                    is_link = stat.S_ISLNK(mode)
                    if mode == pygit2.GIT_FILEMODE_COMMIT:
                        log.debug('Skipping path %s as a commit node', file_path)
                        continue
                    yield ArchiveNode(file_path, mode, is_link, repo[fn.hex].read_raw)

        return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
                            archive_dir_name, commit_id)
@@ -1,1062 +1,1062 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import logging
20 20 import stat
21 21 import urllib.request, urllib.parse, urllib.error
22 22 import urllib.request, urllib.error, urllib.parse
23 23 import traceback
24 24
25 25 from hgext import largefiles, rebase, purge
26 26
27 27 from mercurial import commands
28 28 from mercurial import unionrepo
29 29 from mercurial import verify
30 30 from mercurial import repair
31 31
32 32 import vcsserver
33 33 from vcsserver import exceptions
34 34 from vcsserver.base import RepoFactory, obfuscate_qs, raise_from_original, archive_repo, ArchiveNode
35 35 from vcsserver.hgcompat import (
36 36 archival, bin, clone, config as hgconfig, diffopts, hex, get_ctx,
37 37 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler,
38 38 makepeer, instance, match, memctx, exchange, memfilectx, nullrev, hg_merge,
39 39 patch, peer, revrange, ui, hg_tag, Abort, LookupError, RepoError,
40 40 RepoLookupError, InterventionRequired, RequirementError,
41 41 alwaysmatcher, patternmatcher, hgutil, hgext_strip)
42 from vcsserver.utils import ascii_bytes, ascii_str, safe_str, safe_bytes
42 from vcsserver.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes
43 43 from vcsserver.vcs_base import RemoteBase
44 44
45 45 log = logging.getLogger(__name__)
46 46
47 47
def make_ui_from_config(repo_config):
    """
    Build a mercurial ui object from ``repo_config`` (section, option, value)
    triples, routing all ui output channels through this module's logger.

    :param repo_config: iterable of (section, option, value) str triples
    :returns: configured LoggingUI instance
    """

    class LoggingUI(ui.ui):
        # each channel decodes the byte messages and forwards to `log`

        def status(self, *msg, **opts):
            str_msg = map(safe_str, msg)
            log.info(' '.join(str_msg).rstrip('\n'))
            #super(LoggingUI, self).status(*msg, **opts)

        def warn(self, *msg, **opts):
            str_msg = map(safe_str, msg)
            log.warning('ui_logger:'+' '.join(str_msg).rstrip('\n'))
            #super(LoggingUI, self).warn(*msg, **opts)

        def error(self, *msg, **opts):
            str_msg = map(safe_str, msg)
            log.error('ui_logger:'+' '.join(str_msg).rstrip('\n'))
            #super(LoggingUI, self).error(*msg, **opts)

        def note(self, *msg, **opts):
            str_msg = map(safe_str, msg)
            log.info('ui_logger:'+' '.join(str_msg).rstrip('\n'))
            #super(LoggingUI, self).note(*msg, **opts)

        def debug(self, *msg, **opts):
            str_msg = map(safe_str, msg)
            log.debug('ui_logger:'+' '.join(str_msg).rstrip('\n'))
            #super(LoggingUI, self).debug(*msg, **opts)

    baseui = LoggingUI()

    # clean the baseui object
    baseui._ocfg = hgconfig.config()
    baseui._ucfg = hgconfig.config()
    baseui._tcfg = hgconfig.config()

    for section, option, value in repo_config:
        baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))

    # make our hgweb quiet so it doesn't print output
    baseui.setconfig(b'ui', b'quiet', b'true')

    baseui.setconfig(b'ui', b'paginate', b'never')
    # for better Error reporting of Mercurial
    baseui.setconfig(b'ui', b'message-output', b'stderr')

    # force mercurial to only use 1 thread, otherwise it may try to set a
    # signal in a non-main thread, thus generating a ValueError.
    baseui.setconfig(b'worker', b'numcpus', 1)

    # If there is no config for the largefiles extension, we explicitly disable
    # it here. This overrides settings from repositories hgrc file. Recent
    # mercurial versions enable largefiles in hgrc on clone from largefile
    # repo.
    if not baseui.hasconfig(b'extensions', b'largefiles'):
        log.debug('Explicitly disable largefiles extension for repo.')
        baseui.setconfig(b'extensions', b'largefiles', b'!')

    return baseui
107 107
108 108
def reraise_safe_exceptions(func):
    """Decorator for converting mercurial exceptions to something neutral.

    Improvement: the wrapper now carries the wrapped callable's metadata
    (__name__, __doc__) via functools.wraps, which the original dropped.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (Abort, InterventionRequired) as e:
            raise_from_original(exceptions.AbortException(e), e)
        except RepoLookupError as e:
            raise_from_original(exceptions.LookupException(e), e)
        except RequirementError as e:
            raise_from_original(exceptions.RequirementException(e), e)
        except RepoError as e:
            raise_from_original(exceptions.VcsException(e), e)
        except LookupError as e:
            raise_from_original(exceptions.LookupException(e), e)
        except Exception as e:
            if not hasattr(e, '_vcs_kind'):
                log.exception("Unhandled exception in hg remote call")
                raise_from_original(exceptions.UnhandledException(e), e)

            raise
    return wrapper
132 132
133 133
class MercurialFactory(RepoFactory):
    """Factory producing mercurial repository instances for wire requests."""

    repo_type = 'hg'

    def _create_config(self, config, hooks=True):
        """Build a mercurial ui object; optionally strip rhodecode hooks."""
        if not hooks:
            skip_hooks = frozenset((
                'changegroup.repo_size', 'preoutgoing.pre_pull',
                'outgoing.pull_logger', 'prechangegroup.pre_push'))
            config = [
                entry for entry in config
                if not (entry[0] == 'hooks' and entry[1] in skip_hooks)
            ]

        return make_ui_from_config(config)

    def _create_repo(self, wire, create):
        baseui = self._create_config(wire["config"])
        return instance(baseui, ascii_bytes(wire["path"]), create)

    def repo(self, wire, create=False):
        """
        Get a repository instance for the given path.
        """
        return self._create_repo(wire, create)
161 161
162 162
def patch_ui_message_output(baseui):
    """Redirect all output channels of ``baseui`` into a shared BytesIO buffer.

    :returns: (baseui, buffer) — the same ui object plus the capture buffer
    """
    baseui.setconfig(b'ui', b'quiet', b'false')
    captured = io.BytesIO()

    def write(data, **unused_kwargs):
        captured.write(data)

    # route every output channel through the same sink
    baseui.status = write
    baseui.write = write
    baseui.warn = write
    baseui.debug = write

    return baseui, captured
176 176
177 177
178 178 class HgRemote(RemoteBase):
179 179
    def __init__(self, factory):
        """Store the repo factory and register the bulk-request dispatch table.

        :param factory: MercurialFactory used to resolve wire -> repo
        """
        self._factory = factory
        # attribute name -> resolver method, consumed by bulk_request()
        self._bulk_methods = {
            "affected_files": self.ctx_files,
            "author": self.ctx_user,
            "branch": self.ctx_branch,
            "children": self.ctx_children,
            "date": self.ctx_date,
            "message": self.ctx_description,
            "parents": self.ctx_parents,
            "status": self.ctx_status,
            "obsolete": self.ctx_obsolete,
            "phase": self.ctx_phase,
            "hidden": self.ctx_hidden,
            "_file_paths": self.ctx_list,
        }
196 196
197 197 def _get_ctx(self, repo, ref):
198 198 return get_ctx(repo, ref)
199 199
200 200 @reraise_safe_exceptions
201 201 def discover_hg_version(self):
202 202 from mercurial import util
203 203 return util.version()
204 204
205 205 @reraise_safe_exceptions
206 206 def is_empty(self, wire):
207 207 repo = self._factory.repo(wire)
208 208
209 209 try:
210 210 return len(repo) == 0
211 211 except Exception:
212 212 log.exception("failed to read object_store")
213 213 return False
214 214
215 215 @reraise_safe_exceptions
216 216 def bookmarks(self, wire):
217 217 cache_on, context_uid, repo_id = self._cache_on(wire)
218 218 region = self._region(wire)
219 219 @region.conditional_cache_on_arguments(condition=cache_on)
220 220 def _bookmarks(_context_uid, _repo_id):
221 221 repo = self._factory.repo(wire)
222 222 return dict(repo._bookmarks)
223 223
224 224 return _bookmarks(context_uid, repo_id)
225 225
226 226 @reraise_safe_exceptions
227 227 def branches(self, wire, normal, closed):
228 228 cache_on, context_uid, repo_id = self._cache_on(wire)
229 229 region = self._region(wire)
230 230 @region.conditional_cache_on_arguments(condition=cache_on)
231 231 def _branches(_context_uid, _repo_id, _normal, _closed):
232 232 repo = self._factory.repo(wire)
233 233 iter_branches = repo.branchmap().iterbranches()
234 234 bt = {}
235 235 for branch_name, _heads, tip, is_closed in iter_branches:
236 236 if normal and not is_closed:
237 237 bt[branch_name] = tip
238 238 if closed and is_closed:
239 239 bt[branch_name] = tip
240 240
241 241 return bt
242 242
243 243 return _branches(context_uid, repo_id, normal, closed)
244 244
245 245 @reraise_safe_exceptions
246 246 def bulk_request(self, wire, commit_id, pre_load):
247 247 cache_on, context_uid, repo_id = self._cache_on(wire)
248 248 region = self._region(wire)
249 249 @region.conditional_cache_on_arguments(condition=cache_on)
250 250 def _bulk_request(_repo_id, _commit_id, _pre_load):
251 251 result = {}
252 252 for attr in pre_load:
253 253 try:
254 254 method = self._bulk_methods[attr]
255 255 result[attr] = method(wire, commit_id)
256 256 except KeyError as e:
257 257 raise exceptions.VcsException(e)(
258 258 'Unknown bulk attribute: "%s"' % attr)
259 259 return result
260 260
261 261 return _bulk_request(repo_id, commit_id, sorted(pre_load))
262 262
263 263 @reraise_safe_exceptions
264 264 def ctx_branch(self, wire, commit_id):
265 265 cache_on, context_uid, repo_id = self._cache_on(wire)
266 266 region = self._region(wire)
267 267 @region.conditional_cache_on_arguments(condition=cache_on)
268 268 def _ctx_branch(_repo_id, _commit_id):
269 269 repo = self._factory.repo(wire)
270 270 ctx = self._get_ctx(repo, commit_id)
271 271 return ctx.branch()
272 272 return _ctx_branch(repo_id, commit_id)
273 273
274 274 @reraise_safe_exceptions
275 275 def ctx_date(self, wire, commit_id):
276 276 cache_on, context_uid, repo_id = self._cache_on(wire)
277 277 region = self._region(wire)
278 278 @region.conditional_cache_on_arguments(condition=cache_on)
279 279 def _ctx_date(_repo_id, _commit_id):
280 280 repo = self._factory.repo(wire)
281 281 ctx = self._get_ctx(repo, commit_id)
282 282 return ctx.date()
283 283 return _ctx_date(repo_id, commit_id)
284 284
285 285 @reraise_safe_exceptions
286 286 def ctx_description(self, wire, revision):
287 287 repo = self._factory.repo(wire)
288 288 ctx = self._get_ctx(repo, revision)
289 289 return ctx.description()
290 290
291 291 @reraise_safe_exceptions
292 292 def ctx_files(self, wire, commit_id):
293 293 cache_on, context_uid, repo_id = self._cache_on(wire)
294 294 region = self._region(wire)
295 295 @region.conditional_cache_on_arguments(condition=cache_on)
296 296 def _ctx_files(_repo_id, _commit_id):
297 297 repo = self._factory.repo(wire)
298 298 ctx = self._get_ctx(repo, commit_id)
299 299 return ctx.files()
300 300
301 301 return _ctx_files(repo_id, commit_id)
302 302
303 303 @reraise_safe_exceptions
304 304 def ctx_list(self, path, revision):
305 305 repo = self._factory.repo(path)
306 306 ctx = self._get_ctx(repo, revision)
307 307 return list(ctx)
308 308
309 309 @reraise_safe_exceptions
310 310 def ctx_parents(self, wire, commit_id):
311 311 cache_on, context_uid, repo_id = self._cache_on(wire)
312 312 region = self._region(wire)
313 313 @region.conditional_cache_on_arguments(condition=cache_on)
314 314 def _ctx_parents(_repo_id, _commit_id):
315 315 repo = self._factory.repo(wire)
316 316 ctx = self._get_ctx(repo, commit_id)
317 317 return [parent.hex() for parent in ctx.parents()
318 318 if not (parent.hidden() or parent.obsolete())]
319 319
320 320 return _ctx_parents(repo_id, commit_id)
321 321
322 322 @reraise_safe_exceptions
323 323 def ctx_children(self, wire, commit_id):
324 324 cache_on, context_uid, repo_id = self._cache_on(wire)
325 325 region = self._region(wire)
326 326 @region.conditional_cache_on_arguments(condition=cache_on)
327 327 def _ctx_children(_repo_id, _commit_id):
328 328 repo = self._factory.repo(wire)
329 329 ctx = self._get_ctx(repo, commit_id)
330 330 return [child.hex() for child in ctx.children()
331 331 if not (child.hidden() or child.obsolete())]
332 332
333 333 return _ctx_children(repo_id, commit_id)
334 334
335 335 @reraise_safe_exceptions
336 336 def ctx_phase(self, wire, commit_id):
337 337 cache_on, context_uid, repo_id = self._cache_on(wire)
338 338 region = self._region(wire)
339 339 @region.conditional_cache_on_arguments(condition=cache_on)
340 340 def _ctx_phase(_context_uid, _repo_id, _commit_id):
341 341 repo = self._factory.repo(wire)
342 342 ctx = self._get_ctx(repo, commit_id)
343 343 # public=0, draft=1, secret=3
344 344 return ctx.phase()
345 345 return _ctx_phase(context_uid, repo_id, commit_id)
346 346
347 347 @reraise_safe_exceptions
348 348 def ctx_obsolete(self, wire, commit_id):
349 349 cache_on, context_uid, repo_id = self._cache_on(wire)
350 350 region = self._region(wire)
351 351 @region.conditional_cache_on_arguments(condition=cache_on)
352 352 def _ctx_obsolete(_context_uid, _repo_id, _commit_id):
353 353 repo = self._factory.repo(wire)
354 354 ctx = self._get_ctx(repo, commit_id)
355 355 return ctx.obsolete()
356 356 return _ctx_obsolete(context_uid, repo_id, commit_id)
357 357
358 358 @reraise_safe_exceptions
359 359 def ctx_hidden(self, wire, commit_id):
360 360 cache_on, context_uid, repo_id = self._cache_on(wire)
361 361 region = self._region(wire)
362 362 @region.conditional_cache_on_arguments(condition=cache_on)
363 363 def _ctx_hidden(_context_uid, _repo_id, _commit_id):
364 364 repo = self._factory.repo(wire)
365 365 ctx = self._get_ctx(repo, commit_id)
366 366 return ctx.hidden()
367 367 return _ctx_hidden(context_uid, repo_id, commit_id)
368 368
369 369 @reraise_safe_exceptions
370 370 def ctx_substate(self, wire, revision):
371 371 repo = self._factory.repo(wire)
372 372 ctx = self._get_ctx(repo, revision)
373 373 return ctx.substate
374 374
375 375 @reraise_safe_exceptions
376 376 def ctx_status(self, wire, revision):
377 377 repo = self._factory.repo(wire)
378 378 ctx = self._get_ctx(repo, revision)
379 379 status = repo[ctx.p1().node()].status(other=ctx.node())
380 380 # object of status (odd, custom named tuple in mercurial) is not
381 381 # correctly serializable, we make it a list, as the underling
382 382 # API expects this to be a list
383 383 return list(status)
384 384
385 385 @reraise_safe_exceptions
386 386 def ctx_user(self, wire, revision):
387 387 repo = self._factory.repo(wire)
388 388 ctx = self._get_ctx(repo, revision)
389 389 return ctx.user()
390 390
391 391 @reraise_safe_exceptions
392 392 def check_url(self, url, config):
393 393 _proto = None
394 394 if '+' in url[:url.find('://')]:
395 395 _proto = url[0:url.find('+')]
396 396 url = url[url.find('+') + 1:]
397 397 handlers = []
398 398 url_obj = url_parser(url)
399 399 test_uri, authinfo = url_obj.authinfo()
400 400 url_obj.passwd = '*****' if url_obj.passwd else url_obj.passwd
401 401 url_obj.query = obfuscate_qs(url_obj.query)
402 402
403 403 cleaned_uri = str(url_obj)
404 404 log.info("Checking URL for remote cloning/import: %s", cleaned_uri)
405 405
406 406 if authinfo:
407 407 # create a password manager
408 408 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
409 409 passmgr.add_password(*authinfo)
410 410
411 411 handlers.extend((httpbasicauthhandler(passmgr),
412 412 httpdigestauthhandler(passmgr)))
413 413
414 414 o = urllib.request.build_opener(*handlers)
415 415 o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
416 416 ('Accept', 'application/mercurial-0.1')]
417 417
418 418 q = {"cmd": 'between'}
419 419 q.update({'pairs': "%s-%s" % ('0' * 40, '0' * 40)})
420 420 qs = '?%s' % urllib.parse.urlencode(q)
421 421 cu = "%s%s" % (test_uri, qs)
422 422 req = urllib.request.Request(cu, None, {})
423 423
424 424 try:
425 425 log.debug("Trying to open URL %s", cleaned_uri)
426 426 resp = o.open(req)
427 427 if resp.code != 200:
428 428 raise exceptions.URLError()('Return Code is not 200')
429 429 except Exception as e:
430 430 log.warning("URL cannot be opened: %s", cleaned_uri, exc_info=True)
431 431 # means it cannot be cloned
432 432 raise exceptions.URLError(e)("[%s] org_exc: %s" % (cleaned_uri, e))
433 433
434 434 # now check if it's a proper hg repo, but don't do it for svn
435 435 try:
436 436 if _proto == 'svn':
437 437 pass
438 438 else:
439 439 # check for pure hg repos
440 440 log.debug(
441 441 "Verifying if URL is a Mercurial repository: %s",
442 442 cleaned_uri)
443 443 ui = make_ui_from_config(config)
444 444 peer_checker = makepeer(ui, url)
445 445 peer_checker.lookup('tip')
446 446 except Exception as e:
447 447 log.warning("URL is not a valid Mercurial repository: %s",
448 448 cleaned_uri)
449 449 raise exceptions.URLError(e)(
450 450 "url [%s] does not look like an hg repo org_exc: %s"
451 451 % (cleaned_uri, e))
452 452
453 453 log.info("URL is a valid Mercurial repository: %s", cleaned_uri)
454 454 return True
455 455
456 456 @reraise_safe_exceptions
457 457 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_git, opt_ignorews, context):
458 458 repo = self._factory.repo(wire)
459 459
460 460 if file_filter:
461 461 match_filter = match(file_filter[0], '', [file_filter[1]])
462 462 else:
463 463 match_filter = file_filter
464 464 opts = diffopts(git=opt_git, ignorews=opt_ignorews, context=context, showfunc=1)
465 465
466 466 try:
467 467 return "".join(patch.diff(
468 468 repo, node1=commit_id_1, node2=commit_id_2, match=match_filter, opts=opts))
469 469 except RepoLookupError as e:
470 470 raise exceptions.LookupException(e)()
471 471
472 472 @reraise_safe_exceptions
473 473 def node_history(self, wire, revision, path, limit):
474 474 cache_on, context_uid, repo_id = self._cache_on(wire)
475 475 region = self._region(wire)
476 476
477 477 @region.conditional_cache_on_arguments(condition=cache_on)
478 478 def _node_history(_context_uid, _repo_id, _revision, _path, _limit):
479 479 repo = self._factory.repo(wire)
480 480
481 481 ctx = self._get_ctx(repo, revision)
482 482 fctx = ctx.filectx(safe_bytes(path))
483 483
484 484 def history_iter():
485 485 limit_rev = fctx.rev()
486 486 for obj in reversed(list(fctx.filelog())):
487 487 obj = fctx.filectx(obj)
488 488 ctx = obj.changectx()
489 489 if ctx.hidden() or ctx.obsolete():
490 490 continue
491 491
492 492 if limit_rev >= obj.rev():
493 493 yield obj
494 494
495 495 history = []
496 496 for cnt, obj in enumerate(history_iter()):
497 497 if limit and cnt >= limit:
498 498 break
499 499 history.append(hex(obj.node()))
500 500
501 501 return [x for x in history]
502 502 return _node_history(context_uid, repo_id, revision, path, limit)
503 503
504 504 @reraise_safe_exceptions
505 505 def node_history_untill(self, wire, revision, path, limit):
506 506 cache_on, context_uid, repo_id = self._cache_on(wire)
507 507 region = self._region(wire)
508 508
509 509 @region.conditional_cache_on_arguments(condition=cache_on)
510 510 def _node_history_until(_context_uid, _repo_id):
511 511 repo = self._factory.repo(wire)
512 512 ctx = self._get_ctx(repo, revision)
513 513 fctx = ctx.filectx(safe_bytes(path))
514 514
515 515 file_log = list(fctx.filelog())
516 516 if limit:
517 517 # Limit to the last n items
518 518 file_log = file_log[-limit:]
519 519
520 520 return [hex(fctx.filectx(cs).node()) for cs in reversed(file_log)]
521 521 return _node_history_until(context_uid, repo_id, revision, path, limit)
522 522
523 523 @reraise_safe_exceptions
524 524 def fctx_annotate(self, wire, revision, path):
525 525 repo = self._factory.repo(wire)
526 526 ctx = self._get_ctx(repo, revision)
527 527 fctx = ctx.filectx(safe_bytes(path))
528 528
529 529 result = []
530 530 for i, annotate_obj in enumerate(fctx.annotate(), 1):
531 531 ln_no = i
532 532 sha = hex(annotate_obj.fctx.node())
533 533 content = annotate_obj.text
534 534 result.append((ln_no, sha, content))
535 535 return result
536 536
537 537 @reraise_safe_exceptions
538 538 def fctx_node_data(self, wire, revision, path):
539 539 repo = self._factory.repo(wire)
540 540 ctx = self._get_ctx(repo, revision)
541 541 fctx = ctx.filectx(safe_bytes(path))
542 542 return fctx.data()
543 543
544 544 @reraise_safe_exceptions
545 545 def fctx_flags(self, wire, commit_id, path):
546 546 cache_on, context_uid, repo_id = self._cache_on(wire)
547 547 region = self._region(wire)
548 548
549 549 @region.conditional_cache_on_arguments(condition=cache_on)
550 550 def _fctx_flags(_repo_id, _commit_id, _path):
551 551 repo = self._factory.repo(wire)
552 552 ctx = self._get_ctx(repo, commit_id)
553 553 fctx = ctx.filectx(safe_bytes(path))
554 554 return fctx.flags()
555 555
556 556 return _fctx_flags(repo_id, commit_id, path)
557 557
558 558 @reraise_safe_exceptions
559 559 def fctx_size(self, wire, commit_id, path):
560 560 cache_on, context_uid, repo_id = self._cache_on(wire)
561 561 region = self._region(wire)
562 562
563 563 @region.conditional_cache_on_arguments(condition=cache_on)
564 564 def _fctx_size(_repo_id, _revision, _path):
565 565 repo = self._factory.repo(wire)
566 566 ctx = self._get_ctx(repo, commit_id)
567 567 fctx = ctx.filectx(safe_bytes(path))
568 568 return fctx.size()
569 569 return _fctx_size(repo_id, commit_id, path)
570 570
571 571 @reraise_safe_exceptions
572 572 def get_all_commit_ids(self, wire, name):
573 573 cache_on, context_uid, repo_id = self._cache_on(wire)
574 574 region = self._region(wire)
575 575
576 576 @region.conditional_cache_on_arguments(condition=cache_on)
577 577 def _get_all_commit_ids(_context_uid, _repo_id, _name):
578 578 repo = self._factory.repo(wire)
579 579 revs = [ascii_str(repo[x].hex()) for x in repo.filtered(b'visible').changelog.revs()]
580 580 return revs
581 581 return _get_all_commit_ids(context_uid, repo_id, name)
582 582
583 583 @reraise_safe_exceptions
584 584 def get_config_value(self, wire, section, name, untrusted=False):
585 585 repo = self._factory.repo(wire)
586 586 return repo.ui.config(section, name, untrusted=untrusted)
587 587
588 588 @reraise_safe_exceptions
589 589 def is_large_file(self, wire, commit_id, path):
590 590 cache_on, context_uid, repo_id = self._cache_on(wire)
591 591 region = self._region(wire)
592 592
593 593 @region.conditional_cache_on_arguments(condition=cache_on)
594 594 def _is_large_file(_context_uid, _repo_id, _commit_id, _path):
595 595 return largefiles.lfutil.isstandin(safe_bytes(path))
596 596
597 597 return _is_large_file(context_uid, repo_id, commit_id, path)
598 598
599 599 @reraise_safe_exceptions
600 600 def is_binary(self, wire, revision, path):
601 601 cache_on, context_uid, repo_id = self._cache_on(wire)
602 602 region = self._region(wire)
603 603
604 604 @region.conditional_cache_on_arguments(condition=cache_on)
605 605 def _is_binary(_repo_id, _sha, _path):
606 606 repo = self._factory.repo(wire)
607 607 ctx = self._get_ctx(repo, revision)
608 608 fctx = ctx.filectx(safe_bytes(path))
609 609 return fctx.isbinary()
610 610
611 611 return _is_binary(repo_id, revision, path)
612 612
613 613 @reraise_safe_exceptions
614 614 def in_largefiles_store(self, wire, sha):
615 615 repo = self._factory.repo(wire)
616 616 return largefiles.lfutil.instore(repo, sha)
617 617
618 618 @reraise_safe_exceptions
619 619 def in_user_cache(self, wire, sha):
620 620 repo = self._factory.repo(wire)
621 621 return largefiles.lfutil.inusercache(repo.ui, sha)
622 622
623 623 @reraise_safe_exceptions
624 624 def store_path(self, wire, sha):
625 625 repo = self._factory.repo(wire)
626 626 return largefiles.lfutil.storepath(repo, sha)
627 627
628 628 @reraise_safe_exceptions
629 629 def link(self, wire, sha, path):
630 630 repo = self._factory.repo(wire)
631 631 largefiles.lfutil.link(
632 632 largefiles.lfutil.usercachepath(repo.ui, sha), path)
633 633
634 634 @reraise_safe_exceptions
635 635 def localrepository(self, wire, create=False):
636 636 self._factory.repo(wire, create=create)
637 637
638 638 @reraise_safe_exceptions
639 639 def lookup(self, wire, revision, both):
640 640 cache_on, context_uid, repo_id = self._cache_on(wire)
641 641 region = self._region(wire)
642 642
643 643 @region.conditional_cache_on_arguments(condition=cache_on)
644 644 def _lookup(_context_uid, _repo_id, _revision, _both):
645 645
646 646 repo = self._factory.repo(wire)
647 647 rev = _revision
648 648 if isinstance(rev, int):
649 649 # NOTE(marcink):
650 650 # since Mercurial doesn't support negative indexes properly
651 651 # we need to shift accordingly by one to get proper index, e.g
652 652 # repo[-1] => repo[-2]
653 653 # repo[0] => repo[-1]
654 654 if rev <= 0:
655 655 rev = rev + -1
656 656 try:
657 657 ctx = self._get_ctx(repo, rev)
658 658 except (TypeError, RepoLookupError) as e:
659 659 e._org_exc_tb = traceback.format_exc()
660 660 raise exceptions.LookupException(e)(rev)
661 661 except LookupError as e:
662 662 e._org_exc_tb = traceback.format_exc()
663 663 raise exceptions.LookupException(e)(e.name)
664 664
665 665 if not both:
666 666 return ctx.hex()
667 667
668 668 ctx = repo[ctx.hex()]
669 669 return ctx.hex(), ctx.rev()
670 670
671 671 return _lookup(context_uid, repo_id, revision, both)
672 672
673 673 @reraise_safe_exceptions
674 674 def sync_push(self, wire, url):
675 675 if not self.check_url(url, wire['config']):
676 676 return
677 677
678 678 repo = self._factory.repo(wire)
679 679
680 680 # Disable any prompts for this repo
681 681 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
682 682
683 683 bookmarks = list(dict(repo._bookmarks).keys())
684 684 remote = peer(repo, {}, url)
685 685 # Disable any prompts for this remote
686 686 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
687 687
688 688 return exchange.push(
689 689 repo, remote, newbranch=True, bookmarks=bookmarks).cgresult
690 690
691 691 @reraise_safe_exceptions
692 692 def revision(self, wire, rev):
693 693 repo = self._factory.repo(wire)
694 694 ctx = self._get_ctx(repo, rev)
695 695 return ctx.rev()
696 696
697 697 @reraise_safe_exceptions
698 698 def rev_range(self, wire, commit_filter):
699 699 cache_on, context_uid, repo_id = self._cache_on(wire)
700 700 region = self._region(wire)
701 701
702 702 @region.conditional_cache_on_arguments(condition=cache_on)
703 703 def _rev_range(_context_uid, _repo_id, _filter):
704 704 repo = self._factory.repo(wire)
705 705 revisions = [
706 706 ascii_str(repo[rev].hex())
707 707 for rev in revrange(repo, list(map(ascii_bytes, commit_filter)))
708 708 ]
709 709 return revisions
710 710
711 711 return _rev_range(context_uid, repo_id, sorted(commit_filter))
712 712
713 713 @reraise_safe_exceptions
714 714 def rev_range_hash(self, wire, node):
715 715 repo = self._factory.repo(wire)
716 716
717 717 def get_revs(repo, rev_opt):
718 718 if rev_opt:
719 719 revs = revrange(repo, rev_opt)
720 720 if len(revs) == 0:
721 721 return (nullrev, nullrev)
722 722 return max(revs), min(revs)
723 723 else:
724 724 return len(repo) - 1, 0
725 725
726 726 stop, start = get_revs(repo, [node + ':'])
727 727 revs = [ascii_str(repo[r].hex()) for r in range(start, stop + 1)]
728 728 return revs
729 729
730 730 @reraise_safe_exceptions
731 731 def revs_from_revspec(self, wire, rev_spec, *args, **kwargs):
732 732 other_path = kwargs.pop('other_path', None)
733 733
734 734 # case when we want to compare two independent repositories
735 735 if other_path and other_path != wire["path"]:
736 736 baseui = self._factory._create_config(wire["config"])
737 737 repo = unionrepo.makeunionrepository(baseui, other_path, wire["path"])
738 738 else:
739 739 repo = self._factory.repo(wire)
740 740 return list(repo.revs(rev_spec, *args))
741 741
742 742 @reraise_safe_exceptions
743 743 def verify(self, wire,):
744 744 repo = self._factory.repo(wire)
745 745 baseui = self._factory._create_config(wire['config'])
746 746
747 747 baseui, output = patch_ui_message_output(baseui)
748 748
749 749 repo.ui = baseui
750 750 verify.verify(repo)
751 751 return output.getvalue()
752 752
753 753 @reraise_safe_exceptions
754 754 def hg_update_cache(self, wire,):
755 755 repo = self._factory.repo(wire)
756 756 baseui = self._factory._create_config(wire['config'])
757 757 baseui, output = patch_ui_message_output(baseui)
758 758
759 759 repo.ui = baseui
760 760 with repo.wlock(), repo.lock():
761 761 repo.updatecaches(full=True)
762 762
763 763 return output.getvalue()
764 764
765 765 @reraise_safe_exceptions
766 766 def hg_rebuild_fn_cache(self, wire,):
767 767 repo = self._factory.repo(wire)
768 768 baseui = self._factory._create_config(wire['config'])
769 769 baseui, output = patch_ui_message_output(baseui)
770 770
771 771 repo.ui = baseui
772 772
773 773 repair.rebuildfncache(baseui, repo)
774 774
775 775 return output.getvalue()
776 776
777 777 @reraise_safe_exceptions
778 778 def tags(self, wire):
779 779 cache_on, context_uid, repo_id = self._cache_on(wire)
780 780 region = self._region(wire)
781 781
782 782 @region.conditional_cache_on_arguments(condition=cache_on)
783 783 def _tags(_context_uid, _repo_id):
784 784 repo = self._factory.repo(wire)
785 785 return repo.tags()
786 786
787 787 return _tags(context_uid, repo_id)
788 788
789 789 @reraise_safe_exceptions
790 790 def update(self, wire, node=None, clean=False):
791 791 repo = self._factory.repo(wire)
792 792 baseui = self._factory._create_config(wire['config'])
793 793 commands.update(baseui, repo, node=node, clean=clean)
794 794
795 795 @reraise_safe_exceptions
796 796 def identify(self, wire):
797 797 repo = self._factory.repo(wire)
798 798 baseui = self._factory._create_config(wire['config'])
799 799 output = io.BytesIO()
800 800 baseui.write = output.write
801 801 # This is required to get a full node id
802 802 baseui.debugflag = True
803 803 commands.identify(baseui, repo, id=True)
804 804
805 805 return output.getvalue()
806 806
807 807 @reraise_safe_exceptions
808 808 def heads(self, wire, branch=None):
809 809 repo = self._factory.repo(wire)
810 810 baseui = self._factory._create_config(wire['config'])
811 811 output = io.BytesIO()
812 812
813 813 def write(data, **unused_kwargs):
814 814 output.write(data)
815 815
816 816 baseui.write = write
817 817 if branch:
818 818 args = [branch]
819 819 else:
820 820 args = []
821 821 commands.heads(baseui, repo, template='{node} ', *args)
822 822
823 823 return output.getvalue()
824 824
825 825 @reraise_safe_exceptions
826 826 def ancestor(self, wire, revision1, revision2):
827 827 repo = self._factory.repo(wire)
828 828 changelog = repo.changelog
829 829 lookup = repo.lookup
830 830 a = changelog.ancestor(lookup(revision1), lookup(revision2))
831 831 return hex(a)
832 832
833 833 @reraise_safe_exceptions
834 834 def clone(self, wire, source, dest, update_after_clone=False, hooks=True):
835 835 baseui = self._factory._create_config(wire["config"], hooks=hooks)
836 836 clone(baseui, source, dest, noupdate=not update_after_clone)
837 837
838 838 @reraise_safe_exceptions
839 839 def commitctx(self, wire, message, parents, commit_time, commit_timezone, user, files, extra, removed, updated):
840 840
841 841 repo = self._factory.repo(wire)
842 842 baseui = self._factory._create_config(wire['config'])
843 843 publishing = baseui.configbool('phases', 'publish')
844 844 if publishing:
845 845 new_commit = 'public'
846 846 else:
847 847 new_commit = 'draft'
848 848
849 849 def _filectxfn(_repo, ctx, path):
850 850 """
851 851 Marks given path as added/changed/removed in a given _repo. This is
852 852 for internal mercurial commit function.
853 853 """
854 854
855 855 # check if this path is removed
856 856 if path in removed:
857 857 # returning None is a way to mark node for removal
858 858 return None
859 859
860 860 # check if this path is added
861 861 for node in updated:
862 862 if node['path'] == path:
863 863 return memfilectx(
864 864 _repo,
865 865 changectx=ctx,
866 866 path=node['path'],
867 867 data=node['content'],
868 868 islink=False,
869 869 isexec=bool(node['mode'] & stat.S_IXUSR),
870 870 copysource=False)
871 871
872 872 raise exceptions.AbortException()(
873 873 "Given path haven't been marked as added, "
874 874 "changed or removed (%s)" % path)
875 875
876 876 with repo.ui.configoverride({('phases', 'new-commit'): new_commit}):
877 877
878 878 commit_ctx = memctx(
879 879 repo=repo,
880 880 parents=parents,
881 881 text=message,
882 882 files=files,
883 883 filectxfn=_filectxfn,
884 884 user=user,
885 885 date=(commit_time, commit_timezone),
886 886 extra=extra)
887 887
888 888 n = repo.commitctx(commit_ctx)
889 889 new_id = hex(n)
890 890
891 891 return new_id
892 892
893 893 @reraise_safe_exceptions
894 894 def pull(self, wire, url, commit_ids=None):
895 895 repo = self._factory.repo(wire)
896 896 # Disable any prompts for this repo
897 897 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
898 898
899 899 remote = peer(repo, {}, url)
900 900 # Disable any prompts for this remote
901 901 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
902 902
903 903 if commit_ids:
904 904 commit_ids = [bin(commit_id) for commit_id in commit_ids]
905 905
906 906 return exchange.pull(
907 907 repo, remote, heads=commit_ids, force=None).cgresult
908 908
909 909 @reraise_safe_exceptions
910 910 def pull_cmd(self, wire, source, bookmark=None, branch=None, revision=None, hooks=True):
911 911 repo = self._factory.repo(wire)
912 912 baseui = self._factory._create_config(wire['config'], hooks=hooks)
913 913
914 914 # Mercurial internally has a lot of logic that checks ONLY if
915 915 # option is defined, we just pass those if they are defined then
916 916 opts = {}
917 917 if bookmark:
918 918 opts['bookmark'] = bookmark
919 919 if branch:
920 920 opts['branch'] = branch
921 921 if revision:
922 922 opts['rev'] = revision
923 923
924 924 commands.pull(baseui, repo, source, **opts)
925 925
926 926 @reraise_safe_exceptions
927 927 def push(self, wire, revisions, dest_path, hooks=True, push_branches=False):
928 928 repo = self._factory.repo(wire)
929 929 baseui = self._factory._create_config(wire['config'], hooks=hooks)
930 930 commands.push(baseui, repo, dest=dest_path, rev=revisions,
931 931 new_branch=push_branches)
932 932
933 933 @reraise_safe_exceptions
934 934 def strip(self, wire, revision, update, backup):
935 935 repo = self._factory.repo(wire)
936 936 ctx = self._get_ctx(repo, revision)
937 937 hgext_strip(
938 938 repo.baseui, repo, ctx.node(), update=update, backup=backup)
939 939
940 940 @reraise_safe_exceptions
941 941 def get_unresolved_files(self, wire):
942 942 repo = self._factory.repo(wire)
943 943
944 944 log.debug('Calculating unresolved files for repo: %s', repo)
945 945 output = io.BytesIO()
946 946
947 947 def write(data, **unused_kwargs):
948 948 output.write(data)
949 949
950 950 baseui = self._factory._create_config(wire['config'])
951 951 baseui.write = write
952 952
953 953 commands.resolve(baseui, repo, list=True)
954 954 unresolved = output.getvalue().splitlines(0)
955 955 return unresolved
956 956
957 957 @reraise_safe_exceptions
958 958 def merge(self, wire, revision):
959 959 repo = self._factory.repo(wire)
960 960 baseui = self._factory._create_config(wire['config'])
961 961 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
962 962
963 963 # In case of sub repositories are used mercurial prompts the user in
964 964 # case of merge conflicts or different sub repository sources. By
965 965 # setting the interactive flag to `False` mercurial doesn't prompt the
966 966 # used but instead uses a default value.
967 967 repo.ui.setconfig(b'ui', b'interactive', False)
968 968 commands.merge(baseui, repo, rev=revision)
969 969
970 970 @reraise_safe_exceptions
971 971 def merge_state(self, wire):
972 972 repo = self._factory.repo(wire)
973 973 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
974 974
975 975 # In case of sub repositories are used mercurial prompts the user in
976 976 # case of merge conflicts or different sub repository sources. By
977 977 # setting the interactive flag to `False` mercurial doesn't prompt the
978 978 # used but instead uses a default value.
979 979 repo.ui.setconfig(b'ui', b'interactive', False)
980 980 ms = hg_merge.mergestate(repo)
981 981 return [x for x in ms.unresolved()]
982 982
983 983 @reraise_safe_exceptions
984 984 def commit(self, wire, message, username, close_branch=False):
985 985 repo = self._factory.repo(wire)
986 986 baseui = self._factory._create_config(wire['config'])
987 987 repo.ui.setconfig(b'ui', b'username', username)
988 988 commands.commit(baseui, repo, message=message, close_branch=close_branch)
989 989
990 990 @reraise_safe_exceptions
991 991 def rebase(self, wire, source=None, dest=None, abort=False):
992 992 repo = self._factory.repo(wire)
993 993 baseui = self._factory._create_config(wire['config'])
994 994 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
995 995 # In case of sub repositories are used mercurial prompts the user in
996 996 # case of merge conflicts or different sub repository sources. By
997 997 # setting the interactive flag to `False` mercurial doesn't prompt the
998 998 # used but instead uses a default value.
999 999 repo.ui.setconfig(b'ui', b'interactive', False)
1000 1000 rebase.rebase(baseui, repo, base=source, dest=dest, abort=abort, keep=not abort)
1001 1001
1002 1002 @reraise_safe_exceptions
1003 1003 def tag(self, wire, name, revision, message, local, user, tag_time, tag_timezone):
1004 1004 repo = self._factory.repo(wire)
1005 1005 ctx = self._get_ctx(repo, revision)
1006 1006 node = ctx.node()
1007 1007
1008 1008 date = (tag_time, tag_timezone)
1009 1009 try:
1010 1010 hg_tag.tag(repo, name, node, message, local, user, date)
1011 1011 except Abort as e:
1012 1012 log.exception("Tag operation aborted")
1013 1013 # Exception can contain unicode which we convert
1014 1014 raise exceptions.AbortException(e)(repr(e))
1015 1015
1016 1016 @reraise_safe_exceptions
1017 1017 def bookmark(self, wire, bookmark, revision=None):
1018 1018 repo = self._factory.repo(wire)
1019 1019 baseui = self._factory._create_config(wire['config'])
1020 1020 commands.bookmark(baseui, repo, bookmark, rev=revision, force=True)
1021 1021
1022 1022 @reraise_safe_exceptions
1023 1023 def install_hooks(self, wire, force=False):
1024 1024 # we don't need any special hooks for Mercurial
1025 1025 pass
1026 1026
1027 1027 @reraise_safe_exceptions
1028 1028 def get_hooks_info(self, wire):
1029 1029 return {
1030 1030 'pre_version': vcsserver.__version__,
1031 1031 'post_version': vcsserver.__version__,
1032 1032 }
1033 1033
1034 1034 @reraise_safe_exceptions
1035 1035 def set_head_ref(self, wire, head_name):
1036 1036 pass
1037 1037
1038 1038 @reraise_safe_exceptions
1039 1039 def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
1040 1040 archive_dir_name, commit_id):
1041 1041
1042 1042 def file_walker(_commit_id, path):
1043 1043 repo = self._factory.repo(wire)
1044 1044 ctx = repo[_commit_id]
1045 1045 is_root = path in ['', '/']
1046 1046 if is_root:
1047 1047 matcher = alwaysmatcher(badfn=None)
1048 1048 else:
1049 1049 matcher = patternmatcher('', [(b'glob', path+'/**', b'')], badfn=None)
1050 1050 file_iter = ctx.manifest().walk(matcher)
1051 1051
1052 1052 for fn in file_iter:
1053 1053 file_path = fn
1054 1054 flags = ctx.flags(fn)
1055 1055 mode = b'x' in flags and 0o755 or 0o644
1056 1056 is_link = b'l' in flags
1057 1057
1058 1058 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1059 1059
1060 1060 return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
1061 1061 archive_dir_name, commit_id)
1062 1062
@@ -1,864 +1,864 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18
import os
import functools
import subprocess
from urllib.error import URLError
import urllib.parse
import logging
import posixpath as vcspath
import io
import urllib.request
import urllib.parse
import urllib.error
import traceback
30 30
31 31 import svn.client
32 32 import svn.core
33 33 import svn.delta
34 34 import svn.diff
35 35 import svn.fs
36 36 import svn.repos
37 37
38 38 from vcsserver import svn_diff, exceptions, subprocessio, settings
39 39 from vcsserver.base import RepoFactory, raise_from_original, ArchiveNode, archive_repo
40 40 from vcsserver.exceptions import NoContentException
41 from vcsserver.utils import safe_str
41 from vcsserver.str_utils import safe_str
42 42 from vcsserver.vcs_base import RemoteBase
43 43 from vcsserver.lib.svnremoterepo import svnremoterepo
44 44 log = logging.getLogger(__name__)
45 45
46 46
47 47 svn_compatible_versions_map = {
48 48 'pre-1.4-compatible': '1.3',
49 49 'pre-1.5-compatible': '1.4',
50 50 'pre-1.6-compatible': '1.5',
51 51 'pre-1.8-compatible': '1.7',
52 52 'pre-1.9-compatible': '1.8',
53 53 }
54 54
55 55 current_compatible_version = '1.14'
56 56
57 57
def reraise_safe_exceptions(func):
    """Decorator for converting svn exceptions to something neutral.

    Exceptions that do not carry a ``_vcs_kind`` marker (i.e. were not
    already translated into a vcsserver exception) are logged and re-raised
    through :func:`raise_from_original` as ``UnhandledException`` so remote
    callers get a neutral, serializable error.
    """
    @functools.wraps(func)  # keep __name__/__doc__ for logging & introspection
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            if not hasattr(e, '_vcs_kind'):
                log.exception("Unhandled exception in svn remote call")
                raise_from_original(exceptions.UnhandledException(e))
            raise
    return wrapper
69 69
70 70
class SubversionFactory(RepoFactory):
    repo_type = 'svn'

    def _create_repo(self, wire, create, compatible_version):
        """Open (or create, when ``create`` is set) the repo at ``wire['path']``."""
        canonical_path = svn.core.svn_path_canonicalize(wire['path'])

        if not create:
            repo = svn.repos.open(canonical_path)
            log.debug('Got SVN object: %s', repo)
            return repo

        # Translate a symbolic "pre-X-compatible" alias into a concrete
        # version string, falling back to the value as given, and finally
        # to the current default compatibility level.
        version = current_compatible_version
        if compatible_version:
            version = svn_compatible_versions_map.get(
                compatible_version, compatible_version)
        fs_config = {'compatible-version': version}

        log.debug('Create SVN repo with config "%s"', fs_config)
        repo = svn.repos.create(canonical_path, "", "", None, fs_config)
        log.debug('Got SVN object: %s', repo)
        return repo

    def repo(self, wire, create=False, compatible_version=None):
        """
        Get a repository instance for the given path.
        """
        return self._create_repo(wire, create, compatible_version)
98 98
99 99
100 100 NODE_TYPE_MAPPING = {
101 101 svn.core.svn_node_file: 'file',
102 102 svn.core.svn_node_dir: 'dir',
103 103 }
104 104
105 105
106 106 class SvnRemote(RemoteBase):
107 107
    def __init__(self, factory, hg_factory=None):
        # NOTE(review): ``hg_factory`` is accepted but never stored here —
        # presumably kept for signature compatibility; confirm with callers.
        self._factory = factory
110 110
    @reraise_safe_exceptions
    def discover_svn_version(self):
        """Return the version string of the svn bindings, or None if absent."""
        try:
            import svn.core
            svn_ver = svn.core.SVN_VERSION
        except ImportError:
            svn_ver = None
        return svn_ver
119 119
    @reraise_safe_exceptions
    def is_empty(self, wire):
        """Return True when the repository has no revisions (HEAD == 0)."""

        try:
            return self.lookup(wire, -1) == 0
        except Exception:
            # best-effort: treat unreadable repos as non-empty
            log.exception("failed to read object_store")
            return False
128 128
    def check_url(self, url):
        """Validate that ``url`` points at a reachable Subversion repository.

        :raises urllib.error.URLError: when the remote repo cannot be reached.
        """

        # the uuid property only resolves on a proper repo, else it
        # throws an exception
        username, password, src_url = self.get_url_and_credentials(url)
        try:
            svnremoterepo(username, password, src_url).svn().uuid
        except Exception:
            tb = traceback.format_exc()
            log.debug("Invalid Subversion url: `%s`, tb: %s", url, tb)
            raise URLError(
                '"%s" is not a valid Subversion source url.' % (url, ))
        return True
142 142
143 143 def is_path_valid_repository(self, wire, path):
144 144
145 145 # NOTE(marcink): short circuit the check for SVN repo
146 146 # the repos.open might be expensive to check, but we have one cheap
147 147 # pre condition that we can use, to check for 'format' file
148 148
149 149 if not os.path.isfile(os.path.join(path, 'format')):
150 150 return False
151 151
152 152 try:
153 153 svn.repos.open(path)
154 154 except svn.core.SubversionException:
155 155 tb = traceback.format_exc()
156 156 log.debug("Invalid Subversion path `%s`, tb: %s", path, tb)
157 157 return False
158 158 return True
159 159
    @reraise_safe_exceptions
    def verify(self, wire,):
        """Run ``svnadmin info`` on the repository and return its stdout.

        :raises Exception: when ``wire['path']`` is not a valid SVN repo.
        """
        repo_path = wire['path']
        if not self.is_path_valid_repository(wire, repo_path):
            raise Exception(
                "Path %s is not a valid Subversion repository." % repo_path)

        cmd = ['svnadmin', 'info', repo_path]
        stdout, stderr = subprocessio.run_command(cmd)
        return stdout
170 170
    def lookup(self, wire, revision):
        """Return the youngest (HEAD) revision number.

        Only HEAD-style lookups are supported; any concrete revision other
        than -1/None/'HEAD' raises NotImplementedError.
        """
        if revision not in [-1, None, 'HEAD']:
            raise NotImplementedError
        repo = self._factory.repo(wire)
        fs_ptr = svn.repos.fs(repo)
        head = svn.fs.youngest_rev(fs_ptr)
        return head
178 178
179 179 def lookup_interval(self, wire, start_ts, end_ts):
180 180 repo = self._factory.repo(wire)
181 181 fsobj = svn.repos.fs(repo)
182 182 start_rev = None
183 183 end_rev = None
184 184 if start_ts:
185 185 start_ts_svn = apr_time_t(start_ts)
186 186 start_rev = svn.repos.dated_revision(repo, start_ts_svn) + 1
187 187 else:
188 188 start_rev = 1
189 189 if end_ts:
190 190 end_ts_svn = apr_time_t(end_ts)
191 191 end_rev = svn.repos.dated_revision(repo, end_ts_svn)
192 192 else:
193 193 end_rev = svn.fs.youngest_rev(fsobj)
194 194 return start_rev, end_rev
195 195
    def revision_properties(self, wire, revision):
        """Return the revision property list (svn:author, svn:date, ...)."""

        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)
        @region.conditional_cache_on_arguments(condition=cache_on)
        def _revision_properties(_repo_id, _revision):
            repo = self._factory.repo(wire)
            fs_ptr = svn.repos.fs(repo)
            return svn.fs.revision_proplist(fs_ptr, revision)
        return _revision_properties(repo_id, revision)
206 206
    def revision_changes(self, wire, revision):
        """Replay ``revision`` and classify its file changes.

        :returns: dict with 'added', 'changed' and 'removed' path lists.
            Directory nodes are skipped entirely.
        :raises NotImplementedError: on an unrecognized change action.
        """

        repo = self._factory.repo(wire)
        fsobj = svn.repos.fs(repo)
        rev_root = svn.fs.revision_root(fsobj, revision)

        # Collect per-path changes by replaying the revision through an
        # editor; deltas themselves are not needed, only the change records.
        editor = svn.repos.ChangeCollector(fsobj, rev_root)
        editor_ptr, editor_baton = svn.delta.make_editor(editor)
        base_dir = ""
        send_deltas = False
        svn.repos.replay2(
            rev_root, base_dir, svn.core.SVN_INVALID_REVNUM, send_deltas,
            editor_ptr, editor_baton, None)

        added = []
        changed = []
        removed = []

        # TODO: CHANGE_ACTION_REPLACE: Figure out where it belongs
        for path, change in editor.changes.items():
            # TODO: Decide what to do with directory nodes. Subversion can add
            # empty directories.

            if change.item_kind == svn.core.svn_node_dir:
                continue
            if change.action in [svn.repos.CHANGE_ACTION_ADD]:
                added.append(path)
            elif change.action in [svn.repos.CHANGE_ACTION_MODIFY,
                                   svn.repos.CHANGE_ACTION_REPLACE]:
                changed.append(path)
            elif change.action in [svn.repos.CHANGE_ACTION_DELETE]:
                removed.append(path)
            else:
                raise NotImplementedError(
                    "Action %s not supported on path %s" % (
                        change.action, path))

        changes = {
            'added': added,
            'changed': changed,
            'removed': removed,
        }
        return changes
250 250
    @reraise_safe_exceptions
    def node_history(self, wire, path, revision, limit):
        """Return up to ``limit`` revisions in which ``path`` changed,
        newest first. ``limit`` of 0/None means unlimited."""
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)
        @region.conditional_cache_on_arguments(condition=cache_on)
        def _assert_correct_path(_context_uid, _repo_id, _path, _revision, _limit):
            # cross_copies=False: do not follow history across copy/rename
            cross_copies = False
            repo = self._factory.repo(wire)
            fsobj = svn.repos.fs(repo)
            rev_root = svn.fs.revision_root(fsobj, revision)

            history_revisions = []
            history = svn.fs.node_history(rev_root, path)
            # the first history_prev call positions us on the first real entry
            history = svn.fs.history_prev(history, cross_copies)
            while history:
                __, node_revision = svn.fs.history_location(history)
                history_revisions.append(node_revision)
                if limit and len(history_revisions) >= limit:
                    break
                history = svn.fs.history_prev(history, cross_copies)
            return history_revisions
        return _assert_correct_path(context_uid, repo_id, path, revision, limit)
273 273
    def node_properties(self, wire, path, revision):
        """Return the svn property dict of the node at ``path``@``revision``."""
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)
        @region.conditional_cache_on_arguments(condition=cache_on)
        def _node_properties(_repo_id, _path, _revision):
            repo = self._factory.repo(wire)
            fsobj = svn.repos.fs(repo)
            rev_root = svn.fs.revision_root(fsobj, revision)
            return svn.fs.node_proplist(rev_root, path)
        return _node_properties(repo_id, path, revision)
284 284
    def file_annotate(self, wire, path, revision):
        """Blame ``path`` up to ``revision``.

        :returns: list of ``(line_no, revision, line)`` tuples.
        :raises Exception: when blame is unsupported or the path is missing.
        """
        abs_path = 'file://' + urllib.request.pathname2url(
            vcspath.join(wire['path'], path))
        file_uri = svn.core.svn_path_canonicalize(abs_path)

        start_rev = svn_opt_revision_value_t(0)
        peg_rev = svn_opt_revision_value_t(revision)
        end_rev = peg_rev

        annotations = []

        def receiver(line_no, revision, author, date, line, pool):
            # only line number, revision and content are exposed upstream
            annotations.append((line_no, revision, line))

        # TODO: Cannot use blame5, missing typemap function in the swig code
        try:
            svn.client.blame2(
                file_uri, peg_rev, start_rev, end_rev,
                receiver, svn.client.create_context())
        except svn.core.SubversionException as exc:
            log.exception("Error during blame operation.")
            raise Exception(
                "Blame not supported or file does not exist at path %s. "
                "Error %s." % (path, exc))

        return annotations
311 311
312 312 def get_node_type(self, wire, path, revision=None):
313 313
314 314 cache_on, context_uid, repo_id = self._cache_on(wire)
315 315 region = self._region(wire)
316 316 @region.conditional_cache_on_arguments(condition=cache_on)
317 317 def _get_node_type(_repo_id, _path, _revision):
318 318 repo = self._factory.repo(wire)
319 319 fs_ptr = svn.repos.fs(repo)
320 320 if _revision is None:
321 321 _revision = svn.fs.youngest_rev(fs_ptr)
322 322 root = svn.fs.revision_root(fs_ptr, _revision)
323 323 node = svn.fs.check_path(root, path)
324 324 return NODE_TYPE_MAPPING.get(node, None)
325 325 return _get_node_type(repo_id, path, revision)
326 326
327 327 def get_nodes(self, wire, path, revision=None):
328 328
329 329 cache_on, context_uid, repo_id = self._cache_on(wire)
330 330 region = self._region(wire)
331 331 @region.conditional_cache_on_arguments(condition=cache_on)
332 332 def _get_nodes(_repo_id, _path, _revision):
333 333 repo = self._factory.repo(wire)
334 334 fsobj = svn.repos.fs(repo)
335 335 if _revision is None:
336 336 _revision = svn.fs.youngest_rev(fsobj)
337 337 root = svn.fs.revision_root(fsobj, _revision)
338 338 entries = svn.fs.dir_entries(root, path)
339 339 result = []
340 340 for entry_path, entry_info in entries.items():
341 341 result.append(
342 342 (entry_path, NODE_TYPE_MAPPING.get(entry_info.kind, None)))
343 343 return result
344 344 return _get_nodes(repo_id, path, revision)
345 345
    def get_file_content(self, wire, path, rev=None):
        """Return the raw (bytes) content of ``path`` at ``rev``.

        ``rev`` defaults to the youngest revision.
        """
        repo = self._factory.repo(wire)
        fsobj = svn.repos.fs(repo)
        if rev is None:
            rev = svn.fs.youngest_revision(fsobj)
        root = svn.fs.revision_root(fsobj, rev)
        content = svn.core.Stream(svn.fs.file_contents(root, path))
        return content.read()
354 354
355 355 def get_file_size(self, wire, path, revision=None):
356 356
357 357 cache_on, context_uid, repo_id = self._cache_on(wire)
358 358 region = self._region(wire)
359 359
360 360 @region.conditional_cache_on_arguments(condition=cache_on)
361 361 def _get_file_size(_repo_id, _path, _revision):
362 362 repo = self._factory.repo(wire)
363 363 fsobj = svn.repos.fs(repo)
364 364 if _revision is None:
365 365 _revision = svn.fs.youngest_revision(fsobj)
366 366 root = svn.fs.revision_root(fsobj, _revision)
367 367 size = svn.fs.file_length(root, path)
368 368 return size
369 369 return _get_file_size(repo_id, path, revision)
370 370
    def create_repository(self, wire, compatible_version=None):
        """Create a new SVN repository at ``wire['path']``, optionally pinned
        to an older on-disk compatibility version."""
        log.info('Creating Subversion repository in path "%s"', wire['path'])
        self._factory.repo(wire, create=True,
                           compatible_version=compatible_version)
375 375
376 376 def get_url_and_credentials(self, src_url):
377 377 obj = urllib.parse.urlparse(src_url)
378 378 username = obj.username or None
379 379 password = obj.password or None
380 380 return username, password, src_url
381 381
    def import_remote_repository(self, wire, src_url):
        """Mirror a remote SVN repository into the local one.

        Pipes ``svnrdump dump`` of ``src_url`` straight into
        ``svnadmin load`` of the local repo path.

        :raises Exception: when the local path is not a repo, or either
            subprocess exits non-zero.
        """
        repo_path = wire['path']
        if not self.is_path_valid_repository(wire, repo_path):
            raise Exception(
                "Path %s is not a valid Subversion repository." % repo_path)

        username, password, src_url = self.get_url_and_credentials(src_url)
        rdump_cmd = ['svnrdump', 'dump', '--non-interactive',
                     '--trust-server-cert-failures=unknown-ca']
        if username and password:
            rdump_cmd += ['--username', username, '--password', password]
        rdump_cmd += [src_url]

        # svnrdump stdout feeds directly into svnadmin's stdin
        rdump = subprocess.Popen(
            rdump_cmd,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        load = subprocess.Popen(
            ['svnadmin', 'load', repo_path], stdin=rdump.stdout)

        # TODO: johbo: This can be a very long operation, might be better
        # to track some kind of status and provide an api to check if the
        # import is done.
        # NOTE(review): stderr is a PIPE but only read after wait(); a very
        # chatty stderr could in principle block — confirm under load.
        rdump.wait()
        load.wait()

        log.debug('Return process ended with code: %s', rdump.returncode)
        if rdump.returncode != 0:
            errors = rdump.stderr.read()
            log.error('svnrdump dump failed: statuscode %s: message: %s', rdump.returncode, errors)

            reason = 'UNKNOWN'
            if b'svnrdump: E230001:' in errors:
                reason = 'INVALID_CERTIFICATE'

            if reason == 'UNKNOWN':
                reason = 'UNKNOWN:{}'.format(safe_str(errors))

            raise Exception(
                'Failed to dump the remote repository from %s. Reason:%s' % (
                    src_url, reason))
        if load.returncode != 0:
            raise Exception(
                'Failed to load the dump of remote repository from %s.' %
                (src_url, ))
426 426
    def commit(self, wire, message, author, timestamp, updated, removed):
        """Commit a set of node updates/removals as one new revision.

        :param updated: iterable of node dicts processed by TxnNodeProcessor.
        :param removed: iterable of node dicts to delete.
        :param timestamp: optional unix timestamp to force as svn:date.
        :returns: the new revision number.
        """
        assert isinstance(message, str)
        assert isinstance(author, str)

        repo = self._factory.repo(wire)
        fsobj = svn.repos.fs(repo)

        rev = svn.fs.youngest_rev(fsobj)
        txn = svn.repos.fs_begin_txn_for_commit(repo, rev, author, message)
        txn_root = svn.fs.txn_root(txn)

        for node in updated:
            TxnNodeProcessor(node, txn_root).update()
        for node in removed:
            TxnNodeProcessor(node, txn_root).remove()

        commit_id = svn.repos.fs_commit_txn(repo, txn)

        if timestamp:
            # rewrite svn:date after commit so the revision carries the
            # caller-provided timestamp instead of "now"
            apr_time = apr_time_t(timestamp)
            ts_formatted = svn.core.svn_time_to_cstring(apr_time)
            svn.fs.change_rev_prop(fsobj, commit_id, 'svn:date', ts_formatted)

        log.debug('Committed revision "%s" to "%s".', commit_id, wire['path'])
        return commit_id
452 452
453 453 def diff(self, wire, rev1, rev2, path1=None, path2=None,
454 454 ignore_whitespace=False, context=3):
455 455
456 456 wire.update(cache=False)
457 457 repo = self._factory.repo(wire)
458 458 diff_creator = SvnDiffer(
459 459 repo, rev1, path1, rev2, path2, ignore_whitespace, context)
460 460 try:
461 461 return diff_creator.generate_diff()
462 462 except svn.core.SubversionException as e:
463 463 log.exception(
464 464 "Error during diff operation operation. "
465 465 "Path might not exist %s, %s" % (path1, path2))
466 466 return ""
467 467
    @reraise_safe_exceptions
    def is_large_file(self, wire, path):
        # Subversion has no largefiles/LFS mechanism, so nothing is ever large
        return False
471 471
472 472 @reraise_safe_exceptions
473 473 def is_binary(self, wire, rev, path):
474 474 cache_on, context_uid, repo_id = self._cache_on(wire)
475 475
476 476 region = self._region(wire)
477 477 @region.conditional_cache_on_arguments(condition=cache_on)
478 478 def _is_binary(_repo_id, _rev, _path):
479 479 raw_bytes = self.get_file_content(wire, path, rev)
480 480 return raw_bytes and '\0' in raw_bytes
481 481
482 482 return _is_binary(repo_id, rev, path)
483 483
    @reraise_safe_exceptions
    def run_svn_command(self, wire, cmd, **opts):
        """Execute an svn CLI command inside the repository directory.

        :param cmd: argv-style command list.
        :param opts: extra Popen options; special keys: ``_safe`` (return
            the error instead of raising) and ``extra_env`` (env overrides).
        :returns: ``(stdout, stderr)``; note the success path yields bytes
            while the ``_safe`` failure path yields str — long-standing
            behavior callers may rely on.
        :raises exceptions.VcsException: on OSError when ``_safe`` is unset.
        """
        path = wire.get('path', None)

        if path and os.path.isdir(path):
            opts['cwd'] = path

        safe_call = opts.pop('_safe', False)

        svnenv = os.environ.copy()
        svnenv.update(opts.pop('extra_env', {}))

        # shell=False: cmd is an argv list, no shell interpolation
        _opts = {'env': svnenv, 'shell': False}

        try:
            _opts.update(opts)
            proc = subprocessio.SubprocessIOChunker(cmd, **_opts)

            return b''.join(proc), b''.join(proc.stderr)
        except OSError as err:
            if safe_call:
                return '', safe_str(err).strip()
            else:
                cmd = ' '.join(cmd)  # human friendly CMD
                tb_err = ("Couldn't run svn command (%s).\n"
                          "Original error was:%s\n"
                          "Call options:%s\n"
                          % (cmd, err, _opts))
                log.exception(tb_err)
                raise exceptions.VcsException()(tb_err)
514 514
    @reraise_safe_exceptions
    def install_hooks(self, wire, force=False):
        """Install vcsserver pre/post-commit hooks into the repository."""
        from vcsserver.hook_utils import install_svn_hooks
        repo_path = wire['path']
        binary_dir = settings.BINARY_DIR
        executable = None
        if binary_dir:
            # run hooks with the bundled python, not whatever is on PATH
            executable = os.path.join(binary_dir, 'python')
        return install_svn_hooks(
            repo_path, executable=executable, force_create=force)
525 525
    @reraise_safe_exceptions
    def get_hooks_info(self, wire):
        """Return the versions of installed pre/post hooks for this repo."""
        from vcsserver.hook_utils import (
            get_svn_pre_hook_version, get_svn_post_hook_version)
        repo_path = wire['path']
        return {
            'pre_version': get_svn_pre_hook_version(repo_path),
            'post_version': get_svn_post_hook_version(repo_path),
        }
535 535
    @reraise_safe_exceptions
    def set_head_ref(self, wire, head_name):
        # no-op: SVN has no symbolic HEAD ref to update (git-only concept)
        pass
539 539
    @reraise_safe_exceptions
    def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
                     archive_dir_name, commit_id):
        """Create an archive of the repository content at ``commit_id``.

        Delegates packing to the module-level ``archive_repo`` helper and
        supplies an SVN-specific recursive file walker.
        """

        def walk_tree(root, root_dir, _commit_id):
            """
            Special recursive svn repo walker
            """

            filemode_default = 0o100644
            filemode_executable = 0o100755

            file_iter = svn.fs.dir_entries(root, root_dir)
            for f_name in file_iter:
                f_type = NODE_TYPE_MAPPING.get(file_iter[f_name].kind, None)

                if f_type == 'dir':
                    # return only DIR, and then all entries in that dir
                    yield os.path.join(root_dir, f_name), {'mode': filemode_default}, f_type
                    new_root = os.path.join(root_dir, f_name)
                    for _f_name, _f_data, _f_type in walk_tree(root, new_root, _commit_id):
                        yield _f_name, _f_data, _f_type
                else:
                    f_path = os.path.join(root_dir, f_name).rstrip('/')
                    prop_list = svn.fs.node_proplist(root, f_path)

                    # svn:executable property marks executables
                    f_mode = filemode_default
                    if prop_list.get('svn:executable'):
                        f_mode = filemode_executable

                    # svn:special marks symlinks
                    f_is_link = False
                    if prop_list.get('svn:special'):
                        f_is_link = True

                    data = {
                        'is_link': f_is_link,
                        'mode': f_mode,
                        'content_stream': svn.core.Stream(svn.fs.file_contents(root, f_path)).read
                    }

                    yield f_path, data, f_type

        def file_walker(_commit_id, path):
            repo = self._factory.repo(wire)
            # NOTE(review): uses the outer ``commit_id`` rather than the
            # ``_commit_id`` parameter — same value in practice; confirm.
            root = svn.fs.revision_root(svn.repos.fs(repo), int(commit_id))

            def no_content():
                raise NoContentException()

            for f_name, f_data, f_type in walk_tree(root, path, _commit_id):
                file_path = f_name

                if f_type == 'dir':
                    mode = f_data['mode']
                    yield ArchiveNode(file_path, mode, False, no_content)
                else:
                    mode = f_data['mode']
                    is_link = f_data['is_link']
                    data_stream = f_data['content_stream']
                    yield ArchiveNode(file_path, mode, is_link, data_stream)

        # resolves to the imported module-level ``archive_repo`` helper
        return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
                            archive_dir_name, commit_id)
603 603
604 604
class SvnDiffer(object):
    """
    Utility to create diffs based on difflib and the Subversion api

    Produces git-style unified diffs between ``src_rev``/``src_path`` and
    ``tgt_rev``/``tgt_path``. Binary files (per svn:mime-type) get a
    placeholder header instead of content hunks.
    """

    # set to True once a binary mime-type is detected for the current node
    binary_content = False

    def __init__(
            self, repo, src_rev, src_path, tgt_rev, tgt_path,
            ignore_whitespace, context):
        self.repo = repo
        self.ignore_whitespace = ignore_whitespace
        self.context = context

        fsobj = svn.repos.fs(repo)

        self.tgt_rev = tgt_rev
        self.tgt_path = tgt_path or ''
        self.tgt_root = svn.fs.revision_root(fsobj, tgt_rev)
        self.tgt_kind = svn.fs.check_path(self.tgt_root, self.tgt_path)

        self.src_rev = src_rev
        # source path defaults to the target path when not given
        self.src_path = src_path or self.tgt_path
        self.src_root = svn.fs.revision_root(fsobj, src_rev)
        self.src_kind = svn.fs.check_path(self.src_root, self.src_path)

        self._validate()

    def _validate(self):
        # both nodes must be of the same kind (file vs dir) when both exist
        if (self.tgt_kind != svn.core.svn_node_none and
                self.src_kind != svn.core.svn_node_none and
                self.src_kind != self.tgt_kind):
            # TODO: johbo: proper error handling
            raise Exception(
                "Source and target are not compatible for diff generation. "
                "Source type: %s, target type: %s" %
                (self.src_kind, self.tgt_kind))

    def generate_diff(self):
        """Return the full diff text for the configured node pair."""
        buf = io.StringIO()
        if self.tgt_kind == svn.core.svn_node_dir:
            self._generate_dir_diff(buf)
        else:
            self._generate_file_diff(buf)
        return buf.getvalue()

    def _generate_dir_diff(self, buf):
        # collect per-path changes via a delta editor, then diff each node
        editor = DiffChangeEditor()
        editor_ptr, editor_baton = svn.delta.make_editor(editor)
        svn.repos.dir_delta2(
            self.src_root,
            self.src_path,
            '',  # src_entry
            self.tgt_root,
            self.tgt_path,
            editor_ptr, editor_baton,
            authorization_callback_allow_all,
            False,  # text_deltas
            svn.core.svn_depth_infinity,  # depth
            False,  # entry_props
            False,  # ignore_ancestry
        )

        for path, __, change in sorted(editor.changes):
            self._generate_node_diff(
                buf, change, path, self.tgt_path, path, self.src_path)

    def _generate_file_diff(self, buf):
        change = None
        if self.src_kind == svn.core.svn_node_none:
            change = "add"
        elif self.tgt_kind == svn.core.svn_node_none:
            change = "delete"
        tgt_base, tgt_path = vcspath.split(self.tgt_path)
        src_base, src_path = vcspath.split(self.src_path)
        self._generate_node_diff(
            buf, change, tgt_path, tgt_base, src_path, src_base)

    def _generate_node_diff(
            self, buf, change, tgt_path, tgt_base, src_path, src_base):
        """Write the git-style diff header and hunks for a single node."""

        if self.src_rev == self.tgt_rev and tgt_base == src_base:
            # makes consistent behaviour with git/hg to return empty diff if
            # we compare same revisions
            return

        tgt_full_path = vcspath.join(tgt_base, tgt_path)
        src_full_path = vcspath.join(src_base, src_path)

        self.binary_content = False
        mime_type = self._get_mime_type(tgt_full_path)

        if mime_type and not mime_type.startswith('text'):
            self.binary_content = True
            buf.write("=" * 67 + '\n')
            buf.write("Cannot display: file marked as a binary type.\n")
            buf.write("svn:mime-type = %s\n" % mime_type)
            buf.write("Index: %s\n" % (tgt_path, ))
            buf.write("=" * 67 + '\n')
        buf.write("diff --git a/%(tgt_path)s b/%(tgt_path)s\n" % {
            'tgt_path': tgt_path})

        if change == 'add':
            # TODO: johbo: SVN is missing a zero here compared to git
            buf.write("new file mode 10644\n")

            #TODO(marcink): intro to binary detection of svn patches
            # if self.binary_content:
            #     buf.write('GIT binary patch\n')

            buf.write("--- /dev/null\t(revision 0)\n")
            src_lines = []
        else:
            if change == 'delete':
                buf.write("deleted file mode 10644\n")

            #TODO(marcink): intro to binary detection of svn patches
            # if self.binary_content:
            #     buf.write('GIT binary patch\n')

            buf.write("--- a/%s\t(revision %s)\n" % (
                src_path, self.src_rev))
            src_lines = self._svn_readlines(self.src_root, src_full_path)

        if change == 'delete':
            buf.write("+++ /dev/null\t(revision %s)\n" % (self.tgt_rev, ))
            tgt_lines = []
        else:
            buf.write("+++ b/%s\t(revision %s)\n" % (
                tgt_path, self.tgt_rev))
            tgt_lines = self._svn_readlines(self.tgt_root, tgt_full_path)

        if not self.binary_content:
            udiff = svn_diff.unified_diff(
                src_lines, tgt_lines, context=self.context,
                ignore_blank_lines=self.ignore_whitespace,
                ignore_case=False,
                ignore_space_changes=self.ignore_whitespace)
            buf.writelines(udiff)

    def _get_mime_type(self, path):
        # prefer the target side; fall back to the source (e.g. deleted file)
        try:
            mime_type = svn.fs.node_prop(
                self.tgt_root, path, svn.core.SVN_PROP_MIME_TYPE)
        except svn.core.SubversionException:
            mime_type = svn.fs.node_prop(
                self.src_root, path, svn.core.SVN_PROP_MIME_TYPE)
        return mime_type

    def _svn_readlines(self, fs_root, node_path):
        """Return the node's content as a list of lines (keepends=True)."""
        if self.binary_content:
            return []
        node_kind = svn.fs.check_path(fs_root, node_path)
        if node_kind not in (
                svn.core.svn_node_file, svn.core.svn_node_symlink):
            return []
        content = svn.core.Stream(
            svn.fs.file_contents(fs_root, node_path)).read()
        return content.splitlines(True)
764 764
765 765
class DiffChangeEditor(svn.delta.Editor):
    """
    Records changes between two given revisions

    Each change is collected as a ``(path, kind, action)`` tuple, where kind
    is 'file' or None and action is one of 'add', 'change', 'delete'.
    """

    def __init__(self):
        self.changes = []

    def delete_entry(self, path, revision, parent_baton, pool=None):
        self.changes.append((path, None, 'delete'))

    def add_file(
            self, path, parent_baton, copyfrom_path, copyfrom_revision,
            file_pool=None):
        self.changes.append((path, 'file', 'add'))

    def open_file(self, path, parent_baton, base_revision, file_pool=None):
        self.changes.append((path, 'file', 'change'))
784 784
785 785
def authorization_callback_allow_all(root, path, pool):
    """Authz callback for ``svn.repos.dir_delta2`` that permits every path."""
    return True
788 788
789 789
class TxnNodeProcessor(object):
    """
    Utility to process the change of one node within a transaction root.

    It encapsulates the knowledge of how to add, update or remove
    a node for a given transaction root. The purpose is to support the method
    `SvnRemote.commit`.

    ``node`` is a dict with at least a 'path' key; 'content' and
    'properties' are used when updating.
    """

    def __init__(self, node, txn_root):
        assert isinstance(node['path'], str)

        self.node = node
        self.txn_root = txn_root

    def update(self):
        """Create/overwrite the node's content and properties."""
        self._ensure_parent_dirs()
        self._add_file_if_node_does_not_exist()
        self._update_file_content()
        self._update_file_properties()

    def remove(self):
        svn.fs.delete(self.txn_root, self.node['path'])
        # TODO: Clean up directory if empty

    def _ensure_parent_dirs(self):
        # walk up until an existing dir is found, then create the missing
        # chain top-down (reversed)
        curdir = vcspath.dirname(self.node['path'])
        dirs_to_create = []
        while not self._svn_path_exists(curdir):
            dirs_to_create.append(curdir)
            curdir = vcspath.dirname(curdir)

        for curdir in reversed(dirs_to_create):
            log.debug('Creating missing directory "%s"', curdir)
            svn.fs.make_dir(self.txn_root, curdir)

    def _svn_path_exists(self, path):
        path_status = svn.fs.check_path(self.txn_root, path)
        return path_status != svn.core.svn_node_none

    def _add_file_if_node_does_not_exist(self):
        kind = svn.fs.check_path(self.txn_root, self.node['path'])
        if kind == svn.core.svn_node_none:
            svn.fs.make_file(self.txn_root, self.node['path'])

    def _update_file_content(self):
        assert isinstance(self.node['content'], str)
        # full-text delta against an empty base replaces the file content
        handler, baton = svn.fs.apply_textdelta(
            self.txn_root, self.node['path'], None, None)
        svn.delta.svn_txdelta_send_string(self.node['content'], handler, baton)

    def _update_file_properties(self):
        properties = self.node.get('properties', {})
        for key, value in properties.items():
            svn.fs.change_node_prop(
                self.txn_root, self.node['path'], key, value)
846 846
847 847
def apr_time_t(timestamp):
    """
    Convert a Python timestamp into APR timestamp type apr_time_t
    """
    # APR timestamps count microseconds; 1E6 keeps the float result type.
    microseconds_per_second = 1E6
    return timestamp * microseconds_per_second
853 853
854 854
def svn_opt_revision_value_t(num):
    """
    Put `num` into a `svn_opt_revision_value_t` structure.

    :param num: revision number to wrap.
    :returns: an ``svn_opt_revision_t`` of kind ``svn_opt_revision_number``
        pointing at revision ``num``.
    """
    value = svn.core.svn_opt_revision_value_t()
    value.number = num
    revision = svn.core.svn_opt_revision_t()
    revision.kind = svn.core.svn_opt_revision_number
    revision.value = value
    return revision
@@ -1,242 +1,242 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import logging
20 20 import itertools
21 21
22 22 import mercurial
23 23 import mercurial.error
24 24 import mercurial.wireprotoserver
25 25 import mercurial.hgweb.common
26 26 import mercurial.hgweb.hgweb_mod
27 27 import webob.exc
28 28
29 29 from vcsserver import pygrack, exceptions, settings, git_lfs
30 from vcsserver.utils import ascii_bytes, safe_bytes
30 from vcsserver.str_utils import ascii_bytes, safe_bytes
31 31
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 # propagated from mercurial documentation
36 36 HG_UI_SECTIONS = [
37 37 'alias', 'auth', 'decode/encode', 'defaults', 'diff', 'email', 'extensions',
38 38 'format', 'merge-patterns', 'merge-tools', 'hooks', 'http_proxy', 'smtp',
39 39 'patch', 'paths', 'profiling', 'server', 'trusted', 'ui', 'web',
40 40 ]
41 41
42 42
43 43 class HgWeb(mercurial.hgweb.hgweb_mod.hgweb):
44 44 """Extension of hgweb that simplifies some functions."""
45 45
46 46 def _get_view(self, repo):
47 47 """Views are not supported."""
48 48 return repo
49 49
50 50 def loadsubweb(self):
51 51 """The result is only used in the templater method which is not used."""
52 52 return None
53 53
54 54 def run(self):
55 55 """Unused function so raise an exception if accidentally called."""
56 56 raise NotImplementedError
57 57
58 58 def templater(self, req):
59 59 """Function used in an unreachable code path.
60 60
61 61 This code is unreachable because we guarantee that the HTTP request,
62 62 corresponds to a Mercurial command. See the is_hg method. So, we are
63 63 never going to get a user-visible url.
64 64 """
65 65 raise NotImplementedError
66 66
67 67 def archivelist(self, nodeid):
68 68 """Unused function so raise an exception if accidentally called."""
69 69 raise NotImplementedError
70 70
71 71 def __call__(self, environ, start_response):
72 72 """Run the WSGI application.
73 73
74 74 This may be called by multiple threads.
75 75 """
76 76 from mercurial.hgweb import request as requestmod
77 77 req = requestmod.parserequestfromenv(environ)
78 78 res = requestmod.wsgiresponse(req, start_response)
79 79 gen = self.run_wsgi(req, res)
80 80
81 81 first_chunk = None
82 82
83 83 try:
84 84 data = next(gen)
85 85
86 86 def first_chunk():
87 87 yield data
88 88 except StopIteration:
89 89 pass
90 90
91 91 if first_chunk:
92 92 return itertools.chain(first_chunk(), gen)
93 93 return gen
94 94
95 95 def _runwsgi(self, req, res, repo):
96 96
97 97 cmd = req.qsparams.get(b'cmd', '')
98 98 if not mercurial.wireprotoserver.iscmd(cmd):
99 99 # NOTE(marcink): for unsupported commands, we return bad request
100 100 # internally from HG
101 101 log.warning('cmd: `%s` is not supported by the mercurial wireprotocol v1', cmd)
102 102 from mercurial.hgweb.common import statusmessage
103 103 res.status = statusmessage(mercurial.hgweb.common.HTTP_BAD_REQUEST)
104 104 res.setbodybytes(b'')
105 105 return res.sendresponse()
106 106
107 107 return super(HgWeb, self)._runwsgi(req, res, repo)
108 108
109 109
110 110 def make_hg_ui_from_config(repo_config):
111 111 baseui = mercurial.ui.ui()
112 112
113 113 # clean the baseui object
114 114 baseui._ocfg = mercurial.config.config()
115 115 baseui._ucfg = mercurial.config.config()
116 116 baseui._tcfg = mercurial.config.config()
117 117
118 118 for section, option, value in repo_config:
119 119 baseui.setconfig(
120 120 ascii_bytes(section, allow_bytes=True),
121 121 ascii_bytes(option, allow_bytes=True),
122 122 ascii_bytes(value, allow_bytes=True))
123 123
124 124 # make our hgweb quiet so it doesn't print output
125 125 baseui.setconfig(b'ui', b'quiet', b'true')
126 126
127 127 return baseui
128 128
129 129
130 130 def update_hg_ui_from_hgrc(baseui, repo_path):
131 131 path = os.path.join(repo_path, '.hg', 'hgrc')
132 132
133 133 if not os.path.isfile(path):
134 134 log.debug('hgrc file is not present at %s, skipping...', path)
135 135 return
136 136 log.debug('reading hgrc from %s', path)
137 137 cfg = mercurial.config.config()
138 138 cfg.read(ascii_bytes(path))
139 139 for section in HG_UI_SECTIONS:
140 140 for k, v in cfg.items(section):
141 141 log.debug('settings ui from file: [%s] %s=%s', section, k, v)
142 142 baseui.setconfig(
143 143 ascii_bytes(section, allow_bytes=True),
144 144 ascii_bytes(k, allow_bytes=True),
145 145 ascii_bytes(v, allow_bytes=True))
146 146
147 147
148 148 def create_hg_wsgi_app(repo_path, repo_name, config):
149 149 """
150 150 Prepares a WSGI application to handle Mercurial requests.
151 151
152 152 :param config: is a list of 3-item tuples representing a ConfigObject
153 153 (it is the serialized version of the config object).
154 154 """
155 155 log.debug("Creating Mercurial WSGI application")
156 156
157 157 baseui = make_hg_ui_from_config(config)
158 158 update_hg_ui_from_hgrc(baseui, repo_path)
159 159
160 160 try:
161 161 return HgWeb(safe_bytes(repo_path), name=safe_bytes(repo_name), baseui=baseui)
162 162 except mercurial.error.RequirementError as e:
163 163 raise exceptions.RequirementException(e)(e)
164 164
165 165
166 166 class GitHandler(object):
167 167 """
168 168 Handler for Git operations like push/pull etc
169 169 """
170 170 def __init__(self, repo_location, repo_name, git_path, update_server_info,
171 171 extras):
172 172 if not os.path.isdir(repo_location):
173 173 raise OSError(repo_location)
174 174 self.content_path = repo_location
175 175 self.repo_name = repo_name
176 176 self.repo_location = repo_location
177 177 self.extras = extras
178 178 self.git_path = git_path
179 179 self.update_server_info = update_server_info
180 180
181 181 def __call__(self, environ, start_response):
182 182 app = webob.exc.HTTPNotFound()
183 183 candidate_paths = (
184 184 self.content_path, os.path.join(self.content_path, '.git'))
185 185
186 186 for content_path in candidate_paths:
187 187 try:
188 188 app = pygrack.GitRepository(
189 189 self.repo_name, content_path, self.git_path,
190 190 self.update_server_info, self.extras)
191 191 break
192 192 except OSError:
193 193 continue
194 194
195 195 return app(environ, start_response)
196 196
197 197
198 198 def create_git_wsgi_app(repo_path, repo_name, config):
199 199 """
200 200 Creates a WSGI application to handle Git requests.
201 201
202 202 :param config: is a dictionary holding the extras.
203 203 """
204 204 git_path = settings.GIT_EXECUTABLE
205 205 update_server_info = config.pop('git_update_server_info')
206 206 app = GitHandler(
207 207 repo_path, repo_name, git_path, update_server_info, config)
208 208
209 209 return app
210 210
211 211
212 212 class GitLFSHandler(object):
213 213 """
214 214 Handler for Git LFS operations
215 215 """
216 216
217 217 def __init__(self, repo_location, repo_name, git_path, update_server_info,
218 218 extras):
219 219 if not os.path.isdir(repo_location):
220 220 raise OSError(repo_location)
221 221 self.content_path = repo_location
222 222 self.repo_name = repo_name
223 223 self.repo_location = repo_location
224 224 self.extras = extras
225 225 self.git_path = git_path
226 226 self.update_server_info = update_server_info
227 227
228 228 def get_app(self, git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
229 229 app = git_lfs.create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
230 230 return app
231 231
232 232
233 233 def create_git_lfs_wsgi_app(repo_path, repo_name, config):
234 234 git_path = settings.GIT_EXECUTABLE
235 235 update_server_info = config.pop(b'git_update_server_info')
236 236 git_lfs_enabled = config.pop(b'git_lfs_enabled')
237 237 git_lfs_store_path = config.pop(b'git_lfs_store_path')
238 238 git_lfs_http_scheme = config.pop(b'git_lfs_http_scheme', 'http')
239 239 app = GitLFSHandler(
240 240 repo_path, repo_name, git_path, update_server_info, config)
241 241
242 242 return app.get_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
@@ -1,205 +1,206 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import sys
20 20 import stat
21 21 import pytest
22 22 import vcsserver
23 23 import tempfile
24 24 from vcsserver import hook_utils
25 25 from vcsserver.tests.fixture import no_newline_id_generator
26 from vcsserver.utils import AttributeDict, safe_bytes, safe_str
26 from vcsserver.str_utils import safe_bytes, safe_str
27 from vcsserver.utils import AttributeDict
27 28
28 29
29 30 class TestCheckRhodecodeHook(object):
30 31
31 32 def test_returns_false_when_hook_file_is_wrong_found(self, tmpdir):
32 33 hook = os.path.join(str(tmpdir), 'fake_hook_file.py')
33 34 with open(hook, 'wb') as f:
34 35 f.write(b'dummy test')
35 36 result = hook_utils.check_rhodecode_hook(hook)
36 37 assert result is False
37 38
38 39 def test_returns_true_when_no_hook_file_found(self, tmpdir):
39 40 hook = os.path.join(str(tmpdir), 'fake_hook_file_not_existing.py')
40 41 result = hook_utils.check_rhodecode_hook(hook)
41 42 assert result
42 43
43 44 @pytest.mark.parametrize("file_content, expected_result", [
44 45 ("RC_HOOK_VER = '3.3.3'\n", True),
45 46 ("RC_HOOK = '3.3.3'\n", False),
46 47 ], ids=no_newline_id_generator)
47 48 def test_signatures(self, file_content, expected_result, tmpdir):
48 49 hook = os.path.join(str(tmpdir), 'fake_hook_file_1.py')
49 50 with open(hook, 'wb') as f:
50 51 f.write(safe_bytes(file_content))
51 52
52 53 result = hook_utils.check_rhodecode_hook(hook)
53 54
54 55 assert result is expected_result
55 56
56 57
57 58 class BaseInstallHooks(object):
58 59 HOOK_FILES = ()
59 60
60 61 def _check_hook_file_mode(self, file_path):
61 62 assert os.path.exists(file_path), 'path %s missing' % file_path
62 63 stat_info = os.stat(file_path)
63 64
64 65 file_mode = stat.S_IMODE(stat_info.st_mode)
65 66 expected_mode = int('755', 8)
66 67 assert expected_mode == file_mode
67 68
68 69 def _check_hook_file_content(self, file_path, executable):
69 70 executable = executable or sys.executable
70 71 with open(file_path, 'rt') as hook_file:
71 72 content = hook_file.read()
72 73
73 74 expected_env = '#!{}'.format(executable)
74 75 expected_rc_version = "\nRC_HOOK_VER = '{}'\n".format(safe_str(vcsserver.__version__))
75 76 assert content.strip().startswith(expected_env)
76 77 assert expected_rc_version in content
77 78
78 79 def _create_fake_hook(self, file_path, content):
79 80 with open(file_path, 'w') as hook_file:
80 81 hook_file.write(content)
81 82
82 83 def create_dummy_repo(self, repo_type):
83 84 tmpdir = tempfile.mkdtemp()
84 85 repo = AttributeDict()
85 86 if repo_type == 'git':
86 87 repo.path = os.path.join(tmpdir, 'test_git_hooks_installation_repo')
87 88 os.makedirs(repo.path)
88 89 os.makedirs(os.path.join(repo.path, 'hooks'))
89 90 repo.bare = True
90 91
91 92 elif repo_type == 'svn':
92 93 repo.path = os.path.join(tmpdir, 'test_svn_hooks_installation_repo')
93 94 os.makedirs(repo.path)
94 95 os.makedirs(os.path.join(repo.path, 'hooks'))
95 96
96 97 return repo
97 98
98 99 def check_hooks(self, repo_path, repo_bare=True):
99 100 for file_name in self.HOOK_FILES:
100 101 if repo_bare:
101 102 file_path = os.path.join(repo_path, 'hooks', file_name)
102 103 else:
103 104 file_path = os.path.join(repo_path, '.git', 'hooks', file_name)
104 105 self._check_hook_file_mode(file_path)
105 106 self._check_hook_file_content(file_path, sys.executable)
106 107
107 108
108 109 class TestInstallGitHooks(BaseInstallHooks):
109 110 HOOK_FILES = ('pre-receive', 'post-receive')
110 111
111 112 def test_hooks_are_installed(self):
112 113 repo = self.create_dummy_repo('git')
113 114 result = hook_utils.install_git_hooks(repo.path, repo.bare)
114 115 assert result
115 116 self.check_hooks(repo.path, repo.bare)
116 117
117 118 def test_hooks_are_replaced(self):
118 119 repo = self.create_dummy_repo('git')
119 120 hooks_path = os.path.join(repo.path, 'hooks')
120 121 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
121 122 self._create_fake_hook(
122 123 file_path, content="RC_HOOK_VER = 'abcde'\n")
123 124
124 125 result = hook_utils.install_git_hooks(repo.path, repo.bare)
125 126 assert result
126 127 self.check_hooks(repo.path, repo.bare)
127 128
128 129 def test_non_rc_hooks_are_not_replaced(self):
129 130 repo = self.create_dummy_repo('git')
130 131 hooks_path = os.path.join(repo.path, 'hooks')
131 132 non_rc_content = 'echo "non rc hook"\n'
132 133 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
133 134 self._create_fake_hook(
134 135 file_path, content=non_rc_content)
135 136
136 137 result = hook_utils.install_git_hooks(repo.path, repo.bare)
137 138 assert result
138 139
139 140 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
140 141 with open(file_path, 'rt') as hook_file:
141 142 content = hook_file.read()
142 143 assert content == non_rc_content
143 144
144 145 def test_non_rc_hooks_are_replaced_with_force_flag(self):
145 146 repo = self.create_dummy_repo('git')
146 147 hooks_path = os.path.join(repo.path, 'hooks')
147 148 non_rc_content = 'echo "non rc hook"\n'
148 149 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
149 150 self._create_fake_hook(
150 151 file_path, content=non_rc_content)
151 152
152 153 result = hook_utils.install_git_hooks(
153 154 repo.path, repo.bare, force_create=True)
154 155 assert result
155 156 self.check_hooks(repo.path, repo.bare)
156 157
157 158
158 159 class TestInstallSvnHooks(BaseInstallHooks):
159 160 HOOK_FILES = ('pre-commit', 'post-commit')
160 161
161 162 def test_hooks_are_installed(self):
162 163 repo = self.create_dummy_repo('svn')
163 164 result = hook_utils.install_svn_hooks(repo.path)
164 165 assert result
165 166 self.check_hooks(repo.path)
166 167
167 168 def test_hooks_are_replaced(self):
168 169 repo = self.create_dummy_repo('svn')
169 170 hooks_path = os.path.join(repo.path, 'hooks')
170 171 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
171 172 self._create_fake_hook(
172 173 file_path, content="RC_HOOK_VER = 'abcde'\n")
173 174
174 175 result = hook_utils.install_svn_hooks(repo.path)
175 176 assert result
176 177 self.check_hooks(repo.path)
177 178
178 179 def test_non_rc_hooks_are_not_replaced(self):
179 180 repo = self.create_dummy_repo('svn')
180 181 hooks_path = os.path.join(repo.path, 'hooks')
181 182 non_rc_content = 'echo "non rc hook"\n'
182 183 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
183 184 self._create_fake_hook(
184 185 file_path, content=non_rc_content)
185 186
186 187 result = hook_utils.install_svn_hooks(repo.path)
187 188 assert result
188 189
189 190 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
190 191 with open(file_path, 'rt') as hook_file:
191 192 content = hook_file.read()
192 193 assert content == non_rc_content
193 194
194 195 def test_non_rc_hooks_are_replaced_with_force_flag(self):
195 196 repo = self.create_dummy_repo('svn')
196 197 hooks_path = os.path.join(repo.path, 'hooks')
197 198 non_rc_content = 'echo "non rc hook"\n'
198 199 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
199 200 self._create_fake_hook(
200 201 file_path, content=non_rc_content)
201 202
202 203 result = hook_utils.install_svn_hooks(
203 204 repo.path, force_create=True)
204 205 assert result
205 206 self.check_hooks(repo.path, )
@@ -1,288 +1,287 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import more_itertools
20 20
21 21 import dulwich.protocol
22 22 import mock
23 23 import pytest
24 24 import webob
25 25 import webtest
26 26
27 27 from vcsserver import hooks, pygrack
28 28
29 # pylint: disable=redefined-outer-name,protected-access
30 from vcsserver.utils import ascii_bytes
29 from vcsserver.str_utils import ascii_bytes
31 30
32 31
33 32 @pytest.fixture()
34 33 def pygrack_instance(tmpdir):
35 34 """
36 35 Creates a pygrack app instance.
37 36
38 37 Right now, it does not much helpful regarding the passed directory.
39 38 It just contains the required folders to pass the signature test.
40 39 """
41 40 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
42 41 tmpdir.mkdir(dir_name)
43 42
44 43 return pygrack.GitRepository('repo_name', str(tmpdir), 'git', False, {})
45 44
46 45
47 46 @pytest.fixture()
48 47 def pygrack_app(pygrack_instance):
49 48 """
50 49 Creates a pygrack app wrapped in webtest.TestApp.
51 50 """
52 51 return webtest.TestApp(pygrack_instance)
53 52
54 53
55 54 def test_invalid_service_info_refs_returns_403(pygrack_app):
56 55 response = pygrack_app.get('/info/refs?service=git-upload-packs',
57 56 expect_errors=True)
58 57
59 58 assert response.status_int == 403
60 59
61 60
62 61 def test_invalid_endpoint_returns_403(pygrack_app):
63 62 response = pygrack_app.post('/git-upload-packs', expect_errors=True)
64 63
65 64 assert response.status_int == 403
66 65
67 66
68 67 @pytest.mark.parametrize('sideband', [
69 68 'side-band-64k',
70 69 'side-band',
71 70 'side-band no-progress',
72 71 ])
73 72 def test_pre_pull_hook_fails_with_sideband(pygrack_app, sideband):
74 73 request = ''.join([
75 74 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ',
76 75 'multi_ack %s ofs-delta\n' % sideband,
77 76 '0000',
78 77 '0009done\n',
79 78 ])
80 79 with mock.patch('vcsserver.hooks.git_pre_pull', return_value=hooks.HookResponse(1, 'foo')):
81 80 response = pygrack_app.post(
82 81 '/git-upload-pack', params=request,
83 82 content_type='application/x-git-upload-pack')
84 83
85 84 data = io.BytesIO(response.body)
86 85 proto = dulwich.protocol.Protocol(data.read, None)
87 86 packets = list(proto.read_pkt_seq())
88 87
89 88 expected_packets = [
90 89 b'NAK\n', b'\x02foo', b'\x02Pre pull hook failed: aborting\n',
91 90 b'\x01' + pygrack.GitRepository.EMPTY_PACK,
92 91 ]
93 92 assert packets == expected_packets
94 93
95 94
96 95 def test_pre_pull_hook_fails_no_sideband(pygrack_app):
97 96 request = ''.join([
98 97 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
99 98 'multi_ack ofs-delta\n'
100 99 '0000',
101 100 '0009done\n',
102 101 ])
103 102 with mock.patch('vcsserver.hooks.git_pre_pull',
104 103 return_value=hooks.HookResponse(1, 'foo')):
105 104 response = pygrack_app.post(
106 105 '/git-upload-pack', params=request,
107 106 content_type='application/x-git-upload-pack')
108 107
109 108 assert response.body == pygrack.GitRepository.EMPTY_PACK
110 109
111 110
112 111 def test_pull_has_hook_messages(pygrack_app):
113 112 request = ''.join([
114 113 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
115 114 'multi_ack side-band-64k ofs-delta\n'
116 115 '0000',
117 116 '0009done\n',
118 117 ])
119 118 with mock.patch('vcsserver.hooks.git_pre_pull',
120 119 return_value=hooks.HookResponse(0, 'foo')):
121 120 with mock.patch('vcsserver.hooks.git_post_pull',
122 121 return_value=hooks.HookResponse(1, 'bar')):
123 122 with mock.patch('vcsserver.subprocessio.SubprocessIOChunker',
124 123 return_value=more_itertools.always_iterable([b'0008NAK\n0009subp\n0000'])):
125 124 response = pygrack_app.post(
126 125 '/git-upload-pack', params=request,
127 126 content_type='application/x-git-upload-pack')
128 127
129 128 data = io.BytesIO(response.body)
130 129 proto = dulwich.protocol.Protocol(data.read, None)
131 130 packets = list(proto.read_pkt_seq())
132 131
133 132 assert packets == [b'NAK\n', b'\x02foo', b'subp\n', b'\x02bar']
134 133
135 134
136 135 def test_get_want_capabilities(pygrack_instance):
137 136 data = io.BytesIO(
138 137 b'0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
139 138 b'multi_ack side-band-64k ofs-delta\n00000009done\n')
140 139
141 140 request = webob.Request({
142 141 'wsgi.input': data,
143 142 'REQUEST_METHOD': 'POST',
144 143 'webob.is_body_seekable': True
145 144 })
146 145
147 146 capabilities = pygrack_instance._get_want_capabilities(request)
148 147
149 148 assert capabilities == frozenset(
150 149 (b'ofs-delta', b'multi_ack', b'side-band-64k'))
151 150 assert data.tell() == 0
152 151
153 152
154 153 @pytest.mark.parametrize('data,capabilities,expected', [
155 154 ('foo', [], []),
156 155 ('', [pygrack.CAPABILITY_SIDE_BAND_64K], []),
157 156 ('', [pygrack.CAPABILITY_SIDE_BAND], []),
158 157 ('foo', [pygrack.CAPABILITY_SIDE_BAND_64K], [b'0008\x02foo']),
159 158 ('foo', [pygrack.CAPABILITY_SIDE_BAND], [b'0008\x02foo']),
160 159 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'03ed\x02' + b'f' * 1000]),
161 160 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995, b'000a\x02fffff']),
162 161 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'fff0\x02' + b'f' * 65515, b'000a\x02fffff']),
163 162 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995] * 65 + [b'0352\x02' + b'f' * 845]),
164 163 ], ids=[
165 164 'foo-empty',
166 165 'empty-64k', 'empty',
167 166 'foo-64k', 'foo',
168 167 'f-1000-64k', 'f-1000',
169 168 'f-65520-64k', 'f-65520'])
170 169 def test_get_messages(pygrack_instance, data, capabilities, expected):
171 170 messages = pygrack_instance._get_messages(data, capabilities)
172 171
173 172 assert messages == expected
174 173
175 174
176 175 @pytest.mark.parametrize('response,capabilities,pre_pull_messages,post_pull_messages', [
177 176 # Unexpected response
178 177 ([b'unexpected_response[no_initial_header]'], [pygrack.CAPABILITY_SIDE_BAND_64K], 'foo', 'bar'),
179 178 # No sideband
180 179 ([b'no-sideband'], [], 'foo', 'bar'),
181 180 # No messages
182 181 ([b'no-messages'], [pygrack.CAPABILITY_SIDE_BAND_64K], '', ''),
183 182 ])
184 183 def test_inject_messages_to_response_nothing_to_do(
185 184 pygrack_instance, response, capabilities, pre_pull_messages, post_pull_messages):
186 185
187 186 new_response = pygrack_instance._build_post_pull_response(
188 187 more_itertools.always_iterable(response), capabilities, pre_pull_messages, post_pull_messages)
189 188
190 189 assert list(new_response) == response
191 190
192 191
193 192 @pytest.mark.parametrize('capabilities', [
194 193 [pygrack.CAPABILITY_SIDE_BAND],
195 194 [pygrack.CAPABILITY_SIDE_BAND_64K],
196 195 ])
197 196 def test_inject_messages_to_response_single_element(pygrack_instance, capabilities):
198 197 response = [b'0008NAK\n0009subp\n0000']
199 198 new_response = pygrack_instance._build_post_pull_response(
200 199 more_itertools.always_iterable(response), capabilities, 'foo', 'bar')
201 200
202 201 expected_response = b''.join([
203 202 b'0008NAK\n',
204 203 b'0008\x02foo',
205 204 b'0009subp\n',
206 205 b'0008\x02bar',
207 206 b'0000'])
208 207
209 208 assert b''.join(new_response) == expected_response
210 209
211 210
212 211 @pytest.mark.parametrize('capabilities', [
213 212 [pygrack.CAPABILITY_SIDE_BAND],
214 213 [pygrack.CAPABILITY_SIDE_BAND_64K],
215 214 ])
216 215 def test_inject_messages_to_response_multi_element(pygrack_instance, capabilities):
217 216 response = more_itertools.always_iterable([
218 217 b'0008NAK\n000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n0000'
219 218 ])
220 219 new_response = pygrack_instance._build_post_pull_response(response, capabilities, 'foo', 'bar')
221 220
222 221 expected_response = b''.join([
223 222 b'0008NAK\n',
224 223 b'0008\x02foo',
225 224 b'000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n',
226 225 b'0008\x02bar',
227 226 b'0000'
228 227 ])
229 228
230 229 assert b''.join(new_response) == expected_response
231 230
232 231
233 232 def test_build_failed_pre_pull_response_no_sideband(pygrack_instance):
234 233 response = pygrack_instance._build_failed_pre_pull_response([], 'foo')
235 234
236 235 assert response == [pygrack.GitRepository.EMPTY_PACK]
237 236
238 237
239 238 @pytest.mark.parametrize('capabilities', [
240 239 [pygrack.CAPABILITY_SIDE_BAND],
241 240 [pygrack.CAPABILITY_SIDE_BAND_64K],
242 241 [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'],
243 242 ])
244 243 def test_build_failed_pre_pull_response(pygrack_instance, capabilities):
245 244 response = pygrack_instance._build_failed_pre_pull_response(capabilities, 'foo')
246 245
247 246 expected_response = [
248 247 b'0008NAK\n', b'0008\x02foo', b'0024\x02Pre pull hook failed: aborting\n',
249 248 b'%04x\x01%s' % (len(pygrack.GitRepository.EMPTY_PACK) + 5, pygrack.GitRepository.EMPTY_PACK),
250 249 pygrack.GitRepository.FLUSH_PACKET,
251 250 ]
252 251
253 252 assert response == expected_response
254 253
255 254
256 255 def test_inject_messages_to_response_generator(pygrack_instance):
257 256
258 257 def response_generator():
259 258 response = [
260 259 # protocol start
261 260 b'0008NAK\n',
262 261 ]
263 262 response += [ascii_bytes(f'000asubp{x}\n') for x in range(1000)]
264 263 response += [
265 264 # protocol end
266 265 pygrack.GitRepository.FLUSH_PACKET
267 266 ]
268 267 for elem in response:
269 268 yield elem
270 269
271 270 new_response = pygrack_instance._build_post_pull_response(
272 271 response_generator(), [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'], 'PRE_PULL_MSG\n', 'POST_PULL_MSG\n')
273 272
274 273 assert iter(new_response)
275 274
276 275 expected_response = b''.join([
277 276 # start
278 277 b'0008NAK\n0012\x02PRE_PULL_MSG\n',
279 278 ] + [
280 279 # ... rest
281 280 ascii_bytes(f'000asubp{x}\n') for x in range(1000)
282 281 ] + [
283 282 # final message,
284 283 b'0013\x02POST_PULL_MSG\n0000',
285 284
286 285 ])
287 286
288 287 assert b''.join(new_response) == expected_response
@@ -1,87 +1,87 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19
20 20 import mercurial.hg
21 21 import mercurial.ui
22 22 import mercurial.error
23 23 import mock
24 24 import pytest
25 25 import webtest
26 26
27 27 from vcsserver import scm_app
28 from vcsserver.utils import ascii_bytes
28 from vcsserver.str_utils import ascii_bytes
29 29
30 30
31 31 def test_hg_does_not_accept_invalid_cmd(tmpdir):
32 32 repo = mercurial.hg.repository(mercurial.ui.ui(), ascii_bytes(str(tmpdir)), create=True)
33 33 app = webtest.TestApp(scm_app.HgWeb(repo))
34 34
35 35 response = app.get('/repo?cmd=invalidcmd', expect_errors=True)
36 36
37 37 assert response.status_int == 400
38 38
39 39
40 40 def test_create_hg_wsgi_app_requirement_error(tmpdir):
41 41 repo = mercurial.hg.repository(mercurial.ui.ui(), ascii_bytes(str(tmpdir)), create=True)
42 42 config = (
43 43 ('paths', 'default', ''),
44 44 )
45 45 with mock.patch('vcsserver.scm_app.HgWeb') as hgweb_mock:
46 46 hgweb_mock.side_effect = mercurial.error.RequirementError()
47 47 with pytest.raises(Exception):
48 48 scm_app.create_hg_wsgi_app(str(tmpdir), repo, config)
49 49
50 50
51 51 def test_git_returns_not_found(tmpdir):
52 52 app = webtest.TestApp(
53 53 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
54 54
55 55 response = app.get('/repo_name/inforefs?service=git-upload-pack',
56 56 expect_errors=True)
57 57
58 58 assert response.status_int == 404
59 59
60 60
61 61 def test_git(tmpdir):
62 62 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
63 63 tmpdir.mkdir(dir_name)
64 64
65 65 app = webtest.TestApp(
66 66 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
67 67
68 68 # We set service to git-upload-packs to trigger a 403
69 69 response = app.get('/repo_name/inforefs?service=git-upload-packs',
70 70 expect_errors=True)
71 71
72 72 assert response.status_int == 403
73 73
74 74
75 75 def test_git_fallbacks_to_git_folder(tmpdir):
76 76 tmpdir.mkdir('.git')
77 77 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
78 78 tmpdir.mkdir(os.path.join('.git', dir_name))
79 79
80 80 app = webtest.TestApp(
81 81 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
82 82
83 83 # We set service to git-upload-packs to trigger a 403
84 84 response = app.get('/repo_name/inforefs?service=git-upload-packs',
85 85 expect_errors=True)
86 86
87 87 assert response.status_int == 403
@@ -1,155 +1,155 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21
22 22 import pytest
23 23
24 24 from vcsserver import subprocessio
25 from vcsserver.utils import ascii_bytes
25 from vcsserver.str_utils import ascii_bytes
26 26
27 27
class FileLikeObj(object):  # pragma: no cover
    """Minimal read()-able object that yields ``data`` repeated up to ``size`` bytes."""

    def __init__(self, data: bytes, size):
        repeats = size // len(data)
        self.stream = self._get_stream(data, repeats)

    def _get_stream(self, data, repeats):
        # Lazily emit the same chunk ``repeats`` times.
        for _ in range(repeats):
            yield data

    def read(self, n):
        # Pull chunks from the generator until at least ``n`` bytes collected
        # (or the stream is exhausted).
        collected = b''
        for piece in self.stream:
            collected += piece
            if len(collected) >= n:
                break

        return collected
49 49
50 50
@pytest.fixture(scope='module')
def environ():
    """Return a copy of os.environ with coverage variables removed.

    COV_CORE_* variables leak into the spawned subprocesses and make the
    tests fail.

    Bug fix: the original deleted keys while iterating ``env.keys()``,
    which raises RuntimeError on Python 3 ("dictionary changed size during
    iteration"). Collect the offending keys first, then delete.
    """
    env = dict(os.environ)
    for key in [k for k in env if k.startswith('COV_CORE_')]:
        del env[key]

    return env
60 60
61 61
62 62 def _get_python_args(script):
63 63 return [sys.executable, '-c', 'import sys; import time; import shutil; ' + script]
64 64
65 65
def test_raise_exception_on_non_zero_return_code(environ):
    """A subprocess that dies with a traceback surfaces as OSError."""
    cmd = _get_python_args('raise ValueError("fail")')
    with pytest.raises(OSError):
        chunker = subprocessio.SubprocessIOChunker(cmd, shell=False, env=environ)
        b''.join(chunker)
70 70
71 71
def test_does_not_fail_on_non_zero_return_code(environ):
    """With fail_on_return_code=False the output is still delivered."""
    cmd = _get_python_args('sys.stdout.write("hello"); sys.exit(1)')
    chunker = subprocessio.SubprocessIOChunker(
        cmd, shell=False, fail_on_return_code=False, env=environ)

    assert b''.join(chunker) == b'hello'
78 78
79 79
def test_raise_exception_on_stderr(environ):
    """Data on stderr turns into an OSError mentioning that output."""
    cmd = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); time.sleep(1);')

    with pytest.raises(OSError) as excinfo:
        chunker = subprocessio.SubprocessIOChunker(cmd, shell=False, env=environ)
        b''.join(chunker)

    assert 'exited due to an error:\nWRITE_TO_STDERR' in str(excinfo.value)
87 87
88 88
def test_does_not_fail_on_stderr(environ):
    """With fail_on_stderr=False, stderr noise is ignored and stdout stays empty.

    Bug fix: the child script previously said ``sys.stderr.flush`` (missing
    call parentheses), which is a no-op attribute access — the intended
    explicit flush never happened. Fixed to ``sys.stderr.flush()``.
    """
    cmd = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); sys.stderr.flush(); time.sleep(2);')
    chunker = subprocessio.SubprocessIOChunker(
        cmd, shell=False, fail_on_stderr=False, env=environ)

    assert b''.join(chunker) == b''
95 95
96 96
@pytest.mark.parametrize('size', [
    1,
    10 ** 5
])
def test_output_with_no_input(size, environ):
    """Pure stdout producers stream through the chunker unchanged."""
    cmd = _get_python_args(f'sys.stdout.write("X" * {size});')
    chunker = subprocessio.SubprocessIOChunker(cmd, shell=False, env=environ)

    assert b''.join(chunker) == ascii_bytes("X" * size)
107 107
108 108
@pytest.mark.parametrize('size', [
    1,
    10 ** 5
])
def test_output_with_no_input_does_not_fail(size, environ):
    """Non-zero exit status is tolerated when fail_on_return_code=False."""
    cmd = _get_python_args(f'sys.stdout.write("X" * {size}); sys.exit(1)')
    chunker = subprocessio.SubprocessIOChunker(
        cmd, shell=False, fail_on_return_code=False, env=environ)

    assert b''.join(chunker) == ascii_bytes("X" * size)
120 120
121 121
@pytest.mark.parametrize('size', [
    1,
    10 ** 5
])
def test_output_with_input(size, environ):
    """Input is piped through a cat-like subprocess and fully returned."""
    source = FileLikeObj(b'X', size)

    # Behaves like the `cat` command: copy stdin to stdout.
    cmd = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
    # The chunker is deliberately not bound to a name: stream it directly.
    output = b''.join(
        subprocessio.SubprocessIOChunker(cmd, shell=False, input_stream=source, env=environ)
    )

    assert len(output) == size
138 138
139 139
@pytest.mark.parametrize('size', [
    1,
    10 ** 5
])
def test_output_with_input_skipping_iterator(size, environ):
    """Reading proc.stdout directly (not iterating the chunker) also works."""
    source = FileLikeObj(b'X', size)

    # Behaves like the `cat` command: copy stdin to stdout.
    cmd = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')

    # Keep a reference so the chunker is not garbage collected too early.
    proc = subprocessio.SubprocessIOChunker(cmd, shell=False, input_stream=source, env=environ)
    output = b''.join(proc.stdout)

    assert len(output) == size
@@ -1,103 +1,103 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import mock
20 20 import pytest
21 21 import sys
22 22
23 from vcsserver.utils import ascii_bytes
23 from vcsserver.str_utils import ascii_bytes
24 24
25 25
class MockPopen(object):
    """Stand-in for subprocess.Popen: empty stdout, given stderr, returncode 1."""

    def __init__(self, stderr):
        self.stdout = io.BytesIO(b'')
        self.stderr = io.BytesIO(stderr)
        self.returncode = 1

    def wait(self):
        """No real process to wait for."""
        pass
34 34
35 35
# Verbatim svnrdump output for a failed SSL certificate verification.
INVALID_CERTIFICATE_STDERR = (
    'svnrdump: E230001: Unable to connect to a repository at URL url\n'
    'svnrdump: E230001: Server SSL certificate verification failed: issuer is not trusted'
)
40 40
41 41
@pytest.mark.parametrize('stderr,expected_reason', [
    (INVALID_CERTIFICATE_STDERR, 'INVALID_CERTIFICATE'),
    ('svnrdump: E123456', 'UNKNOWN:svnrdump: E123456'),
], ids=['invalid-cert-stderr', 'svnrdump-err-123456'])
@pytest.mark.xfail(sys.platform == "cygwin",
                   reason="SVN not packaged for Cygwin")
def test_import_remote_repository_certificate_error(stderr, expected_reason):
    """svnrdump stderr output is mapped to the expected failure reason."""
    from vcsserver.remote import svn
    repo_factory = mock.Mock()
    repo_factory.repo = mock.Mock(return_value=mock.Mock())

    remote = svn.SvnRemote(repo_factory)
    remote.is_path_valid_repository = lambda wire, path: True

    popen_patch = mock.patch('subprocess.Popen',
                             return_value=MockPopen(ascii_bytes(stderr)))
    with popen_patch:
        with pytest.raises(Exception) as excinfo:
            remote.import_remote_repository({'path': 'path'}, 'url')

    expected_error_args = 'Failed to dump the remote repository from url. Reason:{}'.format(expected_reason)

    assert excinfo.value.args[0] == expected_error_args
64 64
65 65
def test_svn_libraries_can_be_imported():
    """The svn (subvertpy) bindings must be importable in this environment."""
    import svn.client
    assert svn.client is not None
69 69
70 70
@pytest.mark.parametrize('example_url, parts', [
    ('http://server.com', (None, None, 'http://server.com')),
    ('http://user@server.com', ('user', None, 'http://user@server.com')),
    ('http://user:pass@server.com', ('user', 'pass', 'http://user:pass@server.com')),
    ('<script>', (None, None, '<script>')),
    ('http://', (None, None, 'http://')),
])
def test_username_password_extraction_from_url(example_url, parts):
    """get_url_and_credentials splits user/password out of SVN URLs."""
    from vcsserver.remote import svn

    repo_factory = mock.Mock()
    repo_factory.repo = mock.Mock(return_value=mock.Mock())
    remote = svn.SvnRemote(repo_factory)
    remote.is_path_valid_repository = lambda wire, path: True

    assert remote.get_url_and_credentials(example_url) == parts
88 88
89 89
@pytest.mark.parametrize('call_url', [
    b'https://svn.code.sf.net/p/svnbook/source/trunk/',
    b'https://marcink@svn.code.sf.net/p/svnbook/source/trunk/',
    b'https://marcink:qweqwe@svn.code.sf.net/p/svnbook/source/trunk/',
])
def test_check_url(call_url):
    """check_url accepts reachable SVN URLs with and without credentials."""
    from vcsserver.remote import svn

    repo_factory = mock.Mock()
    repo_factory.repo = mock.Mock(return_value=mock.Mock())
    remote = svn.SvnRemote(repo_factory)
    remote.is_path_valid_repository = lambda wire, path: True

    assert remote.check_url(call_url)
103 103
@@ -1,53 +1,53 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import pytest
19 from vcsserver.utils import ascii_bytes, ascii_str
19 from vcsserver.str_utils import ascii_bytes, ascii_str
20 20
21 21
@pytest.mark.parametrize('given, expected', [
    ('a', b'a'),
    ('a', b'a'),
])
def test_ascii_bytes(given, expected):
    """Plain ASCII str values encode to the matching bytes."""
    assert ascii_bytes(given) == expected
28 28
29 29
@pytest.mark.parametrize('given', [
    'Γ₯',
    'Γ₯'.encode('utf8')
])
def test_ascii_bytes_raises(given):
    """Non-ASCII str and any bytes input are both rejected."""
    with pytest.raises(ValueError):
        ascii_bytes(given)
37 37
38 38
@pytest.mark.parametrize('given, expected', [
    (b'a', 'a'),
])
def test_ascii_str(given, expected):
    """ASCII bytes decode to the matching str."""
    assert ascii_str(given) == expected
44 44
45 45
@pytest.mark.parametrize('given', [
    'a',
    'Γ₯'.encode('utf8'),
    'Γ₯'
])
def test_ascii_str_raises(given):
    """Any str input and non-ASCII bytes are both rejected."""
    with pytest.raises(ValueError):
        ascii_str(given)
@@ -1,98 +1,98 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import wsgiref.simple_server
19 19 import wsgiref.validate
20 20
21 21 from vcsserver import wsgi_app_caller
22 from vcsserver.utils import ascii_bytes, safe_str
22 from vcsserver.str_utils import ascii_bytes, safe_str
23 23
24 24
25 25 @wsgiref.validate.validator
26 26 def demo_app(environ, start_response):
27 27 """WSGI app used for testing."""
28 28
29 29 input_data = safe_str(environ['wsgi.input'].read(1024))
30 30
31 31 data = [
32 32 f'Hello World!\n',
33 33 f'input_data={input_data}\n',
34 34 ]
35 35 for key, value in sorted(environ.items()):
36 36 data.append(f'{key}={value}\n')
37 37
38 38 write = start_response("200 OK", [('Content-Type', 'text/plain')])
39 39 write(b'Old school write method\n')
40 40 write(b'***********************\n')
41 41 return list(map(ascii_bytes, data))
42 42
43 43
# Minimal WSGI environ shared by the tests below; _complete_environ fills in
# the missing wsgi.* keys. 'foo.var' is an arbitrary extension key that the
# caller must pass through untouched.
BASE_ENVIRON = {
    'REQUEST_METHOD': 'GET',
    'SERVER_NAME': 'localhost',
    'SERVER_PORT': '80',
    'SCRIPT_NAME': '',
    'PATH_INFO': '/',
    'QUERY_STRING': '',
    'foo.var': 'bla',
}
53 53
54 54
def test_complete_environ():
    """_complete_environ injects valid wsgi.* keys and a readable input stream."""
    environ = dict(BASE_ENVIRON)
    payload = b"data"
    wsgi_app_caller._complete_environ(environ, payload)
    wsgiref.validate.check_environ(environ)

    assert environ['wsgi.input'].read(1024) == payload
62 62
63 63
def test_start_response():
    """_StartResponse records the status and headers it is called with."""
    start_response = wsgi_app_caller._StartResponse()
    start_response('200 OK', [('Content-Type', 'text/plain')])

    assert start_response.status == '200 OK'
    assert start_response.headers == [('Content-Type', 'text/plain')]
72 72
73 73
def test_start_response_with_error():
    """Passing exc_info does not disturb the recorded status/headers."""
    start_response = wsgi_app_caller._StartResponse()
    status = '500 Internal Server Error'
    headers = [('Content-Type', 'text/plain')]
    start_response(status, headers, (None, None, None))

    assert start_response.status == status
    assert start_response.headers == headers
82 82
83 83
def test_wsgi_app_caller():
    """End-to-end: WSGIAppCaller drives demo_app and merges write() output."""
    environ = dict(BASE_ENVIRON)
    input_data = 'some text'

    caller = wsgi_app_caller.WSGIAppCaller(demo_app)
    responses, status, headers = caller.handle(environ, input_data)
    body = b''.join(responses)

    assert status == '200 OK'
    assert headers == [('Content-Type', 'text/plain')]
    assert body.startswith(b'Old school write method\n***********************\n')
    assert b'Hello World!\n' in body
    assert b'foo.var=bla\n' in body

    assert ascii_bytes(f'input_data={input_data}\n') in body
@@ -1,106 +1,106 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import time
19 19 import logging
20 20
21 21 import vcsserver
22 from vcsserver.utils import safe_str, ascii_str
22 from vcsserver.str_utils import safe_str, ascii_str
23 23
24 24 log = logging.getLogger(__name__)
25 25
26 26
def get_access_path(environ):
    """Return the request path (PATH_INFO) or None when absent."""
    return environ.get('PATH_INFO')
30 30
31 31
def get_user_agent(environ):
    """Return the client's User-Agent header value, if present."""
    agent = environ.get('HTTP_USER_AGENT')
    return agent
34 34
35 35
def get_vcs_method(environ):
    """Return the RhodeCode method header (X-RC-Method), if present."""
    method = environ.get('HTTP_X_RC_METHOD')
    return method
38 38
39 39
def get_vcs_repo(environ):
    """Return the RhodeCode repository name header (X-RC-Repo-Name), if present."""
    repo = environ.get('HTTP_X_RC_REPO_NAME')
    return repo
42 42
43 43
class RequestWrapperTween(object):
    """Pyramid tween that times each request and reports to the log and statsd."""

    def __init__(self, handler, registry):
        self.handler = handler
        self.registry = registry

        # one-time configuration code goes here

    def __call__(self, request):
        start = time.time()
        log.debug('Starting request time measurement')
        response = None

        # Pull request metadata from WSGI environ headers up front.
        ua = get_user_agent(request.environ)
        vcs_method = get_vcs_method(request.environ)
        repo_name = get_vcs_repo(request.environ)

        try:
            response = self.handler(request)
        finally:
            # Timing/metrics are emitted even when the handler raises;
            # in that case response stays None and resp_code is 'UNDEFINED'.
            count = request.request_count()
            _ver_ = ascii_str(vcsserver.__version__)
            _path = safe_str(get_access_path(request.environ))
            ip = '127.0.0.1'
            match_route = request.matched_route.name if request.matched_route else "NOT_FOUND"
            resp_code = getattr(response, 'status_code', 'UNDEFINED')

            total = time.time() - start

            _view_path = f"{repo_name}@{_path}/{vcs_method}"
            log.info(
                'Req[%4s] IP: %s %s Request to %s time: %.4fs [%s], VCSServer %s',
                count, ip, request.environ.get('REQUEST_METHOD'),
                _view_path, total, ua, _ver_,
                extra={"time": total, "ver": _ver_, "code": resp_code,
                       "path": _path, "view_name": match_route, "user_agent": ua,
                       "vcs_method": vcs_method, "repo_name": repo_name}
            )

            statsd = request.registry.statsd
            if statsd:
                # NOTE(review): the fallback here uses _path, while the log line
                # above uses "NOT_FOUND" for the same condition — confirm intended.
                match_route = request.matched_route.name if request.matched_route else _path
                elapsed_time_ms = round(1000.0 * total)  # use ms only
                statsd.timing(
                    "vcsserver_req_timing.histogram", elapsed_time_ms,
                    tags=[
                        "view_name:{}".format(match_route),
                        "code:{}".format(resp_code)
                    ],
                    use_decimals=False
                )
                statsd.incr(
                    "vcsserver_req_total", tags=[
                        "view_name:{}".format(match_route),
                        "code:{}".format(resp_code)
                    ])

        return response
101 101
102 102
def includeme(config):
    """Register the request-wrapper tween with the Pyramid configurator."""
    tween_path = 'vcsserver.tweens.request_wrapper.RequestWrapperTween'
    config.add_tween(tween_path)
@@ -1,137 +1,54 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import logging
18 18 import hashlib
19 19
20 20 log = logging.getLogger(__name__)
21 21
22 22
23 def safe_int(val, default=None):
24 """
25 Returns int() of val if val is not convertable to int use default
26 instead
27
28 :param val:
29 :param default:
30 """
31
32 try:
33 val = int(val)
34 except (ValueError, TypeError):
35 val = default
36
37 return val
38
39
40 def safe_str(str_, to_encoding=None) -> str:
41 """
42 safe str function. Does few trick to turn unicode_ into string
43
44 :param str_: str to encode
45 :param to_encoding: encode to this type UTF8 default
46 :rtype: str
47 :returns: str object
48 """
49 if isinstance(str_, str):
50 return str_
51
52 # if it's bytes cast to str
53 if not isinstance(str_, bytes):
54 return str(str_)
55
56 to_encoding = to_encoding or ['utf8']
57 if not isinstance(to_encoding, (list, tuple)):
58 to_encoding = [to_encoding]
59
60 for enc in to_encoding:
61 try:
62 return str(str_, enc)
63 except UnicodeDecodeError:
64 pass
65
66 return str(str_, to_encoding[0], 'replace')
67
68
69 def safe_bytes(str_, from_encoding=None) -> bytes:
70 """
71 safe bytes function. Does few trick to turn str_ into bytes string:
class AttributeDictBase(dict):
    """Dict base class providing pickle support for attribute-style subclasses."""

    def __getstate__(self):
        # Pickle the instance attribute dictionary (usually empty for these
        # dict subclasses; the mapping contents pickle via the dict protocol).
        odict = self.__dict__
        return odict

    def __setstate__(self, state):
        # Parameter renamed (was ``dict``) to stop shadowing the builtin;
        # pickle calls __setstate__ positionally, so callers are unaffected.
        self.__dict__ = state

    # Attribute assignment/deletion map straight onto dict items.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
134 33
135 34
class StrictAttributeDict(AttributeDictBase):
    """Attribute-style dict that raises AttributeError for missing keys."""

    def __getattr__(self, attr):
        try:
            value = self[attr]
        except KeyError:
            raise AttributeError('%s object has no attribute %s' % (
                self.__class__, attr))
        return value
46
47
class AttributeDict(AttributeDictBase):
    """Attribute-style dict that returns None for missing keys."""

    def __getattr__(self, attr):
        # dict.get defaults to None, matching the permissive contract.
        return self.get(attr)
51
52
def sha1(val):
    """Return the hex SHA-1 digest of ``val`` (a bytes-like object)."""
    digest = hashlib.sha1(val)
    return digest.hexdigest()
@@ -1,116 +1,116 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Extract the responses of a WSGI app."""
19 19
20 20 __all__ = ('WSGIAppCaller',)
21 21
22 22 import io
23 23 import logging
24 24 import os
25 25
26 from vcsserver.utils import ascii_bytes
26 from vcsserver.str_utils import ascii_bytes
27 27
28 28 log = logging.getLogger(__name__)
29 29
30 30 DEV_NULL = open(os.devnull)
31 31
32 32
def _complete_environ(environ, input_data: bytes):
    """Update the missing wsgi.* variables of a WSGI environment.

    :param environ: WSGI environment to update
    :type environ: dict
    :param input_data: data to be read by the app
    :type input_data: bytes
    """
    defaults = {
        'wsgi.version': (1, 0),
        'wsgi.url_scheme': 'http',
        'wsgi.multithread': True,
        'wsgi.multiprocess': True,
        'wsgi.run_once': False,
        'wsgi.input': io.BytesIO(input_data),
        'wsgi.errors': DEV_NULL,
    }
    environ.update(defaults)
50 50
51 51
52 52 # pylint: disable=too-few-public-methods
53 53 class _StartResponse(object):
54 54 """Save the arguments of a start_response call."""
55 55
56 56 __slots__ = ['status', 'headers', 'content']
57 57
58 58 def __init__(self):
59 59 self.status = None
60 60 self.headers = None
61 61 self.content = []
62 62
63 63 def __call__(self, status, headers, exc_info=None):
64 64 # TODO(skreft): do something meaningful with the exc_info
65 65 exc_info = None # avoid dangling circular reference
66 66 self.status = status
67 67 self.headers = headers
68 68
69 69 return self.write
70 70
71 71 def write(self, content):
72 72 """Write method returning when calling this object.
73 73
74 74 All the data written is then available in content.
75 75 """
76 76 self.content.append(content)
77 77
78 78
class WSGIAppCaller(object):
    """Calls a WSGI app and collects its full response."""

    def __init__(self, app):
        """
        :param app: WSGI app to call
        """
        self.app = app

    def handle(self, environ, input_data):
        """Process a request with the WSGI app.

        The returned data of the app is fully consumed into a list.

        :param environ: WSGI environment to update
        :type environ: dict
        :param input_data: data to be read by the app
        :type input_data: str/bytes

        :returns: a tuple with the contents, status and headers
        :rtype: (list<str>, str, list<(str, str)>)
        """
        _complete_environ(environ, ascii_bytes(input_data, allow_bytes=True))
        start_response = _StartResponse()
        log.debug("Calling wrapped WSGI application")
        app_iter = self.app(environ, start_response)
        body_parts = list(app_iter)

        # Per PEP 3333, data passed to write() precedes the returned iterable.
        written = start_response.content
        if written:
            log.debug("Adding returned response to response written via write()")
            written.extend(body_parts)
            body_parts = written

        if hasattr(app_iter, 'close'):
            log.debug("Closing iterator from WSGI application")
            app_iter.close()

        log.debug("Handling of WSGI request done, returning response")
        return body_parts, start_response.status, start_response.headers
General Comments 0
You need to be logged in to leave comments. Login now