core: moved str_utils and type_utils to lib so it's consistent with ce, and easier to sync up codebases
super-admin
r1249:5745b11f default
@@ -1,187 +1,187 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import os
18 18 import sys
19 19 import tempfile
20 20 import logging
21 21 import urllib.parse
22 22
23 23 from vcsserver.lib.rc_cache.archive_cache import get_archival_cache_store
24 24
25 25 from vcsserver import exceptions
26 26 from vcsserver.exceptions import NoContentException
27 27 from vcsserver.hgcompat import archival
28 from vcsserver.str_utils import safe_bytes
28 from vcsserver.lib.str_utils import safe_bytes
29 29 from vcsserver.lib.exc_tracking import format_exc
30 30 log = logging.getLogger(__name__)
31 31
32 32
33 33 class RepoFactory:
34 34 """
35 35 Utility to create instances of repository
36 36
37 37 It provides internal caching of the `repo` object based on
38 38 the :term:`call context`.
39 39 """
40 40 repo_type = None
41 41
42 42 def __init__(self):
43 43 pass
44 44
45 45 def _create_config(self, path, config):
46 46 config = {}
47 47 return config
48 48
49 49 def _create_repo(self, wire, create):
50 50 raise NotImplementedError()
51 51
52 52 def repo(self, wire, create=False):
53 53 raise NotImplementedError()
54 54
55 55
56 56 def obfuscate_qs(query_string):
57 57 if query_string is None:
58 58 return None
59 59
60 60 parsed = []
61 61 for k, v in urllib.parse.parse_qsl(query_string, keep_blank_values=True):
62 62 if k in ['auth_token', 'api_key']:
63 63 v = "*****"
64 64 parsed.append((k, v))
65 65
66 66 return '&'.join('{}{}'.format(
67 67 k, f'={v}' if v else '') for k, v in parsed)
68 68
69 69
70 70 def raise_from_original(new_type, org_exc: Exception):
71 71 """
72 72 Raise a new exception type with original args and traceback.
73 73 """
74 74 exc_info = sys.exc_info()
75 75 exc_type, exc_value, exc_traceback = exc_info
76 76 new_exc = new_type(*exc_value.args)
77 77
78 78 # store the original traceback into the new exc
79 79 new_exc._org_exc_tb = format_exc(exc_info)
80 80
81 81 try:
82 82 raise new_exc.with_traceback(exc_traceback)
83 83 finally:
84 84 del exc_traceback
85 85
86 86
87 87 class ArchiveNode:
88 88 def __init__(self, path, mode, is_link, raw_bytes):
89 89 self.path = path
90 90 self.mode = mode
91 91 self.is_link = is_link
92 92 self.raw_bytes = raw_bytes
93 93
94 94
95 95 def store_archive_in_cache(node_walker, archive_key, kind, mtime, archive_at_path, archive_dir_name,
96 96 commit_id, write_metadata=True, extra_metadata=None, cache_config=None):
97 97 """
98 98 Generates an archive and stores it in a dedicated backend store.
99 99 Here we use diskcache as that backend.
100 100
101 101 :param node_walker: a generator returning nodes to add to archive
102 102 :param archive_key: key used to store the path
103 103 :param kind: archive kind
104 104 :param mtime: time of creation
105 105 :param archive_at_path: default '/', the path at which the archive was started.
106 106 If this is not '/' it means it's a partial archive
107 107 :param archive_dir_name: inside dir name when creating an archive
108 108 :param commit_id: commit sha of revision archive was created at
109 109 :param write_metadata:
110 110 :param extra_metadata:
111 111 :param cache_config:
112 112
113 113 walker should be a file walker, for example,
114 114 def node_walker():
115 115 for file_info in files:
116 116 yield ArchiveNode(fn, mode, is_link, ctx[fn].data)
117 117 """
118 118 extra_metadata = extra_metadata or {}
119 119
120 120 d_cache = get_archival_cache_store(config=cache_config)
121 121
122 122 if archive_key in d_cache:
123 123 reader, metadata = d_cache.fetch(archive_key)
124 124 return reader.name
125 125
126 126 archive_tmp_path = safe_bytes(tempfile.mkstemp()[1])
127 127 log.debug('Creating new temp archive in %s', archive_tmp_path)
128 128
129 129 if kind == "tgz":
130 130 archiver = archival.tarit(archive_tmp_path, mtime, b"gz")
131 131 elif kind == "tbz2":
132 132 archiver = archival.tarit(archive_tmp_path, mtime, b"bz2")
133 133 elif kind == 'zip':
134 134 archiver = archival.zipit(archive_tmp_path, mtime)
135 135 else:
136 136 raise exceptions.ArchiveException()(
137 137 f'Remote does not support: "{kind}" archive type.')
138 138
139 139 for f in node_walker(commit_id, archive_at_path):
140 140 f_path = os.path.join(safe_bytes(archive_dir_name), safe_bytes(f.path).lstrip(b'/'))
141 141
142 142 try:
143 143 archiver.addfile(f_path, f.mode, f.is_link, f.raw_bytes())
144 144 except NoContentException:
145 145 # NOTE(marcink): this is a special case for SVN so we can create "empty"
146 146 # directories which are not supported by archiver
147 147 archiver.addfile(os.path.join(f_path, b'.dir'), f.mode, f.is_link, b'')
148 148
149 149 metadata = dict([
150 150 ('commit_id', commit_id),
151 151 ('mtime', mtime),
152 152 ])
153 153 metadata.update(extra_metadata)
154 154 if write_metadata:
155 155 meta = [safe_bytes(f"{f_name}:{value}") for f_name, value in metadata.items()]
156 156 f_path = os.path.join(safe_bytes(archive_dir_name), b'.archival.txt')
157 157 archiver.addfile(f_path, 0o644, False, b'\n'.join(meta))
158 158
159 159 archiver.done()
160 160
161 161 with open(archive_tmp_path, 'rb') as archive_file:
162 162 add_result = d_cache.store(archive_key, archive_file, metadata=metadata)
163 163 if not add_result:
164 164 log.error('Failed to store cache for key=%s', archive_key)
165 165
166 166 os.remove(archive_tmp_path)
167 167
168 168 reader, metadata = d_cache.fetch(archive_key)
169 169
170 170 return reader.name
171 171
172 172
173 173 class BinaryEnvelope:
174 174 def __init__(self, val):
175 175 self.val = val
176 176
177 177
178 178 class BytesEnvelope(bytes):
179 179 def __new__(cls, content):
180 180 if isinstance(content, bytes):
181 181 return super().__new__(cls, content)
182 182 else:
183 183 raise TypeError('BytesEnvelope content= param must be bytes. Use BinaryEnvelope to wrap other types')
184 184
185 185
186 186 class BinaryBytesEnvelope(BytesEnvelope):
187 187 pass
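The `node_walker` contract sketched in the `store_archive_in_cache` docstring can be satisfied by a small generator. A minimal, hypothetical sketch follows (the `files` iterable, the `app_config` value, and the `vcsserver.base` module path are assumptions; note that `raw_bytes` must be a callable, since the archiver invokes `f.raw_bytes()`):

    from vcsserver.base import ArchiveNode, store_archive_in_cache  # assumed module path

    def node_walker(commit_id, archive_at_path):
        # `files` is a hypothetical iterable of (name, mode, is_link, data_callable)
        for name, mode, is_link, data_callable in files:
            yield ArchiveNode(name, mode, is_link, data_callable)

    archive_path = store_archive_in_cache(
        node_walker, archive_key='repo/abc123.tgz', kind='tgz', mtime=1700000000,
        archive_at_path='/', archive_dir_name='repo-abc123', commit_id='abc123',
        cache_config=app_config)  # whatever get_archival_cache_store() expects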
@@ -1,185 +1,185 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import textwrap
21 21 import string
22 22 import functools
23 23 import logging
24 24 import tempfile
25 25 import logging.config
26 26
27 from vcsserver.type_utils import str2bool, aslist
27 from vcsserver.lib.type_utils import str2bool, aslist
28 28
29 29 log = logging.getLogger(__name__)
30 30
31 31
32 32 # skip keys that are set here, so we don't double-process them
33 33 set_keys = {
34 34 '__file__': ''
35 35 }
36 36
37 37
38 38 class SettingsMaker:
39 39
40 40 def __init__(self, app_settings):
41 41 self.settings = app_settings
42 42
43 43 @classmethod
44 44 def _bool_func(cls, input_val):
45 45 if isinstance(input_val, bytes):
46 46 # decode to str
47 47 input_val = input_val.decode('utf8')
48 48 return str2bool(input_val)
49 49
50 50 @classmethod
51 51 def _int_func(cls, input_val):
52 52 return int(input_val)
53 53
54 54 @classmethod
55 55 def _float_func(cls, input_val):
56 56 return float(input_val)
57 57
58 58 @classmethod
59 59 def _list_func(cls, input_val, sep=','):
60 60 return aslist(input_val, sep=sep)
61 61
62 62 @classmethod
63 63 def _string_func(cls, input_val, lower=True):
64 64 if lower:
65 65 input_val = input_val.lower()
66 66 return input_val
67 67
68 68 @classmethod
69 69 def _string_no_quote_func(cls, input_val, lower=True):
70 70 """
71 71 Special-case string function that detects if a value is set to an empty quoted string
72 72 e.g.
73 73
74 74 core.binary_dir = ""
75 75 """
76 76
77 77 input_val = cls._string_func(input_val, lower=lower)
78 78 if input_val in ['""', "''"]:
79 79 return ''
80 80 return input_val
81 81
82 82 @classmethod
83 83 def _dir_func(cls, input_val, ensure_dir=False, mode=0o755):
84 84
85 85 # ensure we have our dir created
86 86 if not os.path.isdir(input_val) and ensure_dir:
87 87 os.makedirs(input_val, mode=mode, exist_ok=True)
88 88
89 89 if not os.path.isdir(input_val):
90 90 raise Exception(f'Dir at {input_val} does not exist')
91 91 return input_val
92 92
93 93 @classmethod
94 94 def _file_path_func(cls, input_val, ensure_dir=False, mode=0o755):
95 95 dirname = os.path.dirname(input_val)
96 96 cls._dir_func(dirname, ensure_dir=ensure_dir)
97 97 return input_val
98 98
99 99 @classmethod
100 100 def _key_transformator(cls, key):
101 101 return "{}_{}".format('RC'.upper(), key.upper().replace('.', '_').replace('-', '_'))
102 102
103 103 def maybe_env_key(self, key):
104 104 # the key might also be set in the environment; if so, that value takes priority.
105 105 transformed_key = self._key_transformator(key)
106 106 envvar_value = os.environ.get(transformed_key)
107 107 if envvar_value:
108 108 log.debug('using `%s` key instead of `%s` key for config', transformed_key, key)
109 109
110 110 return envvar_value
111 111
112 112 def env_expand(self):
113 113 replaced = {}
114 114 for k, v in self.settings.items():
115 115 if k not in set_keys:
116 116 envvar_value = self.maybe_env_key(k)
117 117 if envvar_value:
118 118 replaced[k] = envvar_value
119 119 set_keys[k] = envvar_value
120 120
121 121 # replace ALL keys updated
122 122 self.settings.update(replaced)
123 123
124 124 def enable_logging(self, logging_conf=None, level='INFO', formatter='generic'):
125 125 """
126 126 Helper to enable debug on running instance
127 127 :return:
128 128 """
129 129
130 130 if not str2bool(self.settings.get('logging.autoconfigure')):
131 131 log.info('logging configuration based on main .ini file')
132 132 return
133 133
134 134 if logging_conf is None:
135 135 logging_conf = self.settings.get('logging.logging_conf_file') or ''
136 136
137 137 if not os.path.isfile(logging_conf):
138 138 log.error('Unable to set up logging based on %s: '
139 139 'file does not exist. Specify the path using the logging.logging_conf_file= config setting.', logging_conf)
140 140 return
141 141
142 142 with open(logging_conf, 'rt') as f:
143 143 ini_template = textwrap.dedent(f.read())
144 144 ini_template = string.Template(ini_template).safe_substitute(
145 145 RC_LOGGING_LEVEL=os.environ.get('RC_LOGGING_LEVEL', '') or level,
146 146 RC_LOGGING_FORMATTER=os.environ.get('RC_LOGGING_FORMATTER', '') or formatter
147 147 )
148 148
149 149 with tempfile.NamedTemporaryFile(prefix='rc_logging_', suffix='.ini', delete=False) as f:
150 150 log.info('Saved Temporary LOGGING config at %s', f.name)
151 151 f.write(ini_template)
152 152
153 153 logging.config.fileConfig(f.name)
154 154 os.remove(f.name)
155 155
156 156 def make_setting(self, key, default, lower=False, default_when_empty=False, parser=None):
157 157 input_val = self.settings.get(key, default)
158 158
159 159 if default_when_empty and not input_val:
160 160 # use default value when value is set in the config but it is empty
161 161 input_val = default
162 162
163 163 parser_func = {
164 164 'bool': self._bool_func,
165 165 'int': self._int_func,
166 166 'float': self._float_func,
167 167 'list': self._list_func,
168 168 'list:newline': functools.partial(self._list_func, sep='\n'),
169 169 'list:spacesep': functools.partial(self._list_func, sep=' '),
170 170 'string': functools.partial(self._string_func, lower=lower),
171 171 'string:noquote': functools.partial(self._string_no_quote_func, lower=lower),
172 172 'dir': self._dir_func,
173 173 'dir:ensured': functools.partial(self._dir_func, ensure_dir=True),
174 174 'file': self._file_path_func,
175 175 'file:ensured': functools.partial(self._file_path_func, ensure_dir=True),
176 176 None: lambda i: i
177 177 }[parser]
178 178
179 179 envvar_value = self.maybe_env_key(key)
180 180 if envvar_value:
181 181 input_val = envvar_value
182 182 set_keys[key] = input_val
183 183
184 184 self.settings[key] = parser_func(input_val)
185 185 return self.settings[key]
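A hedged usage sketch of `SettingsMaker` as defined above: typed parsing of ini-style settings, with `RC_`-prefixed environment variables taking priority (via `_key_transformator`, the key `workers` maps to `RC_WORKERS`). The settings dict below is invented for illustration:

    import os

    settings = {'vcs.start_server': 'true', 'workers': '4', 'locale': ''}
    maker = SettingsMaker(settings)

    maker.make_setting('vcs.start_server', default=False, parser='bool')   # -> True
    maker.make_setting('workers', default=2, parser='int')                 # -> 4
    # an empty configured value falls back to the default:
    maker.make_setting('locale', default='en_US.UTF-8', default_when_empty=True)

    # the environment overrides the config file and is parsed the same way:
    os.environ['RC_WORKERS'] = '8'
    maker.make_setting('workers', default=2, parser='int')                 # -> 8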
@@ -1,296 +1,296 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import re
19 19 import logging
20 20
21 21 from pyramid.config import Configurator
22 22 from pyramid.response import Response, FileIter
23 23 from pyramid.httpexceptions import (
24 24 HTTPBadRequest, HTTPNotImplemented, HTTPNotFound, HTTPForbidden,
25 25 HTTPUnprocessableEntity)
26 26
27 27 from vcsserver.lib.ext_json import json
28 28 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
29 29 from vcsserver.git_lfs.utils import safe_result, get_cython_compat_decorator
30 from vcsserver.str_utils import safe_int
30 from vcsserver.lib.str_utils import safe_int
31 31
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 GIT_LFS_CONTENT_TYPE = 'application/vnd.git-lfs' # +json ?
36 36 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
37 37
38 38
39 39 def write_response_error(http_exception, text=None):
40 40 content_type = GIT_LFS_CONTENT_TYPE + '+json'
41 41 _exception = http_exception(content_type=content_type)
42 42 _exception.content_type = content_type
43 43 if text:
44 44 _exception.body = json.dumps({'message': text})
45 45 log.debug('LFS: writing response of type %s to client with text:%s',
46 46 http_exception, text)
47 47 return _exception
48 48
49 49
50 50 class AuthHeaderRequired:
51 51 """
52 52 Decorator to check if request has proper auth-header
53 53 """
54 54
55 55 def __call__(self, func):
56 56 return get_cython_compat_decorator(self.__wrapper, func)
57 57
58 58 def __wrapper(self, func, *fargs, **fkwargs):
59 59 request = fargs[1]
60 60 auth = request.authorization
61 61 if not auth:
62 62 return write_response_error(HTTPForbidden)
63 63 return func(*fargs[1:], **fkwargs)
64 64
65 65
66 66 # views
67 67
68 68 def lfs_objects(request):
69 69 # indicate not supported, V1 API
70 70 log.warning('LFS: v1 api not supported, reporting it back to client')
71 71 return write_response_error(HTTPNotImplemented, 'LFS: v1 api not supported')
72 72
73 73
74 74 @AuthHeaderRequired()
75 75 def lfs_objects_batch(request):
76 76 """
77 77 The client sends the following information to the Batch endpoint to transfer some objects:
78 78
79 79 operation - Should be download or upload.
80 80 transfers - An optional Array of String identifiers for transfer
81 81 adapters that the client has configured. If omitted, the basic
82 82 transfer adapter MUST be assumed by the server.
83 83 objects - An Array of objects to download.
84 84 oid - String OID of the LFS object.
85 85 size - Integer byte size of the LFS object. Must be at least zero.
86 86 """
87 87 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
88 88 auth = request.authorization
89 89 repo = request.matchdict.get('repo')
90 90 data = request.json
91 91 operation = data.get('operation')
92 92 http_scheme = request.registry.git_lfs_http_scheme
93 93
94 94 if operation not in ('download', 'upload'):
95 95 log.debug('LFS: unsupported operation:%s', operation)
96 96 return write_response_error(
97 97 HTTPBadRequest, f'unsupported operation mode: `{operation}`')
98 98
99 99 if 'objects' not in data:
100 100 log.debug('LFS: missing objects data')
101 101 return write_response_error(
102 102 HTTPBadRequest, 'missing objects data')
103 103
104 104 log.debug('LFS: handling operation of type: %s', operation)
105 105
106 106 objects = []
107 107 for o in data['objects']:
108 108 try:
109 109 oid = o['oid']
110 110 obj_size = o['size']
111 111 except KeyError:
112 112 log.exception('LFS, failed to extract data')
113 113 return write_response_error(
114 114 HTTPBadRequest, 'unsupported data in objects')
115 115
116 116 obj_data = {'oid': oid}
117 117 if http_scheme == 'http':
118 118 # Note(marcink): when using http, we might have a custom port,
119 119 # so we skip setting the scheme; url dispatch then won't generate a port in the URL.
120 120 # We need this for development.
121 121 http_scheme = None
122 122
123 123 obj_href = request.route_url('lfs_objects_oid', repo=repo, oid=oid,
124 124 _scheme=http_scheme)
125 125 obj_verify_href = request.route_url('lfs_objects_verify', repo=repo,
126 126 _scheme=http_scheme)
127 127 store = LFSOidStore(
128 128 oid, repo, store_location=request.registry.git_lfs_store_path)
129 129 handler = OidHandler(
130 130 store, repo, auth, oid, obj_size, obj_data,
131 131 obj_href, obj_verify_href)
132 132
133 133 # this also verifies the OIDs
134 134 actions, errors = handler.exec_operation(operation)
135 135 if errors:
136 136 log.warning('LFS: got following errors: %s', errors)
137 137 obj_data['errors'] = errors
138 138
139 139 if actions:
140 140 obj_data['actions'] = actions
141 141
142 142 obj_data['size'] = obj_size
143 143 obj_data['authenticated'] = True
144 144 objects.append(obj_data)
145 145
146 146 result = {'objects': objects, 'transfer': 'basic'}
147 147 log.debug('LFS Response %s', safe_result(result))
148 148
149 149 return result
150 150
151 151
152 152 def lfs_objects_oid_upload(request):
153 153 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
154 154 repo = request.matchdict.get('repo')
155 155 oid = request.matchdict.get('oid')
156 156 store = LFSOidStore(
157 157 oid, repo, store_location=request.registry.git_lfs_store_path)
158 158 engine = store.get_engine(mode='wb')
159 159 log.debug('LFS: starting chunked write of LFS oid: %s to storage', oid)
160 160
161 161 body = request.environ['wsgi.input']
162 162
163 163 with engine as f:
164 164 blksize = 64 * 1024 # 64kb
165 165 while True:
166 166 # read in chunks as the stream comes in from Gunicorn;
167 167 # this relies on Gunicorn-specific support behaviour and
168 168 # might work differently on waitress
169 169 chunk = body.read(blksize)
170 170 if not chunk:
171 171 break
172 172 f.write(chunk)
173 173
174 174 return {'upload': 'ok'}
175 175
176 176
177 177 def lfs_objects_oid_download(request):
178 178 repo = request.matchdict.get('repo')
179 179 oid = request.matchdict.get('oid')
180 180
181 181 store = LFSOidStore(
182 182 oid, repo, store_location=request.registry.git_lfs_store_path)
183 183 if not store.has_oid():
184 184 log.debug('LFS: oid %s does not exist in store', oid)
185 185 return write_response_error(
186 186 HTTPNotFound, f'requested file with oid `{oid}` not found in store')
187 187
188 188 # TODO(marcink): support range header ?
189 189 # Range: bytes=0-, `bytes=(\d+)\-.*`
190 190
191 191 f = open(store.oid_path, 'rb')
192 192 response = Response(
193 193 content_type='application/octet-stream', app_iter=FileIter(f))
194 194 response.headers.add('X-RC-LFS-Response-Oid', str(oid))
195 195 return response
196 196
197 197
198 198 def lfs_objects_verify(request):
199 199 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
200 200 repo = request.matchdict.get('repo')
201 201
202 202 data = request.json
203 203 oid = data.get('oid')
204 204 size = safe_int(data.get('size'))
205 205
206 206 if not (oid and size):
207 207 return write_response_error(
208 208 HTTPBadRequest, 'missing oid and size in request data')
209 209
210 210 store = LFSOidStore(
211 211 oid, repo, store_location=request.registry.git_lfs_store_path)
212 212 if not store.has_oid():
213 213 log.debug('LFS: oid %s does not exist in store', oid)
214 214 return write_response_error(
215 215 HTTPNotFound, f'oid `{oid}` does not exist in store')
216 216
217 217 store_size = store.size_oid()
218 218 if store_size != size:
219 219 msg = 'requested file size mismatch store size:{} requested:{}'.format(
220 220 store_size, size)
221 221 return write_response_error(
222 222 HTTPUnprocessableEntity, msg)
223 223
224 224 return {'message': {'size': 'ok', 'in_store': 'ok'}}
225 225
226 226
227 227 def lfs_objects_lock(request):
228 228 return write_response_error(
229 229 HTTPNotImplemented, 'GIT LFS locking api not supported')
230 230
231 231
232 232 def not_found(request):
233 233 return write_response_error(
234 234 HTTPNotFound, 'request path not found')
235 235
236 236
237 237 def lfs_disabled(request):
238 238 return write_response_error(
239 239 HTTPNotImplemented, 'GIT LFS disabled for this repo')
240 240
241 241
242 242 def git_lfs_app(config):
243 243
244 244 # v1 API deprecation endpoint
245 245 config.add_route('lfs_objects',
246 246 '/{repo:.*?[^/]}/info/lfs/objects')
247 247 config.add_view(lfs_objects, route_name='lfs_objects',
248 248 request_method='POST', renderer='json')
249 249
250 250 # locking API
251 251 config.add_route('lfs_objects_lock',
252 252 '/{repo:.*?[^/]}/info/lfs/locks')
253 253 config.add_view(lfs_objects_lock, route_name='lfs_objects_lock',
254 254 request_method=('POST', 'GET'), renderer='json')
255 255
256 256 config.add_route('lfs_objects_lock_verify',
257 257 '/{repo:.*?[^/]}/info/lfs/locks/verify')
258 258 config.add_view(lfs_objects_lock, route_name='lfs_objects_lock_verify',
259 259 request_method=('POST', 'GET'), renderer='json')
260 260
261 261 # batch API
262 262 config.add_route('lfs_objects_batch',
263 263 '/{repo:.*?[^/]}/info/lfs/objects/batch')
264 264 config.add_view(lfs_objects_batch, route_name='lfs_objects_batch',
265 265 request_method='POST', renderer='json')
266 266
267 267 # oid upload/download API
268 268 config.add_route('lfs_objects_oid',
269 269 '/{repo:.*?[^/]}/info/lfs/objects/{oid}')
270 270 config.add_view(lfs_objects_oid_upload, route_name='lfs_objects_oid',
271 271 request_method='PUT', renderer='json')
272 272 config.add_view(lfs_objects_oid_download, route_name='lfs_objects_oid',
273 273 request_method='GET', renderer='json')
274 274
275 275 # verification API
276 276 config.add_route('lfs_objects_verify',
277 277 '/{repo:.*?[^/]}/info/lfs/verify')
278 278 config.add_view(lfs_objects_verify, route_name='lfs_objects_verify',
279 279 request_method='POST', renderer='json')
280 280
281 281 # not found handler for API
282 282 config.add_notfound_view(not_found, renderer='json')
283 283
284 284
285 285 def create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
286 286 config = Configurator()
287 287 if git_lfs_enabled:
288 288 config.include(git_lfs_app)
289 289 config.registry.git_lfs_store_path = git_lfs_store_path
290 290 config.registry.git_lfs_http_scheme = git_lfs_http_scheme
291 291 else:
292 292 # not found handler for API, reporting disabled LFS support
293 293 config.add_notfound_view(lfs_disabled, renderer='json')
294 294
295 295 app = config.make_wsgi_app()
296 296 return app
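A minimal sketch of standing the LFS app up on its own, assuming a local store path; `wsgiref` is used only for illustration (in production the app runs under a real WSGI server):

    from wsgiref.simple_server import make_server
    from vcsserver.git_lfs.app import create_app

    app = create_app(
        git_lfs_enabled=True,
        git_lfs_store_path='/tmp/lfs-store',   # hypothetical store location
        git_lfs_http_scheme='http')

    # git-lfs clients then POST to /<repo>/info/lfs/objects/batch
    make_server('127.0.0.1', 8080, app).serve_forever()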
@@ -1,274 +1,274 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 20 from webtest.app import TestApp as WebObTestApp
21 21
22 22 from vcsserver.lib.ext_json import json
23 from vcsserver.str_utils import safe_bytes
23 from vcsserver.lib.str_utils import safe_bytes
24 24 from vcsserver.git_lfs.app import create_app
25 25 from vcsserver.git_lfs.lib import LFSOidStore
26 26
27 27
28 28 @pytest.fixture(scope='function')
29 29 def git_lfs_app(tmpdir):
30 30 custom_app = WebObTestApp(create_app(
31 31 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
32 32 git_lfs_http_scheme='http'))
33 33 custom_app._store = str(tmpdir)
34 34 return custom_app
35 35
36 36
37 37 @pytest.fixture(scope='function')
38 38 def git_lfs_https_app(tmpdir):
39 39 custom_app = WebObTestApp(create_app(
40 40 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
41 41 git_lfs_http_scheme='https'))
42 42 custom_app._store = str(tmpdir)
43 43 return custom_app
44 44
45 45
46 46 @pytest.fixture()
47 47 def http_auth():
48 48 return {'HTTP_AUTHORIZATION': "Basic XXXXX"}
49 49
50 50
51 51 class TestLFSApplication:
52 52
53 53 def test_app_wrong_path(self, git_lfs_app):
54 54 git_lfs_app.get('/repo/info/lfs/xxx', status=404)
55 55
56 56 def test_app_deprecated_endpoint(self, git_lfs_app):
57 57 response = git_lfs_app.post('/repo/info/lfs/objects', status=501)
58 58 assert response.status_code == 501
59 59 assert json.loads(response.text) == {'message': 'LFS: v1 api not supported'}
60 60
61 61 def test_app_lock_verify_api_not_available(self, git_lfs_app):
62 62 response = git_lfs_app.post('/repo/info/lfs/locks/verify', status=501)
63 63 assert response.status_code == 501
64 64 assert json.loads(response.text) == {
65 65 'message': 'GIT LFS locking api not supported'}
66 66
67 67 def test_app_lock_api_not_available(self, git_lfs_app):
68 68 response = git_lfs_app.post('/repo/info/lfs/locks', status=501)
69 69 assert response.status_code == 501
70 70 assert json.loads(response.text) == {
71 71 'message': 'GIT LFS locking api not supported'}
72 72
73 73 def test_app_batch_api_missing_auth(self, git_lfs_app):
74 74 git_lfs_app.post_json(
75 75 '/repo/info/lfs/objects/batch', params={}, status=403)
76 76
77 77 def test_app_batch_api_unsupported_operation(self, git_lfs_app, http_auth):
78 78 response = git_lfs_app.post_json(
79 79 '/repo/info/lfs/objects/batch', params={}, status=400,
80 80 extra_environ=http_auth)
81 81 assert json.loads(response.text) == {
82 82 'message': 'unsupported operation mode: `None`'}
83 83
84 84 def test_app_batch_api_missing_objects(self, git_lfs_app, http_auth):
85 85 response = git_lfs_app.post_json(
86 86 '/repo/info/lfs/objects/batch', params={'operation': 'download'},
87 87 status=400, extra_environ=http_auth)
88 88 assert json.loads(response.text) == {
89 89 'message': 'missing objects data'}
90 90
91 91 def test_app_batch_api_unsupported_data_in_objects(
92 92 self, git_lfs_app, http_auth):
93 93 params = {'operation': 'download',
94 94 'objects': [{}]}
95 95 response = git_lfs_app.post_json(
96 96 '/repo/info/lfs/objects/batch', params=params, status=400,
97 97 extra_environ=http_auth)
98 98 assert json.loads(response.text) == {
99 99 'message': 'unsupported data in objects'}
100 100
101 101 def test_app_batch_api_download_missing_object(
102 102 self, git_lfs_app, http_auth):
103 103 params = {'operation': 'download',
104 104 'objects': [{'oid': '123', 'size': '1024'}]}
105 105 response = git_lfs_app.post_json(
106 106 '/repo/info/lfs/objects/batch', params=params,
107 107 extra_environ=http_auth)
108 108
109 109 expected_objects = [
110 110 {'authenticated': True,
111 111 'errors': {'error': {
112 112 'code': 404,
113 113 'message': 'object: 123 does not exist in store'}},
114 114 'oid': '123',
115 115 'size': '1024'}
116 116 ]
117 117 assert json.loads(response.text) == {
118 118 'objects': expected_objects, 'transfer': 'basic'}
119 119
120 120 def test_app_batch_api_download(self, git_lfs_app, http_auth):
121 121 oid = '456'
122 122 oid_path = LFSOidStore(oid=oid, repo=None, store_location=git_lfs_app._store).oid_path
123 123 if not os.path.isdir(os.path.dirname(oid_path)):
124 124 os.makedirs(os.path.dirname(oid_path))
125 125 with open(oid_path, 'wb') as f:
126 126 f.write(safe_bytes('OID_CONTENT'))
127 127
128 128 params = {'operation': 'download',
129 129 'objects': [{'oid': oid, 'size': '1024'}]}
130 130 response = git_lfs_app.post_json(
131 131 '/repo/info/lfs/objects/batch', params=params,
132 132 extra_environ=http_auth)
133 133
134 134 expected_objects = [
135 135 {'authenticated': True,
136 136 'actions': {
137 137 'download': {
138 138 'header': {'Authorization': 'Basic XXXXX'},
139 139 'href': 'http://localhost/repo/info/lfs/objects/456'},
140 140 },
141 141 'oid': '456',
142 142 'size': '1024'}
143 143 ]
144 144 assert json.loads(response.text) == {
145 145 'objects': expected_objects, 'transfer': 'basic'}
146 146
147 147 def test_app_batch_api_upload(self, git_lfs_app, http_auth):
148 148 params = {'operation': 'upload',
149 149 'objects': [{'oid': '123', 'size': '1024'}]}
150 150 response = git_lfs_app.post_json(
151 151 '/repo/info/lfs/objects/batch', params=params,
152 152 extra_environ=http_auth)
153 153 expected_objects = [
154 154 {'authenticated': True,
155 155 'actions': {
156 156 'upload': {
157 157 'header': {'Authorization': 'Basic XXXXX',
158 158 'Transfer-Encoding': 'chunked'},
159 159 'href': 'http://localhost/repo/info/lfs/objects/123'},
160 160 'verify': {
161 161 'header': {'Authorization': 'Basic XXXXX'},
162 162 'href': 'http://localhost/repo/info/lfs/verify'}
163 163 },
164 164 'oid': '123',
165 165 'size': '1024'}
166 166 ]
167 167 assert json.loads(response.text) == {
168 168 'objects': expected_objects, 'transfer': 'basic'}
169 169
170 170 def test_app_batch_api_upload_for_https(self, git_lfs_https_app, http_auth):
171 171 params = {'operation': 'upload',
172 172 'objects': [{'oid': '123', 'size': '1024'}]}
173 173 response = git_lfs_https_app.post_json(
174 174 '/repo/info/lfs/objects/batch', params=params,
175 175 extra_environ=http_auth)
176 176 expected_objects = [
177 177 {'authenticated': True,
178 178 'actions': {
179 179 'upload': {
180 180 'header': {'Authorization': 'Basic XXXXX',
181 181 'Transfer-Encoding': 'chunked'},
182 182 'href': 'https://localhost/repo/info/lfs/objects/123'},
183 183 'verify': {
184 184 'header': {'Authorization': 'Basic XXXXX'},
185 185 'href': 'https://localhost/repo/info/lfs/verify'}
186 186 },
187 187 'oid': '123',
188 188 'size': '1024'}
189 189 ]
190 190 assert json.loads(response.text) == {
191 191 'objects': expected_objects, 'transfer': 'basic'}
192 192
193 193 def test_app_verify_api_missing_data(self, git_lfs_app):
194 194 params = {'oid': 'missing'}
195 195 response = git_lfs_app.post_json(
196 196 '/repo/info/lfs/verify', params=params,
197 197 status=400)
198 198
199 199 assert json.loads(response.text) == {
200 200 'message': 'missing oid and size in request data'}
201 201
202 202 def test_app_verify_api_missing_obj(self, git_lfs_app):
203 203 params = {'oid': 'missing', 'size': '1024'}
204 204 response = git_lfs_app.post_json(
205 205 '/repo/info/lfs/verify', params=params,
206 206 status=404)
207 207
208 208 assert json.loads(response.text) == {
209 209 'message': 'oid `missing` does not exist in store'}
210 210
211 211 def test_app_verify_api_size_mismatch(self, git_lfs_app):
212 212 oid = 'existing'
213 213 oid_path = LFSOidStore(oid=oid, repo=None, store_location=git_lfs_app._store).oid_path
214 214 if not os.path.isdir(os.path.dirname(oid_path)):
215 215 os.makedirs(os.path.dirname(oid_path))
216 216 with open(oid_path, 'wb') as f:
217 217 f.write(safe_bytes('OID_CONTENT'))
218 218
219 219 params = {'oid': oid, 'size': '1024'}
220 220 response = git_lfs_app.post_json(
221 221 '/repo/info/lfs/verify', params=params, status=422)
222 222
223 223 assert json.loads(response.text) == {
224 224 'message': 'requested file size mismatch '
225 225 'store size:11 requested:1024'}
226 226
227 227 def test_app_verify_api(self, git_lfs_app):
228 228 oid = 'existing'
229 229 oid_path = LFSOidStore(oid=oid, repo=None, store_location=git_lfs_app._store).oid_path
230 230 if not os.path.isdir(os.path.dirname(oid_path)):
231 231 os.makedirs(os.path.dirname(oid_path))
232 232 with open(oid_path, 'wb') as f:
233 233 f.write(safe_bytes('OID_CONTENT'))
234 234
235 235 params = {'oid': oid, 'size': 11}
236 236 response = git_lfs_app.post_json(
237 237 '/repo/info/lfs/verify', params=params)
238 238
239 239 assert json.loads(response.text) == {
240 240 'message': {'size': 'ok', 'in_store': 'ok'}}
241 241
242 242 def test_app_download_api_oid_not_existing(self, git_lfs_app):
243 243 oid = 'missing'
244 244
245 245 response = git_lfs_app.get(
246 246 '/repo/info/lfs/objects/{oid}'.format(oid=oid), status=404)
247 247
248 248 assert json.loads(response.text) == {
249 249 'message': 'requested file with oid `missing` not found in store'}
250 250
251 251 def test_app_download_api(self, git_lfs_app):
252 252 oid = 'existing'
253 253 oid_path = LFSOidStore(oid=oid, repo=None, store_location=git_lfs_app._store).oid_path
254 254 if not os.path.isdir(os.path.dirname(oid_path)):
255 255 os.makedirs(os.path.dirname(oid_path))
256 256 with open(oid_path, 'wb') as f:
257 257 f.write(safe_bytes('OID_CONTENT'))
258 258
259 259 response = git_lfs_app.get(
260 260 '/repo/info/lfs/objects/{oid}'.format(oid=oid))
261 261 assert response
262 262
263 263 def test_app_upload(self, git_lfs_app):
264 264 oid = 'uploaded'
265 265
266 266 response = git_lfs_app.put(
267 267 '/repo/info/lfs/objects/{oid}'.format(oid=oid), params='CONTENT')
268 268
269 269 assert json.loads(response.text) == {'upload': 'ok'}
270 270
271 271 # verify that we actually wrote that OID
272 272 oid_path = LFSOidStore(oid=oid, repo=None, store_location=git_lfs_app._store).oid_path
273 273 assert os.path.isfile(oid_path)
274 274 assert 'CONTENT' == open(oid_path).read()
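For reference, the batch endpoint exercised by these tests expects a request shaped like the following; a hypothetical client-side sketch using the external `requests` library (URL and credentials invented):

    import requests

    payload = {
        'operation': 'download',
        'objects': [{'oid': '456', 'size': '1024'}],
    }
    resp = requests.post(
        'http://localhost:8080/repo/info/lfs/objects/batch',
        json=payload,
        headers={'Authorization': 'Basic XXXXX',
                 'Accept': 'application/vnd.git-lfs+json'})
    print(resp.json()['objects'][0]['actions']['download']['href'])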
@@ -1,142 +1,142 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 from vcsserver.str_utils import safe_bytes
20 from vcsserver.lib.str_utils import safe_bytes
21 21 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
22 22
23 23
24 24 @pytest.fixture()
25 25 def lfs_store(tmpdir):
26 26 repo = 'test'
27 27 oid = '123456789'
28 28 store = LFSOidStore(oid=oid, repo=repo, store_location=str(tmpdir))
29 29 return store
30 30
31 31
32 32 @pytest.fixture()
33 33 def oid_handler(lfs_store):
34 34 store = lfs_store
35 35 repo = store.repo
36 36 oid = store.oid
37 37
38 38 oid_handler = OidHandler(
39 39 store=store, repo_name=repo, auth=('basic', 'xxxx'),
40 40 oid=oid,
41 41 obj_size='1024', obj_data={}, obj_href='http://localhost/handle_oid',
42 42 obj_verify_href='http://localhost/verify')
43 43 return oid_handler
44 44
45 45
46 46 class TestOidHandler:
47 47
48 48 @pytest.mark.parametrize('exec_action', [
49 49 'download',
50 50 'upload',
51 51 ])
52 52 def test_exec_action(self, exec_action, oid_handler):
53 53 handler = oid_handler.exec_operation(exec_action)
54 54 assert handler
55 55
56 56 def test_exec_action_undefined(self, oid_handler):
57 57 with pytest.raises(AttributeError):
58 58 oid_handler.exec_operation('wrong')
59 59
60 60 def test_download_oid_not_existing(self, oid_handler):
61 61 response, has_errors = oid_handler.exec_operation('download')
62 62
63 63 assert response is None
64 64 assert has_errors['error'] == {
65 65 'code': 404,
66 66 'message': 'object: 123456789 does not exist in store'}
67 67
68 68 def test_download_oid(self, oid_handler):
69 69 store = oid_handler.get_store()
70 70 if not os.path.isdir(os.path.dirname(store.oid_path)):
71 71 os.makedirs(os.path.dirname(store.oid_path))
72 72
73 73 with open(store.oid_path, 'wb') as f:
74 74 f.write(safe_bytes('CONTENT'))
75 75
76 76 response, has_errors = oid_handler.exec_operation('download')
77 77
78 78 assert has_errors is None
79 79 assert response['download'] == {
80 80 'header': {'Authorization': 'basic xxxx'},
81 81 'href': 'http://localhost/handle_oid'
82 82 }
83 83
84 84 def test_upload_oid_that_exists(self, oid_handler):
85 85 store = oid_handler.get_store()
86 86 if not os.path.isdir(os.path.dirname(store.oid_path)):
87 87 os.makedirs(os.path.dirname(store.oid_path))
88 88
89 89 with open(store.oid_path, 'wb') as f:
90 90 f.write(safe_bytes('CONTENT'))
91 91 oid_handler.obj_size = 7
92 92 response, has_errors = oid_handler.exec_operation('upload')
93 93 assert has_errors is None
94 94 assert response is None
95 95
96 96 def test_upload_oid_that_exists_but_has_wrong_size(self, oid_handler):
97 97 store = oid_handler.get_store()
98 98 if not os.path.isdir(os.path.dirname(store.oid_path)):
99 99 os.makedirs(os.path.dirname(store.oid_path))
100 100
101 101 with open(store.oid_path, 'wb') as f:
102 102 f.write(safe_bytes('CONTENT'))
103 103
104 104 oid_handler.obj_size = 10240
105 105 response, has_errors = oid_handler.exec_operation('upload')
106 106 assert has_errors is None
107 107 assert response['upload'] == {
108 108 'header': {'Authorization': 'basic xxxx',
109 109 'Transfer-Encoding': 'chunked'},
110 110 'href': 'http://localhost/handle_oid',
111 111 }
112 112
113 113 def test_upload_oid(self, oid_handler):
114 114 response, has_errors = oid_handler.exec_operation('upload')
115 115 assert has_errors is None
116 116 assert response['upload'] == {
117 117 'header': {'Authorization': 'basic xxxx',
118 118 'Transfer-Encoding': 'chunked'},
119 119 'href': 'http://localhost/handle_oid'
120 120 }
121 121
122 122
123 123 class TestLFSStore:
124 124 def test_write_oid(self, lfs_store):
125 125 oid_location = lfs_store.oid_path
126 126
127 127 assert not os.path.isfile(oid_location)
128 128
129 129 engine = lfs_store.get_engine(mode='wb')
130 130 with engine as f:
131 131 f.write(safe_bytes('CONTENT'))
132 132
133 133 assert os.path.isfile(oid_location)
134 134
135 135 def test_detect_has_oid(self, lfs_store):
136 136
137 137 assert lfs_store.has_oid() is False
138 138 engine = lfs_store.get_engine(mode='wb')
139 139 with engine as f:
140 140 f.write(safe_bytes('CONTENT'))
141 141
142 142 assert lfs_store.has_oid() is True
@@ -1,92 +1,92 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """
19 19 Mercurial libs compatibility
20 20 """
21 21
22 22 import mercurial
23 23 from mercurial import demandimport
24 24
25 25 # patch demandimport, due to a bug in mercurial that always triggers
26 26 # demandimport.enable()
27 from vcsserver.str_utils import safe_bytes
27 from vcsserver.lib.str_utils import safe_bytes
28 28
29 29 demandimport.enable = lambda *args, **kwargs: 1
30 30
31 31 from mercurial import ui
32 32 from mercurial import patch
33 33 from mercurial import config
34 34 from mercurial import extensions
35 35 from mercurial import scmutil
36 36 from mercurial import archival
37 37 from mercurial import discovery
38 38 from mercurial import unionrepo
39 39 from mercurial import localrepo
40 40 from mercurial import merge as hg_merge
41 41 from mercurial import subrepo
42 42 from mercurial import subrepoutil
43 43 from mercurial import tags as hg_tag
44 44 from mercurial import util as hgutil
45 45 from mercurial.commands import clone, pull
46 46 from mercurial.node import nullid
47 47 from mercurial.context import memctx, memfilectx
48 48 from mercurial.error import (
49 49 LookupError, RepoError, RepoLookupError, Abort, InterventionRequired,
50 50 RequirementError, ProgrammingError)
51 51 from mercurial.hgweb import hgweb_mod
52 52 from mercurial.localrepo import instance
53 53 from mercurial.match import match, alwaysmatcher, patternmatcher
54 54 from mercurial.mdiff import diffopts
55 55 from mercurial.node import bin, hex
56 56 from mercurial.encoding import tolocal
57 57 from mercurial.discovery import findcommonoutgoing
58 58 from mercurial.hg import peer
59 59 from mercurial.httppeer import make_peer
60 60 from mercurial.utils.urlutil import url as hg_url
61 61 from mercurial.scmutil import revrange, revsymbol
62 62 from mercurial.node import nullrev
63 63 from mercurial import exchange
64 64 from hgext import largefiles
65 65
66 66 # these auth handlers are patched due to a python 2.6.5 bug causing
67 67 # infinite looping when given invalid resources
68 68 from mercurial.url import httpbasicauthhandler, httpdigestauthhandler
69 69
70 70 # hg strip is in core now
71 71 from mercurial import strip as hgext_strip
72 72
73 73
74 74 def get_ctx(repo, ref):
75 75 if not isinstance(ref, int):
76 76 ref = safe_bytes(ref)
77 77
78 78 try:
79 79 ctx = repo[ref]
80 80 return ctx
81 81 except (ProgrammingError, TypeError):
82 82 # we're unable to find the rev using a regular lookup; fall back
83 83 # to the slower, but backward-compatible, revsymbol usage
84 84 pass
85 85 except (LookupError, RepoLookupError):
86 86 # Similar case as above but only for refs that are not numeric
87 87 if isinstance(ref, int):
88 88 raise
89 89
90 90 ctx = revsymbol(repo, ref)
91 91
92 92 return ctx
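A hedged sketch of resolving a changeset through the compat layer, assuming a Mercurial repository on disk at `./repo`:

    from mercurial import ui as hg_ui
    from vcsserver.hgcompat import instance, get_ctx  # re-exported above

    repo = instance(hg_ui.ui.load(), b'./repo', create=False)
    ctx = get_ctx(repo, 'tip')  # str refs are encoded to bytes; symbolic
    print(ctx.hex())            # names fall back to revsymbol()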
@@ -1,230 +1,230 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18
19 19 import re
20 20 import os
21 21 import sys
22 22 import datetime
23 23 import logging
24 24 import pkg_resources
25 25
26 26 import vcsserver
27 27 import vcsserver.settings
28 from vcsserver.str_utils import safe_bytes
28 from vcsserver.lib.str_utils import safe_bytes
29 29
30 30 log = logging.getLogger(__name__)
31 31
32 32 HOOKS_DIR_MODE = 0o755
33 33 HOOKS_FILE_MODE = 0o755
34 34
35 35
36 36 def set_permissions_if_needed(path_to_check, perms: oct):
37 37 # Get current permissions
38 38 current_permissions = os.stat(path_to_check).st_mode & 0o777 # Extract permission bits
39 39
40 40 # Check if current permissions are lower than required
41 41 if current_permissions < int(perms):
42 42 # Change the permissions if they are lower than required
43 43 os.chmod(path_to_check, perms)
44 44
45 45
46 46 def get_git_hooks_path(repo_path, bare):
47 47 hooks_path = os.path.join(repo_path, 'hooks')
48 48 if not bare:
49 49 hooks_path = os.path.join(repo_path, '.git', 'hooks')
50 50
51 51 return hooks_path
52 52
53 53
54 54 def install_git_hooks(repo_path, bare, executable=None, force_create=False):
55 55 """
56 56 Creates a RhodeCode hook inside a git repository
57 57
58 58 :param repo_path: path to repository
59 59 :param bare: defines if repository is considered a bare git repo
60 60 :param executable: binary executable to put in the hooks
61 61 :param force_create: Creates even if the same name hook exists
62 62 """
63 63 executable = executable or sys.executable
64 64 hooks_path = get_git_hooks_path(repo_path, bare)
65 65
66 66 # we always call it to ensure dir exists and it has a proper mode
67 67 if not os.path.exists(hooks_path):
68 68 # If it doesn't exist, create a new directory with the specified mode
69 69 os.makedirs(hooks_path, mode=HOOKS_DIR_MODE, exist_ok=True)
70 70 # If it exists, change the directory's mode to the specified mode
71 71 set_permissions_if_needed(hooks_path, perms=HOOKS_DIR_MODE)
72 72
73 73 tmpl_post = pkg_resources.resource_string(
74 74 'vcsserver', '/'.join(
75 75 ('hook_utils', 'hook_templates', 'git_post_receive.py.tmpl')))
76 76 tmpl_pre = pkg_resources.resource_string(
77 77 'vcsserver', '/'.join(
78 78 ('hook_utils', 'hook_templates', 'git_pre_receive.py.tmpl')))
79 79
80 80 path = '' # not used for now
81 81 timestamp = datetime.datetime.utcnow().isoformat()
82 82
83 83 for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
84 84 log.debug('Installing git hook in repo %s', repo_path)
85 85 _hook_file = os.path.join(hooks_path, f'{h_type}-receive')
86 86 _rhodecode_hook = check_rhodecode_hook(_hook_file)
87 87
88 88 if _rhodecode_hook or force_create:
89 89 log.debug('writing git %s hook file at %s !', h_type, _hook_file)
90 90 try:
91 91 with open(_hook_file, 'wb') as f:
92 92 template = template.replace(b'_TMPL_', safe_bytes(vcsserver.get_version()))
93 93 template = template.replace(b'_DATE_', safe_bytes(timestamp))
94 94 template = template.replace(b'_ENV_', safe_bytes(executable))
95 95 template = template.replace(b'_PATH_', safe_bytes(path))
96 96 f.write(template)
97 97 set_permissions_if_needed(_hook_file, perms=HOOKS_FILE_MODE)
98 98 except OSError:
99 99 log.exception('error writing hook file %s', _hook_file)
100 100 else:
101 101 log.debug('skipping writing hook file')
102 102
103 103 return True
104 104
105 105
106 106 def get_svn_hooks_path(repo_path):
107 107 hooks_path = os.path.join(repo_path, 'hooks')
108 108
109 109 return hooks_path
110 110
111 111
112 112 def install_svn_hooks(repo_path, executable=None, force_create=False):
113 113 """
114 114 Creates RhodeCode hooks inside a svn repository
115 115
116 116 :param repo_path: path to repository
117 117 :param executable: binary executable to put in the hooks
118 118 :param force_create: Create even if same name hook exists
119 119 """
120 120 executable = executable or sys.executable
121 121 hooks_path = get_svn_hooks_path(repo_path)
122 122 if not os.path.isdir(hooks_path):
123 123 os.makedirs(hooks_path, mode=0o777, exist_ok=True)
124 124
125 125 tmpl_post = pkg_resources.resource_string(
126 126 'vcsserver', '/'.join(
127 127 ('hook_utils', 'hook_templates', 'svn_post_commit_hook.py.tmpl')))
128 128 tmpl_pre = pkg_resources.resource_string(
129 129 'vcsserver', '/'.join(
130 130 ('hook_utils', 'hook_templates', 'svn_pre_commit_hook.py.tmpl')))
131 131
132 132 path = '' # not used for now
133 133 timestamp = datetime.datetime.utcnow().isoformat()
134 134
135 135 for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
136 136 log.debug('Installing svn hook in repo %s', repo_path)
137 137 _hook_file = os.path.join(hooks_path, f'{h_type}-commit')
138 138 _rhodecode_hook = check_rhodecode_hook(_hook_file)
139 139
140 140 if _rhodecode_hook or force_create:
141 141 log.debug('writing svn %s hook file at %s !', h_type, _hook_file)
142 142
143 143 env_expand = str([
144 144 ('RC_CORE_BINARY_DIR', vcsserver.settings.BINARY_DIR),
145 145 ('RC_GIT_EXECUTABLE', vcsserver.settings.GIT_EXECUTABLE()),
146 146 ('RC_SVN_EXECUTABLE', vcsserver.settings.SVN_EXECUTABLE()),
147 147 ('RC_SVNLOOK_EXECUTABLE', vcsserver.settings.SVNLOOK_EXECUTABLE()),
148 148
149 149 ])
150 150 try:
151 151 with open(_hook_file, 'wb') as f:
152 152 template = template.replace(b'_TMPL_', safe_bytes(vcsserver.get_version()))
153 153 template = template.replace(b'_DATE_', safe_bytes(timestamp))
154 154 template = template.replace(b'_OS_EXPAND_', safe_bytes(env_expand))
155 155 template = template.replace(b'_ENV_', safe_bytes(executable))
156 156 template = template.replace(b'_PATH_', safe_bytes(path))
157 157
158 158 f.write(template)
159 159 os.chmod(_hook_file, 0o755)
160 160 except OSError:
161 161 log.exception('error writing hook file %s', _hook_file)
162 162 else:
163 163 log.debug('skipping writing hook file')
164 164
165 165 return True
166 166
167 167
168 168 def get_version_from_hook(hook_path):
169 169 version = b''
170 170 hook_content = read_hook_content(hook_path)
171 171 matches = re.search(rb'RC_HOOK_VER\s*=\s*(.*)', hook_content)
172 172 if matches:
173 173 try:
174 174 version = matches.groups()[0]
175 175 log.debug('got version %s from hooks.', version)
176 176 except Exception:
177 177 log.exception("Exception while reading the hook version.")
178 178 return version.replace(b"'", b"")
179 179
180 180
181 181 def check_rhodecode_hook(hook_path):
182 182 """
183 183 Check if the hook was created by RhodeCode
184 184 """
185 185 if not os.path.exists(hook_path):
186 186 return True
187 187
188 188 log.debug('hook exists, checking if it is from RhodeCode')
189 189
190 190 version = get_version_from_hook(hook_path)
191 191 if version:
192 192 return True
193 193
194 194 return False
195 195
196 196
197 197 def read_hook_content(hook_path) -> bytes:
198 198 content = b''
199 199 if os.path.isfile(hook_path):
200 200 with open(hook_path, 'rb') as f:
201 201 content = f.read()
202 202 return content
203 203
204 204
205 205 def get_git_pre_hook_version(repo_path, bare):
206 206 hooks_path = get_git_hooks_path(repo_path, bare)
207 207 _hook_file = os.path.join(hooks_path, 'pre-receive')
208 208 version = get_version_from_hook(_hook_file)
209 209 return version
210 210
211 211
212 212 def get_git_post_hook_version(repo_path, bare):
213 213 hooks_path = get_git_hooks_path(repo_path, bare)
214 214 _hook_file = os.path.join(hooks_path, 'post-receive')
215 215 version = get_version_from_hook(_hook_file)
216 216 return version
217 217
218 218
219 219 def get_svn_pre_hook_version(repo_path):
220 220 hooks_path = get_svn_hooks_path(repo_path)
221 221 _hook_file = os.path.join(hooks_path, 'pre-commit')
222 222 version = get_version_from_hook(_hook_file)
223 223 return version
224 224
225 225
226 226 def get_svn_post_hook_version(repo_path):
227 227 hooks_path = get_svn_hooks_path(repo_path)
228 228 _hook_file = os.path.join(hooks_path, 'post-commit')
229 229 version = get_version_from_hook(_hook_file)
230 230 return version
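A minimal usage sketch of the installers above, assuming a bare git repository at a hypothetical path and that this module is importable as `vcsserver.hook_utils`:

    from vcsserver.hook_utils import install_git_hooks, get_git_pre_hook_version

    install_git_hooks('/srv/repos/myrepo', bare=True, force_create=True)
    # the template embeds RC_HOOK_VER, which the version readers parse back out
    print(get_git_pre_hook_version('/srv/repos/myrepo', bare=True))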
@@ -1,832 +1,832 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21 import logging
22 22 import collections
23 23 import base64
24 24 import msgpack
25 25 import dataclasses
26 26 import pygit2
27 27
28 28 import http.client
29 29 from celery import Celery
30 30
31 31 import mercurial.scmutil
32 32 import mercurial.node
33 33
34 34 from vcsserver.lib.ext_json import json
35 35 from vcsserver import exceptions, subprocessio, settings
36 from vcsserver.str_utils import ascii_str, safe_str
36 from vcsserver.lib.str_utils import ascii_str, safe_str
37 37 from vcsserver.remote.git_remote import Repository
38 38
39 39 celery_app = Celery('__vcsserver__')
40 40 log = logging.getLogger(__name__)
41 41
42 42
43 43 class HooksHttpClient:
44 44 proto = 'msgpack.v1'
45 45 connection = None
46 46
47 47 def __init__(self, hooks_uri):
48 48 self.hooks_uri = hooks_uri
49 49
50 50 def __repr__(self):
51 51 return f'{self.__class__}(hook_uri={self.hooks_uri}, proto={self.proto})'
52 52
53 53 def __call__(self, method, extras):
54 54 connection = http.client.HTTPConnection(self.hooks_uri)
55 55 # binary msgpack body
56 56 headers, body = self._serialize(method, extras)
57 57 log.debug('Doing a new hooks call using HTTPConnection to %s', self.hooks_uri)
58 58
59 59 try:
60 60 try:
61 61 connection.request('POST', '/', body, headers)
62 62 except Exception as error:
63 63 log.error('Hooks calling Connection failed on %s, org error: %s', connection.__dict__, error)
64 64 raise
65 65
66 66 response = connection.getresponse()
67 67 try:
68 68 return msgpack.load(response)
69 69 except Exception:
70 70 response_data = response.read()
71 71 log.exception('Failed to decode hook response msgpack data. '
72 72 'response_code:%s, raw_data:%s',
73 73 response.status, response_data)
74 74 raise
75 75 finally:
76 76 connection.close()
77 77
78 78 @classmethod
79 79 def _serialize(cls, hook_name, extras):
80 80 data = {
81 81 'method': hook_name,
82 82 'extras': extras
83 83 }
84 84 headers = {
85 85 "rc-hooks-protocol": cls.proto,
86 86 "Connection": "keep-alive"
87 87 }
88 88 return headers, msgpack.packb(data)
89 89
90 90
91 91 class HooksCeleryClient:
92 92 TASK_TIMEOUT = 60 # time in seconds
93 93
94 94 def __init__(self, queue, backend):
95 95 celery_app.config_from_object({
96 96 'broker_url': queue, 'result_backend': backend,
97 97 'broker_connection_retry_on_startup': True,
98 98 'task_serializer': 'json',
99 99 'accept_content': ['json', 'msgpack'],
100 100 'result_serializer': 'json',
101 101 'result_accept_content': ['json', 'msgpack']
102 102 })
103 103 self.celery_app = celery_app
104 104
105 105 def __call__(self, method, extras):
106 106 inquired_task = self.celery_app.signature(
107 107 f'rhodecode.lib.celerylib.tasks.{method}'
108 108 )
109 109 return inquired_task.delay(extras).get(timeout=self.TASK_TIMEOUT)
110 110
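For the Celery transport, a hypothetical usage sketch; the broker/backend URLs are placeholders, a running broker is assumed, and the call dispatches a task named rhodecode.lib.celerylib.tasks.<method>:

    # hypothetical: assumes a reachable Redis broker and result backend
    client = HooksCeleryClient(
        'redis://localhost:6379/0',  # task_queue -> broker_url
        'redis://localhost:6379/1',  # task_backend -> result_backend
    )
    # dispatches rhodecode.lib.celerylib.tasks.repo_size, blocking up to TASK_TIMEOUT
    result = client('repo_size', {'repository': 'test-repo'})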
111 111
112 112 class HooksShadowRepoClient:
113 113
114 114 def __call__(self, hook_name, extras):
115 115 return {'output': '', 'status': 0}
116 116
117 117
118 118 class RemoteMessageWriter:
119 119 """Writer base class."""
120 120 def write(self, message):
121 121 raise NotImplementedError()
122 122
123 123
124 124 class HgMessageWriter(RemoteMessageWriter):
125 125 """Writer that knows how to send messages to mercurial clients."""
126 126
127 127 def __init__(self, ui):
128 128 self.ui = ui
129 129
130 130 def write(self, message: str):
131 131 # TODO: Check why the quiet flag is set by default.
132 132 old = self.ui.quiet
133 133 self.ui.quiet = False
134 134 self.ui.status(message.encode('utf-8'))
135 135 self.ui.quiet = old
136 136
137 137
138 138 class GitMessageWriter(RemoteMessageWriter):
139 139 """Writer that knows how to send messages to git clients."""
140 140
141 141 def __init__(self, stdout=None):
142 142 self.stdout = stdout or sys.stdout
143 143
144 144 def write(self, message: str):
145 145 self.stdout.write(message)
146 146
147 147
148 148 class SvnMessageWriter(RemoteMessageWriter):
149 149 """Writer that knows how to send messages to svn clients."""
150 150
151 151 def __init__(self, stderr=None):
152 152 # SVN needs data sent to stderr for back-to-client messaging
153 153 self.stderr = stderr or sys.stderr
154 154
155 155 def write(self, message):
156 156 self.stderr.write(message)
157 157
158 158
159 159 def _handle_exception(result):
160 160 exception_class = result.get('exception')
161 161 exception_traceback = result.get('exception_traceback')
162 162 log.debug('Handling hook-call exception: %s', exception_class)
163 163
164 164 if exception_traceback:
165 165 log.error('Got traceback from remote call:%s', exception_traceback)
166 166
167 167 if exception_class == 'HTTPLockedRC':
168 168 raise exceptions.RepositoryLockedException()(*result['exception_args'])
169 169 elif exception_class == 'HTTPBranchProtected':
170 170 raise exceptions.RepositoryBranchProtectedException()(*result['exception_args'])
171 171 elif exception_class == 'RepositoryError':
172 172 raise exceptions.VcsException()(*result['exception_args'])
173 173 elif exception_class:
174 174 raise Exception(
175 175 f"""Got remote exception "{exception_class}" with args "{result['exception_args']}" """
176 176 )
177 177
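A short sketch of how a remote error payload surfaces locally; the payload shape mirrors what _call_hook receives, and the snippet assumes this module's namespace:

    result = {
        'status': 1,
        'output': '',
        'exception': 'RepositoryError',
        'exception_args': ('boom',),
        'exception_traceback': None,
    }
    try:
        _handle_exception(result)
    except Exception as exc:
        # a vcsserver VcsException carrying the original args
        print(type(exc).__name__, exc.args)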
178 178
179 179 def _get_hooks_client(extras):
180 180 hooks_uri = extras.get('hooks_uri')
181 181 task_queue = extras.get('task_queue')
182 182 task_backend = extras.get('task_backend')
183 183 is_shadow_repo = extras.get('is_shadow_repo')
184 184
185 185 if hooks_uri:
186 186 return HooksHttpClient(hooks_uri)
187 187 elif task_queue and task_backend:
188 188 return HooksCeleryClient(task_queue, task_backend)
189 189 elif is_shadow_repo:
190 190 return HooksShadowRepoClient()
191 191 else:
192 192 raise Exception("Hooks client not found!")
193 193
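Client selection is driven purely by the keys present in extras; a quick illustration (module namespace assumed):

    print(_get_hooks_client({'hooks_uri': 'localhost:10020'}))  # HooksHttpClient
    shadow = _get_hooks_client({'is_shadow_repo': True})
    print(shadow('pre_push', {}))  # shadow repos no-op: {'output': '', 'status': 0}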
194 194
195 195 def _call_hook(hook_name, extras, writer):
196 196 hooks_client = _get_hooks_client(extras)
197 197 log.debug('Hooks, using client:%s', hooks_client)
198 198 result = hooks_client(hook_name, extras)
199 199 log.debug('Hooks got result: %s', result)
200 200 _handle_exception(result)
201 201 writer.write(result['output'])
202 202
203 203 return result['status']
204 204
205 205
206 206 def _extras_from_ui(ui):
207 207 hook_data = ui.config(b'rhodecode', b'RC_SCM_DATA')
208 208 if not hook_data:
209 209 # maybe it's inside environ ?
210 210 env_hook_data = os.environ.get('RC_SCM_DATA')
211 211 if env_hook_data:
212 212 hook_data = env_hook_data
213 213
214 214 extras = {}
215 215 if hook_data:
216 216 extras = json.loads(hook_data)
217 217 return extras
218 218
219 219
220 220 def _rev_range_hash(repo, node, check_heads=False):
221 221 from vcsserver.hgcompat import get_ctx
222 222
223 223 commits = []
224 224 revs = []
225 225 start = get_ctx(repo, node).rev()
226 226 end = len(repo)
227 227 for rev in range(start, end):
228 228 revs.append(rev)
229 229 ctx = get_ctx(repo, rev)
230 230 commit_id = ascii_str(mercurial.node.hex(ctx.node()))
231 231 branch = safe_str(ctx.branch())
232 232 commits.append((commit_id, branch))
233 233
234 234 parent_heads = []
235 235 if check_heads:
236 236 parent_heads = _check_heads(repo, start, end, revs)
237 237 return commits, parent_heads
238 238
239 239
240 240 def _check_heads(repo, start, end, commits):
241 241 from vcsserver.hgcompat import get_ctx
242 242 changelog = repo.changelog
243 243 parents = set()
244 244
245 245 for new_rev in commits:
246 246 for p in changelog.parentrevs(new_rev):
247 247 if p == mercurial.node.nullrev:
248 248 continue
249 249 if p < start:
250 250 parents.add(p)
251 251
252 252 for p in parents:
253 253 branch = get_ctx(repo, p).branch()
254 254 # The heads descending from that parent, on the same branch
255 255 parent_heads = {p}
256 256 reachable = {p}
257 257 for x in range(p + 1, end):
258 258 if get_ctx(repo, x).branch() != branch:
259 259 continue
260 260 for pp in changelog.parentrevs(x):
261 261 if pp in reachable:
262 262 reachable.add(x)
263 263 parent_heads.discard(pp)
264 264 parent_heads.add(x)
265 265 # More than one head? Suggest merging
266 266 if len(parent_heads) > 1:
267 267 return list(parent_heads)
268 268
269 269 return []
270 270
271 271
272 272 def _get_git_env():
273 273 env = {}
274 274 for k, v in os.environ.items():
275 275 if k.startswith('GIT'):
276 276 env[k] = v
277 277
278 278 # serialized version
279 279 return [(k, v) for k, v in env.items()]
280 280
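_get_git_env simply snapshots the GIT* environment as serializable pairs, e.g.:

    os.environ['GIT_DIR'] = '/tmp/repo/.git'  # hypothetical value
    env = dict(_get_git_env())
    assert env['GIT_DIR'] == '/tmp/repo/.git'
    assert all(k.startswith('GIT') for k in env)  # nothing else is propagated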
281 281
282 282 def _get_hg_env(old_rev, new_rev, txnid, repo_path):
283 283 env = {}
284 284 for k, v in os.environ.items():
285 285 if k.startswith('HG'):
286 286 env[k] = v
287 287
288 288 env['HG_NODE'] = old_rev
289 289 env['HG_NODE_LAST'] = new_rev
290 290 env['HG_TXNID'] = txnid
291 291 env['HG_PENDING'] = repo_path
292 292
293 293 return [(k, v) for k, v in env.items()]
294 294
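_get_hg_env does the same for HG* variables and injects the transaction context; a sketch with placeholder revisions:

    env = dict(_get_hg_env(
        old_rev='aaaa', new_rev='bbbb',  # placeholder hashes
        txnid='TXN:x', repo_path='/tmp/repo'))
    assert env['HG_NODE'] == 'aaaa' and env['HG_NODE_LAST'] == 'bbbb'
    assert env['HG_TXNID'] == 'TXN:x' and env['HG_PENDING'] == '/tmp/repo'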
295 295
296 296 def _fix_hooks_executables(ini_path=''):
297 297 """
298 298 This is a trick to set proper settings.EXECUTABLE paths for certain execution patterns,
299 299 especially for Subversion, where hooks strip the entire env and calling just the 'svn' command
300 300 will most likely fail because svn is not on PATH.
301 301 """
302 302 from vcsserver.http_main import sanitize_settings_and_apply_defaults
303 303 from vcsserver.lib.config_utils import get_app_config_lightweight
304 304
305 305 core_binary_dir = settings.BINARY_DIR or '/usr/local/bin/rhodecode_bin/vcs_bin'
306 306 if ini_path:
307 307
308 308 ini_settings = get_app_config_lightweight(ini_path)
309 309 ini_settings = sanitize_settings_and_apply_defaults({'__file__': ini_path}, ini_settings)
310 310 core_binary_dir = ini_settings['core.binary_dir']
311 311
312 312 settings.BINARY_DIR = core_binary_dir
313 313
314 314
315 315 def repo_size(ui, repo, **kwargs):
316 316 extras = _extras_from_ui(ui)
317 317 return _call_hook('repo_size', extras, HgMessageWriter(ui))
318 318
319 319
320 320 def pre_pull(ui, repo, **kwargs):
321 321 extras = _extras_from_ui(ui)
322 322 return _call_hook('pre_pull', extras, HgMessageWriter(ui))
323 323
324 324
325 325 def pre_pull_ssh(ui, repo, **kwargs):
326 326 extras = _extras_from_ui(ui)
327 327 if extras and extras.get('SSH'):
328 328 return pre_pull(ui, repo, **kwargs)
329 329 return 0
330 330
331 331
332 332 def post_pull(ui, repo, **kwargs):
333 333 extras = _extras_from_ui(ui)
334 334 return _call_hook('post_pull', extras, HgMessageWriter(ui))
335 335
336 336
337 337 def post_pull_ssh(ui, repo, **kwargs):
338 338 extras = _extras_from_ui(ui)
339 339 if extras and extras.get('SSH'):
340 340 return post_pull(ui, repo, **kwargs)
341 341 return 0
342 342
343 343
344 344 def pre_push(ui, repo, node=None, **kwargs):
345 345 """
346 346 Mercurial pre_push hook
347 347 """
348 348 extras = _extras_from_ui(ui)
349 349 detect_force_push = extras.get('detect_force_push')
350 350
351 351 rev_data = []
352 352 hook_type: str = safe_str(kwargs.get('hooktype'))
353 353
354 354 if node and hook_type == 'pretxnchangegroup':
355 355 branches = collections.defaultdict(list)
356 356 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
357 357 for commit_id, branch in commits:
358 358 branches[branch].append(commit_id)
359 359
360 360 for branch, commits in branches.items():
361 361 old_rev = ascii_str(kwargs.get('node_last')) or commits[0]
362 362 rev_data.append({
363 363 'total_commits': len(commits),
364 364 'old_rev': old_rev,
365 365 'new_rev': commits[-1],
366 366 'ref': '',
367 367 'type': 'branch',
368 368 'name': branch,
369 369 })
370 370
371 371 for push_ref in rev_data:
372 372 push_ref['multiple_heads'] = _heads
373 373
374 374 repo_path = os.path.join(
375 375 extras.get('repo_store', ''), extras.get('repository', ''))
376 376 push_ref['hg_env'] = _get_hg_env(
377 377 old_rev=push_ref['old_rev'],
378 378 new_rev=push_ref['new_rev'], txnid=ascii_str(kwargs.get('txnid')),
379 379 repo_path=repo_path)
380 380
381 381 extras['hook_type'] = hook_type or 'pre_push'
382 382 extras['commit_ids'] = rev_data
383 383
384 384 return _call_hook('pre_push', extras, HgMessageWriter(ui))
385 385
386 386
387 387 def pre_push_ssh(ui, repo, node=None, **kwargs):
388 388 extras = _extras_from_ui(ui)
389 389 if extras.get('SSH'):
390 390 return pre_push(ui, repo, node, **kwargs)
391 391
392 392 return 0
393 393
394 394
395 395 def pre_push_ssh_auth(ui, repo, node=None, **kwargs):
396 396 """
397 397 Mercurial pre_push hook for SSH
398 398 """
399 399 extras = _extras_from_ui(ui)
400 400 if extras.get('SSH'):
401 401 permission = extras['SSH_PERMISSIONS']
402 402
403 403 if 'repository.write' == permission or 'repository.admin' == permission:
404 404 return 0
405 405
406 406 # non-zero ret code
407 407 return 1
408 408
409 409 return 0
410 410
411 411
412 412 def post_push(ui, repo, node, **kwargs):
413 413 """
414 414 Mercurial post_push hook
415 415 """
416 416 extras = _extras_from_ui(ui)
417 417
418 418 commit_ids = []
419 419 branches = []
420 420 bookmarks = []
421 421 tags = []
422 422 hook_type: str = safe_str(kwargs.get('hooktype'))
423 423
424 424 commits, _heads = _rev_range_hash(repo, node)
425 425 for commit_id, branch in commits:
426 426 commit_ids.append(commit_id)
427 427 if branch not in branches:
428 428 branches.append(branch)
429 429
430 430 if hasattr(ui, '_rc_pushkey_bookmarks'):
431 431 bookmarks = ui._rc_pushkey_bookmarks
432 432
433 433 extras['hook_type'] = hook_type or 'post_push'
434 434 extras['commit_ids'] = commit_ids
435 435
436 436 extras['new_refs'] = {
437 437 'branches': branches,
438 438 'bookmarks': bookmarks,
439 439 'tags': tags
440 440 }
441 441
442 442 return _call_hook('post_push', extras, HgMessageWriter(ui))
443 443
444 444
445 445 def post_push_ssh(ui, repo, node, **kwargs):
446 446 """
447 447 Mercurial post_push hook for SSH
448 448 """
449 449 if _extras_from_ui(ui).get('SSH'):
450 450 return post_push(ui, repo, node, **kwargs)
451 451 return 0
452 452
453 453
454 454 def key_push(ui, repo, **kwargs):
455 455 from vcsserver.hgcompat import get_ctx
456 456
457 457 if kwargs['new'] != b'0' and kwargs['namespace'] == b'bookmarks':
458 458 # store new bookmarks in our UI object propagated later to post_push
459 459 ui._rc_pushkey_bookmarks = get_ctx(repo, kwargs['key']).bookmarks()
460 460 return
461 461
462 462
463 463 # backward compat
464 464 log_pull_action = post_pull
465 465
466 466 # backward compat
467 467 log_push_action = post_push
468 468
469 469
470 470 def handle_git_pre_receive(unused_repo_path, unused_revs, unused_env):
471 471 """
472 472 Old hook name: keep here for backward compatibility.
473 473
474 474 This is only required when the installed git hooks are not upgraded.
475 475 """
476 476 pass
477 477
478 478
479 479 def handle_git_post_receive(unused_repo_path, unused_revs, unused_env):
480 480 """
481 481 Old hook name: keep here for backward compatibility.
482 482
483 483 This is only required when the installed git hooks are not upgraded.
484 484 """
485 485 pass
486 486
487 487
488 488 @dataclasses.dataclass
489 489 class HookResponse:
490 490 status: int
491 491 output: str
492 492
493 493
494 494 def git_pre_pull(extras) -> HookResponse:
495 495 """
496 496 Pre pull hook.
497 497
498 498 :param extras: dictionary containing the keys defined in simplevcs
499 499 :type extras: dict
500 500
501 501 :return: HookResponse with the hook status code (0 for success) and captured output.
502 502 :rtype: HookResponse
503 503 """
504 504
505 505 if 'pull' not in extras['hooks']:
506 506 return HookResponse(0, '')
507 507
508 508 stdout = io.StringIO()
509 509 try:
510 510 status_code = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
511 511
512 512 except Exception as error:
513 513 log.exception('Failed to call pre_pull hook')
514 514 status_code = 128
515 515 stdout.write(f'ERROR: {error}\n')
516 516
517 517 return HookResponse(status_code, stdout.getvalue())
518 518
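When the pull hook is not enabled, git_pre_pull short-circuits without contacting any hooks client (module namespace assumed):

    resp = git_pre_pull({'hooks': []})
    assert resp == HookResponse(0, '')  # dataclass equality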
519 519
520 520 def git_post_pull(extras) -> HookResponse:
521 521 """
522 522 Post pull hook.
523 523
524 524 :param extras: dictionary containing the keys defined in simplevcs
525 525 :type extras: dict
526 526
527 527 :return: HookResponse with the hook status code (0 for success) and captured output.
528 528 :rtype: HookResponse
529 529 """
530 530 if 'pull' not in extras['hooks']:
531 531 return HookResponse(0, '')
532 532
533 533 stdout = io.StringIO()
534 534 try:
535 535 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
536 536 except Exception as error:
537 537 status = 128
538 538 stdout.write(f'ERROR: {error}\n')
539 539
540 540 return HookResponse(status, stdout.getvalue())
541 541
542 542
543 543 def _parse_git_ref_lines(revision_lines):
544 544 rev_data = []
545 545 for revision_line in revision_lines or []:
546 546 old_rev, new_rev, ref = revision_line.strip().split(' ')
547 547 ref_data = ref.split('/', 2)
548 548 if ref_data[1] in ('tags', 'heads'):
549 549 rev_data.append({
550 550 # NOTE(marcink):
551 551 # we're unable to tell total_commits for git at this point
552 552 # but we set the variable for consistency with the other backends
553 553 'total_commits': -1,
554 554 'old_rev': old_rev,
555 555 'new_rev': new_rev,
556 556 'ref': ref,
557 557 'type': ref_data[1],
558 558 'name': ref_data[2],
559 559 })
560 560 return rev_data
561 561
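_parse_git_ref_lines is a pure function, so its contract is easy to pin down with a sample pre-receive line (placeholder hashes):

    zero = '0' * 40
    print(_parse_git_ref_lines([f'{zero} 1a2b3c4d refs/heads/main']))
    # [{'total_commits': -1, 'old_rev': '000...0', 'new_rev': '1a2b3c4d',
    #   'ref': 'refs/heads/main', 'type': 'heads', 'name': 'main'}]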
562 562
563 563 def git_pre_receive(unused_repo_path, revision_lines, env) -> int:
564 564 """
565 565 Pre push hook.
566 566
567 567 :return: status code of the hook. 0 for success.
568 568 """
569 569 extras = json.loads(env['RC_SCM_DATA'])
570 570 rev_data = _parse_git_ref_lines(revision_lines)
571 571 if 'push' not in extras['hooks']:
572 572 return 0
573 573 _fix_hooks_executables()
574 574
575 575 empty_commit_id = '0' * 40
576 576
577 577 detect_force_push = extras.get('detect_force_push')
578 578
579 579 for push_ref in rev_data:
580 580 # store our git-env which holds the temp store
581 581 push_ref['git_env'] = _get_git_env()
582 582 push_ref['pruned_sha'] = ''
583 583 if not detect_force_push:
584 584 # don't check for forced-push when we don't need to
585 585 continue
586 586
587 587 type_ = push_ref['type']
588 588 new_branch = push_ref['old_rev'] == empty_commit_id
589 589 delete_branch = push_ref['new_rev'] == empty_commit_id
590 590 if type_ == 'heads' and not (new_branch or delete_branch):
591 591 old_rev = push_ref['old_rev']
592 592 new_rev = push_ref['new_rev']
593 593 cmd = [settings.GIT_EXECUTABLE(), 'rev-list', old_rev, f'^{new_rev}']
594 594 stdout, stderr = subprocessio.run_command(
595 595 cmd, env=os.environ.copy())
596 596 # non-empty output means there are non-reachable objects, i.e. a forced push was used
597 597 if stdout:
598 598 push_ref['pruned_sha'] = stdout.splitlines()
599 599
600 600 extras['hook_type'] = 'pre_receive'
601 601 extras['commit_ids'] = rev_data
602 602
603 603 stdout = sys.stdout
604 604 status_code = _call_hook('pre_push', extras, GitMessageWriter(stdout))
605 605
606 606 return status_code
607 607
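A hypothetical invocation mirroring what the installed pre-receive hook does; with 'push' absent from the enabled hooks, the call parses the ref lines and returns before touching any executables:

    env = {'RC_SCM_DATA': json.dumps(
        {'hooks': [], 'repository': 'r1', 'repo_store': '/tmp'})}
    line = '%s %s refs/heads/main' % ('0' * 40, '1' * 40)
    assert git_pre_receive('unused', [line], env) == 0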
608 608
609 609 def git_post_receive(unused_repo_path, revision_lines, env) -> int:
610 610 """
611 611 Post push hook.
612 612
613 613 :return: status code of the hook. 0 for success.
614 614 """
615 615 extras = json.loads(env['RC_SCM_DATA'])
616 616 if 'push' not in extras['hooks']:
617 617 return 0
618 618
619 619 _fix_hooks_executables()
620 620
621 621 rev_data = _parse_git_ref_lines(revision_lines)
622 622
623 623 git_revs = []
624 624
625 625 # N.B.(skreft): it is ok to just call git, as git before calling a
626 626 # subcommand sets the PATH environment variable so that it points to the
627 627 # correct version of the git executable.
628 628 empty_commit_id = '0' * 40
629 629 branches = []
630 630 tags = []
631 631 for push_ref in rev_data:
632 632 type_ = push_ref['type']
633 633
634 634 if type_ == 'heads':
635 635 # starting new branch case
636 636 if push_ref['old_rev'] == empty_commit_id:
637 637 push_ref_name = push_ref['name']
638 638
639 639 if push_ref_name not in branches:
640 640 branches.append(push_ref_name)
641 641
642 642 need_head_set = ''
643 643 with Repository(os.getcwd()) as repo:
644 644 try:
645 645 repo.head
646 646 except pygit2.GitError:
647 647 need_head_set = f'refs/heads/{push_ref_name}'
648 648
649 649 if need_head_set:
650 650 repo.set_head(need_head_set)
651 651 print(f"Setting default branch to {push_ref_name}")
652 652
653 653 cmd = [settings.GIT_EXECUTABLE(), 'for-each-ref', '--format=%(refname)', 'refs/heads/*']
654 654 stdout, stderr = subprocessio.run_command(
655 655 cmd, env=os.environ.copy())
656 656 heads = safe_str(stdout)
657 657 heads = heads.replace(push_ref['ref'], '')
658 658 heads = ' '.join(head for head
659 659 in heads.splitlines() if head) or '.'
660 660 cmd = [settings.GIT_EXECUTABLE(), 'log', '--reverse',
661 661 '--pretty=format:%H', '--', push_ref['new_rev'],
662 662 '--not', heads]
663 663 stdout, stderr = subprocessio.run_command(
664 664 cmd, env=os.environ.copy())
665 665 git_revs.extend(list(map(ascii_str, stdout.splitlines())))
666 666
667 667 # delete branch case
668 668 elif push_ref['new_rev'] == empty_commit_id:
669 669 git_revs.append(f'delete_branch=>{push_ref["name"]}')
670 670 else:
671 671 if push_ref['name'] not in branches:
672 672 branches.append(push_ref['name'])
673 673
674 674 cmd = [settings.GIT_EXECUTABLE(), 'log',
675 675 f'{push_ref["old_rev"]}..{push_ref["new_rev"]}',
676 676 '--reverse', '--pretty=format:%H']
677 677 stdout, stderr = subprocessio.run_command(
678 678 cmd, env=os.environ.copy())
679 679 # we get bytes from stdout, we need str to be consistent
680 680 log_revs = list(map(ascii_str, stdout.splitlines()))
681 681 git_revs.extend(log_revs)
682 682
683 683 # Pure pygit2 impl. but still 2-3x slower :/
684 684 # results = []
685 685 #
686 686 # with Repository(os.getcwd()) as repo:
687 687 # repo_new_rev = repo[push_ref['new_rev']]
688 688 # repo_old_rev = repo[push_ref['old_rev']]
689 689 # walker = repo.walk(repo_new_rev.id, pygit2.GIT_SORT_TOPOLOGICAL)
690 690 #
691 691 # for commit in walker:
692 692 # if commit.id == repo_old_rev.id:
693 693 # break
694 694 # results.append(commit.id.hex)
695 695 # # reverse the order, can't use GIT_SORT_REVERSE
696 696 # log_revs = results[::-1]
697 697
698 698 elif type_ == 'tags':
699 699 if push_ref['name'] not in tags:
700 700 tags.append(push_ref['name'])
701 701 git_revs.append(f'tag=>{push_ref["name"]}')
702 702
703 703 extras['hook_type'] = 'post_receive'
704 704 extras['commit_ids'] = git_revs
705 705 extras['new_refs'] = {
706 706 'branches': branches,
707 707 'bookmarks': [],
708 708 'tags': tags,
709 709 }
710 710
711 711 stdout = sys.stdout
712 712
713 713 if 'repo_size' in extras['hooks']:
714 714 try:
715 715 _call_hook('repo_size', extras, GitMessageWriter(stdout))
716 716 except Exception:
717 717 pass
718 718
719 719 status_code = _call_hook('post_push', extras, GitMessageWriter(stdout))
720 720 return status_code
721 721
722 722
723 723 def _get_extras_from_txn_id(path, txn_id):
724 724 _fix_hooks_executables()
725 725
726 726 extras = {}
727 727 try:
728 728 cmd = [settings.SVNLOOK_EXECUTABLE(), 'pget',
729 729 '-t', txn_id,
730 730 '--revprop', path, 'rc-scm-extras']
731 731 stdout, stderr = subprocessio.run_command(
732 732 cmd, env=os.environ.copy())
733 733 extras = json.loads(base64.urlsafe_b64decode(stdout))
734 734 except Exception:
735 735 log.exception('Failed to extract extras info from txn_id')
736 736
737 737 return extras
738 738
739 739
740 740 def _get_extras_from_commit_id(commit_id, path):
741 741 _fix_hooks_executables()
742 742
743 743 extras = {}
744 744 try:
745 745 cmd = [settings.SVNLOOK_EXECUTABLE(), 'pget',
746 746 '-r', commit_id,
747 747 '--revprop', path, 'rc-scm-extras']
748 748 stdout, stderr = subprocessio.run_command(
749 749 cmd, env=os.environ.copy())
750 750 extras = json.loads(base64.urlsafe_b64decode(stdout))
751 751 except Exception:
752 752 log.exception('Failed to extract extras info from commit_id')
753 753
754 754 return extras
755 755
756 756
757 757 def svn_pre_commit(repo_path, commit_data, env):
758 758
759 759 path, txn_id = commit_data
760 760 branches = []
761 761 tags = []
762 762
763 763 if env.get('RC_SCM_DATA'):
764 764 extras = json.loads(env['RC_SCM_DATA'])
765 765 else:
766 766 # fallback method to read from TXN-ID stored data
767 767 extras = _get_extras_from_txn_id(path, txn_id)
768 768
769 769 if not extras:
770 770 # TODO: temporary fix until svn txn-id changes are merged
771 771 return 0
772 772 raise ValueError('Failed to extract context data called extras for hook execution')
773 773
774 774 extras['hook_type'] = 'pre_commit'
775 775 extras['commit_ids'] = [txn_id]
776 776 extras['txn_id'] = txn_id
777 777 extras['new_refs'] = {
778 778 'total_commits': 1,
779 779 'branches': branches,
780 780 'bookmarks': [],
781 781 'tags': tags,
782 782 }
783 783
784 784 return _call_hook('pre_push', extras, SvnMessageWriter())
785 785
786 786
787 787 def svn_post_commit(repo_path, commit_data, env):
788 788 """
789 789 commit_data is path, rev, txn_id
790 790 """
791 791
792 792 if len(commit_data) == 3:
793 793 path, commit_id, txn_id = commit_data
794 794 elif len(commit_data) == 2:
795 795 log.error('Failed to extract txn_id from commit_data using legacy method. '
796 796 'Some functionality might be limited')
797 797 path, commit_id = commit_data
798 798 txn_id = None
799 799 else:
800 800 return 0
801 801
802 802 branches = []
803 803 tags = []
804 804
805 805 if env.get('RC_SCM_DATA'):
806 806 extras = json.loads(env['RC_SCM_DATA'])
807 807 else:
808 808 # fallback method to read from TXN-ID stored data
809 809 extras = _get_extras_from_commit_id(commit_id, path)
810 810
811 811 if not extras:
812 812 # TODO: temporary fix until svn txn-id changes are merged
813 813 return 0
814 814 raise ValueError('Failed to extract context data called extras for hook execution')
815 815
816 816 extras['hook_type'] = 'post_commit'
817 817 extras['commit_ids'] = [commit_id]
818 818 extras['txn_id'] = txn_id
819 819 extras['new_refs'] = {
820 820 'branches': branches,
821 821 'bookmarks': [],
822 822 'tags': tags,
823 823 'total_commits': 1,
824 824 }
825 825
826 826 if 'repo_size' in extras['hooks']:
827 827 try:
828 828 _call_hook('repo_size', extras, SvnMessageWriter())
829 829 except Exception:
830 830 pass
831 831
832 832 return _call_hook('post_push', extras, SvnMessageWriter())
@@ -1,774 +1,774 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import platform
21 21 import sys
22 22 import locale
23 23 import logging
24 24 import uuid
25 25 import time
26 26 import wsgiref.util
27 27 import tempfile
28 28 import psutil
29 29
30 30 from itertools import chain
31 31
32 32 import msgpack
33 33 import configparser
34 34
35 35 from pyramid.config import Configurator
36 36 from pyramid.wsgi import wsgiapp
37 37 from pyramid.response import Response
38 38
39 39 from vcsserver.base import BytesEnvelope, BinaryEnvelope
40 40 from vcsserver.lib.ext_json import json
41 41 from vcsserver.config.settings_maker import SettingsMaker
42 from vcsserver.str_utils import safe_int
42 from vcsserver.lib.str_utils import safe_int
43 43 from vcsserver.lib.statsd_client import StatsdClient
44 44 from vcsserver.tweens.request_wrapper import get_headers_call_context
45 45
46 46 import vcsserver
47 47 from vcsserver import remote_wsgi, scm_app, settings, hgpatches
48 48 from vcsserver.git_lfs.app import GIT_LFS_CONTENT_TYPE, GIT_LFS_PROTO_PAT
49 49 from vcsserver.echo_stub import remote_wsgi as remote_wsgi_stub
50 50 from vcsserver.echo_stub.echo_app import EchoApp
51 51 from vcsserver.exceptions import HTTPRepoLocked, HTTPRepoBranchProtected
52 52 from vcsserver.lib.exc_tracking import store_exception, format_exc
53 53 from vcsserver.server import VcsServer
54 54
55 55 strict_vcs = True
56 56
57 57 git_import_err = None
58 58 try:
59 59 from vcsserver.remote.git_remote import GitFactory, GitRemote
60 60 except ImportError as e:
61 61 GitFactory = None
62 62 GitRemote = None
63 63 git_import_err = e
64 64 if strict_vcs:
65 65 raise
66 66
67 67
68 68 hg_import_err = None
69 69 try:
70 70 from vcsserver.remote.hg_remote import MercurialFactory, HgRemote
71 71 except ImportError as e:
72 72 MercurialFactory = None
73 73 HgRemote = None
74 74 hg_import_err = e
75 75 if strict_vcs:
76 76 raise
77 77
78 78
79 79 svn_import_err = None
80 80 try:
81 81 from vcsserver.remote.svn_remote import SubversionFactory, SvnRemote
82 82 except ImportError as e:
83 83 SubversionFactory = None
84 84 SvnRemote = None
85 85 svn_import_err = e
86 86 if strict_vcs:
87 87 raise
88 88
89 89 log = logging.getLogger(__name__)
90 90
91 91 # due to Mercurial/glibc2.27 problems we need to detect if locale settings are
92 92 # causing problems and "fix" it in case they do, falling back to LC_ALL = C
93 93
94 94 try:
95 95 locale.setlocale(locale.LC_ALL, '')
96 96 except locale.Error as e:
97 97 log.error(
98 98 'LOCALE ERROR: failed to set LC_ALL, fallback to LC_ALL=C, org error: %s', e)
99 99 os.environ['LC_ALL'] = 'C'
100 100
101 101
102 102 def _is_request_chunked(environ):
103 103 stream = environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked'
104 104 return stream
105 105
106 106
107 107 def log_max_fd():
108 108 try:
109 109 maxfd = psutil.Process().rlimit(psutil.RLIMIT_NOFILE)[1]
110 110 log.info('Max file descriptors value: %s', maxfd)
111 111 except Exception:
112 112 pass
113 113
114 114
115 115 class VCS:
116 116 def __init__(self, locale_conf=None, cache_config=None):
117 117 self.locale = locale_conf
118 118 self.cache_config = cache_config
119 119 self._configure_locale()
120 120
121 121 log_max_fd()
122 122
123 123 if GitFactory and GitRemote:
124 124 git_factory = GitFactory()
125 125 self._git_remote = GitRemote(git_factory)
126 126 else:
127 127 log.error("Git client import failed: %s", git_import_err)
128 128
129 129 if MercurialFactory and HgRemote:
130 130 hg_factory = MercurialFactory()
131 131 self._hg_remote = HgRemote(hg_factory)
132 132 else:
133 133 log.error("Mercurial client import failed: %s", hg_import_err)
134 134
135 135 if SubversionFactory and SvnRemote:
136 136 svn_factory = SubversionFactory()
137 137
138 138 # hg factory is used for svn url validation
139 139 hg_factory = MercurialFactory()
140 140 self._svn_remote = SvnRemote(svn_factory, hg_factory=hg_factory)
141 141 else:
142 142 log.error("Subversion client import failed: %s", svn_import_err)
143 143
144 144 self._vcsserver = VcsServer()
145 145
146 146 def _configure_locale(self):
147 147 if self.locale:
148 148 log.info('Setting locale `LC_ALL` to %s', self.locale)
149 149 else:
150 150 log.info('Configuring locale subsystem based on environment variables')
151 151 try:
152 152 # If self.locale is the empty string, then the locale
153 153 # module will use the environment variables. See the
154 154 # documentation of the package `locale`.
155 155 locale.setlocale(locale.LC_ALL, self.locale)
156 156
157 157 language_code, encoding = locale.getlocale()
158 158 log.info(
159 159 'Locale set to language code "%s" with encoding "%s".',
160 160 language_code, encoding)
161 161 except locale.Error:
162 162 log.exception('Cannot set locale, not configuring the locale system')
163 163
164 164
165 165 class WsgiProxy:
166 166 def __init__(self, wsgi):
167 167 self.wsgi = wsgi
168 168
169 169 def __call__(self, environ, start_response):
170 170 input_data = environ['wsgi.input'].read()
171 171 input_data = msgpack.unpackb(input_data)
172 172
173 173 error = None
174 174 try:
175 175 data, status, headers = self.wsgi.handle(
176 176 input_data['environment'], input_data['input_data'],
177 177 *input_data['args'], **input_data['kwargs'])
178 178 except Exception as e:
179 179 data, status, headers = [], None, None
180 180 error = {
181 181 'message': str(e),
182 182 '_vcs_kind': getattr(e, '_vcs_kind', None)
183 183 }
184 184
185 185 start_response('200 OK', [])
186 186 return self._iterator(error, status, headers, data)
187 187
188 188 def _iterator(self, error, status, headers, data):
189 189 initial_data = [
190 190 error,
191 191 status,
192 192 headers,
193 193 ]
194 194
195 195 for d in chain(initial_data, data):
196 196 yield msgpack.packb(d)
197 197
198 198
199 199 def not_found(request):
200 200 return {'status': '404 NOT FOUND'}
201 201
202 202
203 203 class VCSViewPredicate:
204 204 def __init__(self, val, config):
205 205 self.remotes = val
206 206
207 207 def text(self):
208 208 return f'vcs view method = {list(self.remotes.keys())}'
209 209
210 210 phash = text
211 211
212 212 def __call__(self, context, request):
213 213 """
214 214 View predicate that returns true if given backend is supported by
215 215 defined remotes.
216 216 """
217 217 backend = request.matchdict.get('backend')
218 218 return backend in self.remotes
219 219
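The predicate just checks the {backend} route segment against the configured remotes; a standalone check with a stand-in request object:

    from types import SimpleNamespace

    pred = VCSViewPredicate({'git': object(), 'hg': object()}, config=None)
    request = SimpleNamespace(matchdict={'backend': 'git'})
    assert pred(None, request) is True
    print(pred.text())  # vcs view method = ['git', 'hg']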
220 220
221 221 class HTTPApplication:
222 222 ALLOWED_EXCEPTIONS = ('KeyError', 'URLError')
223 223
224 224 remote_wsgi = remote_wsgi
225 225 _use_echo_app = False
226 226
227 227 def __init__(self, settings=None, global_config=None):
228 228
229 229 self.config = Configurator(settings=settings)
230 230 # Init our statsd at very start
231 231 self.config.registry.statsd = StatsdClient.statsd
232 232 self.config.registry.vcs_call_context = {}
233 233
234 234 self.global_config = global_config
235 235 self.config.include('vcsserver.lib.rc_cache')
236 236 self.config.include('vcsserver.lib.rc_cache.archive_cache')
237 237
238 238 settings_locale = settings.get('locale', '') or 'en_US.UTF-8'
239 239 vcs = VCS(locale_conf=settings_locale, cache_config=settings)
240 240 self._remotes = {
241 241 'hg': vcs._hg_remote,
242 242 'git': vcs._git_remote,
243 243 'svn': vcs._svn_remote,
244 244 'server': vcs._vcsserver,
245 245 }
246 246 if settings.get('dev.use_echo_app', 'false').lower() == 'true':
247 247 self._use_echo_app = True
248 248 log.warning("Using EchoApp for VCS operations.")
249 249 self.remote_wsgi = remote_wsgi_stub
250 250
251 251 self._configure_settings(global_config, settings)
252 252
253 253 self._configure()
254 254
255 255 def _configure_settings(self, global_config, app_settings):
256 256 """
257 257 Configure the settings module.
258 258 """
259 259 settings_merged = global_config.copy()
260 260 settings_merged.update(app_settings)
261 261
262 262 binary_dir = app_settings['core.binary_dir']
263 263
264 264 settings.BINARY_DIR = binary_dir
265 265
266 266 # Store the settings to make them available to other modules.
267 267 vcsserver.PYRAMID_SETTINGS = settings_merged
268 268 vcsserver.CONFIG = settings_merged
269 269
270 270 def _configure(self):
271 271 self.config.add_renderer(name='msgpack', factory=self._msgpack_renderer_factory)
272 272
273 273 self.config.add_route('service', '/_service')
274 274 self.config.add_route('status', '/status')
275 275 self.config.add_route('hg_proxy', '/proxy/hg')
276 276 self.config.add_route('git_proxy', '/proxy/git')
277 277
278 278 # rpc methods
279 279 self.config.add_route('vcs', '/{backend}')
280 280
281 281 # streaming rpc remote methods
282 282 self.config.add_route('vcs_stream', '/{backend}/stream')
283 283
284 284 # vcs operations clone/push as streaming
285 285 self.config.add_route('stream_git', '/stream/git/*repo_name')
286 286 self.config.add_route('stream_hg', '/stream/hg/*repo_name')
287 287
288 288 self.config.add_view(self.status_view, route_name='status', renderer='json')
289 289 self.config.add_view(self.service_view, route_name='service', renderer='msgpack')
290 290
291 291 self.config.add_view(self.hg_proxy(), route_name='hg_proxy')
292 292 self.config.add_view(self.git_proxy(), route_name='git_proxy')
293 293 self.config.add_view(self.vcs_view, route_name='vcs', renderer='msgpack',
294 294 vcs_view=self._remotes)
295 295 self.config.add_view(self.vcs_stream_view, route_name='vcs_stream',
296 296 vcs_view=self._remotes)
297 297
298 298 self.config.add_view(self.hg_stream(), route_name='stream_hg')
299 299 self.config.add_view(self.git_stream(), route_name='stream_git')
300 300
301 301 self.config.add_view_predicate('vcs_view', VCSViewPredicate)
302 302
303 303 self.config.add_notfound_view(not_found, renderer='json')
304 304
305 305 self.config.add_view(self.handle_vcs_exception, context=Exception)
306 306
307 307 self.config.add_tween(
308 308 'vcsserver.tweens.request_wrapper.RequestWrapperTween',
309 309 )
310 310 self.config.add_request_method(
311 311 'vcsserver.lib.request_counter.get_request_counter',
312 312 'request_count')
313 313
314 314 def wsgi_app(self):
315 315 return self.config.make_wsgi_app()
316 316
317 317 def _vcs_view_params(self, request):
318 318 remote = self._remotes[request.matchdict['backend']]
319 319 payload = msgpack.unpackb(request.body, use_list=True)
320 320
321 321 method = payload.get('method')
322 322 params = payload['params']
323 323 wire = params.get('wire')
324 324 args = params.get('args')
325 325 kwargs = params.get('kwargs')
326 326 context_uid = None
327 327
328 328 request.registry.vcs_call_context = {
329 329 'method': method,
330 330 'repo_name': payload.get('_repo_name'),
331 331 }
332 332
333 333 if wire:
334 334 try:
335 335 wire['context'] = context_uid = uuid.UUID(wire['context'])
336 336 except KeyError:
337 337 pass
338 338 args.insert(0, wire)
339 339 repo_state_uid = wire.get('repo_state_uid') if wire else None
340 340
341 341 # NOTE(marcink): trading complexity for slight performance
342 342 if log.isEnabledFor(logging.DEBUG):
343 343 # also we SKIP printing out any of those method args since they may be excessive
344 344 just_args_methods = {
345 345 'commitctx': ('content', 'removed', 'updated'),
346 346 'commit': ('content', 'removed', 'updated')
347 347 }
348 348 if method in just_args_methods:
349 349 skip_args = just_args_methods[method]
350 350 call_args = ''
351 351 call_kwargs = {}
352 352 for k in kwargs:
353 353 if k in skip_args:
354 354 # replace our skip key with dummy
355 355 call_kwargs[k] = f'RemovedParam({k})'
356 356 else:
357 357 call_kwargs[k] = kwargs[k]
358 358 else:
359 359 call_args = args[1:]
360 360 call_kwargs = kwargs
361 361
362 362 log.debug('Method requested:`%s` with args:%s kwargs:%s context_uid: %s, repo_state_uid:%s',
363 363 method, call_args, call_kwargs, context_uid, repo_state_uid)
364 364
365 365 statsd = request.registry.statsd
366 366 if statsd:
367 367 statsd.incr(
368 368 'vcsserver_method_total', tags=[
369 369 f"method:{method}",
370 370 ])
371 371 return payload, remote, method, args, kwargs
372 372
373 373 def vcs_view(self, request):
374 374
375 375 payload, remote, method, args, kwargs = self._vcs_view_params(request)
376 376 payload_id = payload.get('id')
377 377
378 378 try:
379 379 resp = getattr(remote, method)(*args, **kwargs)
380 380 except Exception as e:
381 381 exc_info = list(sys.exc_info())
382 382 exc_type, exc_value, exc_traceback = exc_info
383 383
384 384 org_exc = getattr(e, '_org_exc', None)
385 385 org_exc_name = None
386 386 org_exc_tb = ''
387 387 if org_exc:
388 388 org_exc_name = org_exc.__class__.__name__
389 389 org_exc_tb = getattr(e, '_org_exc_tb', '')
390 390 # replace our "faked" exception with our org
391 391 exc_info[0] = org_exc.__class__
392 392 exc_info[1] = org_exc
393 393
394 394 should_store_exc = True
395 395 if org_exc:
396 396 def get_exc_fqn(_exc_obj):
397 397 module_name = getattr(org_exc.__class__, '__module__', 'UNKNOWN')
398 398 return module_name + '.' + org_exc_name
399 399
400 400 exc_fqn = get_exc_fqn(org_exc)
401 401
402 402 if exc_fqn in ['mercurial.error.RepoLookupError',
403 403 'vcsserver.exceptions.RefNotFoundException']:
404 404 should_store_exc = False
405 405
406 406 if should_store_exc:
407 407 store_exception(id(exc_info), exc_info, request_path=request.path)
408 408
409 409 tb_info = format_exc(exc_info)
410 410
411 411 type_ = e.__class__.__name__
412 412 if type_ not in self.ALLOWED_EXCEPTIONS:
413 413 type_ = None
414 414
415 415 resp = {
416 416 'id': payload_id,
417 417 'error': {
418 418 'message': str(e),
419 419 'traceback': tb_info,
420 420 'org_exc': org_exc_name,
421 421 'org_exc_tb': org_exc_tb,
422 422 'type': type_
423 423 }
424 424 }
425 425
426 426 try:
427 427 resp['error']['_vcs_kind'] = getattr(e, '_vcs_kind', None)
428 428 except AttributeError:
429 429 pass
430 430 else:
431 431 resp = {
432 432 'id': payload_id,
433 433 'result': resp
434 434 }
435 435 log.debug('Serving data for method %s', method)
436 436 return resp
437 437
438 438 def vcs_stream_view(self, request):
439 439 payload, remote, method, args, kwargs = self._vcs_view_params(request)
440 440 # this method has a 'stream:' marker; we remove it here
441 441 method = method.split('stream:')[-1]
442 442 chunk_size = safe_int(payload.get('chunk_size')) or 4096
443 443
444 444 resp = getattr(remote, method)(*args, **kwargs)
445 445
446 446 def get_chunked_data(method_resp):
447 447 stream = io.BytesIO(method_resp)
448 448 while 1:
449 449 chunk = stream.read(chunk_size)
450 450 if not chunk:
451 451 break
452 452 yield chunk
453 453
454 454 response = Response(app_iter=get_chunked_data(resp))
455 455 response.content_type = 'application/octet-stream'
456 456
457 457 return response
458 458
459 459 def status_view(self, request):
460 460 import vcsserver
461 461 _platform_id = platform.uname()[1] or 'instance'
462 462
463 463 return {
464 464 "status": "OK",
465 465 "vcsserver_version": vcsserver.get_version(),
466 466 "platform": _platform_id,
467 467 "pid": os.getpid(),
468 468 }
469 469
470 470 def service_view(self, request):
471 471 import vcsserver
472 472
473 473 payload = msgpack.unpackb(request.body, use_list=True)
474 474 server_config, app_config = {}, {}
475 475
476 476 try:
477 477 path = self.global_config['__file__']
478 478 config = configparser.RawConfigParser()
479 479
480 480 config.read(path)
481 481
482 482 if config.has_section('server:main'):
483 483 server_config = dict(config.items('server:main'))
484 484 if config.has_section('app:main'):
485 485 app_config = dict(config.items('app:main'))
486 486
487 487 except Exception:
488 488 log.exception('Failed to read .ini file for display')
489 489
490 490 environ = list(os.environ.items())
491 491
492 492 resp = {
493 493 'id': payload.get('id'),
494 494 'result': dict(
495 495 version=vcsserver.get_version(),
496 496 config=server_config,
497 497 app_config=app_config,
498 498 environ=environ,
499 499 payload=payload,
500 500 )
501 501 }
502 502 return resp
503 503
504 504 def _msgpack_renderer_factory(self, info):
505 505
506 506 def _render(value, system):
507 507 bin_type = False
508 508 res = value.get('result')
509 509 if isinstance(res, BytesEnvelope):
510 510 log.debug('Result is wrapped in BytesEnvelope type')
511 511 bin_type = True
512 512 elif isinstance(res, BinaryEnvelope):
513 513 log.debug('Result is wrapped in BinaryEnvelope type')
514 514 value['result'] = res.val
515 515 bin_type = True
516 516
517 517 request = system.get('request')
518 518 if request is not None:
519 519 response = request.response
520 520 ct = response.content_type
521 521 if ct == response.default_content_type:
522 522 response.content_type = 'application/x-msgpack'
523 523 if bin_type:
524 524 response.content_type = 'application/x-msgpack-bin'
525 525
526 526 return msgpack.packb(value, use_bin_type=bin_type)
527 527 return _render
528 528
529 529 def set_env_from_config(self, environ, config):
530 530 dict_conf = {}
531 531 try:
532 532 for elem in config:
533 533 if elem[0] == 'rhodecode':
534 534 dict_conf = json.loads(elem[2])
535 535 break
536 536 except Exception:
537 537 log.exception('Failed to fetch SCM CONFIG')
538 538 return
539 539
540 540 username = dict_conf.get('username')
541 541 if username:
542 542 environ['REMOTE_USER'] = username
543 543 # mercurial specific, some extension APIs rely on this
544 544 environ['HGUSER'] = username
545 545
546 546 ip = dict_conf.get('ip')
547 547 if ip:
548 548 environ['REMOTE_HOST'] = ip
549 549
550 550 if _is_request_chunked(environ):
551 551 # set the compatibility flag for webob
552 552 environ['wsgi.input_terminated'] = True
553 553
554 554 def hg_proxy(self):
555 555 @wsgiapp
556 556 def _hg_proxy(environ, start_response):
557 557 app = WsgiProxy(self.remote_wsgi.HgRemoteWsgi())
558 558 return app(environ, start_response)
559 559 return _hg_proxy
560 560
561 561 def git_proxy(self):
562 562 @wsgiapp
563 563 def _git_proxy(environ, start_response):
564 564 app = WsgiProxy(self.remote_wsgi.GitRemoteWsgi())
565 565 return app(environ, start_response)
566 566 return _git_proxy
567 567
568 568 def hg_stream(self):
569 569 if self._use_echo_app:
570 570 @wsgiapp
571 571 def _hg_stream(environ, start_response):
572 572 app = EchoApp('fake_path', 'fake_name', None)
573 573 return app(environ, start_response)
574 574 return _hg_stream
575 575 else:
576 576 @wsgiapp
577 577 def _hg_stream(environ, start_response):
578 578 log.debug('http-app: handling hg stream')
579 579 call_context = get_headers_call_context(environ)
580 580
581 581 repo_path = call_context['repo_path']
582 582 repo_name = call_context['repo_name']
583 583 config = call_context['repo_config']
584 584
585 585 app = scm_app.create_hg_wsgi_app(
586 586 repo_path, repo_name, config)
587 587
588 588 # Consistent path information for hgweb
589 589 environ['PATH_INFO'] = call_context['path_info']
590 590 environ['REPO_NAME'] = repo_name
591 591 self.set_env_from_config(environ, config)
592 592
593 593 log.debug('http-app: starting app handler '
594 594 'with %s and process request', app)
595 595 return app(environ, ResponseFilter(start_response))
596 596 return _hg_stream
597 597
598 598 def git_stream(self):
599 599 if self._use_echo_app:
600 600 @wsgiapp
601 601 def _git_stream(environ, start_response):
602 602 app = EchoApp('fake_path', 'fake_name', None)
603 603 return app(environ, start_response)
604 604 return _git_stream
605 605 else:
606 606 @wsgiapp
607 607 def _git_stream(environ, start_response):
608 608 log.debug('http-app: handling git stream')
609 609
610 610 call_context = get_headers_call_context(environ)
611 611
612 612 repo_path = call_context['repo_path']
613 613 repo_name = call_context['repo_name']
614 614 config = call_context['repo_config']
615 615
616 616 environ['PATH_INFO'] = call_context['path_info']
617 617 self.set_env_from_config(environ, config)
618 618
619 619 content_type = environ.get('CONTENT_TYPE', '')
620 620
621 621 path = environ['PATH_INFO']
622 622 is_lfs_request = GIT_LFS_CONTENT_TYPE in content_type
623 623 log.debug(
624 624 'LFS: Detecting if request `%s` is LFS server path based '
625 625 'on content type:`%s`, is_lfs:%s',
626 626 path, content_type, is_lfs_request)
627 627
628 628 if not is_lfs_request:
629 629 # fallback detection by path
630 630 if GIT_LFS_PROTO_PAT.match(path):
631 631 is_lfs_request = True
632 632 log.debug(
633 633 'LFS: fallback detection by path of: `%s`, is_lfs:%s',
634 634 path, is_lfs_request)
635 635
636 636 if is_lfs_request:
637 637 app = scm_app.create_git_lfs_wsgi_app(
638 638 repo_path, repo_name, config)
639 639 else:
640 640 app = scm_app.create_git_wsgi_app(
641 641 repo_path, repo_name, config)
642 642
643 643 log.debug('http-app: starting app handler '
644 644 'with %s and process request', app)
645 645
646 646 return app(environ, start_response)
647 647
648 648 return _git_stream
649 649
650 650 def handle_vcs_exception(self, exception, request):
651 651 _vcs_kind = getattr(exception, '_vcs_kind', '')
652 652
653 653 if _vcs_kind == 'repo_locked':
654 654 headers_call_context = get_headers_call_context(request.environ)
655 655 status_code = safe_int(headers_call_context['locked_status_code'])
656 656
657 657 return HTTPRepoLocked(
658 658 title=str(exception), status_code=status_code, headers=[('X-Rc-Locked', '1')])
659 659
660 660 elif _vcs_kind == 'repo_branch_protected':
661 661 # Get custom repo-branch-protected status code if present.
662 662 return HTTPRepoBranchProtected(
663 663 title=str(exception), headers=[('X-Rc-Branch-Protection', '1')])
664 664
665 665 exc_info = request.exc_info
666 666 store_exception(id(exc_info), exc_info)
667 667
668 668 traceback_info = 'unavailable'
669 669 if request.exc_info:
670 670 traceback_info = format_exc(request.exc_info)
671 671
672 672 log.error(
673 673 'error occurred handling this request for path: %s, \n%s',
674 674 request.path, traceback_info)
675 675
676 676 statsd = request.registry.statsd
677 677 if statsd:
678 678 exc_type = f"{exception.__class__.__module__}.{exception.__class__.__name__}"
679 679 statsd.incr('vcsserver_exception_total',
680 680 tags=[f"type:{exc_type}"])
681 681 raise exception
682 682
683 683
684 684 class ResponseFilter:
685 685
686 686 def __init__(self, start_response):
687 687 self._start_response = start_response
688 688
689 689 def __call__(self, status, response_headers, exc_info=None):
690 690 headers = tuple(
691 691 (h, v) for h, v in response_headers
692 692 if not wsgiref.util.is_hop_by_hop(h))
693 693 return self._start_response(status, headers, exc_info)
694 694
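ResponseFilter strips hop-by-hop headers before delegating to the real start_response; a self-contained demonstration:

    captured = {}

    def fake_start_response(status, headers, exc_info=None):
        captured['status'], captured['headers'] = status, headers

    rf = ResponseFilter(fake_start_response)
    rf('200 OK', [('Content-Type', 'text/plain'), ('Connection', 'keep-alive')])
    print(captured['headers'])  # Connection is hop-by-hop and gets dropped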
695 695
696 696 def sanitize_settings_and_apply_defaults(global_config, settings):
697 697 _global_settings_maker = SettingsMaker(global_config)
698 698 settings_maker = SettingsMaker(settings)
699 699
700 700 settings_maker.make_setting('logging.autoconfigure', False, parser='bool')
701 701
702 702 logging_conf = os.path.join(os.path.dirname(global_config.get('__file__')), 'logging.ini')
703 703 settings_maker.enable_logging(logging_conf)
704 704
705 705 # Default includes, possible to change as a user
706 706 pyramid_includes = settings_maker.make_setting('pyramid.includes', [], parser='list:newline')
707 707 log.debug("Using the following pyramid.includes: %s", pyramid_includes)
708 708
709 709 settings_maker.make_setting('__file__', global_config.get('__file__'))
710 710
711 711 settings_maker.make_setting('pyramid.default_locale_name', 'en')
712 712 settings_maker.make_setting('locale', 'en_US.UTF-8')
713 713
714 714 settings_maker.make_setting(
715 715 'core.binary_dir', '/usr/local/bin/rhodecode_bin/vcs_bin',
716 716 default_when_empty=True, parser='string:noquote')
717 717
718 718 temp_store = tempfile.gettempdir()
719 719 default_cache_dir = os.path.join(temp_store, 'rc_cache')
720 720 # save default cache dir and use it for all backends later.
721 721 default_cache_dir = settings_maker.make_setting(
722 722 'cache_dir',
723 723 default=default_cache_dir, default_when_empty=True,
724 724 parser='dir:ensured')
725 725
726 726 # exception store cache
727 727 settings_maker.make_setting(
728 728 'exception_tracker.store_path',
729 729 default=os.path.join(default_cache_dir, 'exc_store'), default_when_empty=True,
730 730 parser='dir:ensured'
731 731 )
732 732
733 733 # repo_object cache defaults
734 734 settings_maker.make_setting(
735 735 'rc_cache.repo_object.backend',
736 736 default='dogpile.cache.rc.file_namespace',
737 737 parser='string')
738 738 settings_maker.make_setting(
739 739 'rc_cache.repo_object.expiration_time',
740 740 default=30 * 24 * 60 * 60, # 30days
741 741 parser='int')
742 742 settings_maker.make_setting(
743 743 'rc_cache.repo_object.arguments.filename',
744 744 default=os.path.join(default_cache_dir, 'vcsserver_cache_repo_object.db'),
745 745 parser='string')
746 746
747 747 # statsd
748 748 settings_maker.make_setting('statsd.enabled', False, parser='bool')
749 749 settings_maker.make_setting('statsd.statsd_host', 'statsd-exporter', parser='string')
750 750 settings_maker.make_setting('statsd.statsd_port', 9125, parser='int')
751 751 settings_maker.make_setting('statsd.statsd_prefix', '')
752 752 settings_maker.make_setting('statsd.statsd_ipv6', False, parser='bool')
753 753
754 754 settings_maker.env_expand()
755 755
756 756
757 757 def main(global_config, **settings):
758 758 start_time = time.time()
759 759 log.info('Pyramid app config starting')
760 760
761 761 if MercurialFactory:
762 762 hgpatches.patch_largefiles_capabilities()
763 763 hgpatches.patch_subrepo_type_mapping()
764 764
765 765 # Fill in and sanitize the defaults & do ENV expansion
766 766 sanitize_settings_and_apply_defaults(global_config, settings)
767 767
768 768 # init and bootstrap StatsdClient
769 769 StatsdClient.setup(settings)
770 770
771 771 pyramid_app = HTTPApplication(settings=settings, global_config=global_config).wsgi_app()
772 772 total_time = time.time() - start_time
773 773 log.info('Pyramid app created and configured in %.2fs', total_time)
774 774 return pyramid_app
@@ -1,63 +1,63 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18
19 19 import logging
20 20
21 21 from repoze.lru import LRUCache
22 22
23 from vcsserver.str_utils import safe_str
23 from vcsserver.lib.str_utils import safe_str
24 24
25 25 log = logging.getLogger(__name__)
26 26
27 27
28 28 class LRUDict(LRUCache):
29 29 """
30 30 Wrapper to provide partial dict access
31 31 """
32 32
33 33 def __setitem__(self, key, value):
34 34 return self.put(key, value)
35 35
36 36 def __getitem__(self, key):
37 37 return self.get(key)
38 38
39 39 def __contains__(self, key):
40 40 return bool(self.get(key))
41 41
42 42 def __delitem__(self, key):
43 43 del self.data[key]
44 44
45 45 def keys(self):
46 46 return list(self.data.keys())
47 47
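A quick illustration of the dict facade; note that __contains__ is truthiness-based, so falsy values look absent:

    cache = LRUDict(2)        # capacity: two entries
    cache['a'] = 1
    cache['b'] = 2
    cache['c'] = 3            # exceeding capacity evicts the oldest entry
    print(cache.keys())       # ['b', 'c'] once 'a' is evicted
    cache['b'] = 0
    print('b' in cache)       # False, although the key exists (value is falsy)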
48 48
49 49 class LRUDictDebug(LRUDict):
50 50 """
51 51 Wrapper to provide some debug options
52 52 """
53 53 def _report_keys(self):
54 54 elems_cnt = f'{len(list(self.keys()))}/{self.size}'
55 55 # trick for pformat to print it more nicely
56 56 fmt = '\n'
57 57 for cnt, elem in enumerate(self.keys()):
58 58 fmt += f'{cnt+1} - {safe_str(elem)}\n'
59 59 log.debug('current LRU keys (%s):%s', elems_cnt, fmt)
60 60
61 61 def __getitem__(self, key):
62 62 self._report_keys()
63 63 return self.get(key)
@@ -1,303 +1,303 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 #import errno
19 19 import fcntl
20 20 import functools
21 21 import logging
22 22 import os
23 23 import pickle
24 24 #import time
25 25
26 26 #import gevent
27 27 import msgpack
28 28 import redis
29 29
30 30 flock_org = fcntl.flock
31 31 from typing import Union
32 32
33 33 from dogpile.cache.api import Deserializer, Serializer
34 34 from dogpile.cache.backends import file as file_backend
35 35 from dogpile.cache.backends import memory as memory_backend
36 36 from dogpile.cache.backends import redis as redis_backend
37 37 from dogpile.cache.backends.file import FileLock
38 38 from dogpile.cache.util import memoized_property
39 39
40 40 from vcsserver.lib.memory_lru_dict import LRUDict, LRUDictDebug
41 from vcsserver.str_utils import safe_bytes, safe_str
42 from vcsserver.type_utils import str2bool
41 from vcsserver.lib.str_utils import safe_bytes, safe_str
42 from vcsserver.lib.type_utils import str2bool
43 43
44 44 _default_max_size = 1024
45 45
46 46 log = logging.getLogger(__name__)
47 47
48 48
49 49 class LRUMemoryBackend(memory_backend.MemoryBackend):
50 50 key_prefix = 'lru_mem_backend'
51 51 pickle_values = False
52 52
53 53 def __init__(self, arguments):
54 54 self.max_size = arguments.pop('max_size', _default_max_size)
55 55
56 56 LRUDictClass = LRUDict
57 57 if arguments.pop('log_key_count', None):
58 58 LRUDictClass = LRUDictDebug
59 59
60 60 arguments['cache_dict'] = LRUDictClass(self.max_size)
61 61 super().__init__(arguments)
62 62
63 63 def __repr__(self):
64 64 return f'{self.__class__}(maxsize=`{self.max_size}`)'
65 65
66 66 def __str__(self):
67 67 return self.__repr__()
68 68
69 69 def delete(self, key):
70 70 try:
71 71 del self._cache[key]
72 72 except KeyError:
73 73 # we don't care if key isn't there at deletion
74 74 pass
75 75
76 76 def list_keys(self, prefix):
77 77 return list(self._cache.keys())
78 78
79 79 def delete_multi(self, keys):
80 80 for key in keys:
81 81 self.delete(key)
82 82
83 83 def delete_multi_by_prefix(self, prefix):
84 84 cache_keys = self.list_keys(prefix=prefix)
85 85 num_affected_keys = len(cache_keys)
86 86 if num_affected_keys:
87 87 self.delete_multi(cache_keys)
88 88 return num_affected_keys
89 89
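The backend can be exercised directly, outside a dogpile region, using the argument names defined above:

    backend = LRUMemoryBackend({'max_size': 2})
    backend.set('k1', 'v1')
    backend.set('k2', 'v2')
    print(backend.list_keys(prefix=''))        # both keys, LRU-bounded
    print(backend.delete_multi_by_prefix(''))  # 2: everything listed is deleted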
90 90
91 91 class PickleSerializer:
92 92 serializer: None | Serializer = staticmethod( # type: ignore
93 93 functools.partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)
94 94 )
95 95 deserializer: None | Deserializer = staticmethod( # type: ignore
96 96 functools.partial(pickle.loads)
97 97 )
98 98
99 99
100 100 class MsgPackSerializer:
101 101 serializer: None | Serializer = staticmethod( # type: ignore
102 102 msgpack.packb
103 103 )
104 104 deserializer: None | Deserializer = staticmethod( # type: ignore
105 105 functools.partial(msgpack.unpackb, use_list=False)
106 106 )
107 107
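Both serializer shims are plain staticmethod pairs, so a round-trip can be checked directly (msgpack-python >= 1.0 returns str keys by default):

    payload = {'rev': 42, 'branch': 'default'}

    packed = MsgPackSerializer.serializer(payload)
    assert MsgPackSerializer.deserializer(packed) == payload

    pickled = PickleSerializer.serializer(payload)
    assert PickleSerializer.deserializer(pickled) == payload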
108 108
109 109 class CustomLockFactory(FileLock):
110 110
111 111 pass
112 112
113 113
114 114 class FileNamespaceBackend(PickleSerializer, file_backend.DBMBackend):
115 115 key_prefix = 'file_backend'
116 116
117 117 def __init__(self, arguments):
118 118 arguments['lock_factory'] = CustomLockFactory
119 119 db_file = arguments.get('filename')
120 120
121 121 log.debug('initializing cache-backend=%s db in %s', self.__class__.__name__, db_file)
122 122 db_file_dir = os.path.dirname(db_file)
123 123 if not os.path.isdir(db_file_dir):
124 124 os.makedirs(db_file_dir)
125 125
126 126 try:
127 127 super().__init__(arguments)
128 128 except Exception:
129 129 log.exception('Failed to initialize db at: %s', db_file)
130 130 raise
131 131
132 132 def __repr__(self):
133 133 return f'{self.__class__}(file=`{self.filename}`)'
134 134
135 135 def __str__(self):
136 136 return self.__repr__()
137 137
138 138 def _get_keys_pattern(self, prefix: bytes = b''):
139 139 return b'%b:%b' % (safe_bytes(self.key_prefix), safe_bytes(prefix))
140 140
141 141 def list_keys(self, prefix: bytes = b''):
142 142 prefix = self._get_keys_pattern(prefix)
143 143
144 144 def cond(dbm_key: bytes):
145 145 if not prefix:
146 146 return True
147 147
148 148 if dbm_key.startswith(prefix):
149 149 return True
150 150 return False
151 151
152 152 with self._dbm_file(True) as dbm:
153 153 try:
154 154 return list(filter(cond, dbm.keys()))
155 155 except Exception:
156 156 log.error('Failed to fetch DBM keys from DB: %s', self.get_store())
157 157 raise
158 158
159 159 def delete_multi_by_prefix(self, prefix):
160 160 cache_keys = self.list_keys(prefix=prefix)
161 161 num_affected_keys = len(cache_keys)
162 162 if num_affected_keys:
163 163 self.delete_multi(cache_keys)
164 164 return num_affected_keys
165 165
166 166 def get_store(self):
167 167 return self.filename
168 168
169 169
170 170 class BaseRedisBackend(redis_backend.RedisBackend):
171 171 key_prefix = ''
172 172
173 173 def __init__(self, arguments):
174 174 self.db_conn = arguments.get('host', '') or arguments.get('url', '') or 'redis-host'
175 175 super().__init__(arguments)
176 176
177 177 self._lock_timeout = self.lock_timeout
178 178 self._lock_auto_renewal = str2bool(arguments.pop("lock_auto_renewal", True))
179 179
180 180 if self._lock_auto_renewal and not self._lock_timeout:
181 181 # set default timeout for auto_renewal
182 182 self._lock_timeout = 30
183 183
184 184 def __repr__(self):
185 185 return f'{self.__class__}(conn=`{self.db_conn}`)'
186 186
187 187 def __str__(self):
188 188 return self.__repr__()
189 189
190 190 def _create_client(self):
191 191 args = {}
192 192
193 193 if self.url is not None:
194 194 args.update(url=self.url)
195 195
196 196 else:
197 197 args.update(
198 198 host=self.host, password=self.password,
199 199 port=self.port, db=self.db
200 200 )
201 201
202 202 connection_pool = redis.ConnectionPool(**args)
203 203 self.writer_client = redis.StrictRedis(
204 204 connection_pool=connection_pool
205 205 )
206 206 self.reader_client = self.writer_client
207 207
208 208 def _get_keys_pattern(self, prefix: bytes = b''):
209 209 return b'%b:%b*' % (safe_bytes(self.key_prefix), safe_bytes(prefix))
210 210
211 211 def list_keys(self, prefix: bytes = b''):
212 212 prefix = self._get_keys_pattern(prefix)
213 213 return self.reader_client.keys(prefix)
214 214
215 215 def delete_multi_by_prefix(self, prefix, use_lua=False):
216 216 if use_lua:
217 217 # highly efficient Lua script to delete ALL keys by prefix...
218 218 lua = """local keys = redis.call('keys', ARGV[1])
219 219 for i=1,#keys,5000 do
220 220 redis.call('del', unpack(keys, i, math.min(i+(5000-1), #keys)))
221 221 end
222 222 return #keys"""
223 223 num_affected_keys = self.writer_client.eval(
224 224 lua,
225 225 0,
226 226 f"{prefix}*")
227 227 else:
228 228 cache_keys = self.list_keys(prefix=prefix)
229 229 num_affected_keys = len(cache_keys)
230 230 if num_affected_keys:
231 231 self.delete_multi(cache_keys)
232 232 return num_affected_keys
233 233
234 234 def get_store(self):
235 235 return self.reader_client.connection_pool
236 236
237 237 def get_mutex(self, key):
238 238 if self.distributed_lock:
239 239 lock_key = f'_lock_{safe_str(key)}'
240 240 return get_mutex_lock(
241 241 self.writer_client, lock_key,
242 242 self._lock_timeout,
243 243 auto_renewal=self._lock_auto_renewal
244 244 )
245 245 else:
246 246 return None
247 247
248 248
249 249 class RedisPickleBackend(PickleSerializer, BaseRedisBackend):
250 250 key_prefix = 'redis_pickle_backend'
251 251 pass
252 252
253 253
254 254 class RedisMsgPackBackend(MsgPackSerializer, BaseRedisBackend):
255 255 key_prefix = 'redis_msgpack_backend'
256 256 pass
257 257
258 258
259 259 def get_mutex_lock(client, lock_key, lock_timeout, auto_renewal=False):
260 260 from vcsserver.lib._vendor import redis_lock
261 261
262 262 class _RedisLockWrapper:
263 263 """LockWrapper for redis_lock"""
264 264
265 265 @classmethod
266 266 def get_lock(cls):
267 267 return redis_lock.Lock(
268 268 redis_client=client,
269 269 name=lock_key,
270 270 expire=lock_timeout,
271 271 auto_renewal=auto_renewal,
272 272 strict=True,
273 273 )
274 274
275 275 def __repr__(self):
276 276 return f"{self.__class__.__name__}:{lock_key}"
277 277
278 278 def __str__(self):
279 279 return f"{self.__class__.__name__}:{lock_key}"
280 280
281 281 def __init__(self):
282 282 self.lock = self.get_lock()
283 283 self.lock_key = lock_key
284 284
285 285 def acquire(self, wait=True):
286 286 log.debug('Trying to acquire Redis lock for key %s', self.lock_key)
287 287 try:
288 288 acquired = self.lock.acquire(wait)
289 289 log.debug('Got lock for key %s, %s', self.lock_key, acquired)
290 290 return acquired
291 291 except redis_lock.AlreadyAcquired:
292 292 return False
293 293 except redis_lock.AlreadyStarted:
294 294 # refresh thread exists, but it also means we acquired the lock
295 295 return True
296 296
297 297 def release(self):
298 298 try:
299 299 self.lock.release()
300 300 except redis_lock.NotAcquired:
301 301 pass
302 302
303 303 return _RedisLockWrapper()
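For context, a minimal usage sketch of the backends above (illustrative only; it assumes these classes are registered with dogpile under the `dogpile.cache.rc.*` names used elsewhere in this package, e.g. `dogpile.cache.rc.file_namespace`):

    from dogpile.cache import make_region

    region = make_region().configure(
        'dogpile.cache.rc.file_namespace',
        expiration_time=60,
        arguments={'filename': '/tmp/rc_cache/example.cache_db'},
    )
    region.set('key', {'some': 'value'})  # serialized via PickleSerializer
    assert region.get('key')['some'] == 'value'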
@@ -1,245 +1,245 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import functools
19 19 import logging
20 20 import os
21 21 import threading
22 22 import time
23 23
24 24 import decorator
25 25 from dogpile.cache import CacheRegion
26 26
27 27
28 28 from vcsserver.utils import sha1
29 from vcsserver.str_utils import safe_bytes
30 from vcsserver.type_utils import str2bool # noqa :required by imports from .utils
29 from vcsserver.lib.str_utils import safe_bytes
30 from vcsserver.lib.type_utils import str2bool # noqa :required by imports from .utils
31 31
32 32 from . import region_meta
33 33
34 34 log = logging.getLogger(__name__)
35 35
36 36
37 37 class RhodeCodeCacheRegion(CacheRegion):
38 38
39 39 def __repr__(self):
40 40 return f'`{self.__class__.__name__}(name={self.name}, backend={self.backend.__class__})`'
41 41
42 42 def conditional_cache_on_arguments(
43 43 self, namespace=None,
44 44 expiration_time=None,
45 45 should_cache_fn=None,
46 46 to_str=str,
47 47 function_key_generator=None,
48 48 condition=True):
49 49 """
50 50 Custom conditional decorator that will not touch any dogpile internals if
51 51 the condition isn't met. This works a bit differently from should_cache_fn,
52 52 and it's faster in cases where we never want to compute cached values.
53 53 """
54 54 expiration_time_is_callable = callable(expiration_time)
55 55 if not namespace:
56 56 namespace = getattr(self, '_default_namespace', None)
57 57
58 58 if function_key_generator is None:
59 59 function_key_generator = self.function_key_generator
60 60
61 61 def get_or_create_for_user_func(func_key_generator, user_func, *arg, **kw):
62 62
63 63 if not condition:
64 64 log.debug('Calling un-cached method:%s', user_func.__name__)
65 65 start = time.time()
66 66 result = user_func(*arg, **kw)
67 67 total = time.time() - start
68 68 log.debug('un-cached method:%s took %.4fs', user_func.__name__, total)
69 69 return result
70 70
71 71 key = func_key_generator(*arg, **kw)
72 72
73 73 timeout = expiration_time() if expiration_time_is_callable \
74 74 else expiration_time
75 75
76 76 log.debug('Calling cached method:`%s`', user_func.__name__)
77 77 return self.get_or_create(key, user_func, timeout, should_cache_fn, (arg, kw))
78 78
79 79 def cache_decorator(user_func):
80 80 if to_str is str:
81 81 # backwards compatible
82 82 key_generator = function_key_generator(namespace, user_func)
83 83 else:
84 84 key_generator = function_key_generator(namespace, user_func, to_str=to_str)
85 85
86 86 def refresh(*arg, **kw):
87 87 """
88 88 Like invalidate, but regenerates the value instead
89 89 """
90 90 key = key_generator(*arg, **kw)
91 91 value = user_func(*arg, **kw)
92 92 self.set(key, value)
93 93 return value
94 94
95 95 def invalidate(*arg, **kw):
96 96 key = key_generator(*arg, **kw)
97 97 self.delete(key)
98 98
99 99 def set_(value, *arg, **kw):
100 100 key = key_generator(*arg, **kw)
101 101 self.set(key, value)
102 102
103 103 def get(*arg, **kw):
104 104 key = key_generator(*arg, **kw)
105 105 return self.get(key)
106 106
107 107 user_func.set = set_
108 108 user_func.invalidate = invalidate
109 109 user_func.get = get
110 110 user_func.refresh = refresh
111 111 user_func.key_generator = key_generator
112 112 user_func.original = user_func
113 113
114 114 # Use `decorate` to preserve the signature of :param:`user_func`.
115 115 return decorator.decorate(user_func, functools.partial(
116 116 get_or_create_for_user_func, key_generator))
117 117
118 118 return cache_decorator
119 119
120 120
121 121 def make_region(*arg, **kw):
122 122 return RhodeCodeCacheRegion(*arg, **kw)
123 123
124 124
125 125 def get_default_cache_settings(settings, prefixes=None):
126 126 prefixes = prefixes or []
127 127 cache_settings = {}
128 128 for key in settings.keys():
129 129 for prefix in prefixes:
130 130 if key.startswith(prefix):
131 131 name = key.split(prefix)[1].strip()
132 132 val = settings[key]
133 133 if isinstance(val, str):
134 134 val = val.strip()
135 135 cache_settings[name] = val
136 136 return cache_settings
137 137
138 138
139 139 def compute_key_from_params(*args):
140 140 """
141 141 Helper to compute key from given params to be used in cache manager
142 142 """
143 143 return sha1(safe_bytes("_".join(map(str, args))))
144 144
145 145
146 146 def custom_key_generator(backend, namespace, fn):
147 147 func_name = fn.__name__
148 148
149 149 def generate_key(*args):
150 150 backend_pref = getattr(backend, 'key_prefix', None) or 'backend_prefix'
151 151 namespace_pref = namespace or 'default_namespace'
152 152 arg_key = compute_key_from_params(*args)
153 153 final_key = f"{backend_pref}:{namespace_pref}:{func_name}_{arg_key}"
154 154
155 155 return final_key
156 156
157 157 return generate_key
158 158
159 159
160 160 def backend_key_generator(backend):
161 161 """
162 162 Special wrapper that also sends over the backend to the key generator
163 163 """
164 164 def wrapper(namespace, fn):
165 165 return custom_key_generator(backend, namespace, fn)
166 166 return wrapper
167 167
168 168
169 169 def get_or_create_region(region_name, region_namespace: str = None, use_async_runner=False):
170 170 from .backends import FileNamespaceBackend
171 171 from . import async_creation_runner
172 172
173 173 region_obj = region_meta.dogpile_cache_regions.get(region_name)
174 174 if not region_obj:
175 175 reg_keys = list(region_meta.dogpile_cache_regions.keys())
176 176 raise OSError(f'Region `{region_name}` not found in configured regions: {reg_keys}.')
177 177
178 178 region_uid_name = f'{region_name}:{region_namespace}'
179 179
180 180 # Special case for ONLY the FileNamespaceBackend backend. We register one-file-per-region
181 181 if isinstance(region_obj.actual_backend, FileNamespaceBackend):
182 182 if not region_namespace:
183 183 raise ValueError(f'{FileNamespaceBackend} requires the region_namespace param to be specified')
184 184
185 185 region_exist = region_meta.dogpile_cache_regions.get(region_namespace)
186 186 if region_exist:
187 187 log.debug('Using already configured region: %s', region_namespace)
188 188 return region_exist
189 189
190 190 expiration_time = region_obj.expiration_time
191 191
192 192 cache_dir = region_meta.dogpile_config_defaults['cache_dir']
193 193 namespace_cache_dir = cache_dir
194 194
195 195 # we default the namespace_cache_dir to our default cache dir.
196 196 # however, if this backend is configured with the filename= param, we prioritize that,
197 197 # so all caches within that particular region, even namespaced ones, end up in the same path
198 198 if region_obj.actual_backend.filename:
199 199 namespace_cache_dir = os.path.dirname(region_obj.actual_backend.filename)
200 200
201 201 if not os.path.isdir(namespace_cache_dir):
202 202 os.makedirs(namespace_cache_dir)
203 203 new_region = make_region(
204 204 name=region_uid_name,
205 205 function_key_generator=backend_key_generator(region_obj.actual_backend)
206 206 )
207 207
208 208 namespace_filename = os.path.join(
209 209 namespace_cache_dir, f"{region_name}_{region_namespace}.cache_db")
210 210 # special backend type that allows one db file per namespace
211 211 new_region.configure(
212 212 backend='dogpile.cache.rc.file_namespace',
213 213 expiration_time=expiration_time,
214 214 arguments={"filename": namespace_filename}
215 215 )
216 216
217 217 # create and save in region caches
218 218 log.debug('configuring new region: %s', region_uid_name)
219 219 region_obj = region_meta.dogpile_cache_regions[region_namespace] = new_region
220 220
221 221 region_obj._default_namespace = region_namespace
222 222 if use_async_runner:
223 223 region_obj.async_creation_runner = async_creation_runner
224 224 return region_obj
225 225
226 226
227 227 def clear_cache_namespace(cache_region: str | RhodeCodeCacheRegion, cache_namespace_uid: str, method: str) -> int:
228 228 from . import CLEAR_DELETE, CLEAR_INVALIDATE
229 229
230 230 if not isinstance(cache_region, RhodeCodeCacheRegion):
231 231 cache_region = get_or_create_region(cache_region, cache_namespace_uid)
232 232 log.debug('clearing cache region: %s [prefix:%s] with method=%s',
233 233 cache_region, cache_namespace_uid, method)
234 234
235 235 num_affected_keys = 0
236 236
237 237 if method == CLEAR_INVALIDATE:
238 238 # NOTE: The CacheRegion.invalidate() method’s default mode of
239 239 # operation is to set a timestamp local to this CacheRegion in this Python process only.
240 240 # It does not impact other Python processes or regions as the timestamp is only stored locally in memory.
241 241 cache_region.invalidate(hard=True)
242 242
243 243 if method == CLEAR_DELETE:
244 244 num_affected_keys = cache_region.backend.delete_multi_by_prefix(prefix=cache_namespace_uid)
245 245 return num_affected_keys
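For context, a sketch of how the decorator above is typically consumed (illustrative only; the region name `repo_object` and the `cache_on` flag are assumed caller-side configuration):

    region = get_or_create_region('repo_object', region_namespace='repo_id_1')
    cache_on = True  # e.g. computed from the wire/config, as in the remote backends

    @region.conditional_cache_on_arguments(condition=cache_on)
    def _heavy_call(_namespace, _repo_id):
        return f'expensive result for {_repo_id}'  # stand-in for real work

    # condition=False calls the function directly and never touches dogpile;
    # condition=True caches under a key built by backend_key_generator().
    result = _heavy_call('repo_object', 'repo_id_1')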
1 NO CONTENT: file renamed from vcsserver/str_utils.py to vcsserver/lib/str_utils.py
@@ -1,160 +1,160 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import tempfile
20 20
21 21 from svn import client
22 22 from svn import core
23 23 from svn import ra
24 24
25 25 from mercurial import error
26 26
27 from vcsserver.str_utils import safe_bytes
27 from vcsserver.lib.str_utils import safe_bytes
28 28
29 29 core.svn_config_ensure(None)
30 30 svn_config = core.svn_config_get_config(None)
31 31
32 32
33 33 class RaCallbacks(ra.Callbacks):
34 34 @staticmethod
35 35 def open_tmp_file(pool): # pragma: no cover
36 36 (fd, fn) = tempfile.mkstemp()
37 37 os.close(fd)
38 38 return fn
39 39
40 40 @staticmethod
41 41 def get_client_string(pool):
42 42 return b'RhodeCode-subversion-url-checker'
43 43
44 44
45 45 class SubversionException(Exception):
46 46 pass
47 47
48 48
49 49 class SubversionConnectionException(SubversionException):
50 50 """Exception raised when a generic error occurs when connecting to a repository."""
51 51
52 52
53 53 def normalize_url(url):
54 54 if not url:
55 55 return url
56 56 if url.startswith(b'svn+http://') or url.startswith(b'svn+https://'):
57 57 url = url[4:]
58 58 url = url.rstrip(b'/')
59 59 return url
60 60
61 61
62 62 def _create_auth_baton(pool):
63 63 """Create a Subversion authentication baton. """
64 64 # Give the client context baton a suite of authentication
65 65 # providers.
66 66 platform_specific = [
67 67 'svn_auth_get_gnome_keyring_simple_provider',
68 68 'svn_auth_get_gnome_keyring_ssl_client_cert_pw_provider',
69 69 'svn_auth_get_keychain_simple_provider',
70 70 'svn_auth_get_keychain_ssl_client_cert_pw_provider',
71 71 'svn_auth_get_kwallet_simple_provider',
72 72 'svn_auth_get_kwallet_ssl_client_cert_pw_provider',
73 73 'svn_auth_get_ssl_client_cert_file_provider',
74 74 'svn_auth_get_windows_simple_provider',
75 75 'svn_auth_get_windows_ssl_server_trust_provider',
76 76 ]
77 77
78 78 providers = []
79 79
80 80 for p in platform_specific:
81 81 if getattr(core, p, None) is not None:
82 82 try:
83 83 providers.append(getattr(core, p)())
84 84 except RuntimeError:
85 85 pass
86 86
87 87 providers += [
88 88 client.get_simple_provider(),
89 89 client.get_username_provider(),
90 90 client.get_ssl_client_cert_file_provider(),
91 91 client.get_ssl_client_cert_pw_file_provider(),
92 92 client.get_ssl_server_trust_file_provider(),
93 93 ]
94 94
95 95 return core.svn_auth_open(providers, pool)
96 96
97 97
98 98 class SubversionRepo:
99 99 """Wrapper for a Subversion repository.
100 100
101 101 It uses the SWIG Python bindings, see above for requirements.
102 102 """
103 103 def __init__(self, svn_url: bytes = b'', username: bytes = b'', password: bytes = b''):
104 104
105 105 self.username = username
106 106 self.password = password
107 107 self.svn_url = core.svn_path_canonicalize(svn_url)
108 108
109 109 self.auth_baton_pool = core.Pool()
110 110 self.auth_baton = _create_auth_baton(self.auth_baton_pool)
111 111 # self.init_ra_and_client() assumes that a pool already exists
112 112 self.pool = core.Pool()
113 113
114 114 self.ra = self.init_ra_and_client()
115 115 self.uuid = ra.get_uuid(self.ra, self.pool)
116 116
117 117 def init_ra_and_client(self):
118 118 """Initializes the RA and client layers, because sometimes getting
119 119 unified diffs runs the remote server out of open files.
120 120 """
121 121
122 122 if self.username:
123 123 core.svn_auth_set_parameter(self.auth_baton,
124 124 core.SVN_AUTH_PARAM_DEFAULT_USERNAME,
125 125 self.username)
126 126 if self.password:
127 127 core.svn_auth_set_parameter(self.auth_baton,
128 128 core.SVN_AUTH_PARAM_DEFAULT_PASSWORD,
129 129 self.password)
130 130
131 131 callbacks = RaCallbacks()
132 132 callbacks.auth_baton = self.auth_baton
133 133
134 134 try:
135 135 return ra.open2(self.svn_url, callbacks, svn_config, self.pool)
136 136 except SubversionException as e:
137 137 # e.child contains a detailed error message
138 138 msglist = []
139 139 svn_exc = e
140 140 while svn_exc:
141 141 if svn_exc.args[0]:
142 142 msglist.append(svn_exc.args[0])
143 143 svn_exc = svn_exc.child
144 144 msg = '\n'.join(msglist)
145 145 raise SubversionConnectionException(msg)
146 146
147 147
148 148 class svnremoterepo:
149 149 """ the dumb wrapper for actual Subversion repositories """
150 150
151 151 def __init__(self, username: bytes = b'', password: bytes = b'', svn_url: bytes = b''):
152 152 self.username = username or b''
153 153 self.password = password or b''
154 154 self.path = normalize_url(svn_url)
155 155
156 156 def svn(self):
157 157 try:
158 158 return SubversionRepo(self.path, self.username, self.password)
159 159 except SubversionConnectionException as e:
160 160 raise error.Abort(safe_bytes(e))
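For context, a minimal connectivity check using the wrappers above (illustrative only; the URL and credentials are hypothetical):

    remote = svnremoterepo(
        username=b'user',
        password=b'secret',
        svn_url=b'svn+https://svn.example.com/repo',  # 'svn+' prefix is stripped by normalize_url
    )
    repo = remote.svn()  # raises mercurial.error.Abort if the connection fails
    print(repo.uuid)     # repository UUID fetched via ra.get_uuid()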
1 NO CONTENT: file renamed from vcsserver/type_utils.py to vcsserver/lib/type_utils.py
@@ -1,417 +1,417 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Handles the Git smart protocol."""
19 19
20 20 import os
21 21 import socket
22 22 import logging
23 23
24 24 import dulwich.protocol
25 25 from dulwich.protocol import CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K
26 26 from webob import Request, Response, exc
27 27
28 28 from vcsserver.lib.ext_json import json
29 29 from vcsserver import hooks, subprocessio
30 from vcsserver.str_utils import ascii_bytes
30 from vcsserver.lib.str_utils import ascii_bytes
31 31
32 32
33 33 log = logging.getLogger(__name__)
34 34
35 35
36 36 class FileWrapper:
37 37 """File wrapper that ensures how much data is read from it."""
38 38
39 39 def __init__(self, fd, content_length):
40 40 self.fd = fd
41 41 self.content_length = content_length
42 42 self.remain = content_length
43 43
44 44 def read(self, size):
45 45 if size <= self.remain:
46 46 try:
47 47 data = self.fd.read(size)
48 48 except socket.error:
49 49 raise IOError(self)
50 50 self.remain -= size
51 51 elif self.remain:
52 52 data = self.fd.read(self.remain)
53 53 self.remain = 0
54 54 else:
55 55 data = None
56 56 return data
57 57
58 58 def __repr__(self):
59 59 return '<FileWrapper {} len: {}, read: {}>'.format(
60 60 self.fd, self.content_length, self.content_length - self.remain
61 61 )
62 62
63 63
64 64 class GitRepository:
65 65 """WSGI app for handling Git smart protocol endpoints."""
66 66
67 67 git_folder_signature = frozenset(('config', 'head', 'info', 'objects', 'refs'))
68 68 commands = frozenset(('git-upload-pack', 'git-receive-pack'))
69 69 valid_accepts = frozenset(f'application/x-{c}-result' for c in commands)
70 70
71 71 # The last bytes are the SHA1 of the first 12 bytes.
72 72 EMPTY_PACK = (
73 73 b'PACK\x00\x00\x00\x02\x00\x00\x00\x00\x02\x9d\x08' +
74 74 b'\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e'
75 75 )
76 76 FLUSH_PACKET = b"0000"
77 77
78 78 SIDE_BAND_CAPS = frozenset((CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K))
79 79
80 80 def __init__(self, repo_name, content_path, git_path, update_server_info, extras):
81 81 files = frozenset(f.lower() for f in os.listdir(content_path))
82 82 valid_dir_signature = self.git_folder_signature.issubset(files)
83 83
84 84 if not valid_dir_signature:
85 85 raise OSError(f'{content_path} missing git signature')
86 86
87 87 self.content_path = content_path
88 88 self.repo_name = repo_name
89 89 self.extras = extras
90 90 self.git_path = git_path
91 91 self.update_server_info = update_server_info
92 92
93 93 def _get_fixedpath(self, path):
94 94 """
95 95 Small fix for repo_path
96 96
97 97 :param path:
98 98 """
99 99 path = path.split(self.repo_name, 1)[-1]
100 100 if path.startswith('.git'):
101 101 # for bare repos we still get the .git prefix inside, we skip it
102 102 # here, and remove from the service command
103 103 path = path[4:]
104 104
105 105 return path.strip('/')
106 106
107 107 def inforefs(self, request, unused_environ):
108 108 """
109 109 WSGI Response producer for HTTP GET Git Smart
110 110 HTTP /info/refs request.
111 111 """
112 112
113 113 git_command = request.GET.get('service')
114 114 if git_command not in self.commands:
115 115 log.debug('command %s not allowed', git_command)
116 116 return exc.HTTPForbidden()
117 117
118 118 # please, resist the urge to add '\n' to git capture and increment
119 119 # line count by 1.
120 120 # by git docs: Documentation/technical/http-protocol.txt#L214 \n is
121 121 # a part of protocol.
122 122 # The code in Git client not only does NOT need '\n', but actually
123 123 # blows up if you sprinkle "flush" (0000) as "0001\n".
124 124 # It reads binary, per number of bytes specified.
125 125 # if you do add '\n' as part of data, count it.
126 126 server_advert = f'# service={git_command}\n'
127 127 packet_len = hex(len(server_advert) + 4)[2:].rjust(4, '0').lower()
128 128 try:
129 129 gitenv = dict(os.environ)
130 130 # forget all configs
131 131 gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
132 132 command = [self.git_path, git_command[4:], '--stateless-rpc',
133 133 '--advertise-refs', self.content_path]
134 134 out = subprocessio.SubprocessIOChunker(
135 135 command,
136 136 env=gitenv,
137 137 starting_values=[ascii_bytes(packet_len + server_advert) + self.FLUSH_PACKET],
138 138 shell=False
139 139 )
140 140 except OSError:
141 141 log.exception('Error processing command')
142 142 raise exc.HTTPExpectationFailed()
143 143
144 144 resp = Response()
145 145 resp.content_type = f'application/x-{git_command}-advertisement'
146 146 resp.charset = None
147 147 resp.app_iter = out
148 148
149 149 return resp
150 150
151 151 def _get_want_capabilities(self, request):
152 152 """Read the capabilities found in the first want line of the request."""
153 153 pos = request.body_file_seekable.tell()
154 154 first_line = request.body_file_seekable.readline()
155 155 request.body_file_seekable.seek(pos)
156 156
157 157 return frozenset(
158 158 dulwich.protocol.extract_want_line_capabilities(first_line)[1])
159 159
160 160 def _build_failed_pre_pull_response(self, capabilities, pre_pull_messages):
161 161 """
162 162 Construct a response with an empty PACK file.
163 163
164 164 We use an empty PACK file, as that would trigger the failure of the pull
165 165 or clone command.
166 166
167 167 We also print in the error output a message explaining why the command
168 168 was aborted.
169 169
170 170 If additionally, the user is accepting messages we send them the output
171 171 of the pre-pull hook.
172 172
173 173 Note that for clients not supporting side-band we just send them the
174 174 empty PACK file.
175 175 """
176 176
177 177 if self.SIDE_BAND_CAPS.intersection(capabilities):
178 178 response = []
179 179 proto = dulwich.protocol.Protocol(None, response.append)
180 180 proto.write_pkt_line(dulwich.protocol.NAK_LINE)
181 181
182 182 self._write_sideband_to_proto(proto, ascii_bytes(pre_pull_messages, allow_bytes=True), capabilities)
183 183 # N.B.(skreft): Do not change the sideband channel to 3, as that
184 184 # produces a fatal error in the client:
185 185 # fatal: error in sideband demultiplexer
186 186 proto.write_sideband(
187 187 dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS,
188 188 ascii_bytes('Pre pull hook failed: aborting\n', allow_bytes=True))
189 189 proto.write_sideband(
190 190 dulwich.protocol.SIDE_BAND_CHANNEL_DATA,
191 191 ascii_bytes(self.EMPTY_PACK, allow_bytes=True))
192 192
193 193 # writes b"0000" as default
194 194 proto.write_pkt_line(None)
195 195
196 196 return response
197 197 else:
198 198 return [ascii_bytes(self.EMPTY_PACK, allow_bytes=True)]
199 199
200 200 def _build_post_pull_response(self, response, capabilities, start_message, end_message):
201 201 """
202 202 Given a list response we inject the post-pull messages.
203 203
204 204 We only inject the messages if the client supports sideband, and the
205 205 response has the format:
206 206 0008NAK\n...0000
207 207
208 208 Note that we do not check the no-progress capability as by default, git
209 209 sends it, which effectively would block all messages.
210 210 """
211 211
212 212 if not self.SIDE_BAND_CAPS.intersection(capabilities):
213 213 return response
214 214
215 215 if not start_message and not end_message:
216 216 return response
217 217
218 218 try:
219 219 iter(response)
221 221 # it is iterable, we can continue
221 221 except TypeError:
222 222 raise TypeError(f'response must be an iterator: got {type(response)}')
223 223 if isinstance(response, (list, tuple)):
224 224 raise TypeError(f'response must be an iterator: got {type(response)}')
225 225
226 226 def injected_response():
227 227
228 228 do_loop = 1
229 229 header_injected = 0
230 230 next_item = None
231 231 has_item = False
232 232 item = b''
233 233
234 234 while do_loop:
235 235
236 236 try:
237 237 next_item = next(response)
238 238 except StopIteration:
239 239 do_loop = 0
240 240
241 241 if has_item:
243 243 # last item! alter it now
243 243 if do_loop == 0 and item.endswith(self.FLUSH_PACKET):
244 244 new_response = [item[:-4]]
245 245 new_response.extend(self._get_messages(end_message, capabilities))
246 246 new_response.append(self.FLUSH_PACKET)
247 247 item = b''.join(new_response)
248 248
249 249 yield item
250 250
251 251 has_item = True
252 252 item = next_item
253 253
254 254 # alter item if it's the initial chunk
255 255 if not header_injected and item.startswith(b'0008NAK\n'):
256 256 new_response = [b'0008NAK\n']
257 257 new_response.extend(self._get_messages(start_message, capabilities))
258 258 new_response.append(item[8:])
259 259 item = b''.join(new_response)
260 260 header_injected = 1
261 261
262 262 return injected_response()
263 263
264 264 def _write_sideband_to_proto(self, proto, data, capabilities):
265 265 """
266 266 Write the data to the proto's sideband number 2 == SIDE_BAND_CHANNEL_PROGRESS
267 267
268 268 We do not use dulwich's write_sideband directly as it only supports
269 269 side-band-64k.
270 270 """
271 271 if not data:
272 272 return
273 273
274 274 # N.B.(skreft): The values below are explained in the pack protocol
275 275 # documentation, section Packfile Data.
276 276 # https://github.com/git/git/blob/master/Documentation/technical/pack-protocol.txt
277 277 if CAPABILITY_SIDE_BAND_64K in capabilities:
278 278 chunk_size = 65515
279 279 elif CAPABILITY_SIDE_BAND in capabilities:
280 280 chunk_size = 995
281 281 else:
282 282 return
283 283
284 284 chunker = (data[i:i + chunk_size] for i in range(0, len(data), chunk_size))
285 285
286 286 for chunk in chunker:
287 287 proto.write_sideband(dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS, ascii_bytes(chunk, allow_bytes=True))
288 288
289 289 def _get_messages(self, data, capabilities):
290 290 """Return a list with packets for sending data in sideband number 2."""
291 291 response = []
292 292 proto = dulwich.protocol.Protocol(None, response.append)
293 293
294 294 self._write_sideband_to_proto(proto, data, capabilities)
295 295
296 296 return response
297 297
298 298 def backend(self, request, environ):
299 299 """
300 300 WSGI Response producer for HTTP POST Git Smart HTTP requests.
301 301 Reads commands and data from HTTP POST's body.
302 302 returns an iterator obj with contents of git command's
303 303 response to stdout
304 304 """
305 305 # TODO(skreft): think how we could detect an HTTPLockedException, as
306 306 # we probably want to have the same mechanism used by mercurial and
307 307 # simplevcs.
308 308 # For that we would need to parse the output of the command looking for
309 309 # some signs of the HTTPLockedError, parse the data and reraise it in
310 310 # pygrack. However, that would interfere with the streaming.
311 311 #
312 312 # Now the output of a blocked push is:
313 313 # Pushing to http://test_regular:test12@127.0.0.1:5001/vcs_test_git
314 314 # POST git-receive-pack (1047 bytes)
315 315 # remote: ERROR: Repository `vcs_test_git` locked by user `test_admin`. Reason:`lock_auto`
316 316 # To http://test_regular:test12@127.0.0.1:5001/vcs_test_git
317 317 # ! [remote rejected] master -> master (pre-receive hook declined)
318 318 # error: failed to push some refs to 'http://test_regular:test12@127.0.0.1:5001/vcs_test_git'
319 319
320 320 git_command = self._get_fixedpath(request.path_info)
321 321 if git_command not in self.commands:
322 322 log.debug('command %s not allowed', git_command)
323 323 return exc.HTTPForbidden()
324 324
325 325 capabilities = None
326 326 if git_command == 'git-upload-pack':
327 327 capabilities = self._get_want_capabilities(request)
328 328
329 329 if 'CONTENT_LENGTH' in environ:
330 330 inputstream = FileWrapper(request.body_file_seekable,
331 331 request.content_length)
332 332 else:
333 333 inputstream = request.body_file_seekable
334 334
335 335 resp = Response()
336 336 resp.content_type = f'application/x-{git_command}-result'
337 337 resp.charset = None
338 338
339 339 pre_pull_messages = ''
340 340 # Upload-pack == clone
341 341 if git_command == 'git-upload-pack':
342 342 hook_response = hooks.git_pre_pull(self.extras)
343 343 if hook_response.status != 0:
344 344 pre_pull_messages = hook_response.output
345 345 resp.app_iter = self._build_failed_pre_pull_response(
346 346 capabilities, pre_pull_messages)
347 347 return resp
348 348
349 349 gitenv = dict(os.environ)
350 350 # forget all configs
351 351 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
352 352 gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
353 353 cmd = [self.git_path, git_command[4:], '--stateless-rpc',
354 354 self.content_path]
355 355 log.debug('handling cmd %s', cmd)
356 356
357 357 out = subprocessio.SubprocessIOChunker(
358 358 cmd,
359 359 input_stream=inputstream,
360 360 env=gitenv,
361 361 cwd=self.content_path,
362 362 shell=False,
363 363 fail_on_stderr=False,
364 364 fail_on_return_code=False
365 365 )
366 366
367 367 if self.update_server_info and git_command == 'git-receive-pack':
368 368 # We need to fully consume the iterator here, as the
369 369 # update-server-info command needs to be run after the push.
370 370 out = list(out)
371 371
372 372 # Updating refs manually after each push.
373 373 # This is required as some clients are exposing Git repos internally
374 374 # with the dumb protocol.
375 375 cmd = [self.git_path, 'update-server-info']
376 376 log.debug('handling cmd %s', cmd)
377 377 output = subprocessio.SubprocessIOChunker(
378 378 cmd,
379 379 input_stream=inputstream,
380 380 env=gitenv,
381 381 cwd=self.content_path,
382 382 shell=False,
383 383 fail_on_stderr=False,
384 384 fail_on_return_code=False
385 385 )
386 386 # Consume all the output so the subprocess finishes
387 387 for _ in output:
388 388 pass
389 389
390 390 # Upload-pack == clone
391 391 if git_command == 'git-upload-pack':
392 392 hook_response = hooks.git_post_pull(self.extras)
393 393 post_pull_messages = hook_response.output
394 394 resp.app_iter = self._build_post_pull_response(out, capabilities, pre_pull_messages, post_pull_messages)
395 395 else:
396 396 resp.app_iter = out
397 397
398 398 return resp
399 399
400 400 def __call__(self, environ, start_response):
401 401 request = Request(environ)
402 402 _path = self._get_fixedpath(request.path_info)
403 403 if _path.startswith('info/refs'):
404 404 app = self.inforefs
405 405 else:
406 406 app = self.backend
407 407
408 408 try:
409 409 resp = app(request, environ)
410 410 except exc.HTTPException as error:
411 411 log.exception('HTTP Error')
412 412 resp = error
413 413 except Exception:
414 414 log.exception('Unknown error')
415 415 resp = exc.HTTPInternalServerError()
416 416
417 417 return resp(environ, start_response)
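For context, the pkt-line framing used by inforefs() above can be sketched as a standalone helper (illustrative only): each packet carries a 4-hex-digit length prefix that counts the prefix itself, and b"0000" is the flush packet.

    def pkt_line(payload: str) -> bytes:
        # total length = 4 prefix bytes + payload bytes, as lowercase hex
        return (hex(len(payload) + 4)[2:].rjust(4, '0') + payload).encode('ascii')

    assert pkt_line('# service=git-upload-pack\n') == b'001e# service=git-upload-pack\n'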
@@ -1,1526 +1,1526 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import collections
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import traceback
24 24 import urllib.request
25 25 import urllib.parse
26 26 import urllib.error
27 27 from functools import wraps
28 28
29 29 import more_itertools
30 30 import pygit2
31 31 from pygit2 import Repository as LibGit2Repo
32 32 from pygit2 import index as LibGit2Index
33 33 from dulwich import index, objects
34 34 from dulwich.client import HttpGitClient, LocalGitClient, FetchPackResult
35 35 from dulwich.errors import (
36 36 NotGitRepository, ChecksumMismatch, WrongObjectException,
37 37 MissingCommitError, ObjectMissing, HangupException,
38 38 UnexpectedCommandError)
39 39 from dulwich.repo import Repo as DulwichRepo
40 40
41 41 import rhodecode
42 42 from vcsserver import exceptions, settings, subprocessio
43 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes, convert_to_str, splitnewlines
43 from vcsserver.lib.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes, convert_to_str, splitnewlines
44 44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, store_archive_in_cache, BytesEnvelope, BinaryEnvelope
45 45 from vcsserver.hgcompat import (
46 46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
47 47 from vcsserver.git_lfs.lib import LFSOidStore
48 48 from vcsserver.vcs_base import RemoteBase
49 49
50 50 DIR_STAT = stat.S_IFDIR
51 51 FILE_MODE = stat.S_IFMT
52 52 GIT_LINK = objects.S_IFGITLINK
53 53 PEELED_REF_MARKER = b'^{}'
54 54 HEAD_MARKER = b'HEAD'
55 55
56 56 log = logging.getLogger(__name__)
57 57
58 58
59 59 def reraise_safe_exceptions(func):
60 60 """Converts Dulwich exceptions to something neutral."""
61 61
62 62 @wraps(func)
63 63 def wrapper(*args, **kwargs):
64 64 try:
65 65 return func(*args, **kwargs)
66 66 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
67 67 exc = exceptions.LookupException(org_exc=e)
68 68 raise exc(safe_str(e))
69 69 except (HangupException, UnexpectedCommandError) as e:
70 70 exc = exceptions.VcsException(org_exc=e)
71 71 raise exc(safe_str(e))
72 72 except Exception:
73 73 # NOTE(marcink): because of how dulwich handles some exceptions
74 74 # (KeyError on empty repos), we cannot track this and catch all
75 75 # exceptions here; these are exceptions from other handlers
76 76 #if not hasattr(e, '_vcs_kind'):
77 77 #log.exception("Unhandled exception in git remote call")
78 78 #raise_from_original(exceptions.UnhandledException)
79 79 raise
80 80 return wrapper
81 81
82 82
83 83 class Repo(DulwichRepo):
84 84 """
85 85 A wrapper for dulwich Repo class.
86 86
87 87 Since dulwich is sometimes keeping .idx file descriptors open, it leads to
88 88 "Too many open files" error. We need to close all opened file descriptors
89 89 once the repo object is destroyed.
90 90 """
91 91 def __del__(self):
92 92 if hasattr(self, 'object_store'):
93 93 self.close()
94 94
95 95
96 96 class Repository(LibGit2Repo):
97 97
98 98 def __enter__(self):
99 99 return self
100 100
101 101 def __exit__(self, exc_type, exc_val, exc_tb):
102 102 self.free()
103 103
104 104
105 105 class GitFactory(RepoFactory):
106 106 repo_type = 'git'
107 107
108 108 def _create_repo(self, wire, create, use_libgit2=False):
109 109 if use_libgit2:
110 110 repo = Repository(safe_bytes(wire['path']))
111 111 else:
112 112 # dulwich mode
113 113 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
114 114 repo = Repo(repo_path)
115 115
116 116 log.debug('repository created: got GIT object: %s', repo)
117 117 return repo
118 118
119 119 def repo(self, wire, create=False, use_libgit2=False):
120 120 """
121 121 Get a repository instance for the given path.
122 122 """
123 123 return self._create_repo(wire, create, use_libgit2)
124 124
125 125 def repo_libgit2(self, wire):
126 126 return self.repo(wire, use_libgit2=True)
127 127
128 128
129 129 def create_signature_from_string(author_str, **kwargs):
130 130 """
131 131 Creates a pygit2.Signature object from a string of the format 'Name <email>'.
132 132
133 133 :param author_str: String of the format 'Name <email>'
134 134 :return: pygit2.Signature object
135 135 """
136 136 match = re.match(r'^(.+) <(.+)>$', author_str)
137 137 if match is None:
138 138 raise ValueError(f"Invalid format: {author_str}")
139 139
140 140 name, email = match.groups()
141 141 return pygit2.Signature(name, email, **kwargs)
142 142
143 143
144 144 def get_obfuscated_url(url_obj):
145 145 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
146 146 url_obj.query = obfuscate_qs(url_obj.query)
147 147 obfuscated_uri = str(url_obj)
148 148 return obfuscated_uri
149 149
150 150
151 151 class GitRemote(RemoteBase):
152 152
153 153 def __init__(self, factory):
154 154 self._factory = factory
155 155 self._bulk_methods = {
156 156 "date": self.date,
157 157 "author": self.author,
158 158 "branch": self.branch,
159 159 "message": self.message,
160 160 "parents": self.parents,
161 161 "_commit": self.revision,
162 162 }
163 163 self._bulk_file_methods = {
164 164 "size": self.get_node_size,
165 165 "data": self.get_node_data,
166 166 "flags": self.get_node_flags,
167 167 "is_binary": self.get_node_is_binary,
168 168 "md5": self.md5_hash
169 169 }
170 170
171 171 def _wire_to_config(self, wire):
172 172 if 'config' in wire:
173 173 return {x[0] + '_' + x[1]: x[2] for x in wire['config']}
174 174 return {}
175 175
176 176 def _remote_conf(self, config):
177 177 params = [
178 178 '-c', 'core.askpass=""',
179 179 ]
180 180 config_attrs = {
181 181 'vcs_ssl_dir': 'http.sslCAinfo={}',
182 182 'vcs_git_lfs_store_location': 'lfs.storage={}'
183 183 }
184 184 for key, param in config_attrs.items():
185 185 if value := config.get(key):
186 186 params.extend(['-c', param.format(value)])
187 187 return params
188 188
189 189 @reraise_safe_exceptions
190 190 def discover_git_version(self):
191 191 stdout, _ = self.run_git_command(
192 192 {}, ['--version'], _bare=True, _safe=True)
193 193 prefix = b'git version'
194 194 if stdout.startswith(prefix):
195 195 stdout = stdout[len(prefix):]
196 196 return safe_str(stdout.strip())
197 197
198 198 @reraise_safe_exceptions
199 199 def is_empty(self, wire):
200 200 repo_init = self._factory.repo_libgit2(wire)
201 201 with repo_init as repo:
202 202 try:
203 203 has_head = repo.head.name
204 204 if has_head:
205 205 return False
206 206
207 207 # NOTE(marcink): check again using more expensive method
208 208 return repo.is_empty
209 209 except Exception:
210 210 pass
211 211
212 212 return True
213 213
214 214 @reraise_safe_exceptions
215 215 def assert_correct_path(self, wire):
216 216 cache_on, context_uid, repo_id = self._cache_on(wire)
217 217 region = self._region(wire)
218 218
219 219 @region.conditional_cache_on_arguments(condition=cache_on)
220 220 def _assert_correct_path(_context_uid, _repo_id, fast_check):
221 221 if fast_check:
222 222 path = safe_str(wire['path'])
223 223 if pygit2.discover_repository(path):
224 224 return True
225 225 return False
226 226 else:
227 227 try:
228 228 repo_init = self._factory.repo_libgit2(wire)
229 229 with repo_init:
230 230 pass
231 231 except pygit2.GitError:
232 232 path = wire.get('path')
233 233 tb = traceback.format_exc()
234 234 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
235 235 return False
236 236 return True
237 237
238 238 return _assert_correct_path(context_uid, repo_id, True)
239 239
240 240 @reraise_safe_exceptions
241 241 def bare(self, wire):
242 242 repo_init = self._factory.repo_libgit2(wire)
243 243 with repo_init as repo:
244 244 return repo.is_bare
245 245
246 246 @reraise_safe_exceptions
247 247 def get_node_data(self, wire, commit_id, path):
248 248 repo_init = self._factory.repo_libgit2(wire)
249 249 with repo_init as repo:
250 250 commit = repo[commit_id]
251 251 blob_obj = commit.tree[path]
252 252
253 253 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
254 254 raise exceptions.LookupException()(
255 255 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
256 256
257 257 return BytesEnvelope(blob_obj.data)
258 258
259 259 @reraise_safe_exceptions
260 260 def get_node_size(self, wire, commit_id, path):
261 261 repo_init = self._factory.repo_libgit2(wire)
262 262 with repo_init as repo:
263 263 commit = repo[commit_id]
264 264 blob_obj = commit.tree[path]
265 265
266 266 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
267 267 raise exceptions.LookupException()(
268 268 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
269 269
270 270 return blob_obj.size
271 271
272 272 @reraise_safe_exceptions
273 273 def get_node_flags(self, wire, commit_id, path):
274 274 repo_init = self._factory.repo_libgit2(wire)
275 275 with repo_init as repo:
276 276 commit = repo[commit_id]
277 277 blob_obj = commit.tree[path]
278 278
279 279 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
280 280 raise exceptions.LookupException()(
281 281 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
282 282
283 283 return blob_obj.filemode
284 284
285 285 @reraise_safe_exceptions
286 286 def get_node_is_binary(self, wire, commit_id, path):
287 287 repo_init = self._factory.repo_libgit2(wire)
288 288 with repo_init as repo:
289 289 commit = repo[commit_id]
290 290 blob_obj = commit.tree[path]
291 291
292 292 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
293 293 raise exceptions.LookupException()(
294 294 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
295 295
296 296 return blob_obj.is_binary
297 297
298 298 @reraise_safe_exceptions
299 299 def blob_as_pretty_string(self, wire, sha):
300 300 repo_init = self._factory.repo_libgit2(wire)
301 301 with repo_init as repo:
302 302 blob_obj = repo[sha]
303 303 return BytesEnvelope(blob_obj.data)
304 304
305 305 @reraise_safe_exceptions
306 306 def blob_raw_length(self, wire, sha):
307 307 cache_on, context_uid, repo_id = self._cache_on(wire)
308 308 region = self._region(wire)
309 309
310 310 @region.conditional_cache_on_arguments(condition=cache_on)
311 311 def _blob_raw_length(_repo_id, _sha):
312 312
313 313 repo_init = self._factory.repo_libgit2(wire)
314 314 with repo_init as repo:
315 315 blob = repo[sha]
316 316 return blob.size
317 317
318 318 return _blob_raw_length(repo_id, sha)
319 319
320 320 def _parse_lfs_pointer(self, raw_content):
321 321 spec_string = b'version https://git-lfs.github.com/spec'
322 322 if raw_content and raw_content.startswith(spec_string):
323 323
324 324 pattern = re.compile(rb"""
325 325 (?:\n)?
326 326 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
327 327 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
328 328 ^size[ ](?P<oid_size>[0-9]+)\n
329 329 (?:\n)?
330 330 """, re.VERBOSE | re.MULTILINE)
331 331 match = pattern.match(raw_content)
332 332 if match:
333 333 return match.groupdict()
334 334
335 335 return {}
336 336
337 337 @reraise_safe_exceptions
338 338 def is_large_file(self, wire, commit_id):
339 339 cache_on, context_uid, repo_id = self._cache_on(wire)
340 340 region = self._region(wire)
341 341
342 342 @region.conditional_cache_on_arguments(condition=cache_on)
343 343 def _is_large_file(_repo_id, _sha):
344 344 repo_init = self._factory.repo_libgit2(wire)
345 345 with repo_init as repo:
346 346 blob = repo[commit_id]
347 347 if blob.is_binary:
348 348 return {}
349 349
350 350 return self._parse_lfs_pointer(blob.data)
351 351
352 352 return _is_large_file(repo_id, commit_id)
353 353
354 354 @reraise_safe_exceptions
355 355 def is_binary(self, wire, tree_id):
356 356 cache_on, context_uid, repo_id = self._cache_on(wire)
357 357 region = self._region(wire)
358 358
359 359 @region.conditional_cache_on_arguments(condition=cache_on)
360 360 def _is_binary(_repo_id, _tree_id):
361 361 repo_init = self._factory.repo_libgit2(wire)
362 362 with repo_init as repo:
363 363 blob_obj = repo[tree_id]
364 364 return blob_obj.is_binary
365 365
366 366 return _is_binary(repo_id, tree_id)
367 367
368 368 @reraise_safe_exceptions
369 369 def md5_hash(self, wire, commit_id, path):
370 370 cache_on, context_uid, repo_id = self._cache_on(wire)
371 371 region = self._region(wire)
372 372
373 373 @region.conditional_cache_on_arguments(condition=cache_on)
374 374 def _md5_hash(_repo_id, _commit_id, _path):
375 375 repo_init = self._factory.repo_libgit2(wire)
376 376 with repo_init as repo:
377 377 commit = repo[_commit_id]
378 378 blob_obj = commit.tree[_path]
379 379
380 380 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
381 381 raise exceptions.LookupException()(
382 382 f'Tree for commit_id:{_commit_id} is not a blob: {blob_obj.type_str}')
383 383
384 384 return ''
385 385
386 386 return _md5_hash(repo_id, commit_id, path)
387 387
388 388 @reraise_safe_exceptions
389 389 def in_largefiles_store(self, wire, oid):
390 390 conf = self._wire_to_config(wire)
391 391 repo_init = self._factory.repo_libgit2(wire)
392 392 with repo_init as repo:
393 393 repo_name = repo.path
394 394
395 395 store_location = conf.get('vcs_git_lfs_store_location')
396 396 if store_location:
397 397
398 398 store = LFSOidStore(
399 399 oid=oid, repo=repo_name, store_location=store_location)
400 400 return store.has_oid()
401 401
402 402 return False
403 403
404 404 @reraise_safe_exceptions
405 405 def store_path(self, wire, oid):
406 406 conf = self._wire_to_config(wire)
407 407 repo_init = self._factory.repo_libgit2(wire)
408 408 with repo_init as repo:
409 409 repo_name = repo.path
410 410
411 411 store_location = conf.get('vcs_git_lfs_store_location')
412 412 if store_location:
413 413 store = LFSOidStore(
414 414 oid=oid, repo=repo_name, store_location=store_location)
415 415 return store.oid_path
416 416 raise ValueError(f'Unable to fetch oid with path {oid}')
417 417
418 418 @reraise_safe_exceptions
419 419 def bulk_request(self, wire, rev, pre_load):
420 420 cache_on, context_uid, repo_id = self._cache_on(wire)
421 421 region = self._region(wire)
422 422
423 423 @region.conditional_cache_on_arguments(condition=cache_on)
424 424 def _bulk_request(_repo_id, _rev, _pre_load):
425 425 result = {}
426 426 for attr in pre_load:
427 427 try:
428 428 method = self._bulk_methods[attr]
429 429 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
430 430 args = [wire, rev]
431 431 result[attr] = method(*args)
432 432 except KeyError as e:
433 433 raise exceptions.VcsException(e)(f"Unknown bulk attribute: {attr}")
434 434 return result
435 435
436 436 return _bulk_request(repo_id, rev, sorted(pre_load))
437 437
438 438 @reraise_safe_exceptions
439 439 def bulk_file_request(self, wire, commit_id, path, pre_load):
440 440 cache_on, context_uid, repo_id = self._cache_on(wire)
441 441 region = self._region(wire)
442 442
443 443 @region.conditional_cache_on_arguments(condition=cache_on)
444 444 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
445 445 result = {}
446 446 for attr in pre_load:
447 447 try:
448 448 method = self._bulk_file_methods[attr]
449 449 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
450 450 result[attr] = method(wire, _commit_id, _path)
451 451 except KeyError as e:
452 452 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
453 453 return result
454 454
455 455 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
456 456
457 457 def _build_opener(self, url: str):
458 458 handlers = []
459 459 url_obj = url_parser(safe_bytes(url))
460 460 authinfo = url_obj.authinfo()[1]
461 461
462 462 if authinfo:
463 463 # create a password manager
464 464 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
465 465 passmgr.add_password(*convert_to_str(authinfo))
466 466
467 467 handlers.extend((httpbasicauthhandler(passmgr),
468 468 httpdigestauthhandler(passmgr)))
469 469
470 470 return urllib.request.build_opener(*handlers)
471 471
472 472 @reraise_safe_exceptions
473 473 def check_url(self, url, config):
474 474 url_obj = url_parser(safe_bytes(url))
475 475
476 476 test_uri = safe_str(url_obj.authinfo()[0])
477 477 obfuscated_uri = get_obfuscated_url(url_obj)
478 478
479 479 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
480 480
481 481 if not test_uri.endswith('info/refs'):
482 482 test_uri = test_uri.rstrip('/') + '/info/refs'
483 483
484 484 o = self._build_opener(url=url)
485 485 o.addheaders = [('User-Agent', 'git/1.7.8.0')] # fake some git
486 486
487 487 q = {"service": 'git-upload-pack'}
488 488 qs = f'?{urllib.parse.urlencode(q)}'
489 489 cu = f"{test_uri}{qs}"
490 490
491 491 try:
492 492 req = urllib.request.Request(cu, None, {})
493 493 log.debug("Trying to open URL %s", obfuscated_uri)
494 494 resp = o.open(req)
495 495 if resp.code != 200:
496 496 raise exceptions.URLError()('Return Code is not 200')
497 497 except Exception as e:
498 498 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
499 499 # means it cannot be cloned
500 500 raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")
501 501
502 502 # now detect if it's proper git repo
503 503 gitdata: bytes = resp.read()
504 504
505 505 if b'service=git-upload-pack' in gitdata:
506 506 pass
507 507 elif re.findall(br'[0-9a-fA-F]{40}\s+refs', gitdata):
508 508 # old style git can return some other format!
509 509 pass
510 510 else:
511 511 e = None
512 512 raise exceptions.URLError(e)(
513 513 f"url [{obfuscated_uri}] does not look like an hg repo org_exc: {e}")
514 514
515 515 return True
516 516
517 517 @reraise_safe_exceptions
518 518 def clone(self, wire, url, deferred, valid_refs, update_after_clone):
519 519 # TODO(marcink): deprecate this method. Last I checked we don't use it anymore
520 520 remote_refs = self.pull(wire, url, apply_refs=False)
521 521 repo = self._factory.repo(wire)
522 522 if isinstance(valid_refs, list):
523 523 valid_refs = tuple(valid_refs)
524 524
525 525 for k in remote_refs:
526 526 # only parse heads/tags and skip so-called deferred tags
527 527 if k.startswith(valid_refs) and not k.endswith(deferred):
528 528 repo[k] = remote_refs[k]
529 529
530 530 if update_after_clone:
531 531 # we want to checkout HEAD
532 532 repo["HEAD"] = remote_refs["HEAD"]
533 533 index.build_index_from_tree(repo.path, repo.index_path(),
534 534 repo.object_store, repo["HEAD"].tree)
535 535
536 536 @reraise_safe_exceptions
537 537 def branch(self, wire, commit_id):
538 538 cache_on, context_uid, repo_id = self._cache_on(wire)
539 539 region = self._region(wire)
540 540
541 541 @region.conditional_cache_on_arguments(condition=cache_on)
542 542 def _branch(_context_uid, _repo_id, _commit_id):
543 543 regex = re.compile('^refs/heads')
544 544
545 545 def filter_with(ref):
546 546 return regex.match(ref[0]) and ref[1] == _commit_id
547 547
548 548 branches = list(filter(filter_with, list(self.get_refs(wire).items())))
549 549 return [x[0].split('refs/heads/')[-1] for x in branches]
550 550
551 551 return _branch(context_uid, repo_id, commit_id)
552 552
553 553 @reraise_safe_exceptions
554 554 def delete_branch(self, wire, branch_name):
555 555 repo_init = self._factory.repo_libgit2(wire)
556 556 with repo_init as repo:
557 557 if branch := repo.lookup_branch(branch_name):
558 558 branch.delete()
559 559
560 560 @reraise_safe_exceptions
561 561 def commit_branches(self, wire, commit_id):
562 562 cache_on, context_uid, repo_id = self._cache_on(wire)
563 563 region = self._region(wire)
564 564
565 565 @region.conditional_cache_on_arguments(condition=cache_on)
566 566 def _commit_branches(_context_uid, _repo_id, _commit_id):
567 567 repo_init = self._factory.repo_libgit2(wire)
568 568 with repo_init as repo:
569 569 branches = [x for x in repo.branches.with_commit(_commit_id)]
570 570 return branches
571 571
572 572 return _commit_branches(context_uid, repo_id, commit_id)
573 573
574 574 @reraise_safe_exceptions
575 575 def add_object(self, wire, content):
576 576 repo_init = self._factory.repo_libgit2(wire)
577 577 with repo_init as repo:
578 578 blob = objects.Blob()
579 579 blob.set_raw_string(content)
580 580 repo.object_store.add_object(blob)
581 581 return blob.id
582 582
583 583 @reraise_safe_exceptions
584 584 def create_commit(self, wire, author, committer, message, branch, new_tree_id,
585 585 date_args: list[int] | None = None,
586 586 parents: list | None = None):
587 587
588 588 repo_init = self._factory.repo_libgit2(wire)
589 589 with repo_init as repo:
590 590
591 591 if date_args:
592 592 current_time, offset = date_args
593 593
594 594 kw = {
595 595 'time': current_time,
596 596 'offset': offset
597 597 }
598 598 author = create_signature_from_string(author, **kw)
599 599 committer = create_signature_from_string(committer, **kw)
600 600
601 601 tree = new_tree_id
602 602 if isinstance(tree, (bytes, str)):
603 603 # validate this tree is in the repo...
604 604 tree = repo[safe_str(tree)].id
605 605
606 606 if parents:
607 607 # run via sha's and validate them in repo
608 608 parents = [repo[c].id for c in parents]
609 609 else:
610 610 parents = []
611 611 # ensure we COMMIT on top of given branch head
612 612 # check if this repo has ANY branches, otherwise it's a new-branch case we need to handle
613 613 if branch in repo.branches.local:
614 614 parents += [repo.branches[branch].target]
615 615 elif [x for x in repo.branches.local]:
616 616 parents += [repo.head.target]
617 617 #else:
618 618 # in case we want to commit on new branch we create it on top of HEAD
619 619 #repo.branches.local.create(branch, repo.revparse_single('HEAD'))
620 620
621 621 # create a new commit
622 622 commit_oid = repo.create_commit(
623 623 f'refs/heads/{branch}', # the name of the reference to update
624 624 author, # the author of the commit
625 625 committer, # the committer of the commit
626 626 message, # the commit message
627 627 tree, # the tree produced by the index
628 628 parents # list of parents for the new commit, usually just one,
629 629 )
630 630
631 631 new_commit_id = safe_str(commit_oid)
632 632
633 633 return new_commit_id
634 634
635 635 @reraise_safe_exceptions
636 636 def commit(self, wire, commit_data, branch, commit_tree, updated, removed):
637 637
638 638 def mode2pygit(mode):
639 639 """
640 640 git only supports two blob filemodes: 644 and 755
641 641
642 642 0o100755 -> 33261
643 643 0o100644 -> 33188
644 644 """
645 645 return {
646 646 0o100644: pygit2.GIT_FILEMODE_BLOB,
647 647 0o100755: pygit2.GIT_FILEMODE_BLOB_EXECUTABLE,
648 648 0o120000: pygit2.GIT_FILEMODE_LINK
649 649 }.get(mode) or pygit2.GIT_FILEMODE_BLOB
650 650
651 651 repo_init = self._factory.repo_libgit2(wire)
652 652 with repo_init as repo:
653 653 repo_index = repo.index
654 654
655 655 commit_parents = None
656 656 if commit_tree and commit_data['parents']:
657 657 commit_parents = commit_data['parents']
658 658 parent_commit = repo[commit_parents[0]]
659 659 repo_index.read_tree(parent_commit.tree)
660 660
661 661 for pathspec in updated:
662 662 blob_id = repo.create_blob(pathspec['content'])
663 663 ie = pygit2.IndexEntry(pathspec['path'], blob_id, mode2pygit(pathspec['mode']))
664 664 repo_index.add(ie)
665 665
666 666 for pathspec in removed:
667 667 repo_index.remove(pathspec)
668 668
669 669 # Write changes to the index
670 670 repo_index.write()
671 671
672 672 # Create a tree from the updated index
673 673 written_commit_tree = repo_index.write_tree()
674 674
675 675 new_tree_id = written_commit_tree
676 676
677 677 author = commit_data['author']
678 678 committer = commit_data['committer']
679 679 message = commit_data['message']
680 680
681 681 date_args = [int(commit_data['commit_time']), int(commit_data['commit_timezone'])]
682 682
683 683 new_commit_id = self.create_commit(wire, author, committer, message, branch,
684 684 new_tree_id, date_args=date_args, parents=commit_parents)
685 685
686 686 # libgit2, ensure the branch is there and exists
687 687 self.create_branch(wire, branch, new_commit_id)
688 688
689 689 # libgit2, set new ref to this created commit
690 690 self.set_refs(wire, f'refs/heads/{branch}', new_commit_id)
691 691
692 692 return new_commit_id
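# Quick illustration of the mode2pygit mapping used above: the octal git
# tree-entry modes and their pygit2 constants (values per libgit2):
import pygit2

assert int(pygit2.GIT_FILEMODE_BLOB) == 0o100644             # == 33188
assert int(pygit2.GIT_FILEMODE_BLOB_EXECUTABLE) == 0o100755  # == 33261
assert int(pygit2.GIT_FILEMODE_LINK) == 0o120000             # symlink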
693 693
694 694 @reraise_safe_exceptions
695 695 def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
696 696 if url != 'default' and '://' not in url:
697 697 client = LocalGitClient(url)
698 698 else:
699 699 url_obj = url_parser(safe_bytes(url))
700 700 o = self._build_opener(url)
701 701 url = url_obj.authinfo()[0]
702 702 client = HttpGitClient(base_url=url, opener=o)
703 703 repo = self._factory.repo(wire)
704 704
705 705 determine_wants = repo.object_store.determine_wants_all
706 706
707 707 if refs:
708 708 refs: list[bytes] = [ascii_bytes(x) for x in refs]
709 709
710 710 def determine_wants_requested(_remote_refs):
711 711 determined = []
712 712 for ref_name, ref_hash in _remote_refs.items():
713 713 bytes_ref_name = safe_bytes(ref_name)
714 714
715 715 if bytes_ref_name in refs:
716 716 bytes_ref_hash = safe_bytes(ref_hash)
717 717 determined.append(bytes_ref_hash)
718 718 return determined
719 719
720 720 # swap with our custom requested wants
721 721 determine_wants = determine_wants_requested
722 722
723 723 try:
724 724 remote_refs = client.fetch(
725 725 path=url, target=repo, determine_wants=determine_wants)
726 726
727 727 except NotGitRepository as e:
728 728 log.warning(
729 729 'Trying to fetch from "%s" failed, not a Git repository.', url)
730 730 # Exception can contain unicode which we convert
731 731 raise exceptions.AbortException(e)(repr(e))
732 732
733 733 # mikhail: client.fetch() returns all the remote refs, but fetches only
734 734 # refs filtered by `determine_wants` function. We need to filter result
735 735 # as well
736 736 if refs:
737 737 remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}
738 738
739 739 if apply_refs:
740 740 # TODO: johbo: Needs proper test coverage with a git repository
741 741 # that contains a tag object, so that we would end up with
742 742 # a peeled ref at this point.
743 743 for k in remote_refs:
744 744 if k.endswith(PEELED_REF_MARKER):
745 745 log.debug("Skipping peeled reference %s", k)
746 746 continue
747 747 repo[k] = remote_refs[k]
748 748
749 749 if refs and not update_after:
750 750 # update to ref
751 751 # mikhail: explicitly set the head to the last ref.
752 752 update_to_ref = refs[-1]
753 753 if isinstance(update_after, str):
754 754 update_to_ref = update_after
755 755
756 756 repo[HEAD_MARKER] = remote_refs[update_to_ref]
757 757
758 758 if update_after:
759 759 # we want to check out HEAD
760 760 repo[HEAD_MARKER] = remote_refs[HEAD_MARKER]
761 761 index.build_index_from_tree(repo.path, repo.index_path(),
762 762 repo.object_store, repo[HEAD_MARKER].tree)
763 763
764 764 if isinstance(remote_refs, FetchPackResult):
765 765 return remote_refs.refs
766 766 return remote_refs
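# Sketch of the same ref-filtered fetch with dulwich in isolation (repo path
# and URL are hypothetical; dulwich's fetch() returns a FetchPackResult):
from dulwich.client import HttpGitClient
from dulwich.repo import Repo

def fetch_selected(repo_path: str, url: str, wanted: list[bytes]):
    repo = Repo(repo_path)
    client = HttpGitClient(url)

    def determine_wants(remote_refs):
        # keep only the SHAs of the refs the caller asked for
        return [sha for ref, sha in remote_refs.items() if ref in wanted]

    result = client.fetch(path=url, target=repo, determine_wants=determine_wants)
    return result.refs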
767 767
768 768 @reraise_safe_exceptions
769 769 def sync_fetch(self, wire, url, refs=None, all_refs=False, **kwargs):
770 770 self._factory.repo(wire)
771 771 if refs and not isinstance(refs, (list, tuple)):
772 772 refs = [refs]
773 773
774 774 config = self._wire_to_config(wire)
775 775 # get all remote refs we'll use to fetch later
776 776 cmd = ['ls-remote']
777 777 if not all_refs:
778 778 cmd += ['--heads', '--tags']
779 779 cmd += [url]
780 780 output, __ = self.run_git_command(
781 781 wire, cmd, fail_on_stderr=False,
782 782 _copts=self._remote_conf(config),
783 783 extra_env={'GIT_TERMINAL_PROMPT': '0'})
784 784
785 785 remote_refs = collections.OrderedDict()
786 786 fetch_refs = []
787 787
788 788 for ref_line in output.splitlines():
789 789 sha, ref = ref_line.split(b'\t')
790 790 sha = sha.strip()
791 791 if ref in remote_refs:
792 792 # duplicate, skip
793 793 continue
794 794 if ref.endswith(PEELED_REF_MARKER):
795 795 log.debug("Skipping peeled reference %s", ref)
796 796 continue
797 797 # don't sync HEAD
798 798 if ref in [HEAD_MARKER]:
799 799 continue
800 800
801 801 remote_refs[ref] = sha
802 802
803 803 if refs and sha in refs:
804 804 # we filter fetch using our specified refs
805 805 fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
806 806 elif not refs:
807 807 fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
808 808 log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))
809 809
810 810 if fetch_refs:
811 811 for chunk in more_itertools.chunked(fetch_refs, 128):
812 812 fetch_refs_chunks = list(chunk)
813 813 log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
814 814 self.run_git_command(
815 815 wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
816 816 fail_on_stderr=False,
817 817 _copts=self._remote_conf(config),
818 818 extra_env={'GIT_TERMINAL_PROMPT': '0'})
819 819 if kwargs.get('sync_large_objects'):
820 820 self.run_git_command(
821 821 wire, ['lfs', 'fetch', url, '--all'],
822 822 fail_on_stderr=False,
823 823 _copts=self._remote_conf(config),
824 824 )
825 825
826 826 return remote_refs
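# The batching idiom above in isolation: more_itertools.chunked caps how many
# refspecs land on a single `git fetch` command line (128 per invocation):
import more_itertools

fetch_refs = [f'refs/heads/b{i}:refs/heads/b{i}' for i in range(300)]
for chunk in more_itertools.chunked(fetch_refs, 128):
    argv = ['fetch', 'https://example.com/repo.git', '--force', '--prune', '--'] + list(chunk)
    # each argv would be handed to run_git_command(wire, argv, ...) as above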
827 827
828 828 @reraise_safe_exceptions
829 829 def sync_push(self, wire, url, refs=None, **kwargs):
830 830 if not self.check_url(url, wire):
831 831 return
832 832 config = self._wire_to_config(wire)
833 833 self._factory.repo(wire)
834 834 self.run_git_command(
835 835 wire, ['push', url, '--mirror'], fail_on_stderr=False,
836 836 _copts=self._remote_conf(config),
837 837 extra_env={'GIT_TERMINAL_PROMPT': '0'})
838 838 if kwargs.get('sync_large_objects'):
839 839 self.run_git_command(
840 840 wire, ['lfs', 'push', url, '--all'],
841 841 fail_on_stderr=False,
842 842 _copts=self._remote_conf(config),
843 843 )
844 844
845 845 @reraise_safe_exceptions
846 846 def get_remote_refs(self, wire, url):
847 847 repo = Repo(url)
848 848 return repo.get_refs()
849 849
850 850 @reraise_safe_exceptions
851 851 def get_description(self, wire):
852 852 repo = self._factory.repo(wire)
853 853 return repo.get_description()
854 854
855 855 @reraise_safe_exceptions
856 856 def get_missing_revs(self, wire, rev1, rev2, other_repo_path):
857 857 origin_repo_path = wire['path']
858 858 repo = self._factory.repo(wire)
859 859 # fetch from other_repo_path to our origin repo
860 860 LocalGitClient(thin_packs=False).fetch(other_repo_path, repo)
861 861
862 862 wire_remote = wire.copy()
863 863 wire_remote['path'] = other_repo_path
864 864 repo_remote = self._factory.repo(wire_remote)
865 865
866 866 # fetch from origin_repo_path to our remote repo
867 867 LocalGitClient(thin_packs=False).fetch(origin_repo_path, repo_remote)
868 868
869 869 revs = [
870 870 x.commit.id
871 871 for x in repo_remote.get_walker(include=[safe_bytes(rev2)], exclude=[safe_bytes(rev1)])]
872 872 return revs
873 873
874 874 @reraise_safe_exceptions
875 875 def get_object(self, wire, sha, maybe_unreachable=False):
876 876 cache_on, context_uid, repo_id = self._cache_on(wire)
877 877 region = self._region(wire)
878 878
879 879 @region.conditional_cache_on_arguments(condition=cache_on)
880 880 def _get_object(_context_uid, _repo_id, _sha):
881 881 repo_init = self._factory.repo_libgit2(wire)
882 882 with repo_init as repo:
883 883
884 884 missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
885 885 try:
886 886 commit = repo.revparse_single(sha)
887 887 except KeyError:
888 888 # NOTE(marcink): KeyError doesn't give us any meaningful information
889 889 # here, we instead give something more explicit
890 890 e = exceptions.RefNotFoundException('SHA: %s not found', sha)
891 891 raise exceptions.LookupException(e)(missing_commit_err)
892 892 except ValueError as e:
893 893 raise exceptions.LookupException(e)(missing_commit_err)
894 894
895 895 is_tag = False
896 896 if isinstance(commit, pygit2.Tag):
897 897 commit = repo.get(commit.target)
898 898 is_tag = True
899 899
900 900 check_dangling = True
901 901 if is_tag:
902 902 check_dangling = False
903 903
904 904 if check_dangling and maybe_unreachable:
905 905 check_dangling = False
906 906
907 907 # if a symbolic reference was used and it resolved, the commit is not dangling
908 908 if sha != commit.hex:
909 909 check_dangling = False
910 910
911 911 if check_dangling:
912 912 # check for dangling commit
913 913 for branch in repo.branches.with_commit(commit.hex):
914 914 if branch:
915 915 break
916 916 else:
917 917 # NOTE(marcink): Empty error doesn't give us any meaningful information
918 918 # here, we instead give something more explicit
919 919 e = exceptions.RefNotFoundException('SHA: %s not found in branches', sha)
920 920 raise exceptions.LookupException(e)(missing_commit_err)
921 921
922 922 commit_id = commit.hex
923 923 type_str = commit.type_str
924 924
925 925 return {
926 926 'id': commit_id,
927 927 'type': type_str,
928 928 'commit_id': commit_id,
929 929 'idx': 0
930 930 }
931 931
932 932 return _get_object(context_uid, repo_id, sha)
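# The dangling-commit test above leans on Python's for/else; sketch with a
# hypothetical repository handle:
import pygit2

def assert_reachable(repo: pygit2.Repository, sha: str) -> None:
    for _branch in repo.branches.with_commit(sha):
        break  # at least one branch reaches the commit
    else:
        # for/else - the else branch runs only when the loop never hit `break`
        raise LookupError(f'SHA: {sha} not found in branches')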
933 933
934 934 @reraise_safe_exceptions
935 935 def get_refs(self, wire):
936 936 cache_on, context_uid, repo_id = self._cache_on(wire)
937 937 region = self._region(wire)
938 938
939 939 @region.conditional_cache_on_arguments(condition=cache_on)
940 940 def _get_refs(_context_uid, _repo_id):
941 941
942 942 repo_init = self._factory.repo_libgit2(wire)
943 943 with repo_init as repo:
944 944 regex = re.compile('^refs/(heads|tags)/')
945 945 return {x.name: x.target.hex for x in
946 946 [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}
947 947
948 948 return _get_refs(context_uid, repo_id)
949 949
950 950 @reraise_safe_exceptions
951 951 def get_branch_pointers(self, wire):
952 952 cache_on, context_uid, repo_id = self._cache_on(wire)
953 953 region = self._region(wire)
954 954
955 955 @region.conditional_cache_on_arguments(condition=cache_on)
956 956 def _get_branch_pointers(_context_uid, _repo_id):
957 957
958 958 repo_init = self._factory.repo_libgit2(wire)
959 959 regex = re.compile('^refs/heads')
960 960 with repo_init as repo:
961 961 branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
962 962 return {x.target.hex: x.shorthand for x in branches}
963 963
964 964 return _get_branch_pointers(context_uid, repo_id)
965 965
966 966 @reraise_safe_exceptions
967 967 def head(self, wire, show_exc=True):
968 968 cache_on, context_uid, repo_id = self._cache_on(wire)
969 969 region = self._region(wire)
970 970
971 971 @region.conditional_cache_on_arguments(condition=cache_on)
972 972 def _head(_context_uid, _repo_id, _show_exc):
973 973 repo_init = self._factory.repo_libgit2(wire)
974 974 with repo_init as repo:
975 975 try:
976 976 return repo.head.peel().hex
977 977 except Exception:
978 978 if show_exc:
979 979 raise
980 980 return _head(context_uid, repo_id, show_exc)
981 981
982 982 @reraise_safe_exceptions
983 983 def init(self, wire):
984 984 repo_path = safe_str(wire['path'])
985 985 os.makedirs(repo_path, mode=0o755)
986 986 pygit2.init_repository(repo_path, bare=False)
987 987
988 988 @reraise_safe_exceptions
989 989 def init_bare(self, wire):
990 990 repo_path = safe_str(wire['path'])
991 991 os.makedirs(repo_path, mode=0o755)
992 992 pygit2.init_repository(repo_path, bare=True)
993 993
994 994 @reraise_safe_exceptions
995 995 def revision(self, wire, rev):
996 996
997 997 cache_on, context_uid, repo_id = self._cache_on(wire)
998 998 region = self._region(wire)
999 999
1000 1000 @region.conditional_cache_on_arguments(condition=cache_on)
1001 1001 def _revision(_context_uid, _repo_id, _rev):
1002 1002 repo_init = self._factory.repo_libgit2(wire)
1003 1003 with repo_init as repo:
1004 1004 commit = repo[rev]
1005 1005 obj_data = {
1006 1006 'id': commit.id.hex,
1007 1007 }
1008 1008 # tree objects themselves don't have a tree_id attribute
1009 1009 if hasattr(commit, 'tree_id'):
1010 1010 obj_data['tree'] = commit.tree_id.hex
1011 1011
1012 1012 return obj_data
1013 1013 return _revision(context_uid, repo_id, rev)
1014 1014
1015 1015 @reraise_safe_exceptions
1016 1016 def date(self, wire, commit_id):
1017 1017 cache_on, context_uid, repo_id = self._cache_on(wire)
1018 1018 region = self._region(wire)
1019 1019
1020 1020 @region.conditional_cache_on_arguments(condition=cache_on)
1021 1021 def _date(_repo_id, _commit_id):
1022 1022 repo_init = self._factory.repo_libgit2(wire)
1023 1023 with repo_init as repo:
1024 1024 commit = repo[commit_id]
1025 1025
1026 1026 if hasattr(commit, 'commit_time'):
1027 1027 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
1028 1028 else:
1029 1029 commit = commit.get_object()
1030 1030 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
1031 1031
1032 1032 # TODO(marcink): check dulwich difference of offset vs timezone
1033 1033 return [commit_time, commit_time_offset]
1034 1034 return _date(repo_id, commit_id)
1035 1035
1036 1036 @reraise_safe_exceptions
1037 1037 def author(self, wire, commit_id):
1038 1038 cache_on, context_uid, repo_id = self._cache_on(wire)
1039 1039 region = self._region(wire)
1040 1040
1041 1041 @region.conditional_cache_on_arguments(condition=cache_on)
1042 1042 def _author(_repo_id, _commit_id):
1043 1043 repo_init = self._factory.repo_libgit2(wire)
1044 1044 with repo_init as repo:
1045 1045 commit = repo[commit_id]
1046 1046
1047 1047 if hasattr(commit, 'author'):
1048 1048 author = commit.author
1049 1049 else:
1050 1050 author = commit.get_object().author
1051 1051
1052 1052 if author.email:
1053 1053 return f"{author.name} <{author.email}>"
1054 1054
1055 1055 try:
1056 1056 return f"{author.name}"
1057 1057 except Exception:
1058 1058 return f"{safe_str(author.raw_name)}"
1059 1059
1060 1060 return _author(repo_id, commit_id)
1061 1061
1062 1062 @reraise_safe_exceptions
1063 1063 def message(self, wire, commit_id):
1064 1064 cache_on, context_uid, repo_id = self._cache_on(wire)
1065 1065 region = self._region(wire)
1066 1066
1067 1067 @region.conditional_cache_on_arguments(condition=cache_on)
1068 1068 def _message(_repo_id, _commit_id):
1069 1069 repo_init = self._factory.repo_libgit2(wire)
1070 1070 with repo_init as repo:
1071 1071 commit = repo[commit_id]
1072 1072 return commit.message
1073 1073 return _message(repo_id, commit_id)
1074 1074
1075 1075 @reraise_safe_exceptions
1076 1076 def parents(self, wire, commit_id):
1077 1077 cache_on, context_uid, repo_id = self._cache_on(wire)
1078 1078 region = self._region(wire)
1079 1079
1080 1080 @region.conditional_cache_on_arguments(condition=cache_on)
1081 1081 def _parents(_repo_id, _commit_id):
1082 1082 repo_init = self._factory.repo_libgit2(wire)
1083 1083 with repo_init as repo:
1084 1084 commit = repo[commit_id]
1085 1085 if hasattr(commit, 'parent_ids'):
1086 1086 parent_ids = commit.parent_ids
1087 1087 else:
1088 1088 parent_ids = commit.get_object().parent_ids
1089 1089
1090 1090 return [x.hex for x in parent_ids]
1091 1091 return _parents(repo_id, commit_id)
1092 1092
1093 1093 @reraise_safe_exceptions
1094 1094 def children(self, wire, commit_id):
1095 1095 cache_on, context_uid, repo_id = self._cache_on(wire)
1096 1096 region = self._region(wire)
1097 1097
1098 1098 head = self.head(wire)
1099 1099
1100 1100 @region.conditional_cache_on_arguments(condition=cache_on)
1101 1101 def _children(_repo_id, _commit_id):
1102 1102
1103 1103 output, __ = self.run_git_command(
1104 1104 wire, ['rev-list', '--all', '--children', f'{commit_id}^..{head}'])
1105 1105
1106 1106 child_ids = []
1107 1107 pat = re.compile(fr'^{commit_id}')
1108 1108 for line in output.splitlines():
1109 1109 line = safe_str(line)
1110 1110 if pat.match(line):
1111 1111 found_ids = line.split(' ')[1:]
1112 1112 child_ids.extend(found_ids)
1113 1113 break
1114 1114
1115 1115 return child_ids
1116 1116 return _children(repo_id, commit_id)
1117 1117
1118 1118 @reraise_safe_exceptions
1119 1119 def set_refs(self, wire, key, value):
1120 1120 repo_init = self._factory.repo_libgit2(wire)
1121 1121 with repo_init as repo:
1122 1122 repo.references.create(key, value, force=True)
1123 1123
1124 1124 @reraise_safe_exceptions
1125 1125 def update_refs(self, wire, key, value):
1126 1126 repo_init = self._factory.repo_libgit2(wire)
1127 1127 with repo_init as repo:
1128 1128 if key not in repo.references:
1129 1129 raise ValueError(f'Reference {key} not found in the repository')
1130 1130 repo.references.create(key, value, force=True)
1131 1131
1132 1132 @reraise_safe_exceptions
1133 1133 def create_branch(self, wire, branch_name, commit_id, force=False):
1134 1134 repo_init = self._factory.repo_libgit2(wire)
1135 1135 with repo_init as repo:
1136 1136 if commit_id:
1137 1137 commit = repo[commit_id]
1138 1138 else:
1139 1139 # if commit is not given just use the HEAD
1140 1140 commit = repo.head()
1141 1141
1142 1142 if force:
1143 1143 repo.branches.local.create(branch_name, commit, force=force)
1144 1144 elif not repo.branches.get(branch_name):
1145 1145 # create only if that branch isn't existing
1146 1146 repo.branches.local.create(branch_name, commit, force=force)
1147 1147
1148 1148 @reraise_safe_exceptions
1149 1149 def remove_ref(self, wire, key):
1150 1150 repo_init = self._factory.repo_libgit2(wire)
1151 1151 with repo_init as repo:
1152 1152 repo.references.delete(key)
1153 1153
1154 1154 @reraise_safe_exceptions
1155 1155 def tag_remove(self, wire, tag_name):
1156 1156 repo_init = self._factory.repo_libgit2(wire)
1157 1157 with repo_init as repo:
1158 1158 key = f'refs/tags/{tag_name}'
1159 1159 repo.references.delete(key)
1160 1160
1161 1161 @reraise_safe_exceptions
1162 1162 def tree_changes(self, wire, source_id, target_id):
1163 1163 repo = self._factory.repo(wire)
1164 1164 # source can be empty
1165 1165 source_id = safe_bytes(source_id if source_id else b'')
1166 1166 target_id = safe_bytes(target_id)
1167 1167
1168 1168 source = repo[source_id].tree if source_id else None
1169 1169 target = repo[target_id].tree
1170 1170 result = repo.object_store.tree_changes(source, target)
1171 1171
1172 1172 added = set()
1173 1173 modified = set()
1174 1174 deleted = set()
1175 1175 for (old_path, new_path), (_, _), (_, _) in list(result):
1176 1176 if new_path and old_path:
1177 1177 modified.add(new_path)
1178 1178 elif new_path and not old_path:
1179 1179 added.add(new_path)
1180 1180 elif not new_path and old_path:
1181 1181 deleted.add(old_path)
1182 1182
1183 1183 return list(added), list(modified), list(deleted)
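# The triples unpacked above are ((old_path, new_path), (old_mode, new_mode),
# (old_sha, new_sha)); sketch against a hypothetical repo, diffing the empty
# tree (None) to HEAD so every path shows up as added:
from dulwich.repo import Repo

repo = Repo('/path/to/repo')  # hypothetical
head_tree = repo[repo.head()].tree
for (old_path, new_path), _modes, _shas in repo.object_store.tree_changes(None, head_tree):
    status = 'A' if old_path is None else ('D' if new_path is None else 'M')
    print(status, (new_path or old_path).decode())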
1184 1184
1185 1185 @reraise_safe_exceptions
1186 1186 def tree_and_type_for_path(self, wire, commit_id, path):
1187 1187
1188 1188 cache_on, context_uid, repo_id = self._cache_on(wire)
1189 1189 region = self._region(wire)
1190 1190
1191 1191 @region.conditional_cache_on_arguments(condition=cache_on)
1192 1192 def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
1193 1193 repo_init = self._factory.repo_libgit2(wire)
1194 1194
1195 1195 with repo_init as repo:
1196 1196 commit = repo[commit_id]
1197 1197 try:
1198 1198 tree = commit.tree[path]
1199 1199 except KeyError:
1200 1200 return None, None, None
1201 1201
1202 1202 return tree.id.hex, tree.type_str, tree.filemode
1203 1203 return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)
1204 1204
1205 1205 @reraise_safe_exceptions
1206 1206 def tree_items(self, wire, tree_id):
1207 1207 cache_on, context_uid, repo_id = self._cache_on(wire)
1208 1208 region = self._region(wire)
1209 1209
1210 1210 @region.conditional_cache_on_arguments(condition=cache_on)
1211 1211 def _tree_items(_repo_id, _tree_id):
1212 1212
1213 1213 repo_init = self._factory.repo_libgit2(wire)
1214 1214 with repo_init as repo:
1215 1215 try:
1216 1216 tree = repo[tree_id]
1217 1217 except KeyError:
1218 1218 raise ObjectMissing(f'No tree with id: {tree_id}')
1219 1219
1220 1220 result = []
1221 1221 for item in tree:
1222 1222 item_sha = item.hex
1223 1223 item_mode = item.filemode
1224 1224 item_type = item.type_str
1225 1225
1226 1226 if item_type == 'commit':
1227 1227 # NOTE(marcink): we translate submodules to 'link' for backward compat
1228 1228 item_type = 'link'
1229 1229
1230 1230 result.append((item.name, item_mode, item_sha, item_type))
1231 1231 return result
1232 1232 return _tree_items(repo_id, tree_id)
1233 1233
1234 1234 @reraise_safe_exceptions
1235 1235 def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1236 1236 """
1237 1237 Old version that uses subprocess to call diff
1238 1238 """
1239 1239
1240 1240 flags = [
1241 1241 f'-U{context}', '--patch',
1242 1242 '--binary',
1243 1243 '--find-renames',
1244 1244 '--no-indent-heuristic',
1245 1245 # '--indent-heuristic',
1246 1246 #'--full-index',
1247 1247 #'--abbrev=40'
1248 1248 ]
1249 1249
1250 1250 if opt_ignorews:
1251 1251 flags.append('--ignore-all-space')
1252 1252
1253 1253 if commit_id_1 == self.EMPTY_COMMIT:
1254 1254 cmd = ['show'] + flags + [commit_id_2]
1255 1255 else:
1256 1256 cmd = ['diff'] + flags + [commit_id_1, commit_id_2]
1257 1257
1258 1258 if file_filter:
1259 1259 cmd.extend(['--', file_filter])
1260 1260
1261 1261 diff, __ = self.run_git_command(wire, cmd)
1262 1262 # If we used 'show' command, strip first few lines (until actual diff
1263 1263 # starts)
1264 1264 if commit_id_1 == self.EMPTY_COMMIT:
1265 1265 lines = diff.splitlines()
1266 1266 x = 0
1267 1267 for line in lines:
1268 1268 if line.startswith(b'diff'):
1269 1269 break
1270 1270 x += 1
1271 1271 # Append a newline just like the 'diff' command does
1272 1272 diff = b'\n'.join(lines[x:]) + b'\n'
1273 1273 return diff
1274 1274
1275 1275 @reraise_safe_exceptions
1276 1276 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1277 1277 repo_init = self._factory.repo_libgit2(wire)
1278 1278
1279 1279 with repo_init as repo:
1280 1280 swap = True
1281 1281 flags = 0
1282 1282 flags |= pygit2.GIT_DIFF_SHOW_BINARY
1283 1283
1284 1284 if opt_ignorews:
1285 1285 flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE
1286 1286
1287 1287 if commit_id_1 == self.EMPTY_COMMIT:
1288 1288 comm1 = repo[commit_id_2]
1289 1289 diff_obj = comm1.tree.diff_to_tree(
1290 1290 flags=flags, context_lines=context, swap=swap)
1291 1291
1292 1292 else:
1293 1293 comm1 = repo[commit_id_2]
1294 1294 comm2 = repo[commit_id_1]
1295 1295 diff_obj = comm1.tree.diff_to_tree(
1296 1296 comm2.tree, flags=flags, context_lines=context, swap=swap)
1297 1297 similar_flags = 0
1298 1298 similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
1299 1299 diff_obj.find_similar(flags=similar_flags)
1300 1300
1301 1301 if file_filter:
1302 1302 for p in diff_obj:
1303 1303 if p.delta.old_file.path == file_filter:
1304 1304 return BytesEnvelope(p.data) or BytesEnvelope(b'')
1305 1305 # no matching path == no diff
1306 1306 return BytesEnvelope(b'')
1307 1307
1308 1308 return BytesEnvelope(safe_bytes(diff_obj.patch)) or BytesEnvelope(b'')
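# Flag composition for the libgit2 diff above in isolation (repo path and
# commit ids are hypothetical; context_lines mirrors the `context` argument):
import pygit2

repo = pygit2.Repository('/path/to/repo')
comm1, comm2 = repo['deadbeef'], repo['cafebabe']  # hypothetical ids
flags = pygit2.GIT_DIFF_SHOW_BINARY | pygit2.GIT_DIFF_IGNORE_WHITESPACE
diff_obj = comm1.tree.diff_to_tree(comm2.tree, flags=flags, context_lines=3, swap=True)
diff_obj.find_similar(flags=pygit2.GIT_DIFF_FIND_RENAMES)  # rename detection
patch_text = diff_obj.patch or ''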
1309 1309
1310 1310 @reraise_safe_exceptions
1311 1311 def node_history(self, wire, commit_id, path, limit):
1312 1312 cache_on, context_uid, repo_id = self._cache_on(wire)
1313 1313 region = self._region(wire)
1314 1314
1315 1315 @region.conditional_cache_on_arguments(condition=cache_on)
1316 1316 def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
1317 1317 # optimize for n==1, rev-list is much faster for that use-case
1318 1318 if limit == 1:
1319 1319 cmd = ['rev-list', '-1', commit_id, '--', path]
1320 1320 else:
1321 1321 cmd = ['log']
1322 1322 if limit:
1323 1323 cmd.extend(['-n', str(safe_int(limit, 0))])
1324 1324 cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])
1325 1325
1326 1326 output, __ = self.run_git_command(wire, cmd)
1327 1327 commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)
1328 1328
1329 1329 return [x for x in commit_ids]
1330 1330 return _node_history(context_uid, repo_id, commit_id, path, limit)
1331 1331
1332 1332 @reraise_safe_exceptions
1333 1333 def node_annotate_legacy(self, wire, commit_id, path):
1334 1334 # note: replaced by pygit2 implementation
1335 1335 cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
1336 1336 # -l ==> outputs long shas (and we need all 40 characters)
1337 1337 # --root ==> doesn't put '^' character for boundaries
1338 1338 # -r commit_id ==> blames for the given commit
1339 1339 output, __ = self.run_git_command(wire, cmd)
1340 1340
1341 1341 result = []
1342 1342 for i, blame_line in enumerate(output.splitlines()[:-1]):
1343 1343 line_no = i + 1
1344 1344 blame_commit_id, line = re.split(rb' ', blame_line, 1)
1345 1345 result.append((line_no, blame_commit_id, line))
1346 1346
1347 1347 return result
1348 1348
1349 1349 @reraise_safe_exceptions
1350 1350 def node_annotate(self, wire, commit_id, path):
1351 1351
1352 1352 result_libgit = []
1353 1353 repo_init = self._factory.repo_libgit2(wire)
1354 1354 with repo_init as repo:
1355 1355 commit = repo[commit_id]
1356 1356 blame_obj = repo.blame(path, newest_commit=commit_id)
1357 1357 file_content = commit.tree[path].data
1358 1358 for i, line in enumerate(splitnewlines(file_content)):
1359 1359 line_no = i + 1
1360 1360 hunk = blame_obj.for_line(line_no)
1361 1361 blame_commit_id = hunk.final_commit_id.hex
1362 1362
1363 1363 result_libgit.append((line_no, blame_commit_id, line))
1364 1364
1365 1365 return BinaryEnvelope(result_libgit)
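# The pygit2 blame walk above in isolation (path and commit id hypothetical;
# splitnewlines is approximated with splitlines(keepends=True)):
import pygit2

repo = pygit2.Repository('/path/to/repo')
commit_id = 'deadbeef'
blame = repo.blame('README.md', newest_commit=commit_id)
data = repo[commit_id].tree['README.md'].data
for line_no, line in enumerate(data.splitlines(keepends=True), start=1):
    hunk = blame.for_line(line_no)
    print(line_no, hunk.final_commit_id.hex, line)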
1366 1366
1367 1367 @reraise_safe_exceptions
1368 1368 def update_server_info(self, wire, force=False):
1369 1369 cmd = ['update-server-info']
1370 1370 if force:
1371 1371 cmd += ['--force']
1372 1372 output, __ = self.run_git_command(wire, cmd)
1373 1373 return output.splitlines()
1374 1374
1375 1375 @reraise_safe_exceptions
1376 1376 def get_all_commit_ids(self, wire):
1377 1377
1378 1378 cache_on, context_uid, repo_id = self._cache_on(wire)
1379 1379 region = self._region(wire)
1380 1380
1381 1381 @region.conditional_cache_on_arguments(condition=cache_on)
1382 1382 def _get_all_commit_ids(_context_uid, _repo_id):
1383 1383
1384 1384 cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
1385 1385 try:
1386 1386 output, __ = self.run_git_command(wire, cmd)
1387 1387 return output.splitlines()
1388 1388 except Exception:
1389 1389 # Can be raised for empty repositories
1390 1390 return []
1391 1391
1392 1392 @region.conditional_cache_on_arguments(condition=cache_on)
1393 1393 def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
1394 1394 repo_init = self._factory.repo_libgit2(wire)
1395 1395 from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
1396 1396 results = []
1397 1397 with repo_init as repo:
1398 1398 for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
1399 1399 results.append(commit.id.hex)
1400 1400
1401 1401 return _get_all_commit_ids(context_uid, repo_id)
1402 1402
1403 1403 @reraise_safe_exceptions
1404 1404 def run_git_command(self, wire, cmd, **opts):
1405 1405 path = wire.get('path', None)
1406 1406 debug_mode = rhodecode.ConfigGet().get_bool('debug')
1407 1407
1408 1408 if path and os.path.isdir(path):
1409 1409 opts['cwd'] = path
1410 1410
1411 1411 if '_bare' in opts:
1412 1412 _copts = []
1413 1413 del opts['_bare']
1414 1414 else:
1415 1415 _copts = ['-c', 'core.quotepath=false', '-c', 'advice.diverging=false']
1416 1416 safe_call = False
1417 1417 if '_safe' in opts:
1418 1418 # no exc on failure
1419 1419 del opts['_safe']
1420 1420 safe_call = True
1421 1421
1422 1422 if '_copts' in opts:
1423 1423 _copts.extend(opts['_copts'] or [])
1424 1424 del opts['_copts']
1425 1425
1426 1426 gitenv = os.environ.copy()
1427 1427 gitenv.update(opts.pop('extra_env', {}))
1428 1428 # need to clean/unset GIT_DIR!
1429 1429 if 'GIT_DIR' in gitenv:
1430 1430 del gitenv['GIT_DIR']
1431 1431 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1432 1432 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1433 1433
1434 1434 cmd = [settings.GIT_EXECUTABLE()] + _copts + cmd
1435 1435 _opts = {'env': gitenv, 'shell': False}
1436 1436
1437 1437 proc = None
1438 1438 try:
1439 1439 _opts.update(opts)
1440 1440 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1441 1441
1442 1442 return b''.join(proc), b''.join(proc.stderr)
1443 1443 except OSError as err:
1444 1444 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
1445 1445 call_opts = {}
1446 1446 if debug_mode:
1447 1447 call_opts = _opts
1448 1448
1449 1449 tb_err = ("Couldn't run git command ({}).\n"
1450 1450 "Original error was: {}\n"
1451 1451 "Call options: {}\n"
1452 1452 .format(cmd, err, call_opts))
1453 1453 log.exception(tb_err)
1454 1454 if safe_call:
1455 1455 return '', err
1456 1456 else:
1457 1457 raise exceptions.VcsException()(tb_err)
1458 1458 finally:
1459 1459 if proc:
1460 1460 proc.close()
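# The environment sanitization above, as a standalone subprocess sketch (git
# assumed on PATH; SubprocessIOChunker replaced with check_output):
import os
import subprocess

def run_git(argv: list[str], cwd: str | None = None) -> bytes:
    env = os.environ.copy()
    env.pop('GIT_DIR', None)  # a caller's GIT_DIR must not leak in
    env['GIT_CONFIG_NOGLOBAL'] = '1'
    env['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
    copts = ['-c', 'core.quotepath=false', '-c', 'advice.diverging=false']
    return subprocess.check_output(['git'] + copts + argv, cwd=cwd, env=env)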
1461 1461
1462 1462 @reraise_safe_exceptions
1463 1463 def install_hooks(self, wire, force=False):
1464 1464 from vcsserver.hook_utils import install_git_hooks
1465 1465 bare = self.bare(wire)
1466 1466 path = wire['path']
1467 1467 binary_dir = settings.BINARY_DIR
1468 1468 if binary_dir:
1469 1469 os.path.join(binary_dir, 'python3')
1470 1470 return install_git_hooks(path, bare, force_create=force)
1471 1471
1472 1472 @reraise_safe_exceptions
1473 1473 def get_hooks_info(self, wire):
1474 1474 from vcsserver.hook_utils import (
1475 1475 get_git_pre_hook_version, get_git_post_hook_version)
1476 1476 bare = self.bare(wire)
1477 1477 path = wire['path']
1478 1478 return {
1479 1479 'pre_version': get_git_pre_hook_version(path, bare),
1480 1480 'post_version': get_git_post_hook_version(path, bare),
1481 1481 }
1482 1482
1483 1483 @reraise_safe_exceptions
1484 1484 def set_head_ref(self, wire, head_name):
1485 1485 log.debug('Setting refs/head to `%s`', head_name)
1486 1486 repo_init = self._factory.repo_libgit2(wire)
1487 1487 with repo_init as repo:
1488 1488 repo.set_head(f'refs/heads/{head_name}')
1489 1489
1490 1490 return [head_name] + [f'set HEAD to refs/heads/{head_name}']
1491 1491
1492 1492 @reraise_safe_exceptions
1493 1493 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1494 1494 archive_dir_name, commit_id, cache_config):
1495 1495
1496 1496 def file_walker(_commit_id, path):
1497 1497 repo_init = self._factory.repo_libgit2(wire)
1498 1498
1499 1499 with repo_init as repo:
1500 1500 commit = repo[commit_id]
1501 1501
1502 1502 if path in ['', '/']:
1503 1503 tree = commit.tree
1504 1504 else:
1505 1505 tree = commit.tree[path.rstrip('/')]
1506 1506 tree_id = tree.id.hex
1507 1507 try:
1508 1508 tree = repo[tree_id]
1509 1509 except KeyError:
1510 1510 raise ObjectMissing(f'No tree with id: {tree_id}')
1511 1511
1512 1512 index = LibGit2Index.Index()
1513 1513 index.read_tree(tree)
1514 1514 file_iter = index
1515 1515
1516 1516 for file_node in file_iter:
1517 1517 file_path = file_node.path
1518 1518 mode = file_node.mode
1519 1519 is_link = stat.S_ISLNK(mode)
1520 1520 if mode == pygit2.GIT_FILEMODE_COMMIT:
1521 1521 log.debug('Skipping path %s as a commit node', file_path)
1522 1522 continue
1523 1523 yield ArchiveNode(file_path, mode, is_link, repo[file_node.hex].read_raw)
1524 1524
1525 1525 return store_archive_in_cache(
1526 1526 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
@@ -1,1216 +1,1216 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import binascii
19 19 import io
20 20 import logging
21 21 import stat
22 22 import sys
23 23 import urllib.request
24 24 import urllib.parse
25 25 import hashlib
26 26
27 27 from hgext import largefiles, rebase
28 28
29 29 from mercurial import commands
30 30 from mercurial import unionrepo
31 31 from mercurial import verify
32 32 from mercurial import repair
33 33 from mercurial.error import AmbiguousPrefixLookupError
34 34 from mercurial.utils.urlutil import path as hg_path
35 35
36 36 import vcsserver
37 37 from vcsserver import exceptions
38 38 from vcsserver.base import (
39 39 RepoFactory,
40 40 obfuscate_qs,
41 41 raise_from_original,
42 42 store_archive_in_cache,
43 43 ArchiveNode,
44 44 BytesEnvelope,
45 45 BinaryEnvelope,
46 46 )
47 47 from vcsserver.hgcompat import (
48 48 archival,
49 49 bin,
50 50 clone,
51 51 config as hgconfig,
52 52 diffopts,
53 53 hex,
54 54 get_ctx,
55 55 hg_url as url_parser,
56 56 httpbasicauthhandler,
57 57 httpdigestauthhandler,
58 58 make_peer,
59 59 instance,
60 60 match,
61 61 memctx,
62 62 exchange,
63 63 memfilectx,
64 64 nullrev,
65 65 hg_merge,
66 66 patch,
67 67 peer,
68 68 revrange,
69 69 ui,
70 70 hg_tag,
71 71 Abort,
72 72 LookupError,
73 73 RepoError,
74 74 RepoLookupError,
75 75 InterventionRequired,
76 76 RequirementError,
77 77 alwaysmatcher,
78 78 patternmatcher,
79 79 hgext_strip,
80 80 )
81 from vcsserver.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes, convert_to_str
81 from vcsserver.lib.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes, convert_to_str
82 82 from vcsserver.vcs_base import RemoteBase
83 83 from vcsserver.config import hooks as hooks_config
84 84 from vcsserver.lib.exc_tracking import format_exc
85 85
86 86 log = logging.getLogger(__name__)
87 87
88 88
89 89 def make_ui_from_config(repo_config):
90 90
91 91 class LoggingUI(ui.ui):
92 92
93 93 def status(self, *msg, **opts):
94 94 str_msg = map(safe_str, msg)
95 95 log.info(' '.join(str_msg).rstrip('\n'))
96 96 #super(LoggingUI, self).status(*msg, **opts)
97 97
98 98 def warn(self, *msg, **opts):
99 99 str_msg = map(safe_str, msg)
100 100 log.warning('ui_logger:'+' '.join(str_msg).rstrip('\n'))
101 101 #super(LoggingUI, self).warn(*msg, **opts)
102 102
103 103 def error(self, *msg, **opts):
104 104 str_msg = map(safe_str, msg)
105 105 log.error('ui_logger:'+' '.join(str_msg).rstrip('\n'))
106 106 #super(LoggingUI, self).error(*msg, **opts)
107 107
108 108 def note(self, *msg, **opts):
109 109 str_msg = map(safe_str, msg)
110 110 log.info('ui_logger:'+' '.join(str_msg).rstrip('\n'))
111 111 #super(LoggingUI, self).note(*msg, **opts)
112 112
113 113 def debug(self, *msg, **opts):
114 114 str_msg = map(safe_str, msg)
115 115 log.debug('ui_logger:'+' '.join(str_msg).rstrip('\n'))
116 116 #super(LoggingUI, self).debug(*msg, **opts)
117 117
118 118 baseui = LoggingUI()
119 119
120 120 # clean the baseui object
121 121 baseui._ocfg = hgconfig.config()
122 122 baseui._ucfg = hgconfig.config()
123 123 baseui._tcfg = hgconfig.config()
124 124
125 125 for section, option, value in repo_config:
126 126 baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))
127 127
128 128 # make our hgweb quiet so it doesn't print output
129 129 baseui.setconfig(b'ui', b'quiet', b'true')
130 130
131 131 baseui.setconfig(b'ui', b'paginate', b'never')
132 132 # for better error reporting from Mercurial
133 133 baseui.setconfig(b'ui', b'message-output', b'stderr')
134 134
135 135 # force mercurial to only use 1 thread, otherwise it may try to set a
136 136 # signal in a non-main thread, thus generating a ValueError.
137 137 baseui.setconfig(b'worker', b'numcpus', 1)
138 138
139 139 # If there is no config for the largefiles extension, we explicitly disable
140 140 # it here. This overrides settings from repositories hgrc file. Recent
141 141 # mercurial versions enable largefiles in hgrc on clone from largefile
142 142 # repo.
143 143 if not baseui.hasconfig(b'extensions', b'largefiles'):
144 144 log.debug('Explicitly disable largefiles extension for repo.')
145 145 baseui.setconfig(b'extensions', b'largefiles', b'!')
146 146
147 147 return baseui
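# Mercurial's config API is bytes-only; the pattern above in miniature:
from mercurial import ui as hg_ui

baseui = hg_ui.ui()
baseui.setconfig(b'ui', b'quiet', b'true')            # keys/values are bytes
baseui.setconfig(b'extensions', b'largefiles', b'!')  # '!' disables an extension
assert baseui.configbool(b'ui', b'quiet')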
148 148
149 149
150 150 def reraise_safe_exceptions(func):
151 151 """Decorator for converting mercurial exceptions to something neutral."""
152 152
153 153 def wrapper(*args, **kwargs):
154 154 try:
155 155 return func(*args, **kwargs)
156 156 except (Abort, InterventionRequired) as e:
157 157 raise_from_original(exceptions.AbortException(e), e)
158 158 except RepoLookupError as e:
159 159 raise_from_original(exceptions.LookupException(e), e)
160 160 except RequirementError as e:
161 161 raise_from_original(exceptions.RequirementException(e), e)
162 162 except RepoError as e:
163 163 raise_from_original(exceptions.VcsException(e), e)
164 164 except LookupError as e:
165 165 raise_from_original(exceptions.LookupException(e), e)
166 166 except Exception as e:
167 167 if not hasattr(e, '_vcs_kind'):
168 168 log.exception("Unhandled exception in hg remote call")
169 169 raise_from_original(exceptions.UnhandledException(e), e)
170 170
171 171 raise
172 172 return wrapper
173 173
174 174
175 175 class MercurialFactory(RepoFactory):
176 176 repo_type = 'hg'
177 177
178 178 def _create_config(self, config, hooks=True):
179 179 if not hooks:
180 180
181 181 hooks_to_clean = {
182 182
183 183 hooks_config.HOOK_REPO_SIZE,
184 184 hooks_config.HOOK_PRE_PULL,
185 185 hooks_config.HOOK_PULL,
186 186
187 187 hooks_config.HOOK_PRE_PUSH,
188 188 # TODO: what about PRETXT, this was disabled in pre 5.0.0
189 189 hooks_config.HOOK_PRETX_PUSH,
190 190
191 191 }
192 192 new_config = []
193 193 for section, option, value in config:
194 194 if section == 'hooks' and option in hooks_to_clean:
195 195 continue
196 196 new_config.append((section, option, value))
197 197 config = new_config
198 198
199 199 baseui = make_ui_from_config(config)
200 200 return baseui
201 201
202 202 def _create_repo(self, wire, create):
203 203 baseui = self._create_config(wire["config"])
204 204 repo = instance(baseui, safe_bytes(wire["path"]), create)
205 205 log.debug('repository created: got HG object: %s', repo)
206 206 return repo
207 207
208 208 def repo(self, wire, create=False):
209 209 """
210 210 Get a repository instance for the given path.
211 211 """
212 212 return self._create_repo(wire, create)
213 213
214 214
215 215 def patch_ui_message_output(baseui):
216 216 baseui.setconfig(b'ui', b'quiet', b'false')
217 217 output = io.BytesIO()
218 218
219 219 def write(data, **unused_kwargs):
220 220 output.write(data)
221 221
222 222 baseui.status = write
223 223 baseui.write = write
224 224 baseui.warn = write
225 225 baseui.debug = write
226 226
227 227 return baseui, output
228 228
229 229
230 230 def get_obfuscated_url(url_obj):
231 231 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
232 232 url_obj.query = obfuscate_qs(url_obj.query)
233 233 obfuscated_uri = str(url_obj)
234 234 return obfuscated_uri
235 235
236 236
237 237 def normalize_url_for_hg(url: str):
238 238 _proto = None
239 239
240 240 if '+' in url[:url.find('://')]:
241 241 _proto = url[0:url.find('+')]
242 242 url = url[url.find('+') + 1:]
243 243 return url, _proto
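# Behaviour of normalize_url_for_hg above on scheme-prefixed URLs (illustrative):
assert normalize_url_for_hg('svn+http://host/repo') == ('http://host/repo', 'svn')
assert normalize_url_for_hg('http://host/repo') == ('http://host/repo', None)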
244 244
245 245
246 246 class HgRemote(RemoteBase):
247 247
248 248 def __init__(self, factory):
249 249 self._factory = factory
250 250 self._bulk_methods = {
251 251 "affected_files": self.ctx_files,
252 252 "author": self.ctx_user,
253 253 "branch": self.ctx_branch,
254 254 "children": self.ctx_children,
255 255 "date": self.ctx_date,
256 256 "message": self.ctx_description,
257 257 "parents": self.ctx_parents,
258 258 "status": self.ctx_status,
259 259 "obsolete": self.ctx_obsolete,
260 260 "phase": self.ctx_phase,
261 261 "hidden": self.ctx_hidden,
262 262 "_file_paths": self.ctx_list,
263 263 }
264 264 self._bulk_file_methods = {
265 265 "size": self.fctx_size,
266 266 "data": self.fctx_node_data,
267 267 "flags": self.fctx_flags,
268 268 "is_binary": self.is_binary,
269 269 "md5": self.md5_hash,
270 270 }
271 271
272 272 def _get_ctx(self, repo, ref):
273 273 return get_ctx(repo, ref)
274 274
275 275 @reraise_safe_exceptions
276 276 def discover_hg_version(self):
277 277 from mercurial import util
278 278 return safe_str(util.version())
279 279
280 280 @reraise_safe_exceptions
281 281 def is_empty(self, wire):
282 282 repo = self._factory.repo(wire)
283 283
284 284 try:
285 285 return len(repo) == 0
286 286 except Exception:
287 287 log.exception("failed to read object_store")
288 288 return False
289 289
290 290 @reraise_safe_exceptions
291 291 def bookmarks(self, wire):
292 292 cache_on, context_uid, repo_id = self._cache_on(wire)
293 293 region = self._region(wire)
294 294
295 295 @region.conditional_cache_on_arguments(condition=cache_on)
296 296 def _bookmarks(_context_uid, _repo_id):
297 297 repo = self._factory.repo(wire)
298 298 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo._bookmarks.items()}
299 299
300 300 return _bookmarks(context_uid, repo_id)
301 301
302 302 @reraise_safe_exceptions
303 303 def branches(self, wire, normal, closed):
304 304 cache_on, context_uid, repo_id = self._cache_on(wire)
305 305 region = self._region(wire)
306 306
307 307 @region.conditional_cache_on_arguments(condition=cache_on)
308 308 def _branches(_context_uid, _repo_id, _normal, _closed):
309 309 repo = self._factory.repo(wire)
310 310 iter_branches = repo.branchmap().iterbranches()
311 311 bt = {}
312 312 for branch_name, _heads, tip_node, is_closed in iter_branches:
313 313 if normal and not is_closed:
314 314 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
315 315 if closed and is_closed:
316 316 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
317 317
318 318 return bt
319 319
320 320 return _branches(context_uid, repo_id, normal, closed)
321 321
322 322 @reraise_safe_exceptions
323 323 def bulk_request(self, wire, commit_id, pre_load):
324 324 cache_on, context_uid, repo_id = self._cache_on(wire)
325 325 region = self._region(wire)
326 326
327 327 @region.conditional_cache_on_arguments(condition=cache_on)
328 328 def _bulk_request(_repo_id, _commit_id, _pre_load):
329 329 result = {}
330 330 for attr in pre_load:
331 331 try:
332 332 method = self._bulk_methods[attr]
333 333 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
334 334 result[attr] = method(wire, commit_id)
335 335 except KeyError as e:
336 336 raise exceptions.VcsException(e)(
337 337 f'Unknown bulk attribute: "{attr}"')
338 338 return result
339 339
340 340 return _bulk_request(repo_id, commit_id, sorted(pre_load))
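# The dispatch pattern above, reduced to its shape (names illustrative):
def bulk(methods: dict, attrs: list, *args):
    out = {}
    for attr in attrs:
        method = methods.get(attr)
        if method is None:
            raise ValueError(f'Unknown bulk attribute: "{attr}"')
        out[attr] = method(*args)
    return out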
341 341
342 342 @reraise_safe_exceptions
343 343 def ctx_branch(self, wire, commit_id):
344 344 cache_on, context_uid, repo_id = self._cache_on(wire)
345 345 region = self._region(wire)
346 346
347 347 @region.conditional_cache_on_arguments(condition=cache_on)
348 348 def _ctx_branch(_repo_id, _commit_id):
349 349 repo = self._factory.repo(wire)
350 350 ctx = self._get_ctx(repo, commit_id)
351 351 return ctx.branch()
352 352 return _ctx_branch(repo_id, commit_id)
353 353
354 354 @reraise_safe_exceptions
355 355 def ctx_date(self, wire, commit_id):
356 356 cache_on, context_uid, repo_id = self._cache_on(wire)
357 357 region = self._region(wire)
358 358
359 359 @region.conditional_cache_on_arguments(condition=cache_on)
360 360 def _ctx_date(_repo_id, _commit_id):
361 361 repo = self._factory.repo(wire)
362 362 ctx = self._get_ctx(repo, commit_id)
363 363 return ctx.date()
364 364 return _ctx_date(repo_id, commit_id)
365 365
366 366 @reraise_safe_exceptions
367 367 def ctx_description(self, wire, revision):
368 368 repo = self._factory.repo(wire)
369 369 ctx = self._get_ctx(repo, revision)
370 370 return ctx.description()
371 371
372 372 @reraise_safe_exceptions
373 373 def ctx_files(self, wire, commit_id):
374 374 cache_on, context_uid, repo_id = self._cache_on(wire)
375 375 region = self._region(wire)
376 376
377 377 @region.conditional_cache_on_arguments(condition=cache_on)
378 378 def _ctx_files(_repo_id, _commit_id):
379 379 repo = self._factory.repo(wire)
380 380 ctx = self._get_ctx(repo, commit_id)
381 381 return ctx.files()
382 382
383 383 return _ctx_files(repo_id, commit_id)
384 384
385 385 @reraise_safe_exceptions
386 386 def ctx_list(self, path, revision):
387 387 repo = self._factory.repo(path)
388 388 ctx = self._get_ctx(repo, revision)
389 389 return list(ctx)
390 390
391 391 @reraise_safe_exceptions
392 392 def ctx_parents(self, wire, commit_id):
393 393 cache_on, context_uid, repo_id = self._cache_on(wire)
394 394 region = self._region(wire)
395 395
396 396 @region.conditional_cache_on_arguments(condition=cache_on)
397 397 def _ctx_parents(_repo_id, _commit_id):
398 398 repo = self._factory.repo(wire)
399 399 ctx = self._get_ctx(repo, commit_id)
400 400 return [parent.hex() for parent in ctx.parents()
401 401 if not (parent.hidden() or parent.obsolete())]
402 402
403 403 return _ctx_parents(repo_id, commit_id)
404 404
405 405 @reraise_safe_exceptions
406 406 def ctx_children(self, wire, commit_id):
407 407 cache_on, context_uid, repo_id = self._cache_on(wire)
408 408 region = self._region(wire)
409 409
410 410 @region.conditional_cache_on_arguments(condition=cache_on)
411 411 def _ctx_children(_repo_id, _commit_id):
412 412 repo = self._factory.repo(wire)
413 413 ctx = self._get_ctx(repo, commit_id)
414 414 return [child.hex() for child in ctx.children()
415 415 if not (child.hidden() or child.obsolete())]
416 416
417 417 return _ctx_children(repo_id, commit_id)
418 418
419 419 @reraise_safe_exceptions
420 420 def ctx_phase(self, wire, commit_id):
421 421 cache_on, context_uid, repo_id = self._cache_on(wire)
422 422 region = self._region(wire)
423 423
424 424 @region.conditional_cache_on_arguments(condition=cache_on)
425 425 def _ctx_phase(_context_uid, _repo_id, _commit_id):
426 426 repo = self._factory.repo(wire)
427 427 ctx = self._get_ctx(repo, commit_id)
428 428 # public=0, draft=1, secret=2
429 429 return ctx.phase()
430 430 return _ctx_phase(context_uid, repo_id, commit_id)
431 431
432 432 @reraise_safe_exceptions
433 433 def ctx_obsolete(self, wire, commit_id):
434 434 cache_on, context_uid, repo_id = self._cache_on(wire)
435 435 region = self._region(wire)
436 436
437 437 @region.conditional_cache_on_arguments(condition=cache_on)
438 438 def _ctx_obsolete(_context_uid, _repo_id, _commit_id):
439 439 repo = self._factory.repo(wire)
440 440 ctx = self._get_ctx(repo, commit_id)
441 441 return ctx.obsolete()
442 442 return _ctx_obsolete(context_uid, repo_id, commit_id)
443 443
444 444 @reraise_safe_exceptions
445 445 def ctx_hidden(self, wire, commit_id):
446 446 cache_on, context_uid, repo_id = self._cache_on(wire)
447 447 region = self._region(wire)
448 448
449 449 @region.conditional_cache_on_arguments(condition=cache_on)
450 450 def _ctx_hidden(_context_uid, _repo_id, _commit_id):
451 451 repo = self._factory.repo(wire)
452 452 ctx = self._get_ctx(repo, commit_id)
453 453 return ctx.hidden()
454 454 return _ctx_hidden(context_uid, repo_id, commit_id)
455 455
456 456 @reraise_safe_exceptions
457 457 def ctx_substate(self, wire, revision):
458 458 repo = self._factory.repo(wire)
459 459 ctx = self._get_ctx(repo, revision)
460 460 return ctx.substate
461 461
462 462 @reraise_safe_exceptions
463 463 def ctx_status(self, wire, revision):
464 464 repo = self._factory.repo(wire)
465 465 ctx = self._get_ctx(repo, revision)
466 466 status = repo[ctx.p1().node()].status(other=ctx.node())
467 467 # object of status (odd, custom named tuple in mercurial) is not
468 468 # correctly serializable, we make it a list, as the underling
469 469 # API expects this to be a list
470 470 return list(status)
471 471
472 472 @reraise_safe_exceptions
473 473 def ctx_user(self, wire, revision):
474 474 repo = self._factory.repo(wire)
475 475 ctx = self._get_ctx(repo, revision)
476 476 return ctx.user()
477 477
478 478 @reraise_safe_exceptions
479 479 def check_url(self, url, config):
480 480 url, _proto = normalize_url_for_hg(url)
481 481 url_obj = url_parser(safe_bytes(url))
482 482
483 483 test_uri = safe_str(url_obj.authinfo()[0])
484 484 authinfo = url_obj.authinfo()[1]
485 485 obfuscated_uri = get_obfuscated_url(url_obj)
486 486 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
487 487
488 488 handlers = []
489 489 if authinfo:
490 490 # create a password manager
491 491 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
492 492 passmgr.add_password(*convert_to_str(authinfo))
493 493
494 494 handlers.extend((httpbasicauthhandler(passmgr),
495 495 httpdigestauthhandler(passmgr)))
496 496
497 497 o = urllib.request.build_opener(*handlers)
498 498 o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
499 499 ('Accept', 'application/mercurial-0.1')]
500 500
501 501 q = {"cmd": 'between'}
502 502 q.update({'pairs': "{}-{}".format('0' * 40, '0' * 40)})
503 503 qs = f'?{urllib.parse.urlencode(q)}'
504 504 cu = f"{test_uri}{qs}"
505 505
506 506 try:
507 507 req = urllib.request.Request(cu, None, {})
508 508 log.debug("Trying to open URL %s", obfuscated_uri)
509 509 resp = o.open(req)
510 510 if resp.code != 200:
511 511 raise exceptions.URLError()('Return Code is not 200')
512 512 except Exception as e:
513 513 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
514 514 # means it cannot be cloned
515 515 raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")
516 516
517 517 # now check if it's a proper hg repo, but don't do it for svn
518 518 try:
519 519 if _proto == 'svn':
520 520 pass
521 521 else:
522 522 # check for pure hg repos
523 523 log.debug(
524 524 "Verifying if URL is a Mercurial repository: %s", obfuscated_uri)
525 525 # Create repo path with custom mercurial path object
526 526 ui = make_ui_from_config(config)
527 527 repo_path = hg_path(ui=ui, rawloc=safe_bytes(test_uri))
528 528 peer_checker = make_peer(ui, repo_path, False)
529 529 peer_checker.lookup(b'tip')
530 530 except Exception as e:
531 531 log.warning("URL is not a valid Mercurial repository: %s",
532 532 obfuscated_uri)
533 533 raise exceptions.URLError(e)(
534 534 f"url [{obfuscated_uri}] does not look like an hg repo org_exc: {e}")
535 535
536 536 log.info("URL is a valid Mercurial repository: %s", obfuscated_uri)
537 537 return True
538 538
539 539 @reraise_safe_exceptions
540 540 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_git, opt_ignorews, context):
541 541 repo = self._factory.repo(wire)
542 542
543 543 if file_filter:
544 544 # unpack the file-filter
545 545 repo_path, node_path = file_filter
546 546 match_filter = match(safe_bytes(repo_path), b'', [safe_bytes(node_path)])
547 547 else:
548 548 match_filter = file_filter
549 549 opts = diffopts(git=opt_git, ignorews=opt_ignorews, context=context, showfunc=1)
550 550
551 551 try:
552 552 diff_iter = patch.diff(
553 553 repo, node1=commit_id_1, node2=commit_id_2, match=match_filter, opts=opts)
554 554 return BytesEnvelope(b"".join(diff_iter))
555 555 except RepoLookupError as e:
556 556 raise exceptions.LookupException(e)()
557 557
558 558 @reraise_safe_exceptions
559 559 def node_history(self, wire, revision, path, limit):
560 560 cache_on, context_uid, repo_id = self._cache_on(wire)
561 561 region = self._region(wire)
562 562
563 563 @region.conditional_cache_on_arguments(condition=cache_on)
564 564 def _node_history(_context_uid, _repo_id, _revision, _path, _limit):
565 565 repo = self._factory.repo(wire)
566 566
567 567 ctx = self._get_ctx(repo, revision)
568 568 fctx = ctx.filectx(safe_bytes(path))
569 569
570 570 def history_iter():
571 571 limit_rev = fctx.rev()
572 572
573 573 for fctx_candidate in reversed(list(fctx.filelog())):
574 574 f_obj = fctx.filectx(fctx_candidate)
575 575
576 576 # NOTE: this can be problematic - hiding the ONLY history node results in an empty history
577 577 _ctx = f_obj.changectx()
578 578 if _ctx.hidden() or _ctx.obsolete():
579 579 continue
580 580
581 581 if limit_rev >= f_obj.rev():
582 582 yield f_obj
583 583
584 584 history = []
585 585 for cnt, obj in enumerate(history_iter()):
586 586 if limit and cnt >= limit:
587 587 break
588 588 history.append(hex(obj.node()))
589 589
590 590 return [x for x in history]
591 591 return _node_history(context_uid, repo_id, revision, path, limit)
592 592
593 593 @reraise_safe_exceptions
594 594 def node_history_until(self, wire, revision, path, limit):
595 595 cache_on, context_uid, repo_id = self._cache_on(wire)
596 596 region = self._region(wire)
597 597
598 598 @region.conditional_cache_on_arguments(condition=cache_on)
599 599 def _node_history_until(_context_uid, _repo_id, _revision, _path, _limit):
600 600 repo = self._factory.repo(wire)
601 601 ctx = self._get_ctx(repo, revision)
602 602 fctx = ctx.filectx(safe_bytes(path))
603 603
604 604 file_log = list(fctx.filelog())
605 605 if limit:
606 606 # Limit to the last n items
607 607 file_log = file_log[-limit:]
608 608
609 609 return [hex(fctx.filectx(cs).node()) for cs in reversed(file_log)]
610 610 return _node_history_until(context_uid, repo_id, revision, path, limit)
611 611
612 612 @reraise_safe_exceptions
613 613 def bulk_file_request(self, wire, commit_id, path, pre_load):
614 614 cache_on, context_uid, repo_id = self._cache_on(wire)
615 615 region = self._region(wire)
616 616
617 617 @region.conditional_cache_on_arguments(condition=cache_on)
618 618 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
619 619 result = {}
620 620 for attr in pre_load:
621 621 try:
622 622 method = self._bulk_file_methods[attr]
623 623 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
624 624 result[attr] = method(wire, _commit_id, _path)
625 625 except KeyError as e:
626 626 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
627 627 return result
628 628
629 629 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
630 630
631 631 @reraise_safe_exceptions
632 632 def fctx_annotate(self, wire, revision, path):
633 633 repo = self._factory.repo(wire)
634 634 ctx = self._get_ctx(repo, revision)
635 635 fctx = ctx.filectx(safe_bytes(path))
636 636
637 637 result = []
638 638 for i, annotate_obj in enumerate(fctx.annotate(), 1):
639 639 ln_no = i
640 640 sha = hex(annotate_obj.fctx.node())
641 641 content = annotate_obj.text
642 642 result.append((ln_no, ascii_str(sha), content))
643 643 return BinaryEnvelope(result)
644 644
645 645 @reraise_safe_exceptions
646 646 def fctx_node_data(self, wire, revision, path):
647 647 repo = self._factory.repo(wire)
648 648 ctx = self._get_ctx(repo, revision)
649 649 fctx = ctx.filectx(safe_bytes(path))
650 650 return BytesEnvelope(fctx.data())
651 651
652 652 @reraise_safe_exceptions
653 653 def fctx_flags(self, wire, commit_id, path):
654 654 cache_on, context_uid, repo_id = self._cache_on(wire)
655 655 region = self._region(wire)
656 656
657 657 @region.conditional_cache_on_arguments(condition=cache_on)
658 658 def _fctx_flags(_repo_id, _commit_id, _path):
659 659 repo = self._factory.repo(wire)
660 660 ctx = self._get_ctx(repo, commit_id)
661 661 fctx = ctx.filectx(safe_bytes(path))
662 662 return fctx.flags()
663 663
664 664 return _fctx_flags(repo_id, commit_id, path)
665 665
666 666 @reraise_safe_exceptions
667 667 def fctx_size(self, wire, commit_id, path):
668 668 cache_on, context_uid, repo_id = self._cache_on(wire)
669 669 region = self._region(wire)
670 670
671 671 @region.conditional_cache_on_arguments(condition=cache_on)
672 672 def _fctx_size(_repo_id, _revision, _path):
673 673 repo = self._factory.repo(wire)
674 674 ctx = self._get_ctx(repo, commit_id)
675 675 fctx = ctx.filectx(safe_bytes(path))
676 676 return fctx.size()
677 677 return _fctx_size(repo_id, commit_id, path)
678 678
679 679 @reraise_safe_exceptions
680 680 def get_all_commit_ids(self, wire, name):
681 681 cache_on, context_uid, repo_id = self._cache_on(wire)
682 682 region = self._region(wire)
683 683
684 684 @region.conditional_cache_on_arguments(condition=cache_on)
685 685 def _get_all_commit_ids(_context_uid, _repo_id, _name):
686 686 repo = self._factory.repo(wire)
687 687 revs = [ascii_str(repo[x].hex()) for x in repo.filtered(b'visible').changelog.revs()]
688 688 return revs
689 689 return _get_all_commit_ids(context_uid, repo_id, name)
690 690
691 691 @reraise_safe_exceptions
692 692 def get_config_value(self, wire, section, name, untrusted=False):
693 693 repo = self._factory.repo(wire)
694 694 return repo.ui.config(ascii_bytes(section), ascii_bytes(name), untrusted=untrusted)
695 695
696 696 @reraise_safe_exceptions
697 697 def is_large_file(self, wire, commit_id, path):
698 698 cache_on, context_uid, repo_id = self._cache_on(wire)
699 699 region = self._region(wire)
700 700
701 701 @region.conditional_cache_on_arguments(condition=cache_on)
702 702 def _is_large_file(_context_uid, _repo_id, _commit_id, _path):
703 703 return largefiles.lfutil.isstandin(safe_bytes(path))
704 704
705 705 return _is_large_file(context_uid, repo_id, commit_id, path)
706 706
707 707 @reraise_safe_exceptions
708 708 def is_binary(self, wire, revision, path):
709 709 cache_on, context_uid, repo_id = self._cache_on(wire)
710 710 region = self._region(wire)
711 711
712 712 @region.conditional_cache_on_arguments(condition=cache_on)
713 713 def _is_binary(_repo_id, _sha, _path):
714 714 repo = self._factory.repo(wire)
715 715 ctx = self._get_ctx(repo, revision)
716 716 fctx = ctx.filectx(safe_bytes(path))
717 717 return fctx.isbinary()
718 718
719 719 return _is_binary(repo_id, revision, path)
720 720
721 721 @reraise_safe_exceptions
722 722 def md5_hash(self, wire, revision, path):
723 723 cache_on, context_uid, repo_id = self._cache_on(wire)
724 724 region = self._region(wire)
725 725
726 726 @region.conditional_cache_on_arguments(condition=cache_on)
727 727 def _md5_hash(_repo_id, _sha, _path):
728 728 repo = self._factory.repo(wire)
729 729 ctx = self._get_ctx(repo, revision)
730 730 fctx = ctx.filectx(safe_bytes(path))
731 731 return hashlib.md5(fctx.data()).hexdigest()
732 732
733 733 return _md5_hash(repo_id, revision, path)
734 734
735 735 @reraise_safe_exceptions
736 736 def in_largefiles_store(self, wire, sha):
737 737 repo = self._factory.repo(wire)
738 738 return largefiles.lfutil.instore(repo, sha)
739 739
740 740 @reraise_safe_exceptions
741 741 def in_user_cache(self, wire, sha):
742 742 repo = self._factory.repo(wire)
743 743 return largefiles.lfutil.inusercache(repo.ui, sha)
744 744
745 745 @reraise_safe_exceptions
746 746 def store_path(self, wire, sha):
747 747 repo = self._factory.repo(wire)
748 748 return largefiles.lfutil.storepath(repo, sha)
749 749
750 750 @reraise_safe_exceptions
751 751 def link(self, wire, sha, path):
752 752 repo = self._factory.repo(wire)
753 753 largefiles.lfutil.link(
754 754 largefiles.lfutil.usercachepath(repo.ui, sha), path)
755 755
756 756 @reraise_safe_exceptions
757 757 def localrepository(self, wire, create=False):
758 758 self._factory.repo(wire, create=create)
759 759
760 760 @reraise_safe_exceptions
761 761 def lookup(self, wire, revision, both):
762 762 cache_on, context_uid, repo_id = self._cache_on(wire)
763 763 region = self._region(wire)
764 764
765 765 @region.conditional_cache_on_arguments(condition=cache_on)
766 766 def _lookup(_context_uid, _repo_id, _revision, _both):
767 767 repo = self._factory.repo(wire)
768 768 rev = _revision
769 769 if isinstance(rev, int):
770 770 # NOTE(marcink):
771 771 # since Mercurial doesn't support negative indexes properly
772 772 # we need to shift accordingly by one to get a proper index, e.g.
773 773 # repo[-1] => repo[-2]
774 774 # repo[0] => repo[-1]
775 775 if rev <= 0:
776 776 rev = rev - 1
777 777 try:
778 778 ctx = self._get_ctx(repo, rev)
779 779 except AmbiguousPrefixLookupError:
780 780 e = RepoLookupError(rev)
781 781 e._org_exc_tb = format_exc(sys.exc_info())
782 782 raise exceptions.LookupException(e)(rev)
783 783 except (TypeError, RepoLookupError, binascii.Error) as e:
784 784 e._org_exc_tb = format_exc(sys.exc_info())
785 785 raise exceptions.LookupException(e)(rev)
786 786 except LookupError as e:
787 787 e._org_exc_tb = format_exc(sys.exc_info())
788 788 raise exceptions.LookupException(e)(e.name)
789 789
790 790 if not both:
791 791 return ctx.hex()
792 792
793 793 ctx = repo[ctx.hex()]
794 794 return ctx.hex(), ctx.rev()
795 795
796 796 return _lookup(context_uid, repo_id, revision, both)
797 797
798 798 @reraise_safe_exceptions
799 799 def sync_push(self, wire, url):
800 800 if not self.check_url(url, wire['config']):
801 801 return
802 802
803 803 repo = self._factory.repo(wire)
804 804
805 805 # Disable any prompts for this repo
806 806 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
807 807
808 808 bookmarks = list(dict(repo._bookmarks).keys())
809 809 remote = peer(repo, {}, safe_bytes(url))
810 810 # Disable any prompts for this remote
811 811 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
812 812
813 813 return exchange.push(
814 814 repo, remote, newbranch=True, bookmarks=bookmarks).cgresult
815 815
816 816 @reraise_safe_exceptions
817 817 def revision(self, wire, rev):
818 818 repo = self._factory.repo(wire)
819 819 ctx = self._get_ctx(repo, rev)
820 820 return ctx.rev()
821 821
822 822 @reraise_safe_exceptions
823 823 def rev_range(self, wire, commit_filter):
824 824 cache_on, context_uid, repo_id = self._cache_on(wire)
825 825 region = self._region(wire)
826 826
827 827 @region.conditional_cache_on_arguments(condition=cache_on)
828 828 def _rev_range(_context_uid, _repo_id, _filter):
829 829 repo = self._factory.repo(wire)
830 830 revisions = [
831 831 ascii_str(repo[rev].hex())
832 832 for rev in revrange(repo, list(map(ascii_bytes, commit_filter)))
833 833 ]
834 834 return revisions
835 835
836 836 return _rev_range(context_uid, repo_id, sorted(commit_filter))
837 837
838 838 @reraise_safe_exceptions
839 839 def rev_range_hash(self, wire, node):
840 840 repo = self._factory.repo(wire)
841 841
842 842 def get_revs(repo, rev_opt):
843 843 if rev_opt:
844 844 revs = revrange(repo, rev_opt)
845 845 if len(revs) == 0:
846 846 return (nullrev, nullrev)
847 847 return max(revs), min(revs)
848 848 else:
849 849 return len(repo) - 1, 0
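# note the (max, min) ordering; the caller below unpacks it as (stop, start)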
850 850
851 851 stop, start = get_revs(repo, [node + ':'])
852 852 revs = [ascii_str(repo[r].hex()) for r in range(start, stop + 1)]
853 853 return revs
854 854
855 855 @reraise_safe_exceptions
856 856 def revs_from_revspec(self, wire, rev_spec, *args, **kwargs):
857 857 org_path = safe_bytes(wire["path"])
858 858 other_path = safe_bytes(kwargs.pop('other_path', ''))
859 859
860 860 # case when we want to compare two independent repositories
861 861 if other_path and other_path != org_path:
862 862 baseui = self._factory._create_config(wire["config"])
863 863 repo = unionrepo.makeunionrepository(baseui, other_path, org_path)
864 864 else:
865 865 repo = self._factory.repo(wire)
866 866 return list(repo.revs(rev_spec, *args))
867 867
868 868 @reraise_safe_exceptions
869 869 def verify(self, wire,):
870 870 repo = self._factory.repo(wire)
871 871 baseui = self._factory._create_config(wire['config'])
872 872
873 873 baseui, output = patch_ui_message_output(baseui)
874 874
875 875 repo.ui = baseui
876 876 verify.verify(repo)
877 877 return output.getvalue()
878 878
879 879 @reraise_safe_exceptions
880 880 def hg_update_cache(self, wire,):
881 881 repo = self._factory.repo(wire)
882 882 baseui = self._factory._create_config(wire['config'])
883 883 baseui, output = patch_ui_message_output(baseui)
884 884
885 885 repo.ui = baseui
886 886 with repo.wlock(), repo.lock():
887 887 repo.updatecaches(full=True)
888 888
889 889 return output.getvalue()
890 890
891 891 @reraise_safe_exceptions
892 892 def hg_rebuild_fn_cache(self, wire,):
893 893 repo = self._factory.repo(wire)
894 894 baseui = self._factory._create_config(wire['config'])
895 895 baseui, output = patch_ui_message_output(baseui)
896 896
897 897 repo.ui = baseui
898 898
899 899 repair.rebuildfncache(baseui, repo)
900 900
901 901 return output.getvalue()
902 902
903 903 @reraise_safe_exceptions
904 904 def tags(self, wire):
905 905 cache_on, context_uid, repo_id = self._cache_on(wire)
906 906 region = self._region(wire)
907 907
908 908 @region.conditional_cache_on_arguments(condition=cache_on)
909 909 def _tags(_context_uid, _repo_id):
910 910 repo = self._factory.repo(wire)
911 911 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo.tags().items()}
912 912
913 913 return _tags(context_uid, repo_id)
914 914
915 915 @reraise_safe_exceptions
916 916 def update(self, wire, node='', clean=False):
917 917 repo = self._factory.repo(wire)
918 918 baseui = self._factory._create_config(wire['config'])
919 919 node = safe_bytes(node)
920 920
921 921 commands.update(baseui, repo, node=node, clean=clean)
922 922
923 923 @reraise_safe_exceptions
924 924 def identify(self, wire):
925 925 repo = self._factory.repo(wire)
926 926 baseui = self._factory._create_config(wire['config'])
927 927 output = io.BytesIO()
928 928 baseui.write = output.write
929 929 # This is required to get a full node id
930 930 baseui.debugflag = True
931 931 commands.identify(baseui, repo, id=True)
932 932
933 933 return output.getvalue()
934 934
935 935 @reraise_safe_exceptions
936 936 def heads(self, wire, branch=None):
937 937 repo = self._factory.repo(wire)
938 938 baseui = self._factory._create_config(wire['config'])
939 939 output = io.BytesIO()
940 940
941 941 def write(data, **unused_kwargs):
942 942 output.write(data)
943 943
944 944 baseui.write = write
945 945 if branch:
946 946 args = [safe_bytes(branch)]
947 947 else:
948 948 args = []
949 949 commands.heads(baseui, repo, *args, template=b'{node} ')
950 950
951 951 return output.getvalue()
952 952
953 953 @reraise_safe_exceptions
954 954 def ancestor(self, wire, revision1, revision2):
955 955 repo = self._factory.repo(wire)
956 956 changelog = repo.changelog
957 957 lookup = repo.lookup
958 958 a = changelog.ancestor(lookup(safe_bytes(revision1)), lookup(safe_bytes(revision2)))
959 959 return hex(a)
960 960
961 961 @reraise_safe_exceptions
962 962 def clone(self, wire, source, dest, update_after_clone=False, hooks=True):
963 963 baseui = self._factory._create_config(wire["config"], hooks=hooks)
964 964 clone(baseui, safe_bytes(source), safe_bytes(dest), noupdate=not update_after_clone)
965 965
966 966 @reraise_safe_exceptions
967 967 def commitctx(self, wire, message, parents, commit_time, commit_timezone, user, files, extra, removed, updated):
968 968
969 969 repo = self._factory.repo(wire)
970 970 baseui = self._factory._create_config(wire['config'])
971 971 publishing = baseui.configbool(b'phases', b'publish')
972 972
973 973 def _filectxfn(_repo, ctx, path: bytes):
974 974 """
975 975 Marks the given path as added/changed/removed in the given _repo. This is
976 976 used by Mercurial's internal commit function.
977 977 """
978 978
979 979 # check if this path is removed
980 980 if safe_str(path) in removed:
981 981 # returning None is a way to mark node for removal
982 982 return None
983 983
984 984 # check if this path is added
985 985 for node in updated:
986 986 if safe_bytes(node['path']) == path:
987 987 return memfilectx(
988 988 _repo,
989 989 changectx=ctx,
990 990 path=safe_bytes(node['path']),
991 991 data=safe_bytes(node['content']),
992 992 islink=False,
993 993 isexec=bool(node['mode'] & stat.S_IXUSR),
994 994 copysource=False)
995 995 abort_exc = exceptions.AbortException()
996 996 raise abort_exc(f"Given path hasn't been marked as added, changed or removed ({path})")
997 997
998 998 if publishing:
999 999 new_commit_phase = b'public'
1000 1000 else:
1001 1001 new_commit_phase = b'draft'
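# overriding `phases.new-commit` makes the in-memory commit follow the
# repo's publishing setting instead of always defaulting to draft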
1002 1002 with repo.ui.configoverride({(b'phases', b'new-commit'): new_commit_phase}):
1003 1003 kwargs = {safe_bytes(k): safe_bytes(v) for k, v in extra.items()}
1004 1004 commit_ctx = memctx(
1005 1005 repo=repo,
1006 1006 parents=parents,
1007 1007 text=safe_bytes(message),
1008 1008 files=[safe_bytes(x) for x in files],
1009 1009 filectxfn=_filectxfn,
1010 1010 user=safe_bytes(user),
1011 1011 date=(commit_time, commit_timezone),
1012 1012 extra=kwargs)
1013 1013
1014 1014 n = repo.commitctx(commit_ctx)
1015 1015 new_id = hex(n)
1016 1016
1017 1017 return new_id
1018 1018
1019 1019 @reraise_safe_exceptions
1020 1020 def pull(self, wire, url, commit_ids=None):
1021 1021 repo = self._factory.repo(wire)
1022 1022 # Disable any prompts for this repo
1023 1023 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
1024 1024
1025 1025 remote = peer(repo, {}, safe_bytes(url))
1026 1026 # Disable any prompts for this remote
1027 1027 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
1028 1028
1029 1029 if commit_ids:
1030 1030 commit_ids = [bin(commit_id) for commit_id in commit_ids]
1031 1031
1032 1032 return exchange.pull(
1033 1033 repo, remote, heads=commit_ids, force=None).cgresult
1034 1034
1035 1035 @reraise_safe_exceptions
1036 1036 def pull_cmd(self, wire, source, bookmark='', branch='', revision='', hooks=True):
1037 1037 repo = self._factory.repo(wire)
1038 1038 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1039 1039
1040 1040 source = safe_bytes(source)
1041 1041
1042 1042 # Mercurial internally has a lot of logic that checks ONLY whether an
1043 1043 # option is defined; we only pass options that are actually defined
1044 1044 opts = {"remote_hidden": False}
1045 1045
1046 1046 if bookmark:
1047 1047 opts['bookmark'] = [safe_bytes(x) for x in bookmark] \
1048 1048 if isinstance(bookmark, list) else safe_bytes(bookmark)
1049 1049
1050 1050 if branch:
1051 1051 opts['branch'] = [safe_bytes(x) for x in branch] \
1052 1052 if isinstance(branch, list) else safe_bytes(branch)
1053 1053
1054 1054 if revision:
1055 1055 opts['rev'] = [safe_bytes(x) for x in revision] \
1056 1056 if isinstance(revision, list) else safe_bytes(revision)
1057 1057
1058 1058 commands.pull(baseui, repo, source, **opts)
1059 1059
1060 1060 @reraise_safe_exceptions
1061 1061 def push(self, wire, revisions, dest_path, hooks: bool = True, push_branches: bool = False):
1062 1062 repo = self._factory.repo(wire)
1063 1063 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1064 1064
1065 1065 revisions = [safe_bytes(x) for x in revisions] \
1066 1066 if isinstance(revisions, list) else safe_bytes(revisions)
1067 1067
1068 1068 commands.push(baseui, repo, safe_bytes(dest_path),
1069 1069 rev=revisions,
1070 1070 new_branch=push_branches)
1071 1071
1072 1072 @reraise_safe_exceptions
1073 1073 def strip(self, wire, revision, update, backup):
1074 1074 repo = self._factory.repo(wire)
1075 1075 ctx = self._get_ctx(repo, revision)
1076 1076 hgext_strip.strip(
1077 1077 repo.baseui, repo, ctx.node(), update=update, backup=backup)
1078 1078
1079 1079 @reraise_safe_exceptions
1080 1080 def get_unresolved_files(self, wire):
1081 1081 repo = self._factory.repo(wire)
1082 1082
1083 1083 log.debug('Calculating unresolved files for repo: %s', repo)
1084 1084 output = io.BytesIO()
1085 1085
1086 1086 def write(data, **unused_kwargs):
1087 1087 output.write(data)
1088 1088
1089 1089 baseui = self._factory._create_config(wire['config'])
1090 1090 baseui.write = write
1091 1091
1092 1092 commands.resolve(baseui, repo, list=True)
1093 1093 unresolved = output.getvalue().splitlines()
1094 1094 return unresolved
1095 1095
1096 1096 @reraise_safe_exceptions
1097 1097 def merge(self, wire, revision):
1098 1098 repo = self._factory.repo(wire)
1099 1099 baseui = self._factory._create_config(wire['config'])
1100 1100 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1101 1101
1102 1102 # In case sub repositories are used, mercurial prompts the user in
1103 1103 # case of merge conflicts or different sub repository sources. By
1104 1104 # setting the interactive flag to `False` mercurial doesn't prompt the
1105 1105 # user but instead uses a default value.
1106 1106 repo.ui.setconfig(b'ui', b'interactive', False)
1107 1107 commands.merge(baseui, repo, rev=safe_bytes(revision))
1108 1108
1109 1109 @reraise_safe_exceptions
1110 1110 def merge_state(self, wire):
1111 1111 repo = self._factory.repo(wire)
1112 1112 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1113 1113
1114 1114 # In case sub repositories are used, mercurial prompts the user in
1115 1115 # case of merge conflicts or different sub repository sources. By
1116 1116 # setting the interactive flag to `False` mercurial doesn't prompt the
1117 1117 # user but instead uses a default value.
1118 1118 repo.ui.setconfig(b'ui', b'interactive', False)
1119 1119 ms = hg_merge.mergestate(repo)
1120 1120 return list(ms.unresolved())
1121 1121
1122 1122 @reraise_safe_exceptions
1123 1123 def commit(self, wire, message, username, close_branch=False):
1124 1124 repo = self._factory.repo(wire)
1125 1125 baseui = self._factory._create_config(wire['config'])
1126 1126 repo.ui.setconfig(b'ui', b'username', safe_bytes(username))
1127 1127 commands.commit(baseui, repo, message=safe_bytes(message), close_branch=close_branch)
1128 1128
1129 1129 @reraise_safe_exceptions
1130 1130 def rebase(self, wire, source='', dest='', abort=False):
1131 1131
1132 1132 repo = self._factory.repo(wire)
1133 1133 baseui = self._factory._create_config(wire['config'])
1134 1134 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1135 1135 # In case sub repositories are used, mercurial prompts the user in
1136 1136 # case of merge conflicts or different sub repository sources. By
1137 1137 # setting the interactive flag to `False` mercurial doesn't prompt the
1138 1138 # user but instead uses a default value.
1139 1139 repo.ui.setconfig(b'ui', b'interactive', False)
1140 1140
1141 1141 rebase_kws = dict(
1142 1142 keep=not abort,
1143 1143 abort=abort
1144 1144 )
1145 1145
1146 1146 if source:
1147 1147 source = repo[source]
1148 1148 rebase_kws['base'] = [source.hex()]
1149 1149 if dest:
1150 1150 dest = repo[dest]
1151 1151 rebase_kws['dest'] = dest.hex()
1152 1152
1153 1153 rebase.rebase(baseui, repo, **rebase_kws)
1154 1154
1155 1155 @reraise_safe_exceptions
1156 1156 def tag(self, wire, name, revision, message, local, user, tag_time, tag_timezone):
1157 1157 repo = self._factory.repo(wire)
1158 1158 ctx = self._get_ctx(repo, revision)
1159 1159 node = ctx.node()
1160 1160
1161 1161 date = (tag_time, tag_timezone)
1162 1162 try:
1163 1163 hg_tag.tag(repo, safe_bytes(name), node, safe_bytes(message), local, safe_bytes(user), date)
1164 1164 except Abort as e:
1165 1165 log.exception("Tag operation aborted")
1166 1166 # The exception can contain unicode, which we convert to a safe repr
1167 1167 raise exceptions.AbortException(e)(repr(e))
1168 1168
1169 1169 @reraise_safe_exceptions
1170 1170 def bookmark(self, wire, bookmark, revision=''):
1171 1171 repo = self._factory.repo(wire)
1172 1172 baseui = self._factory._create_config(wire['config'])
1173 1173 revision = revision or ''
1174 1174 commands.bookmark(baseui, repo, safe_bytes(bookmark), rev=safe_bytes(revision), force=True)
1175 1175
1176 1176 @reraise_safe_exceptions
1177 1177 def install_hooks(self, wire, force=False):
1178 1178 # we don't need any special hooks for Mercurial
1179 1179 pass
1180 1180
1181 1181 @reraise_safe_exceptions
1182 1182 def get_hooks_info(self, wire):
1183 1183 return {
1184 1184 'pre_version': vcsserver.get_version(),
1185 1185 'post_version': vcsserver.get_version(),
1186 1186 }
1187 1187
1188 1188 @reraise_safe_exceptions
1189 1189 def set_head_ref(self, wire, head_name):
1190 1190 pass
1191 1191
1192 1192 @reraise_safe_exceptions
1193 1193 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1194 1194 archive_dir_name, commit_id, cache_config):
1195 1195
1196 1196 def file_walker(_commit_id, path):
1197 1197 repo = self._factory.repo(wire)
1198 1198 ctx = repo[_commit_id]
1199 1199 is_root = path in ['', '/']
1200 1200 if is_root:
1201 1201 matcher = alwaysmatcher(badfn=None)
1202 1202 else:
1203 1203 matcher = patternmatcher('', [(b'glob', safe_bytes(path)+b'/**', b'')], badfn=None)
1204 1204 file_iter = ctx.manifest().walk(matcher)
1205 1205
1206 1206 for fn in file_iter:
1207 1207 file_path = fn
1208 1208 flags = ctx.flags(fn)
1209 1209 mode = 0o755 if b'x' in flags else 0o644
1210 1210 is_link = b'l' in flags
1211 1211
1212 1212 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1213 1213
1214 1214 return store_archive_in_cache(
1215 1215 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
1216 1216
@@ -1,954 +1,954 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18
19 19 import os
20 20 import subprocess
21 21 from urllib.error import URLError
22 22 import urllib.parse
23 23 import logging
24 24 import posixpath as vcspath
25 25 import io
26 26 import urllib.request
27 27 import urllib.parse
28 28 import urllib.error
29 29 import traceback
30 30
31 31
32 32 import svn.client # noqa
33 33 import svn.core # noqa
34 34 import svn.delta # noqa
35 35 import svn.diff # noqa
36 36 import svn.fs # noqa
37 37 import svn.repos # noqa
38 38
39 39 import rhodecode
40 40 from vcsserver import svn_diff, exceptions, subprocessio, settings
41 41 from vcsserver.base import (
42 42 RepoFactory,
43 43 raise_from_original,
44 44 ArchiveNode,
45 45 store_archive_in_cache,
46 46 BytesEnvelope,
47 47 BinaryEnvelope,
48 48 )
49 49 from vcsserver.exceptions import NoContentException
50 from vcsserver.str_utils import safe_str, safe_bytes
51 from vcsserver.type_utils import assert_bytes
50 from vcsserver.lib.str_utils import safe_str, safe_bytes
51 from vcsserver.lib.type_utils import assert_bytes
52 52 from vcsserver.vcs_base import RemoteBase
53 53 from vcsserver.lib.svnremoterepo import svnremoterepo
54 54
55 55 log = logging.getLogger(__name__)
56 56
57 57
58 58 svn_compatible_versions_map = {
59 59 'pre-1.4-compatible': '1.3',
60 60 'pre-1.5-compatible': '1.4',
61 61 'pre-1.6-compatible': '1.5',
62 62 'pre-1.8-compatible': '1.7',
63 63 'pre-1.9-compatible': '1.8',
64 64 }
65 65
66 66 current_compatible_version = '1.14'
67 67
68 68
69 69 def reraise_safe_exceptions(func):
70 70 """Decorator for converting svn exceptions to something neutral."""
71 71 def wrapper(*args, **kwargs):
72 72 try:
73 73 return func(*args, **kwargs)
74 74 except Exception as e:
75 75 if not hasattr(e, '_vcs_kind'):
76 76 log.exception("Unhandled exception in svn remote call")
77 77 raise_from_original(exceptions.UnhandledException(e), e)
78 78 raise
79 79 return wrapper
80 80
81 81
82 82 class SubversionFactory(RepoFactory):
83 83 repo_type = 'svn'
84 84
85 85 def _create_repo(self, wire, create, compatible_version):
86 86 path = svn.core.svn_path_canonicalize(wire['path'])
87 87 if create:
88 88 fs_config = {'compatible-version': current_compatible_version}
89 89 if compatible_version:
90 90
91 91 compatible_version_string = \
92 92 svn_compatible_versions_map.get(compatible_version) \
93 93 or compatible_version
94 94 fs_config['compatible-version'] = compatible_version_string
95 95
96 96 log.debug('Create SVN repo with config `%s`', fs_config)
97 97 repo = svn.repos.create(path, "", "", None, fs_config)
98 98 else:
99 99 repo = svn.repos.open(path)
100 100
101 101 log.debug('repository created: got SVN object: %s', repo)
102 102 return repo
103 103
104 104 def repo(self, wire, create=False, compatible_version=None):
105 105 """
106 106 Get a repository instance for the given path.
107 107 """
108 108 return self._create_repo(wire, create, compatible_version)
109 109
110 110
111 111 NODE_TYPE_MAPPING = {
112 112 svn.core.svn_node_file: 'file',
113 113 svn.core.svn_node_dir: 'dir',
114 114 }
115 115
116 116
117 117 class SvnRemote(RemoteBase):
118 118
119 119 def __init__(self, factory, hg_factory=None):
120 120 self._factory = factory
121 121
122 122 self._bulk_methods = {
123 123 # NOT supported in SVN ATM...
124 124 }
125 125 self._bulk_file_methods = {
126 126 "size": self.get_file_size,
127 127 "data": self.get_file_content,
128 128 "flags": self.get_node_type,
129 129 "is_binary": self.is_binary,
130 130 "md5": self.md5_hash
131 131 }
132 132
133 133 @reraise_safe_exceptions
134 134 def bulk_file_request(self, wire, commit_id, path, pre_load):
135 135 cache_on, context_uid, repo_id = self._cache_on(wire)
136 136 region = self._region(wire)
137 137
138 138 # since we use a unified API, we need to cast from str to int for SVN
139 139 commit_id = int(commit_id)
140 140
141 141 @region.conditional_cache_on_arguments(condition=cache_on)
142 142 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
143 143 result = {}
144 144 for attr in pre_load:
145 145 try:
146 146 method = self._bulk_file_methods[attr]
147 147 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
148 148 result[attr] = method(wire, _commit_id, _path)
149 149 except KeyError as e:
150 150 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
151 151 return result
152 152
153 153 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
154 154
155 155 @reraise_safe_exceptions
156 156 def discover_svn_version(self):
157 157 try:
158 158 import svn.core
159 159 svn_ver = svn.core.SVN_VERSION
160 160 except ImportError:
161 161 svn_ver = None
162 162 return safe_str(svn_ver)
163 163
164 164 @reraise_safe_exceptions
165 165 def is_empty(self, wire):
166 166 try:
167 167 return self.lookup(wire, -1) == 0
168 168 except Exception:
169 169 log.exception("failed to read object_store")
170 170 return False
171 171
172 172 def check_url(self, url, config):
173 173
174 174 # the uuid call returns a valid UUID only for a proper repo,
175 175 # otherwise it raises an exception
176 176 username, password, src_url = self.get_url_and_credentials(url)
177 177 try:
178 178 svnremoterepo(safe_bytes(username), safe_bytes(password), safe_bytes(src_url)).svn().uuid
179 179 except Exception:
180 180 tb = traceback.format_exc()
181 181 log.debug("Invalid Subversion url: `%s`, tb: %s", url, tb)
182 182 raise URLError(f'"{url}" is not a valid Subversion source url.')
183 183 return True
184 184
185 185 def is_path_valid_repository(self, wire, path):
186 186 # NOTE(marcink): short circuit the check for SVN repo
187 187 # the repos.open check might be expensive, but we have one cheap
188 188 # pre-condition we can use: checking for the 'format' file
189 189 if not os.path.isfile(os.path.join(path, 'format')):
190 190 return False
191 191
192 192 cache_on, context_uid, repo_id = self._cache_on(wire)
193 193 region = self._region(wire)
194 194
195 195 @region.conditional_cache_on_arguments(condition=cache_on)
196 196 def _assert_correct_path(_context_uid, _repo_id, fast_check):
197 197
198 198 try:
199 199 svn.repos.open(path)
200 200 except svn.core.SubversionException:
201 201 tb = traceback.format_exc()
202 202 log.debug("Invalid Subversion path `%s`, tb: %s", path, tb)
203 203 return False
204 204 return True
205 205
206 206 return _assert_correct_path(context_uid, repo_id, True)
207 207
208 208 @reraise_safe_exceptions
209 209 def verify(self, wire,):
210 210 repo_path = wire['path']
211 211 if not self.is_path_valid_repository(wire, repo_path):
212 212 raise Exception(
213 213 f"Path {repo_path} is not a valid Subversion repository.")
214 214
215 215 cmd = ['svnadmin', 'info', repo_path]
216 216 stdout, stderr = subprocessio.run_command(cmd)
217 217 return stdout
218 218
219 219 @reraise_safe_exceptions
220 220 def lookup(self, wire, revision):
221 221 if revision not in [-1, None, 'HEAD']:
222 222 raise NotImplementedError
223 223 repo = self._factory.repo(wire)
224 224 fs_ptr = svn.repos.fs(repo)
225 225 head = svn.fs.youngest_rev(fs_ptr)
226 226 return head
227 227
228 228 @reraise_safe_exceptions
229 229 def lookup_interval(self, wire, start_ts, end_ts):
230 230 repo = self._factory.repo(wire)
231 231 fsobj = svn.repos.fs(repo)
232 232 start_rev = None
233 233 end_rev = None
234 234 if start_ts:
235 235 start_ts_svn = apr_time_t(start_ts)
236 236 start_rev = svn.repos.dated_revision(repo, start_ts_svn) + 1
237 237 else:
238 238 start_rev = 1
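# dated_revision resolves a timestamp to the youngest revision at or
# before it, hence the +1 above so the range starts after start_ts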
239 239 if end_ts:
240 240 end_ts_svn = apr_time_t(end_ts)
241 241 end_rev = svn.repos.dated_revision(repo, end_ts_svn)
242 242 else:
243 243 end_rev = svn.fs.youngest_rev(fsobj)
244 244 return start_rev, end_rev
245 245
246 246 @reraise_safe_exceptions
247 247 def revision_properties(self, wire, revision):
248 248
249 249 cache_on, context_uid, repo_id = self._cache_on(wire)
250 250 region = self._region(wire)
251 251
252 252 @region.conditional_cache_on_arguments(condition=cache_on)
253 253 def _revision_properties(_repo_id, _revision):
254 254 repo = self._factory.repo(wire)
255 255 fs_ptr = svn.repos.fs(repo)
256 256 return svn.fs.revision_proplist(fs_ptr, revision)
257 257 return _revision_properties(repo_id, revision)
258 258
259 259 def revision_changes(self, wire, revision):
260 260
261 261 repo = self._factory.repo(wire)
262 262 fsobj = svn.repos.fs(repo)
263 263 rev_root = svn.fs.revision_root(fsobj, revision)
264 264
265 265 editor = svn.repos.ChangeCollector(fsobj, rev_root)
266 266 editor_ptr, editor_baton = svn.delta.make_editor(editor)
267 267 base_dir = ""
268 268 send_deltas = False
269 269 svn.repos.replay2(
270 270 rev_root, base_dir, svn.core.SVN_INVALID_REVNUM, send_deltas,
271 271 editor_ptr, editor_baton, None)
272 272
273 273 added = []
274 274 changed = []
275 275 removed = []
276 276
277 277 # TODO: CHANGE_ACTION_REPLACE: Figure out where it belongs
278 278 for path, change in editor.changes.items():
279 279 # TODO: Decide what to do with directory nodes. Subversion can add
280 280 # empty directories.
281 281
282 282 if change.item_kind == svn.core.svn_node_dir:
283 283 continue
284 284 if change.action in [svn.repos.CHANGE_ACTION_ADD]:
285 285 added.append(path)
286 286 elif change.action in [svn.repos.CHANGE_ACTION_MODIFY,
287 287 svn.repos.CHANGE_ACTION_REPLACE]:
288 288 changed.append(path)
289 289 elif change.action in [svn.repos.CHANGE_ACTION_DELETE]:
290 290 removed.append(path)
291 291 else:
292 292 raise NotImplementedError(
293 293 "Action {} not supported on path {}".format(
294 294 change.action, path))
295 295
296 296 changes = {
297 297 'added': added,
298 298 'changed': changed,
299 299 'removed': removed,
300 300 }
301 301 return changes
302 302
303 303 @reraise_safe_exceptions
304 304 def node_history(self, wire, path, revision, limit):
305 305 cache_on, context_uid, repo_id = self._cache_on(wire)
306 306 region = self._region(wire)
307 307
308 308 @region.conditional_cache_on_arguments(condition=cache_on)
309 309 def _assert_correct_path(_context_uid, _repo_id, _path, _revision, _limit):
310 310 cross_copies = False
311 311 repo = self._factory.repo(wire)
312 312 fsobj = svn.repos.fs(repo)
313 313 rev_root = svn.fs.revision_root(fsobj, revision)
314 314
315 315 history_revisions = []
316 316 history = svn.fs.node_history(rev_root, path)
317 317 history = svn.fs.history_prev(history, cross_copies)
318 318 while history:
319 319 __, node_revision = svn.fs.history_location(history)
320 320 history_revisions.append(node_revision)
321 321 if limit and len(history_revisions) >= limit:
322 322 break
323 323 history = svn.fs.history_prev(history, cross_copies)
324 324 return history_revisions
325 325 return _assert_correct_path(context_uid, repo_id, path, revision, limit)
326 326
327 327 @reraise_safe_exceptions
328 328 def node_properties(self, wire, path, revision):
329 329 cache_on, context_uid, repo_id = self._cache_on(wire)
330 330 region = self._region(wire)
331 331
332 332 @region.conditional_cache_on_arguments(condition=cache_on)
333 333 def _node_properties(_repo_id, _path, _revision):
334 334 repo = self._factory.repo(wire)
335 335 fsobj = svn.repos.fs(repo)
336 336 rev_root = svn.fs.revision_root(fsobj, revision)
337 337 return svn.fs.node_proplist(rev_root, path)
338 338 return _node_properties(repo_id, path, revision)
339 339
340 340 def file_annotate(self, wire, path, revision):
341 341 abs_path = 'file://' + urllib.request.pathname2url(
342 342 vcspath.join(wire['path'], path))
343 343 file_uri = svn.core.svn_path_canonicalize(abs_path)
344 344
345 345 start_rev = svn_opt_revision_value_t(0)
346 346 peg_rev = svn_opt_revision_value_t(revision)
347 347 end_rev = peg_rev
348 348
349 349 annotations = []
350 350
351 351 def receiver(line_no, revision, author, date, line, pool):
352 352 annotations.append((line_no, revision, line))
353 353
354 354 # TODO: Cannot use blame5, missing typemap function in the swig code
355 355 try:
356 356 svn.client.blame2(
357 357 file_uri, peg_rev, start_rev, end_rev,
358 358 receiver, svn.client.create_context())
359 359 except svn.core.SubversionException as exc:
360 360 log.exception("Error during blame operation.")
361 361 raise Exception(
362 362 f"Blame not supported or file does not exist at path {path}. "
363 363 f"Error {exc}.")
364 364
365 365 return BinaryEnvelope(annotations)
366 366
367 367 @reraise_safe_exceptions
368 368 def get_node_type(self, wire, revision=None, path=''):
369 369
370 370 cache_on, context_uid, repo_id = self._cache_on(wire)
371 371 region = self._region(wire)
372 372
373 373 @region.conditional_cache_on_arguments(condition=cache_on)
374 374 def _get_node_type(_repo_id, _revision, _path):
375 375 repo = self._factory.repo(wire)
376 376 fs_ptr = svn.repos.fs(repo)
377 377 if _revision is None:
378 378 _revision = svn.fs.youngest_rev(fs_ptr)
379 379 root = svn.fs.revision_root(fs_ptr, _revision)
380 380 node = svn.fs.check_path(root, path)
381 381 return NODE_TYPE_MAPPING.get(node, None)
382 382 return _get_node_type(repo_id, revision, path)
383 383
384 384 @reraise_safe_exceptions
385 385 def get_nodes(self, wire, revision=None, path=''):
386 386
387 387 cache_on, context_uid, repo_id = self._cache_on(wire)
388 388 region = self._region(wire)
389 389
390 390 @region.conditional_cache_on_arguments(condition=cache_on)
391 391 def _get_nodes(_repo_id, _path, _revision):
392 392 repo = self._factory.repo(wire)
393 393 fsobj = svn.repos.fs(repo)
394 394 if _revision is None:
395 395 _revision = svn.fs.youngest_rev(fsobj)
396 396 root = svn.fs.revision_root(fsobj, _revision)
397 397 entries = svn.fs.dir_entries(root, path)
398 398 result = []
399 399 for entry_path, entry_info in entries.items():
400 400 result.append(
401 401 (entry_path, NODE_TYPE_MAPPING.get(entry_info.kind, None)))
402 402 return result
403 403 return _get_nodes(repo_id, path, revision)
404 404
405 405 @reraise_safe_exceptions
406 406 def get_file_content(self, wire, rev=None, path=''):
407 407 repo = self._factory.repo(wire)
408 408 fsobj = svn.repos.fs(repo)
409 409
410 410 if rev is None:
411 411 rev = svn.fs.youngest_rev(fsobj)
412 412
413 413 root = svn.fs.revision_root(fsobj, rev)
414 414 content = svn.core.Stream(svn.fs.file_contents(root, path))
415 415 return BytesEnvelope(content.read())
416 416
417 417 @reraise_safe_exceptions
418 418 def get_file_size(self, wire, revision=None, path=''):
419 419
420 420 cache_on, context_uid, repo_id = self._cache_on(wire)
421 421 region = self._region(wire)
422 422
423 423 @region.conditional_cache_on_arguments(condition=cache_on)
424 424 def _get_file_size(_repo_id, _revision, _path):
425 425 repo = self._factory.repo(wire)
426 426 fsobj = svn.repos.fs(repo)
427 427 if _revision is None:
428 428 _revision = svn.fs.youngest_rev(fsobj)
429 429 root = svn.fs.revision_root(fsobj, _revision)
430 430 size = svn.fs.file_length(root, path)
431 431 return size
432 432 return _get_file_size(repo_id, revision, path)
433 433
434 434 def create_repository(self, wire, compatible_version=None):
435 435 log.info('Creating Subversion repository in path "%s"', wire['path'])
436 436 self._factory.repo(wire, create=True,
437 437 compatible_version=compatible_version)
438 438
439 439 def get_url_and_credentials(self, src_url) -> tuple[str, str, str]:
440 440 obj = urllib.parse.urlparse(src_url)
441 441 username = obj.username or ''
442 442 password = obj.password or ''
443 443 return username, password, src_url
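# e.g. a hypothetical 'svn://user:secret@host/repo' would yield
# ('user', 'secret', 'svn://user:secret@host/repo'); credentials are
# parsed out, but the URL itself is returned unchanged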
444 444
445 445 def import_remote_repository(self, wire, src_url):
446 446 repo_path = wire['path']
447 447 if not self.is_path_valid_repository(wire, repo_path):
448 448 raise Exception(
449 449 f"Path {repo_path} is not a valid Subversion repository.")
450 450
451 451 username, password, src_url = self.get_url_and_credentials(src_url)
452 452 rdump_cmd = ['svnrdump', 'dump', '--non-interactive',
453 453 '--trust-server-cert-failures=unknown-ca']
454 454 if username and password:
455 455 rdump_cmd += ['--username', username, '--password', password]
456 456 rdump_cmd += [src_url]
457 457
458 458 rdump = subprocess.Popen(
459 459 rdump_cmd,
460 460 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
461 461 load = subprocess.Popen(
462 462 ['svnadmin', 'load', repo_path], stdin=rdump.stdout)
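# the remote dump is streamed from svnrdump straight into `svnadmin load`
# through the pipe; nothing is buffered on disk in between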
463 463
464 464 # TODO: johbo: This can be a very long operation, might be better
465 465 # to track some kind of status and provide an api to check if the
466 466 # import is done.
467 467 rdump.wait()
468 468 load.wait()
469 469
470 470 log.debug('Return process ended with code: %s', rdump.returncode)
471 471 if rdump.returncode != 0:
472 472 errors = rdump.stderr.read()
473 473 log.error('svnrdump dump failed: statuscode %s: message: %s', rdump.returncode, errors)
474 474
475 475 reason = 'UNKNOWN'
476 476 if b'svnrdump: E230001:' in errors:
477 477 reason = 'INVALID_CERTIFICATE'
478 478
479 479 if reason == 'UNKNOWN':
480 480 reason = f'UNKNOWN:{safe_str(errors)}'
481 481
482 482 raise Exception(
483 483 'Failed to dump the remote repository from {}. Reason:{}'.format(
484 484 src_url, reason))
485 485 if load.returncode != 0:
486 486 raise Exception(
487 487 f'Failed to load the dump of remote repository from {src_url}.')
488 488
489 489 def commit(self, wire, message, author, timestamp, updated, removed):
490 490
491 491 message = safe_bytes(message)
492 492 author = safe_bytes(author)
493 493
494 494 repo = self._factory.repo(wire)
495 495 fsobj = svn.repos.fs(repo)
496 496
497 497 rev = svn.fs.youngest_rev(fsobj)
498 498 txn = svn.repos.fs_begin_txn_for_commit(repo, rev, author, message)
499 499 txn_root = svn.fs.txn_root(txn)
500 500
501 501 for node in updated:
502 502 TxnNodeProcessor(node, txn_root).update()
503 503 for node in removed:
504 504 TxnNodeProcessor(node, txn_root).remove()
505 505
506 506 commit_id = svn.repos.fs_commit_txn(repo, txn)
507 507
508 508 if timestamp:
509 509 apr_time = apr_time_t(timestamp)
510 510 ts_formatted = svn.core.svn_time_to_cstring(apr_time)
511 511 svn.fs.change_rev_prop(fsobj, commit_id, 'svn:date', ts_formatted)
512 512
513 513 log.debug('Committed revision "%s" to "%s".', commit_id, wire['path'])
514 514 return commit_id
515 515
516 516 @reraise_safe_exceptions
517 517 def diff(self, wire, rev1, rev2, path1=None, path2=None,
518 518 ignore_whitespace=False, context=3):
519 519
520 520 wire.update(cache=False)
521 521 repo = self._factory.repo(wire)
522 522 diff_creator = SvnDiffer(
523 523 repo, rev1, path1, rev2, path2, ignore_whitespace, context)
524 524 try:
525 525 return BytesEnvelope(diff_creator.generate_diff())
526 526 except svn.core.SubversionException as e:
527 527 log.exception(
528 528 "Error during diff operation. "
529 529 "Path might not exist %s, %s", path1, path2)
530 530 return BytesEnvelope(b'')
531 531
532 532 @reraise_safe_exceptions
533 533 def is_large_file(self, wire, path):
534 534 return False
535 535
536 536 @reraise_safe_exceptions
537 537 def is_binary(self, wire, rev, path):
538 538 cache_on, context_uid, repo_id = self._cache_on(wire)
539 539 region = self._region(wire)
540 540
541 541 @region.conditional_cache_on_arguments(condition=cache_on)
542 542 def _is_binary(_repo_id, _rev, _path):
543 543 raw_bytes = self.get_file_content(wire, rev, path)
544 544 if not raw_bytes:
545 545 return False
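# simple heuristic (similar to git's): a NUL byte in the content
# marks it as binary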
546 546 return b'\0' in raw_bytes
547 547
548 548 return _is_binary(repo_id, rev, path)
549 549
550 550 @reraise_safe_exceptions
551 551 def md5_hash(self, wire, rev, path):
552 552 cache_on, context_uid, repo_id = self._cache_on(wire)
553 553 region = self._region(wire)
554 554
555 555 @region.conditional_cache_on_arguments(condition=cache_on)
556 556 def _md5_hash(_repo_id, _rev, _path):
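# md5 is not computed for SVN content; an empty string is returned,
# presumably so callers of the unified API still receive a value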
557 557 return ''
558 558
559 559 return _md5_hash(repo_id, rev, path)
560 560
561 561 @reraise_safe_exceptions
562 562 def run_svn_command(self, wire, cmd, **opts):
563 563 path = wire.get('path', None)
564 564 debug_mode = rhodecode.ConfigGet().get_bool('debug')
565 565
566 566 if path and os.path.isdir(path):
567 567 opts['cwd'] = path
568 568
569 569 safe_call = opts.pop('_safe', False)
570 570
571 571 svnenv = os.environ.copy()
572 572 svnenv.update(opts.pop('extra_env', {}))
573 573
574 574 _opts = {'env': svnenv, 'shell': False}
575 575
576 576 try:
577 577 _opts.update(opts)
578 578 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
579 579
580 580 return b''.join(proc), b''.join(proc.stderr)
581 581 except OSError as err:
582 582 if safe_call:
583 583 return '', safe_str(err).strip()
584 584 else:
585 585 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
586 586 call_opts = {}
587 587 if debug_mode:
588 588 call_opts = _opts
589 589
590 590 tb_err = ("Couldn't run svn command ({}).\n"
591 591 "Original error was:{}\n"
592 592 "Call options:{}\n"
593 593 .format(cmd, err, call_opts))
594 594 log.exception(tb_err)
595 595 raise exceptions.VcsException()(tb_err)
596 596
597 597 @reraise_safe_exceptions
598 598 def install_hooks(self, wire, force=False):
599 599 from vcsserver.hook_utils import install_svn_hooks
600 600 repo_path = wire['path']
601 601 binary_dir = settings.BINARY_DIR
602 602 executable = None
603 603 if binary_dir:
604 604 executable = os.path.join(binary_dir, 'python3')
605 605 return install_svn_hooks(repo_path, force_create=force)
606 606
607 607 @reraise_safe_exceptions
608 608 def get_hooks_info(self, wire):
609 609 from vcsserver.hook_utils import (
610 610 get_svn_pre_hook_version, get_svn_post_hook_version)
611 611 repo_path = wire['path']
612 612 return {
613 613 'pre_version': get_svn_pre_hook_version(repo_path),
614 614 'post_version': get_svn_post_hook_version(repo_path),
615 615 }
616 616
617 617 @reraise_safe_exceptions
618 618 def set_head_ref(self, wire, head_name):
619 619 pass
620 620
621 621 @reraise_safe_exceptions
622 622 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
623 623 archive_dir_name, commit_id, cache_config):
624 624
625 625 def walk_tree(root, root_dir, _commit_id):
626 626 """
627 627 Special recursive svn repo walker
628 628 """
629 629 root_dir = safe_bytes(root_dir)
630 630
631 631 filemode_default = 0o100644
632 632 filemode_executable = 0o100755
633 633
634 634 file_iter = svn.fs.dir_entries(root, root_dir)
635 635 for f_name in file_iter:
636 636 f_type = NODE_TYPE_MAPPING.get(file_iter[f_name].kind, None)
637 637
638 638 if f_type == 'dir':
639 639 # return only DIR, and then all entries in that dir
640 640 yield os.path.join(root_dir, f_name), {'mode': filemode_default}, f_type
641 641 new_root = os.path.join(root_dir, f_name)
642 642 yield from walk_tree(root, new_root, _commit_id)
643 643 else:
644 644
645 645 f_path = os.path.join(root_dir, f_name).rstrip(b'/')
646 646 prop_list = svn.fs.node_proplist(root, f_path)
647 647
648 648 f_mode = filemode_default
649 649 if prop_list.get('svn:executable'):
650 650 f_mode = filemode_executable
651 651
652 652 f_is_link = False
653 653 if prop_list.get('svn:special'):
654 654 f_is_link = True
655 655
656 656 data = {
657 657 'is_link': f_is_link,
658 658 'mode': f_mode,
659 659 'content_stream': svn.core.Stream(svn.fs.file_contents(root, f_path)).read
660 660 }
661 661
662 662 yield f_path, data, f_type
663 663
664 664 def file_walker(_commit_id, path):
665 665 repo = self._factory.repo(wire)
666 666 root = svn.fs.revision_root(svn.repos.fs(repo), int(commit_id))
667 667
668 668 def no_content():
669 669 raise NoContentException()
670 670
671 671 for f_name, f_data, f_type in walk_tree(root, path, _commit_id):
672 672 file_path = f_name
673 673
674 674 if f_type == 'dir':
675 675 mode = f_data['mode']
676 676 yield ArchiveNode(file_path, mode, False, no_content)
677 677 else:
678 678 mode = f_data['mode']
679 679 is_link = f_data['is_link']
680 680 data_stream = f_data['content_stream']
681 681 yield ArchiveNode(file_path, mode, is_link, data_stream)
682 682
683 683 return store_archive_in_cache(
684 684 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
685 685
686 686
687 687 class SvnDiffer:
688 688 """
689 689 Utility to create diffs based on difflib and the Subversion api
690 690 """
691 691
692 692 binary_content = False
693 693
694 694 def __init__(
695 695 self, repo, src_rev, src_path, tgt_rev, tgt_path,
696 696 ignore_whitespace, context):
697 697 self.repo = repo
698 698 self.ignore_whitespace = ignore_whitespace
699 699 self.context = context
700 700
701 701 fsobj = svn.repos.fs(repo)
702 702
703 703 self.tgt_rev = tgt_rev
704 704 self.tgt_path = tgt_path or ''
705 705 self.tgt_root = svn.fs.revision_root(fsobj, tgt_rev)
706 706 self.tgt_kind = svn.fs.check_path(self.tgt_root, self.tgt_path)
707 707
708 708 self.src_rev = src_rev
709 709 self.src_path = src_path or self.tgt_path
710 710 self.src_root = svn.fs.revision_root(fsobj, src_rev)
711 711 self.src_kind = svn.fs.check_path(self.src_root, self.src_path)
712 712
713 713 self._validate()
714 714
715 715 def _validate(self):
716 716 if (self.tgt_kind != svn.core.svn_node_none and
717 717 self.src_kind != svn.core.svn_node_none and
718 718 self.src_kind != self.tgt_kind):
719 719 # TODO: johbo: proper error handling
720 720 raise Exception(
721 721 "Source and target are not compatible for diff generation. "
722 722 "Source type: %s, target type: %s" %
723 723 (self.src_kind, self.tgt_kind))
724 724
725 725 def generate_diff(self) -> bytes:
726 726 buf = io.BytesIO()
727 727 if self.tgt_kind == svn.core.svn_node_dir:
728 728 self._generate_dir_diff(buf)
729 729 else:
730 730 self._generate_file_diff(buf)
731 731 return buf.getvalue()
732 732
733 733 def _generate_dir_diff(self, buf: io.BytesIO):
734 734 editor = DiffChangeEditor()
735 735 editor_ptr, editor_baton = svn.delta.make_editor(editor)
736 736 svn.repos.dir_delta2(
737 737 self.src_root,
738 738 self.src_path,
739 739 '', # src_entry
740 740 self.tgt_root,
741 741 self.tgt_path,
742 742 editor_ptr, editor_baton,
743 743 authorization_callback_allow_all,
744 744 False, # text_deltas
745 745 svn.core.svn_depth_infinity, # depth
746 746 False, # entry_props
747 747 False, # ignore_ancestry
748 748 )
749 749
750 750 for path, __, change in sorted(editor.changes):
751 751 self._generate_node_diff(
752 752 buf, change, path, self.tgt_path, path, self.src_path)
753 753
754 754 def _generate_file_diff(self, buf: io.BytesIO):
755 755 change = None
756 756 if self.src_kind == svn.core.svn_node_none:
757 757 change = "add"
758 758 elif self.tgt_kind == svn.core.svn_node_none:
759 759 change = "delete"
760 760 tgt_base, tgt_path = vcspath.split(self.tgt_path)
761 761 src_base, src_path = vcspath.split(self.src_path)
762 762 self._generate_node_diff(
763 763 buf, change, tgt_path, tgt_base, src_path, src_base)
764 764
765 765 def _generate_node_diff(
766 766 self, buf: io.BytesIO, change, tgt_path, tgt_base, src_path, src_base):
767 767
768 768 tgt_path_bytes = safe_bytes(tgt_path)
769 769 tgt_path = safe_str(tgt_path)
770 770
771 771 src_path_bytes = safe_bytes(src_path)
772 772 src_path = safe_str(src_path)
773 773
774 774 if self.src_rev == self.tgt_rev and tgt_base == src_base:
775 775 # keep behaviour consistent with git/hg: return an empty diff when
776 776 # comparing the same revisions
777 777 return
778 778
779 779 tgt_full_path = vcspath.join(tgt_base, tgt_path)
780 780 src_full_path = vcspath.join(src_base, src_path)
781 781
782 782 self.binary_content = False
783 783 mime_type = self._get_mime_type(tgt_full_path)
784 784
785 785 if mime_type and not mime_type.startswith(b'text'):
786 786 self.binary_content = True
787 787 buf.write(b"=" * 67 + b'\n')
788 788 buf.write(b"Cannot display: file marked as a binary type.\n")
789 789 buf.write(b"svn:mime-type = %s\n" % mime_type)
790 790 buf.write(b"Index: %b\n" % tgt_path_bytes)
791 791 buf.write(b"=" * 67 + b'\n')
792 792 buf.write(b"diff --git a/%b b/%b\n" % (tgt_path_bytes, tgt_path_bytes))
793 793
794 794 if change == 'add':
795 795 # TODO: johbo: SVN is missing a zero here compared to git
796 796 buf.write(b"new file mode 10644\n")
797 797
798 798 # TODO(marcink): intro to binary detection of svn patches
799 799 # if self.binary_content:
800 800 # buf.write(b'GIT binary patch\n')
801 801
802 802 buf.write(b"--- /dev/null\t(revision 0)\n")
803 803 src_lines = []
804 804 else:
805 805 if change == 'delete':
806 806 buf.write(b"deleted file mode 10644\n")
807 807
808 808 # TODO(marcink): intro to binary detection of svn patches
809 809 # if self.binary_content:
810 810 # buf.write('GIT binary patch\n')
811 811
812 812 buf.write(b"--- a/%b\t(revision %d)\n" % (src_path_bytes, self.src_rev))
813 813 src_lines = self._svn_readlines(self.src_root, src_full_path)
814 814
815 815 if change == 'delete':
816 816 buf.write(b"+++ /dev/null\t(revision %d)\n" % self.tgt_rev)
817 817 tgt_lines = []
818 818 else:
819 819 buf.write(b"+++ b/%b\t(revision %d)\n" % (tgt_path_bytes, self.tgt_rev))
820 820 tgt_lines = self._svn_readlines(self.tgt_root, tgt_full_path)
821 821
822 822 # we made our diff header, time to generate the diff content into our buffer
823 823
824 824 if not self.binary_content:
825 825 udiff = svn_diff.unified_diff(
826 826 src_lines, tgt_lines, context=self.context,
827 827 ignore_blank_lines=self.ignore_whitespace,
828 828 ignore_case=False,
829 829 ignore_space_changes=self.ignore_whitespace)
830 830
831 831 buf.writelines(udiff)
832 832
833 833 def _get_mime_type(self, path) -> bytes:
834 834 try:
835 835 mime_type = svn.fs.node_prop(
836 836 self.tgt_root, path, svn.core.SVN_PROP_MIME_TYPE)
837 837 except svn.core.SubversionException:
838 838 mime_type = svn.fs.node_prop(
839 839 self.src_root, path, svn.core.SVN_PROP_MIME_TYPE)
840 840 return mime_type
841 841
842 842 def _svn_readlines(self, fs_root, node_path):
843 843 if self.binary_content:
844 844 return []
845 845 node_kind = svn.fs.check_path(fs_root, node_path)
846 846 if node_kind not in (
847 847 svn.core.svn_node_file, svn.core.svn_node_symlink):
848 848 return []
849 849 content = svn.core.Stream(
850 850 svn.fs.file_contents(fs_root, node_path)).read()
851 851
852 852 return content.splitlines(True)
853 853
854 854
855 855 class DiffChangeEditor(svn.delta.Editor):
856 856 """
857 857 Records changes between two given revisions
858 858 """
859 859
860 860 def __init__(self):
861 861 self.changes = []
862 862
863 863 def delete_entry(self, path, revision, parent_baton, pool=None):
864 864 self.changes.append((path, None, 'delete'))
865 865
866 866 def add_file(
867 867 self, path, parent_baton, copyfrom_path, copyfrom_revision,
868 868 file_pool=None):
869 869 self.changes.append((path, 'file', 'add'))
870 870
871 871 def open_file(self, path, parent_baton, base_revision, file_pool=None):
872 872 self.changes.append((path, 'file', 'change'))
873 873
874 874
875 875 def authorization_callback_allow_all(root, path, pool):
876 876 return True
877 877
878 878
879 879 class TxnNodeProcessor:
880 880 """
881 881 Utility to process the change of one node within a transaction root.
882 882
883 883 It encapsulates the knowledge of how to add, update or remove
884 884 a node for a given transaction root. The purpose is to support the method
885 885 `SvnRemote.commit`.
886 886 """
887 887
888 888 def __init__(self, node, txn_root):
889 889 assert_bytes(node['path'])
890 890
891 891 self.node = node
892 892 self.txn_root = txn_root
893 893
894 894 def update(self):
895 895 self._ensure_parent_dirs()
896 896 self._add_file_if_node_does_not_exist()
897 897 self._update_file_content()
898 898 self._update_file_properties()
899 899
900 900 def remove(self):
901 901 svn.fs.delete(self.txn_root, self.node['path'])
902 902 # TODO: Clean up directory if empty
903 903
904 904 def _ensure_parent_dirs(self):
905 905 curdir = vcspath.dirname(self.node['path'])
906 906 dirs_to_create = []
907 907 while not self._svn_path_exists(curdir):
908 908 dirs_to_create.append(curdir)
909 909 curdir = vcspath.dirname(curdir)
910 910
911 911 for curdir in reversed(dirs_to_create):
912 912 log.debug('Creating missing directory "%s"', curdir)
913 913 svn.fs.make_dir(self.txn_root, curdir)
914 914
915 915 def _svn_path_exists(self, path):
916 916 path_status = svn.fs.check_path(self.txn_root, path)
917 917 return path_status != svn.core.svn_node_none
918 918
919 919 def _add_file_if_node_does_not_exist(self):
920 920 kind = svn.fs.check_path(self.txn_root, self.node['path'])
921 921 if kind == svn.core.svn_node_none:
922 922 svn.fs.make_file(self.txn_root, self.node['path'])
923 923
924 924 def _update_file_content(self):
925 925 assert_bytes(self.node['content'])
926 926
927 927 handler, baton = svn.fs.apply_textdelta(
928 928 self.txn_root, self.node['path'], None, None)
929 929 svn.delta.svn_txdelta_send_string(self.node['content'], handler, baton)
930 930
931 931 def _update_file_properties(self):
932 932 properties = self.node.get('properties', {})
933 933 for key, value in properties.items():
934 934 svn.fs.change_node_prop(
935 935 self.txn_root, self.node['path'], safe_bytes(key), safe_bytes(value))
936 936
937 937
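For orientation, here is a sketch of driving one node update with TxnNodeProcessor; the node dict shape follows the assertions in the class above, while the transaction root is assumed to have been opened elsewhere via the svn bindings (names here are illustrative):

    # hypothetical node payload; path and content must already be bytes
    node = {
        'path': b'docs/readme.txt',
        'content': b'new file body\n',
        'properties': {'svn:mime-type': 'text/plain'},
    }
    processor = TxnNodeProcessor(node, txn_root)  # txn_root: an open svn txn root
    processor.update()  # ensures parent dirs, creates the file, sets content/props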
938 938 def apr_time_t(timestamp):
939 939 """
940 940 Convert a Python timestamp into APR timestamp type apr_time_t
941 941 """
942 942 return int(timestamp * 1E6)
943 943
944 944
945 945 def svn_opt_revision_value_t(num):
946 946 """
947 947 Put `num` into a `svn_opt_revision_value_t` structure.
948 948 """
949 949 value = svn.core.svn_opt_revision_value_t()
950 950 value.number = num
951 951 revision = svn.core.svn_opt_revision_t()
952 952 revision.kind = svn.core.svn_opt_revision_number
953 953 revision.value = value
954 954 return revision
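A minimal sketch of the two helpers above in use, assuming the Subversion SWIG bindings (`svn.core`) are importable:

    import time
    import svn.core  # Subversion SWIG bindings

    # apr_time_t: APR counts microseconds, so the float timestamp is scaled
    apr_now = apr_time_t(time.time())  # e.g. 1700000000.5 -> 1700000000500000

    # pin an operation to revision 42 via svn_opt_revision_value_t()
    rev = svn_opt_revision_value_t(42)
    assert rev.kind == svn.core.svn_opt_revision_number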
@@ -1,255 +1,255 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import logging
20 20 import itertools
21 21
22 22 import mercurial
23 23 import mercurial.error
24 24 import mercurial.wireprotoserver
25 25 import mercurial.hgweb.common
26 26 import mercurial.hgweb.hgweb_mod
27 27 import webob.exc
28 28
29 29 from vcsserver import pygrack, exceptions, settings, git_lfs
30 from vcsserver.str_utils import ascii_bytes, safe_bytes
30 from vcsserver.lib.str_utils import ascii_bytes, safe_bytes
31 31
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 # propagated from mercurial documentation
36 36 HG_UI_SECTIONS = [
37 37 'alias', 'auth', 'decode/encode', 'defaults', 'diff', 'email', 'extensions',
38 38 'format', 'merge-patterns', 'merge-tools', 'hooks', 'http_proxy', 'smtp',
39 39 'patch', 'paths', 'profiling', 'server', 'trusted', 'ui', 'web',
40 40 ]
41 41
42 42
43 43 class HgWeb(mercurial.hgweb.hgweb_mod.hgweb):
44 44 """Extension of hgweb that simplifies some functions."""
45 45
46 46 def _get_view(self, repo):
47 47 """Views are not supported."""
48 48 return repo
49 49
50 50 def loadsubweb(self):
51 51 """The result is only used in the templater method which is not used."""
52 52 return None
53 53
54 54 def run(self):
55 55 """Unused function so raise an exception if accidentally called."""
56 56 raise NotImplementedError
57 57
58 58 def templater(self, req):
59 59 """Function used in an unreachable code path.
60 60
61 61 This code is unreachable because we guarantee that the HTTP request
62 62 corresponds to a Mercurial command. See the is_hg method. So, we are
63 63 never going to get a user-visible URL.
64 64 """
65 65 raise NotImplementedError
66 66
67 67 def archivelist(self, nodeid):
68 68 """Unused function so raise an exception if accidentally called."""
69 69 raise NotImplementedError
70 70
71 71 def __call__(self, environ, start_response):
72 72 """Run the WSGI application.
73 73
74 74 This may be called by multiple threads.
75 75 """
76 76 from mercurial.hgweb import request as requestmod
77 77 req = requestmod.parserequestfromenv(environ)
78 78 res = requestmod.wsgiresponse(req, start_response)
79 79 gen = self.run_wsgi(req, res)
80 80
81 81 first_chunk = None
82 82
83 83 try:
84 84 data = next(gen)
85 85
86 86 def first_chunk():
87 87 yield data
88 88 except StopIteration:
89 89 pass
90 90
91 91 if first_chunk:
92 92 return itertools.chain(first_chunk(), gen)
93 93 return gen
94 94
95 95 def _runwsgi(self, req, res, repo):
96 96
97 97 cmd = req.qsparams.get(b'cmd', '')
98 98 if not mercurial.wireprotoserver.iscmd(cmd):
99 99 # NOTE(marcink): for unsupported commands, we return bad request
100 100 # internally from HG
101 101 log.warning('cmd: `%s` is not supported by the mercurial wireprotocol v1', cmd)
102 102 from mercurial.hgweb.common import statusmessage
103 103 res.status = statusmessage(mercurial.hgweb.common.HTTP_BAD_REQUEST)
104 104 res.setbodybytes(b'')
105 105 return res.sendresponse()
106 106
107 107 return super()._runwsgi(req, res, repo)
108 108
109 109
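The `__call__` above eagerly pulls the first chunk so that any error surfaces before the WSGI response body starts streaming. Isolated from hgweb, the pattern is just this (a standalone sketch, not part of the module):

    import itertools

    def peek_first(gen):
        # Force the first read now so exceptions raise here, then re-attach
        # the consumed item in front of the untouched generator.
        try:
            first = next(gen)
        except StopIteration:
            return iter(())
        return itertools.chain([first], gen)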
110 110 def sanitize_hg_ui(baseui):
111 111 # NOTE(marcink): since Python 3, hgsubversion is deprecated.
112 112 # Old installations might still have this extension enabled;
113 113 # we explicitly remove it here to make sure it won't propagate further.
114 114
115 115 if baseui.config(b'extensions', b'hgsubversion') is not None:
116 116 for cfg in (baseui._ocfg, baseui._tcfg, baseui._ucfg):
117 117 if b'extensions' in cfg:
118 118 if b'hgsubversion' in cfg[b'extensions']:
119 119 del cfg[b'extensions'][b'hgsubversion']
120 120
121 121
122 122 def make_hg_ui_from_config(repo_config):
123 123 baseui = mercurial.ui.ui()
124 124
125 125 # clean the baseui object
126 126 baseui._ocfg = mercurial.config.config()
127 127 baseui._ucfg = mercurial.config.config()
128 128 baseui._tcfg = mercurial.config.config()
129 129
130 130 for section, option, value in repo_config:
131 131 baseui.setconfig(
132 132 ascii_bytes(section, allow_bytes=True),
133 133 ascii_bytes(option, allow_bytes=True),
134 134 ascii_bytes(value, allow_bytes=True))
135 135
136 136 # make our hgweb quiet so it doesn't print output
137 137 baseui.setconfig(b'ui', b'quiet', b'true')
138 138
139 139 return baseui
140 140
141 141
142 142 def update_hg_ui_from_hgrc(baseui, repo_path):
143 143 path = os.path.join(repo_path, '.hg', 'hgrc')
144 144
145 145 if not os.path.isfile(path):
146 146 log.debug('hgrc file is not present at %s, skipping...', path)
147 147 return
148 148 log.debug('reading hgrc from %s', path)
149 149 cfg = mercurial.config.config()
150 150 cfg.read(ascii_bytes(path))
151 151 for section in HG_UI_SECTIONS:
152 152 for k, v in cfg.items(section):
153 153 log.debug('setting ui from file: [%s] %s=%s', section, k, v)
154 154 baseui.setconfig(
155 155 ascii_bytes(section, allow_bytes=True),
156 156 ascii_bytes(k, allow_bytes=True),
157 157 ascii_bytes(v, allow_bytes=True))
158 158
159 159
160 160 def create_hg_wsgi_app(repo_path, repo_name, config):
161 161 """
162 162 Prepares a WSGI application to handle Mercurial requests.
163 163
164 164 :param config: is a list of 3-item tuples representing a ConfigObject
165 165 (it is the serialized version of the config object).
166 166 """
167 167 log.debug("Creating Mercurial WSGI application")
168 168
169 169 baseui = make_hg_ui_from_config(config)
170 170 update_hg_ui_from_hgrc(baseui, repo_path)
171 171 sanitize_hg_ui(baseui)
172 172
173 173 try:
174 174 return HgWeb(safe_bytes(repo_path), name=safe_bytes(repo_name), baseui=baseui)
175 175 except mercurial.error.RequirementError as e:
176 176 raise exceptions.RequirementException(e)(e)
177 177
178 178
179 179 class GitHandler:
180 180 """
181 181 Handler for Git operations like push/pull etc
182 182 """
183 183 def __init__(self, repo_location, repo_name, git_path, update_server_info,
184 184 extras):
185 185 if not os.path.isdir(repo_location):
186 186 raise OSError(repo_location)
187 187 self.content_path = repo_location
188 188 self.repo_name = repo_name
189 189 self.repo_location = repo_location
190 190 self.extras = extras
191 191 self.git_path = git_path
192 192 self.update_server_info = update_server_info
193 193
194 194 def __call__(self, environ, start_response):
195 195 app = webob.exc.HTTPNotFound()
196 196 candidate_paths = (
197 197 self.content_path, os.path.join(self.content_path, '.git'))
198 198
199 199 for content_path in candidate_paths:
200 200 try:
201 201 app = pygrack.GitRepository(
202 202 self.repo_name, content_path, self.git_path,
203 203 self.update_server_info, self.extras)
204 204 break
205 205 except OSError:
206 206 continue
207 207
208 208 return app(environ, start_response)
209 209
210 210
211 211 def create_git_wsgi_app(repo_path, repo_name, config):
212 212 """
213 213 Creates a WSGI application to handle Git requests.
214 214
215 215 :param config: is a dictionary holding the extras.
216 216 """
217 217 git_path = settings.GIT_EXECUTABLE()
218 218 update_server_info = config.pop('git_update_server_info')
219 219 app = GitHandler(
220 220 repo_path, repo_name, git_path, update_server_info, config)
221 221
222 222 return app
223 223
224 224
225 225 class GitLFSHandler:
226 226 """
227 227 Handler for Git LFS operations
228 228 """
229 229
230 230 def __init__(self, repo_location, repo_name, git_path, update_server_info,
231 231 extras):
232 232 if not os.path.isdir(repo_location):
233 233 raise OSError(repo_location)
234 234 self.content_path = repo_location
235 235 self.repo_name = repo_name
236 236 self.repo_location = repo_location
237 237 self.extras = extras
238 238 self.git_path = git_path
239 239 self.update_server_info = update_server_info
240 240
241 241 def get_app(self, git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
242 242 app = git_lfs.create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
243 243 return app
244 244
245 245
246 246 def create_git_lfs_wsgi_app(repo_path, repo_name, config):
247 247 git_path = settings.GIT_EXECUTABLE()
248 248 update_server_info = config.pop('git_update_server_info')
249 249 git_lfs_enabled = config.pop('git_lfs_enabled')
250 250 git_lfs_store_path = config.pop('git_lfs_store_path')
251 251 git_lfs_http_scheme = config.pop('git_lfs_http_scheme', 'http')
252 252 app = GitLFSHandler(
253 253 repo_path, repo_name, git_path, update_server_info, config)
254 254
255 255 return app.get_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
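For local experimentation the factories above can be mounted on the reference WSGI server; a sketch assuming a bare git repository at a hypothetical path and only the extras this handler actually pops from `config`:

    from wsgiref.simple_server import make_server
    from vcsserver import scm_app

    config = {'git_update_server_info': False}  # hypothetical minimal extras
    app = scm_app.create_git_wsgi_app('/tmp/repo.git', 'repo_name', config)
    make_server('127.0.0.1', 8080, app).serve_forever()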
@@ -1,563 +1,563 b''
1 1 """
2 2 Module provides a class allowing to wrap communication over subprocess.Popen
3 3 input, output, error streams into a meaningfull, non-blocking, concurrent
4 4 stream processor exposing the output data as an iterator fitting to be a
5 5 return value passed by a WSGI applicaiton to a WSGI server per PEP 3333.
6 6
7 7 Copyright (c) 2011 Daniel Dotsenko <dotsa[at]hotmail.com>
8 8
9 9 This file is part of git_http_backend.py Project.
10 10
11 11 git_http_backend.py Project is free software: you can redistribute it and/or
12 12 modify it under the terms of the GNU Lesser General Public License as
13 13 published by the Free Software Foundation, either version 2.1 of the License,
14 14 or (at your option) any later version.
15 15
16 16 git_http_backend.py Project is distributed in the hope that it will be useful,
17 17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 19 GNU Lesser General Public License for more details.
20 20
21 21 You should have received a copy of the GNU Lesser General Public License
22 22 along with git_http_backend.py Project.
23 23 If not, see <http://www.gnu.org/licenses/>.
24 24 """
25 25 import os
26 26 import collections
27 27 import logging
28 28 import subprocess
29 29 import threading
30 30
31 from vcsserver.str_utils import safe_str
31 from vcsserver.lib.str_utils import safe_str
32 32
33 33 log = logging.getLogger(__name__)
34 34
35 35
36 36 class StreamFeeder(threading.Thread):
37 37 """
38 38 Normal writing into a pipe-like object blocks once the buffer is filled.
39 39 This thread feeds data from a file-like object into a pipe
40 40 without blocking the main thread.
41 41 We close the input pipe once the end of the source stream is reached.
42 42 """
43 43
44 44 def __init__(self, source):
45 45 super().__init__()
46 46 self.daemon = True
47 47 filelike = False
48 48 self.bytes = b''
49 49 if type(source) in (str, bytes, bytearray): # string-like
50 50 self.bytes = source.encode() if isinstance(source, str) else bytes(source)
51 51 else: # can be either file pointer or file-like
52 52 if isinstance(source, int): # file pointer it is
53 53 # converting file descriptor (int) stdin into file-like
54 54 source = os.fdopen(source, 'rb', 16384)
55 55 # let's see if source is file-like by now
56 56 filelike = hasattr(source, 'read')
57 57 if not filelike and not self.bytes:
58 58 raise TypeError("StreamFeeder's source object must be a readable "
59 59 "file-like, a file descriptor, or a string-like.")
60 60 self.source = source
61 61 self.readiface, self.writeiface = os.pipe()
62 62
63 63 def run(self):
64 64 writer = self.writeiface
65 65 try:
66 66 if self.bytes:
67 67 os.write(writer, self.bytes)
68 68 else:
69 69 s = self.source
70 70
71 71 while 1:
72 72 _bytes = s.read(4096)
73 73 if not _bytes:
74 74 break
75 75 os.write(writer, _bytes)
76 76
77 77 finally:
78 78 os.close(writer)
79 79
80 80 @property
81 81 def output(self):
82 82 return self.readiface
83 83
84 84
85 85 class InputStreamChunker(threading.Thread):
86 86 def __init__(self, source, target, buffer_size, chunk_size):
87 87
88 88 super().__init__()
89 89
90 90 self.daemon = True # die die die.
91 91
92 92 self.source = source
93 93 self.target = target
94 94 self.chunk_count_max = int(buffer_size / chunk_size) + 1
95 95 self.chunk_size = chunk_size
96 96
97 97 self.data_added = threading.Event()
98 98 self.data_added.clear()
99 99
100 100 self.keep_reading = threading.Event()
101 101 self.keep_reading.set()
102 102
103 103 self.EOF = threading.Event()
104 104 self.EOF.clear()
105 105
106 106 self.go = threading.Event()
107 107 self.go.set()
108 108
109 109 def stop(self):
110 110 self.go.clear()
111 111 self.EOF.set()
112 112 try:
113 113 # this is not proper, but is done to force the reader thread let
114 114 # go of the input because, if successful, .close() will send EOF
115 115 # down the pipe.
116 116 self.source.close()
117 117 except Exception:
118 118 pass
119 119
120 120 def run(self):
121 121 s = self.source
122 122 t = self.target
123 123 cs = self.chunk_size
124 124 chunk_count_max = self.chunk_count_max
125 125 keep_reading = self.keep_reading
126 126 da = self.data_added
127 127 go = self.go
128 128
129 129 try:
130 130 b = s.read(cs)
131 131 except ValueError:
132 132 b = ''
133 133
134 134 timeout_input = 20
135 135 while b and go.is_set():
136 136 if len(t) > chunk_count_max:
137 137 keep_reading.clear()
138 138 keep_reading.wait(timeout_input)
139 139 if len(t) > chunk_count_max + timeout_input:
140 140 log.error("Timed out while waiting for input from subprocess.")
141 141 os._exit(-1) # this will cause the worker to recycle itself
142 142
143 143 t.append(b)
144 144 da.set()
145 145
146 146 try:
147 147 b = s.read(cs)
148 148 except ValueError: # probably "I/O operation on closed file"
149 149 b = ''
150 150
151 151 self.EOF.set()
152 152 da.set() # for cases when done but there was no input.
153 153
154 154
155 155 class BufferedGenerator:
156 156 """
157 157 Class behaves as a non-blocking, buffered pipe reader.
158 158 Reads chunks of data (through a thread)
159 159 from a blocking pipe, and attaches these to an array (Deque) of chunks.
160 160 Reading is halted in the thread when max chunks is internally buffered.
161 161 The .next() may operate in a blocking or non-blocking fashion: by yielding
162 162 '' if no data is ready to be sent, or by not returning until there is
163 163 some data to send.
164 164 When we get EOF from the underlying source pipe, we raise the marker to raise
165 165 StopIteration after the last chunk of data is yielded.
166 166 """
167 167
168 168 def __init__(self, name, source, buffer_size=65536, chunk_size=4096,
169 169 starting_values=None, bottomless=False):
170 170 starting_values = starting_values or []
171 171 self.name = name
172 172 self.buffer_size = buffer_size
173 173 self.chunk_size = chunk_size
174 174
175 175 if bottomless:
176 176 maxlen = int(buffer_size / chunk_size)
177 177 else:
178 178 maxlen = None
179 179
180 180 self.data_queue = collections.deque(starting_values, maxlen)
181 181 self.worker = InputStreamChunker(source, self.data_queue, buffer_size, chunk_size)
182 182 if starting_values:
183 183 self.worker.data_added.set()
184 184 self.worker.start()
185 185
186 186 ####################
187 187 # Generator's methods
188 188 ####################
189 189 def __str__(self):
190 190 return f'BufferedGenerator(name={self.name} chunk: {self.chunk_size} on buffer: {self.buffer_size})'
191 191
192 192 def __iter__(self):
193 193 return self
194 194
195 195 def __next__(self):
196 196
197 197 while not self.length and not self.worker.EOF.is_set():
198 198 self.worker.data_added.clear()
199 199 self.worker.data_added.wait(0.2)
200 200
201 201 if self.length:
202 202 self.worker.keep_reading.set()
203 203 return bytes(self.data_queue.popleft())
204 204 elif self.worker.EOF.is_set():
205 205 raise StopIteration
206 206
207 207 def throw(self, exc_type, value=None, traceback=None):
208 208 if not self.worker.EOF.is_set():
209 209 raise exc_type(value)
210 210
211 211 def start(self):
212 212 self.worker.start()
213 213
214 214 def stop(self):
215 215 self.worker.stop()
216 216
217 217 def close(self):
218 218 try:
219 219 self.worker.stop()
220 220 self.throw(GeneratorExit)
221 221 except (GeneratorExit, StopIteration):
222 222 pass
223 223
224 224 ####################
225 225 # Threaded reader's infrastructure.
226 226 ####################
227 227 @property
228 228 def input(self):
229 229 return self.worker.source
230 230
231 231 @property
232 232 def data_added_event(self):
233 233 return self.worker.data_added
234 234
235 235 @property
236 236 def data_added(self):
237 237 return self.worker.data_added.is_set()
238 238
239 239 @property
240 240 def reading_paused(self):
241 241 return not self.worker.keep_reading.is_set()
242 242
243 243 @property
244 244 def done_reading_event(self):
245 245 """
246 246 Done_reading does not mean that the iterator's buffer is empty.
247 247 The iterator might be done reading from the underlying source, but the
248 248 read chunks might still be available for serving through the .next() method.
249 249
250 250 :returns: An Event class instance.
251 251 """
252 252 return self.worker.EOF
253 253
254 254 @property
255 255 def done_reading(self):
256 256 """
257 257 Done_reading does not mean that the iterator's buffer is empty.
258 258 The iterator might be done reading from the underlying source, but the
259 259 read chunks might still be available for serving through the .next() method.
260 260 
261 261 :returns: A bool value.
262 262 """
263 263 return self.worker.EOF.is_set()
264 264
265 265 @property
266 266 def length(self):
267 267 """
268 268 returns int.
269 269
270 270 This is the length of the queue of chunks, not the length of
271 271 the combined contents in those chunks.
272 272
273 273 __len__() cannot be meaningfully implemented because this
274 274 reader is just flying through a bottomless pit of content and
275 275 can only know the length of what it has already seen.
276 276 
277 277 If __len__() returns a value, a WSGI server per PEP 3333
278 278 will set the response's length to it. In order not to
279 279 confuse WSGI PEP 3333 servers, we do not implement __len__
280 280 at all.
281 281 """
282 282 return len(self.data_queue)
283 283
284 284 def prepend(self, x):
285 285 self.data_queue.appendleft(x)
286 286
287 287 def append(self, x):
288 288 self.data_queue.append(x)
289 289
290 290 def extend(self, o):
291 291 self.data_queue.extend(o)
292 292
293 293 def __getitem__(self, i):
294 294 return self.data_queue[i]
295 295
296 296
297 297 class SubprocessIOChunker:
298 298 """
299 299 Processor class wrapping handling of subprocess IO.
300 300
301 301 .. important::
302 302
303 303 Watch out for the method `__del__` on this class. If this object
304 304 is deleted, it will kill the subprocess, so avoid returning
305 305 the `output` attribute or using it like in the following
306 306 example::
307 307
308 308 # `args` expected to run a program that produces a lot of output
309 309 output = ''.join(SubprocessIOChunker(
310 310 args, shell=False, inputstream=inputstream, env=environ).output)
311 311
312 312 # `output` will not contain all the data, because the __del__ method
313 313 # has already killed the subprocess in this case before all output
314 314 # has been consumed.
315 315
316 316
317 317
318 318 In a way, this is a "communicate()" replacement with a twist.
319 319
320 320 - We are multithreaded. Writing in, reading out and err are all separate threads.
321 321 - We support concurrent (in and out) stream processing.
322 322 - The output is not a stream. It's a queue of read (bytes, not str)
323 323 chunks. The object behaves as an iterable: you can iterate it with "for chunk in obj:".
324 324 - We are non-blocking in more respects than communicate()
325 325 (reading from subprocess out pauses when internal buffer is full, but
326 326 does not block the parent calling code. On the flip side, reading from
327 327 a slow-yielding subprocess may block the iteration until data shows up. This
328 328 does not block the inpipe reading occurring in a parallel thread.)
329 329
330 330 The purpose of the object is to allow us to wrap subprocess interactions into
331 331 an iterable that can be passed to a WSGI server as the application's return
332 332 value. Because of stream-processing-ability, WSGI does not have to read ALL
333 333 of the subprocess's output and buffer it, before handing it to WSGI server for
334 334 of the subprocess's output and buffer it before handing it to the WSGI server for
335 335 the HTTP response. Instead, the class initializer reads just a bit of the stream
336 336 to figure out if an error occurred or is likely to occur and, if not, hands the
337 337 further iteration over subprocess output to the server for completion of the HTTP
338 338
339 339 The real or perceived subprocess error is trapped and raised as one of
340 340 the OSError family of exceptions.
341 341
342 342 Example usage:
343 343 # try:
344 344 # answer = SubprocessIOChunker(
345 345 # cmd,
346 346 # input,
347 347 # buffer_size = 65536,
348 348 # chunk_size = 4096
349 349 # )
350 350 # except OSError as e:
351 351 # print(str(e))
352 352 # raise e
353 353 #
354 354 # return answer
355 355
356 356
357 357 """
358 358
359 359 # TODO: johbo: This is used to make sure that the open end of the PIPE
360 360 # is closed in the end. It would be way better to wrap this into an
361 361 # object, so that it is closed automatically once it is consumed or
362 362 # something similar.
363 363 _close_input_fd = None
364 364
365 365 _closed = False
366 366 _stdout = None
367 367 _stderr = None
368 368
369 369 def __init__(self, cmd, input_stream=None, buffer_size=65536,
370 370 chunk_size=4096, starting_values=None, fail_on_stderr=True,
371 371 fail_on_return_code=True, **kwargs):
372 372 """
373 373 Initializes SubprocessIOChunker
374 374
375 375 :param cmd: A Subprocess.Popen style "cmd". Can be string or array of strings
376 376 :param input_stream: (Default: None) A file-like, string, or file pointer.
377 377 :param buffer_size: (Default: 65536) A size of total buffer per stream in bytes.
378 378 :param chunk_size: (Default: 4096) A max size of a chunk. Actual chunk may be smaller.
379 379 :param starting_values: (Default: []) An array of strings to put in front of the output queue.
380 380 :param fail_on_stderr: (Default: True) Whether to raise an exception in
381 381 case something is written to stderr.
382 382 :param fail_on_return_code: (Default: True) Whether to raise an
383 383 exception if the return code is not 0.
384 384 """
385 385
386 386 kwargs['shell'] = kwargs.get('shell', True)
387 387
388 388 starting_values = starting_values or []
389 389 if input_stream:
390 390 input_streamer = StreamFeeder(input_stream)
391 391 input_streamer.start()
392 392 input_stream = input_streamer.output
393 393 self._close_input_fd = input_stream
394 394
395 395 self._fail_on_stderr = fail_on_stderr
396 396 self._fail_on_return_code = fail_on_return_code
397 397 self.cmd = cmd
398 398
399 399 _p = subprocess.Popen(cmd, bufsize=-1, stdin=input_stream, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
400 400 **kwargs)
401 401 self.process = _p
402 402
403 403 bg_out = BufferedGenerator('stdout', _p.stdout, buffer_size, chunk_size, starting_values)
404 404 bg_err = BufferedGenerator('stderr', _p.stderr, 10240, 1, bottomless=True)
405 405
406 406 while not bg_out.done_reading and not bg_out.reading_paused and not bg_err.length:
407 407 # doing this until we reach either end of file, or end of buffer.
408 408 bg_out.data_added_event.wait(0.2)
409 409 bg_out.data_added_event.clear()
410 410
411 411 # at this point it's still ambiguous if we are done reading or just full buffer.
412 412 # Either way, if error (returned by ended process, or implied based on
413 413 # presence of stuff in stderr output) we error out.
414 414 # Else, we are happy.
415 415 return_code = _p.poll()
416 416 ret_code_ok = return_code in [None, 0]
417 417 ret_code_fail = return_code is not None and return_code != 0
418 418 if (
419 419 (ret_code_fail and fail_on_return_code) or
420 420 (ret_code_ok and fail_on_stderr and bg_err.length)
421 421 ):
422 422
423 423 try:
424 424 _p.terminate()
425 425 except Exception:
426 426 pass
427 427
428 428 bg_out.stop()
429 429 out = b''.join(bg_out)
430 430 self._stdout = out
431 431
432 432 bg_err.stop()
433 433 err = b''.join(bg_err)
434 434 self._stderr = err
435 435
436 436 # code from https://github.com/schacon/grack/pull/7
437 437 if err.strip() == b'fatal: The remote end hung up unexpectedly' and out.startswith(b'0034shallow '):
438 438 bg_out = iter([out])
439 439 _p = None
440 440 elif err and fail_on_stderr:
441 441 text_err = err.decode()
442 442 raise OSError(
443 443 f"Subprocess exited due to an error:\n{text_err}")
444 444
445 445 if ret_code_fail and fail_on_return_code:
446 446 text_err = err.decode()
447 447 if not err:
448 448 # maybe get empty stderr, try stdout instead
449 449 # in many cases git reports the errors on stdout too
450 450 text_err = out.decode()
451 451 raise OSError(
452 452 f"Subprocess exited with non 0 ret code:{return_code}: stderr:{text_err}")
453 453
454 454 self.stdout = bg_out
455 455 self.stderr = bg_err
456 456 self.inputstream = input_stream
457 457
458 458 def __str__(self):
459 459 proc = getattr(self, 'process', 'NO_PROCESS')
460 460 return f'SubprocessIOChunker: {proc}'
461 461
462 462 def __iter__(self):
463 463 return self
464 464
465 465 def __next__(self):
466 466 # Note: mikhail: We need to be sure that we are checking the return
467 467 # code after the stdout stream is closed. Some processes, e.g. git
468 468 # are doing some magic in between closing stdout and terminating the
469 469 # process and, as a result, we are not getting return code on "slow"
470 470 # systems.
471 471 result = None
472 472 stop_iteration = None
473 473 try:
474 474 result = next(self.stdout)
475 475 except StopIteration as e:
476 476 stop_iteration = e
477 477
478 478 if self.process:
479 479 return_code = self.process.poll()
480 480 ret_code_fail = return_code is not None and return_code != 0
481 481 if ret_code_fail and self._fail_on_return_code:
482 482 self.stop_streams()
483 483 err = self.get_stderr()
484 484 raise OSError(
485 485 f"Subprocess exited (exit_code:{return_code}) due to an error during iteration:\n{err}")
486 486
487 487 if stop_iteration:
488 488 raise stop_iteration
489 489 return result
490 490
491 491 def throw(self, exc_type, value=None, traceback=None):
492 492 if self.stdout.length or not self.stdout.done_reading:
493 493 raise exc_type(value)
494 494
495 495 def close(self):
496 496 if self._closed:
497 497 return
498 498
499 499 try:
500 500 self.process.terminate()
501 501 except Exception:
502 502 pass
503 503 if self._close_input_fd:
504 504 os.close(self._close_input_fd)
505 505 try:
506 506 self.stdout.close()
507 507 except Exception:
508 508 pass
509 509 try:
510 510 self.stderr.close()
511 511 except Exception:
512 512 pass
513 513 try:
514 514 os.close(self.inputstream)
515 515 except Exception:
516 516 pass
517 517
518 518 self._closed = True
519 519
520 520 def stop_streams(self):
521 521 getattr(self.stdout, 'stop', lambda: None)()
522 522 getattr(self.stderr, 'stop', lambda: None)()
523 523
524 524 def get_stdout(self):
525 525 if self._stdout:
526 526 return self._stdout
527 527 else:
528 528 return b''.join(self.stdout)
529 529
530 530 def get_stderr(self):
531 531 if self._stderr:
532 532 return self._stderr
533 533 else:
534 534 return b''.join(self.stderr)
535 535
536 536
537 537 def run_command(arguments, env=None):
538 538 """
539 539 Run the specified command and return the stdout.
540 540
541 541 :param arguments: sequence of program arguments (including the program name)
542 542 :type arguments: list[str]
543 543 """
544 544
545 545 cmd = arguments
546 546 log.debug('Running subprocessio command %s', cmd)
547 547 proc = None
548 548 try:
549 549 _opts = {'shell': False, 'fail_on_stderr': False}
550 550 if env:
551 551 _opts.update({'env': env})
552 552 proc = SubprocessIOChunker(cmd, **_opts)
553 553 return b''.join(proc), b''.join(proc.stderr)
554 554 except OSError as err:
555 555 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
556 556 tb_err = ("Couldn't run subprocessio command (%s).\n"
557 557 "Original error was:%s\n" % (cmd, err))
558 558 log.exception(tb_err)
559 559 raise Exception(tb_err)
560 560 finally:
561 561 if proc:
562 562 proc.close()
563 563
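Typical use of the module, sketched with a throwaway command; in vcsserver the chunker is normally handed straight to the WSGI layer rather than joined:

    from vcsserver import subprocessio

    # stream stdout chunk by chunk instead of buffering it all
    chunker = subprocessio.SubprocessIOChunker(
        ['echo', 'hello'], shell=False, fail_on_stderr=False)
    try:
        for chunk in chunker:
            pass  # e.g. yield each chunk to the WSGI server
    finally:
        chunker.close()

    # or collect small outputs in one call
    stdout, stderr = subprocessio.run_command(['echo', 'hello'])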
@@ -1,289 +1,289 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import sys
20 20 import stat
21 21 import pytest
22 22 import vcsserver
23 23 import tempfile
24 24 from vcsserver import hook_utils
25 25 from vcsserver.hook_utils import set_permissions_if_needed, HOOKS_DIR_MODE, HOOKS_FILE_MODE
26 26 from vcsserver.tests.fixture import no_newline_id_generator
27 from vcsserver.str_utils import safe_bytes
27 from vcsserver.lib.str_utils import safe_bytes
28 28 from vcsserver.utils import AttributeDict
29 29
30 30
31 31 class TestCheckRhodecodeHook:
32 32
33 33 def test_returns_false_when_hook_file_is_wrong_found(self, tmpdir):
34 34 hook = os.path.join(str(tmpdir), 'fake_hook_file.py')
35 35 with open(hook, 'wb') as f:
36 36 f.write(b'dummy test')
37 37 result = hook_utils.check_rhodecode_hook(hook)
38 38 assert result is False
39 39
40 40 def test_returns_true_when_no_hook_file_found(self, tmpdir):
41 41 hook = os.path.join(str(tmpdir), 'fake_hook_file_not_existing.py')
42 42 result = hook_utils.check_rhodecode_hook(hook)
43 43 assert result
44 44
45 45 @pytest.mark.parametrize("file_content, expected_result", [
46 46 ("RC_HOOK_VER = '3.3.3'\n", True),
47 47 ("RC_HOOK = '3.3.3'\n", False),
48 48 ], ids=no_newline_id_generator)
49 49 def test_signatures(self, file_content, expected_result, tmpdir):
50 50 hook = os.path.join(str(tmpdir), 'fake_hook_file_1.py')
51 51 with open(hook, 'wb') as f:
52 52 f.write(safe_bytes(file_content))
53 53
54 54 result = hook_utils.check_rhodecode_hook(hook)
55 55
56 56 assert result is expected_result
57 57
58 58
59 59 class BaseInstallHooks:
60 60 HOOK_FILES = ()
61 61
62 62 def _check_hook_file_dir_mode(self, file_path):
63 63 dir_path = os.path.dirname(file_path)
64 64 assert os.path.exists(dir_path), f'dir {dir_path} missing'
65 65 stat_info = os.stat(dir_path)
66 66
67 67 file_mode = stat.S_IMODE(stat_info.st_mode)
68 68 expected_mode = int(HOOKS_DIR_MODE)
69 69 assert expected_mode == file_mode, f'expected mode: {oct(expected_mode)} got: {oct(file_mode)} for {dir_path}'
70 70
71 71 def _check_hook_file_mode(self, file_path):
72 72 assert os.path.exists(file_path), f'path {file_path} missing'
73 73 stat_info = os.stat(file_path)
74 74
75 75 file_mode = stat.S_IMODE(stat_info.st_mode)
76 76 expected_mode = int(HOOKS_FILE_MODE)
77 77 assert expected_mode == file_mode, f'expected mode: {oct(expected_mode)} got: {oct(file_mode)} for {file_path}'
78 78
79 79 def _check_hook_file_content(self, file_path, executable):
80 80 executable = executable or sys.executable
81 81 with open(file_path, 'rt') as hook_file:
82 82 content = hook_file.read()
83 83
84 84 expected_env = '#!{}'.format(executable)
85 85 expected_rc_version = "\nRC_HOOK_VER = '{}'\n".format(vcsserver.get_version())
86 86 assert content.strip().startswith(expected_env)
87 87 assert expected_rc_version in content
88 88
89 89 def _create_fake_hook(self, file_path, content):
90 90 with open(file_path, 'w') as hook_file:
91 91 hook_file.write(content)
92 92
93 93 def create_dummy_repo(self, repo_type):
94 94 tmpdir = tempfile.mkdtemp()
95 95 repo = AttributeDict()
96 96 if repo_type == 'git':
97 97 repo.path = os.path.join(tmpdir, 'test_git_hooks_installation_repo')
98 98 os.makedirs(repo.path)
99 99 os.makedirs(os.path.join(repo.path, 'hooks'))
100 100 repo.bare = True
101 101
102 102 elif repo_type == 'svn':
103 103 repo.path = os.path.join(tmpdir, 'test_svn_hooks_installation_repo')
104 104 os.makedirs(repo.path)
105 105 os.makedirs(os.path.join(repo.path, 'hooks'))
106 106
107 107 return repo
108 108
109 109 def check_hooks(self, repo_path, repo_bare=True):
110 110 for file_name in self.HOOK_FILES:
111 111 if repo_bare:
112 112 file_path = os.path.join(repo_path, 'hooks', file_name)
113 113 else:
114 114 file_path = os.path.join(repo_path, '.git', 'hooks', file_name)
115 115
116 116 self._check_hook_file_dir_mode(file_path)
117 117 self._check_hook_file_mode(file_path)
118 118 self._check_hook_file_content(file_path, sys.executable)
119 119
120 120
121 121 class TestInstallGitHooks(BaseInstallHooks):
122 122 HOOK_FILES = ('pre-receive', 'post-receive')
123 123
124 124 def test_hooks_are_installed(self):
125 125 repo = self.create_dummy_repo('git')
126 126 result = hook_utils.install_git_hooks(repo.path, repo.bare)
127 127 assert result
128 128 self.check_hooks(repo.path, repo.bare)
129 129
130 130 def test_hooks_are_replaced(self):
131 131 repo = self.create_dummy_repo('git')
132 132 hooks_path = os.path.join(repo.path, 'hooks')
133 133 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
134 134 self._create_fake_hook(
135 135 file_path, content="RC_HOOK_VER = 'abcde'\n")
136 136
137 137 result = hook_utils.install_git_hooks(repo.path, repo.bare)
138 138 assert result
139 139 self.check_hooks(repo.path, repo.bare)
140 140
141 141 def test_non_rc_hooks_are_not_replaced(self):
142 142 repo = self.create_dummy_repo('git')
143 143 hooks_path = os.path.join(repo.path, 'hooks')
144 144 non_rc_content = 'echo "non rc hook"\n'
145 145 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
146 146 self._create_fake_hook(
147 147 file_path, content=non_rc_content)
148 148
149 149 result = hook_utils.install_git_hooks(repo.path, repo.bare)
150 150 assert result
151 151
152 152 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
153 153 with open(file_path, 'rt') as hook_file:
154 154 content = hook_file.read()
155 155 assert content == non_rc_content
156 156
157 157 def test_non_rc_hooks_are_replaced_with_force_flag(self):
158 158 repo = self.create_dummy_repo('git')
159 159 hooks_path = os.path.join(repo.path, 'hooks')
160 160 non_rc_content = 'echo "non rc hook"\n'
161 161 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
162 162 self._create_fake_hook(
163 163 file_path, content=non_rc_content)
164 164
165 165 result = hook_utils.install_git_hooks(
166 166 repo.path, repo.bare, force_create=True)
167 167 assert result
168 168 self.check_hooks(repo.path, repo.bare)
169 169
170 170
171 171 class TestInstallSvnHooks(BaseInstallHooks):
172 172 HOOK_FILES = ('pre-commit', 'post-commit')
173 173
174 174 def test_hooks_are_installed(self):
175 175 repo = self.create_dummy_repo('svn')
176 176 result = hook_utils.install_svn_hooks(repo.path)
177 177 assert result
178 178 self.check_hooks(repo.path)
179 179
180 180 def test_hooks_are_replaced(self):
181 181 repo = self.create_dummy_repo('svn')
182 182 hooks_path = os.path.join(repo.path, 'hooks')
183 183 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
184 184 self._create_fake_hook(
185 185 file_path, content="RC_HOOK_VER = 'abcde'\n")
186 186
187 187 result = hook_utils.install_svn_hooks(repo.path)
188 188 assert result
189 189 self.check_hooks(repo.path)
190 190
191 191 def test_non_rc_hooks_are_not_replaced(self):
192 192 repo = self.create_dummy_repo('svn')
193 193 hooks_path = os.path.join(repo.path, 'hooks')
194 194 non_rc_content = 'echo "non rc hook"\n'
195 195 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
196 196 self._create_fake_hook(
197 197 file_path, content=non_rc_content)
198 198
199 199 result = hook_utils.install_svn_hooks(repo.path)
200 200 assert result
201 201
202 202 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
203 203 with open(file_path, 'rt') as hook_file:
204 204 content = hook_file.read()
205 205 assert content == non_rc_content
206 206
207 207 def test_non_rc_hooks_are_replaced_with_force_flag(self):
208 208 repo = self.create_dummy_repo('svn')
209 209 hooks_path = os.path.join(repo.path, 'hooks')
210 210 non_rc_content = 'echo "non rc hook"\n'
211 211 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
212 212 self._create_fake_hook(
213 213 file_path, content=non_rc_content)
214 214
215 215 result = hook_utils.install_svn_hooks(
216 216 repo.path, force_create=True)
217 217 assert result
218 218 self.check_hooks(repo.path, )
219 219
220 220
221 221 def create_test_file(filename):
222 222 """Utility function to create a test file."""
223 223 with open(filename, 'w') as f:
224 224 f.write("Test file")
225 225
226 226
227 227 def remove_test_file(filename):
228 228 """Utility function to remove a test file."""
229 229 if os.path.exists(filename):
230 230 os.remove(filename)
231 231
232 232
233 233 @pytest.fixture
234 234 def test_file():
235 235 filename = 'test_file.txt'
236 236 create_test_file(filename)
237 237 yield filename
238 238 remove_test_file(filename)
239 239
240 240
241 241 def test_increase_permissions(test_file):
242 242 # Set initial lower permissions
243 243 initial_perms = 0o644
244 244 os.chmod(test_file, initial_perms)
245 245
246 246 # Set higher permissions
247 247 new_perms = 0o666
248 248 set_permissions_if_needed(test_file, new_perms)
249 249
250 250 # Check if permissions were updated
251 251 assert (os.stat(test_file).st_mode & 0o777) == new_perms
252 252
253 253
254 254 def test_no_permission_change_needed(test_file):
255 255 # Set initial permissions
256 256 initial_perms = 0o666
257 257 os.chmod(test_file, initial_perms)
258 258
259 259 # Attempt to set the same permissions
260 260 set_permissions_if_needed(test_file, initial_perms)
261 261
262 262 # Check if permissions were unchanged
263 263 assert (os.stat(test_file).st_mode & 0o777) == initial_perms
264 264
265 265
266 266 def test_no_permission_reduction(test_file):
267 267 # Set initial higher permissions
268 268 initial_perms = 0o666
269 269 os.chmod(test_file, initial_perms)
270 270
271 271 # Attempt to set lower permissions
272 272 lower_perms = 0o644
273 273 set_permissions_if_needed(test_file, lower_perms)
274 274
275 275 # Check if permissions were not reduced
276 276 assert (os.stat(test_file).st_mode & 0o777) == initial_perms
277 277
278 278
279 279 def test_no_permission_reduction_when_on_777(test_file):
280 280 # Set initial higher permissions
281 281 initial_perms = 0o777
282 282 os.chmod(test_file, initial_perms)
283 283
284 284 # Attempt to set lower permissions
285 285 lower_perms = 0o755
286 286 set_permissions_if_needed(test_file, lower_perms)
287 287
288 288 # Check if permissions were not reduced
289 289 assert (os.stat(test_file).st_mode & 0o777) == initial_perms
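For reference, the never-reduce behaviour these four tests pin down fits in a few lines; this is a sketch of plausible semantics, not the actual `hook_utils` implementation:

    import os
    import stat

    def widen_permissions(path, wanted_mode):
        # OR the requested bits into the current mode, so permissions are
        # only ever widened, never reduced -- matching the tests above.
        current = stat.S_IMODE(os.stat(path).st_mode)
        merged = current | wanted_mode
        if merged != current:
            os.chmod(path, merged)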
@@ -1,295 +1,295 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import more_itertools
20 20
21 21 import dulwich.protocol
22 22 import mock
23 23 import pytest
24 24 import webob
25 25 import webtest
26 26
27 27 from vcsserver import hooks, pygrack
28 28
29 from vcsserver.str_utils import ascii_bytes
29 from vcsserver.lib.str_utils import ascii_bytes
30 30
31 31
32 32 @pytest.fixture()
33 33 def pygrack_instance(tmpdir):
34 34 """
35 35 Creates a pygrack app instance.
36 36
37 37 Right now, it does not do much with the passed directory.
38 38 It just contains the required folders to pass the signature test.
39 39 """
40 40 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
41 41 tmpdir.mkdir(dir_name)
42 42
43 43 return pygrack.GitRepository('repo_name', str(tmpdir), 'git', False, {})
44 44
45 45
46 46 @pytest.fixture()
47 47 def pygrack_app(pygrack_instance):
48 48 """
49 49 Creates a pygrack app wrapped in webtest.TestApp.
50 50 """
51 51 return webtest.TestApp(pygrack_instance)
52 52
53 53
54 54 def test_invalid_service_info_refs_returns_403(pygrack_app):
55 55 response = pygrack_app.get('/info/refs?service=git-upload-packs',
56 56 expect_errors=True)
57 57
58 58 assert response.status_int == 403
59 59
60 60
61 61 def test_invalid_endpoint_returns_403(pygrack_app):
62 62 response = pygrack_app.post('/git-upload-packs', expect_errors=True)
63 63
64 64 assert response.status_int == 403
65 65
66 66
67 67 @pytest.mark.parametrize('sideband', [
68 68 'side-band-64k',
69 69 'side-band',
70 70 'side-band no-progress',
71 71 ])
72 72 def test_pre_pull_hook_fails_with_sideband(pygrack_app, sideband):
73 73 request = ''.join([
74 74 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ',
75 75 f'multi_ack {sideband} ofs-delta\n',
76 76 '0000',
77 77 '0009done\n',
78 78 ])
79 79 with mock.patch('vcsserver.hooks.git_pre_pull', return_value=hooks.HookResponse(1, 'foo')):
80 80 response = pygrack_app.post(
81 81 '/git-upload-pack', params=request,
82 82 content_type='application/x-git-upload-pack')
83 83
84 84 data = io.BytesIO(response.body)
85 85 proto = dulwich.protocol.Protocol(data.read, None)
86 86 packets = list(proto.read_pkt_seq())
87 87
88 88 expected_packets = [
89 89 b'NAK\n', b'\x02foo', b'\x02Pre pull hook failed: aborting\n',
90 90 b'\x01' + pygrack.GitRepository.EMPTY_PACK,
91 91 ]
92 92 assert packets == expected_packets
93 93
94 94
95 95 def test_pre_pull_hook_fails_no_sideband(pygrack_app):
96 96 request = ''.join([
97 97 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
98 98 'multi_ack ofs-delta\n'
99 99 '0000',
100 100 '0009done\n',
101 101 ])
102 102 with mock.patch('vcsserver.hooks.git_pre_pull',
103 103 return_value=hooks.HookResponse(1, 'foo')):
104 104 response = pygrack_app.post(
105 105 '/git-upload-pack', params=request,
106 106 content_type='application/x-git-upload-pack')
107 107
108 108 assert response.body == pygrack.GitRepository.EMPTY_PACK
109 109
110 110
111 111 def test_pull_has_hook_messages(pygrack_app):
112 112 request = ''.join([
113 113 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
114 114 'multi_ack side-band-64k ofs-delta\n'
115 115 '0000',
116 116 '0009done\n',
117 117 ])
118 118
119 119 pre_pull = 'pre_pull_output'
120 120 post_pull = 'post_pull_output'
121 121
122 122 with mock.patch('vcsserver.hooks.git_pre_pull',
123 123 return_value=hooks.HookResponse(0, pre_pull)):
124 124 with mock.patch('vcsserver.hooks.git_post_pull',
125 125 return_value=hooks.HookResponse(1, post_pull)):
126 126 with mock.patch('vcsserver.subprocessio.SubprocessIOChunker',
127 127 return_value=more_itertools.always_iterable([b'0008NAK\n0009subp\n0000'])):
128 128 response = pygrack_app.post(
129 129 '/git-upload-pack', params=request,
130 130 content_type='application/x-git-upload-pack')
131 131
132 132 data = io.BytesIO(response.body)
133 133 proto = dulwich.protocol.Protocol(data.read, None)
134 134 packets = list(proto.read_pkt_seq())
135 135
136 136 assert packets == [b'NAK\n',
137 137 # pre-pull only outputs if IT FAILS as in != 0 ret code
138 138 #b'\x02pre_pull_output',
139 139 b'subp\n',
140 140 b'\x02post_pull_output']
141 141
142 142
143 143 def test_get_want_capabilities(pygrack_instance):
144 144 data = io.BytesIO(
145 145 b'0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
146 146 b'multi_ack side-band-64k ofs-delta\n00000009done\n')
147 147
148 148 request = webob.Request({
149 149 'wsgi.input': data,
150 150 'REQUEST_METHOD': 'POST',
151 151 'webob.is_body_seekable': True
152 152 })
153 153
154 154 capabilities = pygrack_instance._get_want_capabilities(request)
155 155
156 156 assert capabilities == frozenset(
157 157 (b'ofs-delta', b'multi_ack', b'side-band-64k'))
158 158 assert data.tell() == 0
159 159
160 160
161 161 @pytest.mark.parametrize('data,capabilities,expected', [
162 162 ('foo', [], []),
163 163 ('', [pygrack.CAPABILITY_SIDE_BAND_64K], []),
164 164 ('', [pygrack.CAPABILITY_SIDE_BAND], []),
165 165 ('foo', [pygrack.CAPABILITY_SIDE_BAND_64K], [b'0008\x02foo']),
166 166 ('foo', [pygrack.CAPABILITY_SIDE_BAND], [b'0008\x02foo']),
167 167 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'03ed\x02' + b'f' * 1000]),
168 168 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995, b'000a\x02fffff']),
169 169 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'fff0\x02' + b'f' * 65515, b'000a\x02fffff']),
170 170 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995] * 65 + [b'0352\x02' + b'f' * 845]),
171 171 ], ids=[
172 172 'foo-empty',
173 173 'empty-64k', 'empty',
174 174 'foo-64k', 'foo',
175 175 'f-1000-64k', 'f-1000',
176 176 'f-65520-64k', 'f-65520'])
177 177 def test_get_messages(pygrack_instance, data, capabilities, expected):
178 178 messages = pygrack_instance._get_messages(data, capabilities)
179 179
180 180 assert messages == expected
181 181
182 182
183 183 @pytest.mark.parametrize('response,capabilities,pre_pull_messages,post_pull_messages', [
184 184 # Unexpected response
185 185 ([b'unexpected_response[no_initial_header]'], [pygrack.CAPABILITY_SIDE_BAND_64K], 'foo', 'bar'),
186 186 # No sideband
187 187 ([b'no-sideband'], [], 'foo', 'bar'),
188 188 # No messages
189 189 ([b'no-messages'], [pygrack.CAPABILITY_SIDE_BAND_64K], '', ''),
190 190 ])
191 191 def test_inject_messages_to_response_nothing_to_do(
192 192 pygrack_instance, response, capabilities, pre_pull_messages, post_pull_messages):
193 193
194 194 new_response = pygrack_instance._build_post_pull_response(
195 195 more_itertools.always_iterable(response), capabilities, pre_pull_messages, post_pull_messages)
196 196
197 197 assert list(new_response) == response
198 198
199 199
200 200 @pytest.mark.parametrize('capabilities', [
201 201 [pygrack.CAPABILITY_SIDE_BAND],
202 202 [pygrack.CAPABILITY_SIDE_BAND_64K],
203 203 ])
204 204 def test_inject_messages_to_response_single_element(pygrack_instance, capabilities):
205 205 response = [b'0008NAK\n0009subp\n0000']
206 206 new_response = pygrack_instance._build_post_pull_response(
207 207 more_itertools.always_iterable(response), capabilities, 'foo', 'bar')
208 208
209 209 expected_response = b''.join([
210 210 b'0008NAK\n',
211 211 b'0008\x02foo',
212 212 b'0009subp\n',
213 213 b'0008\x02bar',
214 214 b'0000'])
215 215
216 216 assert b''.join(new_response) == expected_response
217 217
218 218
219 219 @pytest.mark.parametrize('capabilities', [
220 220 [pygrack.CAPABILITY_SIDE_BAND],
221 221 [pygrack.CAPABILITY_SIDE_BAND_64K],
222 222 ])
223 223 def test_inject_messages_to_response_multi_element(pygrack_instance, capabilities):
224 224 response = more_itertools.always_iterable([
225 225 b'0008NAK\n000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n0000'
226 226 ])
227 227 new_response = pygrack_instance._build_post_pull_response(response, capabilities, 'foo', 'bar')
228 228
229 229 expected_response = b''.join([
230 230 b'0008NAK\n',
231 231 b'0008\x02foo',
232 232 b'000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n',
233 233 b'0008\x02bar',
234 234 b'0000'
235 235 ])
236 236
237 237 assert b''.join(new_response) == expected_response
238 238
239 239
240 240 def test_build_failed_pre_pull_response_no_sideband(pygrack_instance):
241 241 response = pygrack_instance._build_failed_pre_pull_response([], 'foo')
242 242
243 243 assert response == [pygrack.GitRepository.EMPTY_PACK]
244 244
245 245
246 246 @pytest.mark.parametrize('capabilities', [
247 247 [pygrack.CAPABILITY_SIDE_BAND],
248 248 [pygrack.CAPABILITY_SIDE_BAND_64K],
249 249 [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'],
250 250 ])
251 251 def test_build_failed_pre_pull_response(pygrack_instance, capabilities):
252 252 response = pygrack_instance._build_failed_pre_pull_response(capabilities, 'foo')
253 253
254 254 expected_response = [
255 255 b'0008NAK\n', b'0008\x02foo', b'0024\x02Pre pull hook failed: aborting\n',
256 256 b'%04x\x01%s' % (len(pygrack.GitRepository.EMPTY_PACK) + 5, pygrack.GitRepository.EMPTY_PACK),
257 257 pygrack.GitRepository.FLUSH_PACKET,
258 258 ]
259 259
260 260 assert response == expected_response
261 261
262 262
263 263 def test_inject_messages_to_response_generator(pygrack_instance):
264 264
265 265 def response_generator():
266 266 response = [
267 267 # protocol start
268 268 b'0008NAK\n',
269 269 ]
270 270 response += [ascii_bytes(f'000asubp{x}\n') for x in range(1000)]
271 271 response += [
272 272 # protocol end
273 273 pygrack.GitRepository.FLUSH_PACKET
274 274 ]
275 275 for elem in response:
276 276 yield elem
277 277
278 278 new_response = pygrack_instance._build_post_pull_response(
279 279 response_generator(), [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'], 'PRE_PULL_MSG\n', 'POST_PULL_MSG\n')
280 280
281 281 assert iter(new_response)
282 282
283 283 expected_response = b''.join([
284 284 # start
285 285 b'0008NAK\n0012\x02PRE_PULL_MSG\n',
286 286 ] + [
287 287 # ... rest
288 288 ascii_bytes(f'000asubp{x}\n') for x in range(1000)
289 289 ] + [
290 290 # final message,
291 291 b'0013\x02POST_PULL_MSG\n0000',
292 292
293 293 ])
294 294
295 295 assert b''.join(new_response) == expected_response
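The byte-exact expectations in this file all follow git's pkt-line framing: four hex digits give the total packet length (header included), followed by the payload, with a leading band byte (\x01 pack data, \x02 progress) when sideband is negotiated. A quick sketch of the framing:

    def pkt_line(payload: bytes) -> bytes:
        # 4 hex digits cover the whole packet, the 4-byte header included
        return b'%04x' % (len(payload) + 4) + payload

    assert pkt_line(b'\x02foo') == b'0008\x02foo'  # as asserted in the tests above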
@@ -1,87 +1,87 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19
20 20 import mercurial.hg
21 21 import mercurial.ui
22 22 import mercurial.error
23 23 import mock
24 24 import pytest
25 25 import webtest
26 26
27 27 from vcsserver import scm_app
28 from vcsserver.str_utils import ascii_bytes
28 from vcsserver.lib.str_utils import ascii_bytes
29 29
30 30
31 31 def test_hg_does_not_accept_invalid_cmd(tmpdir):
32 32 repo = mercurial.hg.repository(mercurial.ui.ui(), ascii_bytes(str(tmpdir)), create=True)
33 33 app = webtest.TestApp(scm_app.HgWeb(repo))
34 34
35 35 response = app.get('/repo?cmd=invalidcmd', expect_errors=True)
36 36
37 37 assert response.status_int == 400
38 38
39 39
40 40 def test_create_hg_wsgi_app_requirement_error(tmpdir):
41 41 repo = mercurial.hg.repository(mercurial.ui.ui(), ascii_bytes(str(tmpdir)), create=True)
42 42 config = (
43 43 ('paths', 'default', ''),
44 44 )
45 45 with mock.patch('vcsserver.scm_app.HgWeb') as hgweb_mock:
46 46 hgweb_mock.side_effect = mercurial.error.RequirementError()
47 47 with pytest.raises(Exception):
48 48 scm_app.create_hg_wsgi_app(str(tmpdir), repo, config)
49 49
50 50
51 51 def test_git_returns_not_found(tmpdir):
52 52 app = webtest.TestApp(
53 53 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
54 54
55 55 response = app.get('/repo_name/inforefs?service=git-upload-pack',
56 56 expect_errors=True)
57 57
58 58 assert response.status_int == 404
59 59
60 60
61 61 def test_git(tmpdir):
62 62 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
63 63 tmpdir.mkdir(dir_name)
64 64
65 65 app = webtest.TestApp(
66 66 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
67 67
68 68 # We set service to git-upload-packs to trigger a 403
69 69 response = app.get('/repo_name/inforefs?service=git-upload-packs',
70 70 expect_errors=True)
71 71
72 72 assert response.status_int == 403
73 73
74 74
75 75 def test_git_fallbacks_to_git_folder(tmpdir):
76 76 tmpdir.mkdir('.git')
77 77 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
78 78 tmpdir.mkdir(os.path.join('.git', dir_name))
79 79
80 80 app = webtest.TestApp(
81 81 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
82 82
83 83 # We set service to git-upload-packs to trigger a 403
84 84 response = app.get('/repo_name/inforefs?service=git-upload-packs',
85 85 expect_errors=True)
86 86
87 87 assert response.status_int == 403
@@ -1,155 +1,155 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21
22 22 import pytest
23 23
24 24 from vcsserver import subprocessio
25 from vcsserver.str_utils import ascii_bytes
25 from vcsserver.lib.str_utils import ascii_bytes
26 26
27 27
28 28 class FileLikeObj: # pragma: no cover
29 29
30 30 def __init__(self, data: bytes, size):
31 31 chunks = size // len(data)
32 32
33 33 self.stream = self._get_stream(data, chunks)
34 34
35 35 def _get_stream(self, data, chunks):
36 36 for x in range(chunks):
37 37 yield data
38 38
39 39 def read(self, n):
40 40
41 41 buffer_stream = b''
42 42 for chunk in self.stream:
43 43 buffer_stream += chunk
44 44 if len(buffer_stream) >= n:
45 45 break
46 46
47 47 # self.stream = self.bytes[n:]
48 48 return buffer_stream
49 49
50 50
51 51 @pytest.fixture(scope='module')
52 52 def environ():
53 53 """Delete coverage variables, as they make the tests fail."""
54 54 env = dict(os.environ)
55 55 for key in list(env.keys()):
56 56 if key.startswith('COV_CORE_'):
57 57 del env[key]
58 58
59 59 return env
60 60
61 61
62 62 def _get_python_args(script):
63 63 return [sys.executable, '-c', 'import sys; import time; import shutil; ' + script]
64 64
65 65
66 66 def test_raise_exception_on_non_zero_return_code(environ):
67 67 call_args = _get_python_args('raise ValueError("fail")')
68 68 with pytest.raises(OSError):
69 69 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
70 70
71 71
72 72 def test_does_not_fail_on_non_zero_return_code(environ):
73 73 call_args = _get_python_args('sys.stdout.write("hello"); sys.exit(1)')
74 74 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
75 75 output = b''.join(proc)
76 76
77 77 assert output == b'hello'
78 78
79 79
80 80 def test_raise_exception_on_stderr(environ):
81 81 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); time.sleep(1);')
82 82
83 83 with pytest.raises(OSError) as excinfo:
84 84 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
85 85
86 86 assert 'exited due to an error:\nWRITE_TO_STDERR' in str(excinfo.value)
87 87
88 88
89 89 def test_does_not_fail_on_stderr(environ):
90 90 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); sys.stderr.flush(); time.sleep(2);')
91 91 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_stderr=False, env=environ)
92 92 output = b''.join(proc)
93 93
94 94 assert output == b''
95 95
96 96
97 97 @pytest.mark.parametrize('size', [
98 98 1,
99 99 10 ** 5
100 100 ])
101 101 def test_output_with_no_input(size, environ):
102 102 call_args = _get_python_args(f'sys.stdout.write("X" * {size});')
103 103 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ)
104 104 output = b''.join(proc)
105 105
106 106 assert output == ascii_bytes("X" * size)
107 107
108 108
109 109 @pytest.mark.parametrize('size', [
110 110 1,
111 111 10 ** 5
112 112 ])
113 113 def test_output_with_no_input_does_not_fail(size, environ):
114 114
115 115 call_args = _get_python_args(f'sys.stdout.write("X" * {size}); sys.exit(1)')
116 116 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
117 117 output = b''.join(proc)
118 118
119 119 assert output == ascii_bytes("X" * size)
120 120
121 121
122 122 @pytest.mark.parametrize('size', [
123 123 1,
124 124 10 ** 5
125 125 ])
126 126 def test_output_with_input(size, environ):
127 127 data_len = size
128 128 inputstream = FileLikeObj(b'X', size)
129 129
130 130 # This acts like the cat command.
131 131 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
132 132 # note: in this test we explicitly don't assign the chunker to a variable and let it stream directly
133 133 output = b''.join(
134 134 subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
135 135 )
136 136
137 137 assert len(output) == data_len
138 138
139 139
140 140 @pytest.mark.parametrize('size', [
141 141 1,
142 142 10 ** 5
143 143 ])
144 144 def test_output_with_input_skipping_iterator(size, environ):
145 145 data_len = size
146 146 inputstream = FileLikeObj(b'X', size)
147 147
148 148 # This acts like the cat command.
149 149 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
150 150
151 151 # Note: assigning the chunker makes sure that it is not deleted too early
152 152 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
153 153 output = b''.join(proc.stdout)
154 154
155 155 assert len(output) == data_len
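Reviewer note: a minimal usage sketch distilled from the tests above; it assumes only the constructor arguments already exercised there (shell, fail_on_stderr, env).

import os
from vcsserver import subprocessio

# Stream a command's stdout chunk by chunk while the process runs;
# chunks arrive as bytes, so join them to get the full output.
chunker = subprocessio.SubprocessIOChunker(
    ['echo', 'hello'], shell=False, fail_on_stderr=False, env=dict(os.environ))
output = b''.join(chunker)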
@@ -1,103 +1,103 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import mock
20 20 import pytest
21 21 import sys
22 22
23 from vcsserver.str_utils import ascii_bytes
23 from vcsserver.lib.str_utils import ascii_bytes
24 24
25 25
26 26 class MockPopen:
27 27 def __init__(self, stderr):
28 28 self.stdout = io.BytesIO(b'')
29 29 self.stderr = io.BytesIO(stderr)
30 30 self.returncode = 1
31 31
32 32 def wait(self):
33 33 pass
34 34
35 35
36 36 INVALID_CERTIFICATE_STDERR = '\n'.join([
37 37 'svnrdump: E230001: Unable to connect to a repository at URL url',
38 38 'svnrdump: E230001: Server SSL certificate verification failed: issuer is not trusted',
39 39 ])
40 40
41 41
42 42 @pytest.mark.parametrize('stderr,expected_reason', [
43 43 (INVALID_CERTIFICATE_STDERR, 'INVALID_CERTIFICATE'),
44 44 ('svnrdump: E123456', 'UNKNOWN:svnrdump: E123456'),
45 45 ], ids=['invalid-cert-stderr', 'svnrdump-err-123456'])
46 46 @pytest.mark.xfail(sys.platform == "cygwin",
47 47 reason="SVN not packaged for Cygwin")
48 48 def test_import_remote_repository_certificate_error(stderr, expected_reason):
49 49 from vcsserver.remote import svn_remote
50 50 factory = mock.Mock()
51 51 factory.repo = mock.Mock(return_value=mock.Mock())
52 52
53 53 remote = svn_remote.SvnRemote(factory)
54 54 remote.is_path_valid_repository = lambda wire, path: True
55 55
56 56 with mock.patch('subprocess.Popen',
57 57 return_value=MockPopen(ascii_bytes(stderr))):
58 58 with pytest.raises(Exception) as excinfo:
59 59 remote.import_remote_repository({'path': 'path'}, 'url')
60 60
61 61 expected_error_args = 'Failed to dump the remote repository from url. Reason:{}'.format(expected_reason)
62 62
63 63 assert excinfo.value.args[0] == expected_error_args
64 64
65 65
66 66 def test_svn_libraries_can_be_imported():
67 67 import svn.client # noqa
68 68 assert svn.client is not None
69 69
70 70
71 71 @pytest.mark.parametrize('example_url, parts', [
72 72 ('http://server.com', ('', '', 'http://server.com')),
73 73 ('http://user@server.com', ('user', '', 'http://user@server.com')),
74 74 ('http://user:pass@server.com', ('user', 'pass', 'http://user:pass@server.com')),
75 75 ('<script>', ('', '', '<script>')),
76 76 ('http://', ('', '', 'http://')),
77 77 ])
78 78 def test_username_password_extraction_from_url(example_url, parts):
79 79 from vcsserver.remote import svn_remote
80 80
81 81 factory = mock.Mock()
82 82 factory.repo = mock.Mock(return_value=mock.Mock())
83 83
84 84 remote = svn_remote.SvnRemote(factory)
85 85 remote.is_path_valid_repository = lambda wire, path: True
86 86
87 87 assert remote.get_url_and_credentials(example_url) == parts
88 88
89 89
90 90 @pytest.mark.parametrize('call_url', [
91 91 b'https://svn.code.sf.net/p/svnbook/source/trunk/',
92 92 b'https://marcink@svn.code.sf.net/p/svnbook/source/trunk/',
93 93 b'https://marcink:qweqwe@svn.code.sf.net/p/svnbook/source/trunk/',
94 94 ])
95 95 def test_check_url(call_url):
96 96 from vcsserver.remote import svn_remote
97 97 factory = mock.Mock()
98 98 factory.repo = mock.Mock(return_value=mock.Mock())
99 99
100 100 remote = svn_remote.SvnRemote(factory)
101 101 remote.is_path_valid_repository = lambda wire, path: True
102 102 assert remote.check_url(call_url, {'dummy': 'config'})
103 103
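Reviewer note: the parametrized cases above fully pin down the (username, password, url) contract of get_url_and_credentials. A hypothetical reconstruction is sketched below; SvnRemote's real implementation may differ.

import urllib.parse

def get_url_and_credentials(url):
    # Extract user/password if present, but return the original URL
    # untouched, matching the expected tuples in the test table.
    parsed = urllib.parse.urlparse(url)
    return parsed.username or '', parsed.password or '', url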
@@ -1,69 +1,69 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import pytest
19 from vcsserver.str_utils import ascii_bytes, ascii_str, convert_to_str
19 from vcsserver.lib.str_utils import ascii_bytes, ascii_str, convert_to_str
20 20
21 21
22 22 @pytest.mark.parametrize('given, expected', [
23 23 ('a', b'a'),
24 24 ('a', b'a'),
25 25 ])
26 26 def test_ascii_bytes(given, expected):
27 27 assert ascii_bytes(given) == expected
28 28
29 29
30 30 @pytest.mark.parametrize('given', [
31 31 'å',
32 32 'å'.encode('utf8')
33 33 ])
34 34 def test_ascii_bytes_raises(given):
35 35 with pytest.raises(ValueError):
36 36 ascii_bytes(given)
37 37
38 38
39 39 @pytest.mark.parametrize('given, expected', [
40 40 (b'a', 'a'),
41 41 ])
42 42 def test_ascii_str(given, expected):
43 43 assert ascii_str(given) == expected
44 44
45 45
46 46 @pytest.mark.parametrize('given', [
47 47 'a',
48 48 'å'.encode('utf8'),
49 49 'å'
50 50 ])
51 51 def test_ascii_str_raises(given):
52 52 with pytest.raises(ValueError):
53 53 ascii_str(given)
54 54
55 55
56 56 @pytest.mark.parametrize('given, expected', [
57 57 ('a', 'a'),
58 58 (b'a', 'a'),
59 59 # tuple
60 60 (('a', b'b', b'c'), ('a', 'b', 'c')),
61 61 # nested tuple
62 62 (('a', b'b', (b'd', b'e')), ('a', 'b', ('d', 'e'))),
63 63 # list
64 64 (['a', b'b', b'c'], ['a', 'b', 'c']),
65 65 # mixed
66 66 (['a', b'b', b'c', (b'b1', b'b2')], ['a', 'b', 'c', ('b1', 'b2')])
67 67 ])
68 68 def test_convert_to_str(given, expected):
69 69 assert convert_to_str(given) == expected
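Reviewer note: the cases above encode convert_to_str's semantics precisely: str passes through, bytes decode, and tuples/lists recurse while preserving their container type. A hypothetical reconstruction, for illustration only:

def convert_to_str(obj):
    # Hypothetical sketch matching the parametrized cases; the real
    # vcsserver.lib.str_utils implementation may differ in details.
    if isinstance(obj, bytes):
        return obj.decode('utf-8')
    if isinstance(obj, tuple):
        return tuple(convert_to_str(item) for item in obj)
    if isinstance(obj, list):
        return [convert_to_str(item) for item in obj]
    return obj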
@@ -1,98 +1,98 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import wsgiref.simple_server
19 19 import wsgiref.validate
20 20
21 21 from vcsserver import wsgi_app_caller
22 from vcsserver.str_utils import ascii_bytes, safe_str
22 from vcsserver.lib.str_utils import ascii_bytes, safe_str
23 23
24 24
25 25 @wsgiref.validate.validator
26 26 def demo_app(environ, start_response):
27 27 """WSGI app used for testing."""
28 28
29 29 input_data = safe_str(environ['wsgi.input'].read(1024))
30 30
31 31 data = [
32 32 'Hello World!\n',
33 33 f'input_data={input_data}\n',
34 34 ]
35 35 for key, value in sorted(environ.items()):
36 36 data.append(f'{key}={value}\n')
37 37
38 38 write = start_response("200 OK", [('Content-Type', 'text/plain')])
39 39 write(b'Old school write method\n')
40 40 write(b'***********************\n')
41 41 return list(map(ascii_bytes, data))
42 42
43 43
44 44 BASE_ENVIRON = {
45 45 'REQUEST_METHOD': 'GET',
46 46 'SERVER_NAME': 'localhost',
47 47 'SERVER_PORT': '80',
48 48 'SCRIPT_NAME': '',
49 49 'PATH_INFO': '/',
50 50 'QUERY_STRING': '',
51 51 'foo.var': 'bla',
52 52 }
53 53
54 54
55 55 def test_complete_environ():
56 56 environ = dict(BASE_ENVIRON)
57 57 data = b"data"
58 58 wsgi_app_caller._complete_environ(environ, data)
59 59 wsgiref.validate.check_environ(environ)
60 60
61 61 assert data == environ['wsgi.input'].read(1024)
62 62
63 63
64 64 def test_start_response():
65 65 start_response = wsgi_app_caller._StartResponse()
66 66 status = '200 OK'
67 67 headers = [('Content-Type', 'text/plain')]
68 68 start_response(status, headers)
69 69
70 70 assert status == start_response.status
71 71 assert headers == start_response.headers
72 72
73 73
74 74 def test_start_response_with_error():
75 75 start_response = wsgi_app_caller._StartResponse()
76 76 status = '500 Internal Server Error'
77 77 headers = [('Content-Type', 'text/plain')]
78 78 start_response(status, headers, (None, None, None))
79 79
80 80 assert status == start_response.status
81 81 assert headers == start_response.headers
82 82
83 83
84 84 def test_wsgi_app_caller():
85 85 environ = dict(BASE_ENVIRON)
86 86 input_data = 'some text'
87 87
88 88 caller = wsgi_app_caller.WSGIAppCaller(demo_app)
89 89 responses, status, headers = caller.handle(environ, input_data)
90 90 response = b''.join(responses)
91 91
92 92 assert status == '200 OK'
93 93 assert headers == [('Content-Type', 'text/plain')]
94 94 assert response.startswith(b'Old school write method\n***********************\n')
95 95 assert b'Hello World!\n' in response
96 96 assert b'foo.var=bla\n' in response
97 97
98 98 assert ascii_bytes(f'input_data={input_data}\n') in response
@@ -1,123 +1,123 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import base64
18 18 import logging
19 19 import time
20 20
21 21 import msgpack
22 22
23 23 import vcsserver
24 from vcsserver.str_utils import safe_str
24 from vcsserver.lib.str_utils import safe_str
25 25
26 26 log = logging.getLogger(__name__)
27 27
28 28
29 29 def get_access_path(environ):
30 30 path = environ.get('PATH_INFO')
31 31 return path
32 32
33 33
34 34 def get_user_agent(environ):
35 35 return environ.get('HTTP_USER_AGENT')
36 36
37 37
38 38 def get_call_context(request) -> dict:
39 39 cc = {}
40 40 registry = request.registry
41 41 if hasattr(registry, 'vcs_call_context'):
42 42 cc.update({
43 43 'X-RC-Method': registry.vcs_call_context.get('method'),
44 44 'X-RC-Repo-Name': registry.vcs_call_context.get('repo_name')
45 45 })
46 46
47 47 return cc
48 48
49 49
50 50 def get_headers_call_context(environ, strict=True):
51 51 if 'HTTP_X_RC_VCS_STREAM_CALL_CONTEXT' in environ:
52 52 packed_cc = base64.b64decode(environ['HTTP_X_RC_VCS_STREAM_CALL_CONTEXT'])
53 53 return msgpack.unpackb(packed_cc)
54 54 elif strict:
55 55 raise ValueError('Expected header HTTP_X_RC_VCS_STREAM_CALL_CONTEXT not found')
56 56
57 57
58 58 class RequestWrapperTween:
59 59 def __init__(self, handler, registry):
60 60 self.handler = handler
61 61 self.registry = registry
62 62
63 63 # one-time configuration code goes here
64 64
65 65 def __call__(self, request):
66 66 start = time.time()
67 67 log.debug('Starting request time measurement')
68 68 response = None
69 69
70 70 try:
71 71 response = self.handler(request)
72 72 finally:
73 73 ua = get_user_agent(request.environ)
74 74 call_context = get_call_context(request)
75 75 vcs_method = call_context.get('X-RC-Method', '_NO_VCS_METHOD')
76 76 repo_name = call_context.get('X-RC-Repo-Name', '')
77 77
78 78 count = request.request_count()
79 79 _ver_ = vcsserver.get_version()
80 80 _path = safe_str(get_access_path(request.environ))
81 81
82 82 ip = '127.0.0.1'
83 83 match_route = request.matched_route.name if request.matched_route else "NOT_FOUND"
84 84 resp_code = getattr(response, 'status_code', 'UNDEFINED')
85 85
86 86 _view_path = f"{repo_name}@{_path}/{vcs_method}"
87 87
88 88 total = time.time() - start
89 89
90 90 log.info(
91 91 'Req[%4s] IP: %s %s Request to %s time: %.4fs [%s], VCSServer %s',
92 92 count, ip, request.environ.get('REQUEST_METHOD'),
93 93 _view_path, total, ua, _ver_,
94 94 extra={"time": total, "ver": _ver_, "code": resp_code,
95 95 "path": _path, "view_name": match_route, "user_agent": ua,
96 96 "vcs_method": vcs_method, "repo_name": repo_name}
97 97 )
98 98
99 99 statsd = request.registry.statsd
100 100 if statsd:
101 101 match_route = request.matched_route.name if request.matched_route else _path
102 102 elapsed_time_ms = round(1000.0 * total) # use ms only
103 103 statsd.timing(
104 104 "vcsserver_req_timing.histogram", elapsed_time_ms,
105 105 tags=[
106 106 f"view_name:{match_route}",
107 107 f"code:{resp_code}"
108 108 ],
109 109 use_decimals=False
110 110 )
111 111 statsd.incr(
112 112 "vcsserver_req_total", tags=[
113 113 f"view_name:{match_route}",
114 114 f"code:{resp_code}"
115 115 ])
116 116
117 117 return response
118 118
119 119
120 120 def includeme(config):
121 121 config.add_tween(
122 122 'vcsserver.tweens.request_wrapper.RequestWrapperTween',
123 123 )
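Reviewer note: a sketch of how a client would construct the stream call-context header that get_headers_call_context above unpacks. The dictionary values are hypothetical; only the header name and the msgpack + base64 encoding mirror the code above.

import base64
import msgpack

# Pack the call context the same way get_headers_call_context unpacks it:
# msgpack-serialize, then base64-encode into the CGI-style environ key.
call_context = {'method': 'pull', 'repo_name': 'some/repo'}  # hypothetical values
packed_cc = base64.b64encode(msgpack.packb(call_context))
environ = {'HTTP_X_RC_VCS_STREAM_CALL_CONTEXT': packed_cc}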
@@ -1,116 +1,116 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Extract the responses of a WSGI app."""
19 19
20 20 __all__ = ('WSGIAppCaller',)
21 21
22 22 import io
23 23 import logging
24 24 import os
25 25
26 from vcsserver.str_utils import ascii_bytes
26 from vcsserver.lib.str_utils import ascii_bytes
27 27
28 28 log = logging.getLogger(__name__)
29 29
30 30 DEV_NULL = open(os.devnull)
31 31
32 32
33 33 def _complete_environ(environ, input_data: bytes):
34 34 """Update the missing wsgi.* variables of a WSGI environment.
35 35
36 36 :param environ: WSGI environment to update
37 37 :type environ: dict
38 38 :param input_data: data to be read by the app
39 39 :type input_data: bytes
40 40 """
41 41 environ.update({
42 42 'wsgi.version': (1, 0),
43 43 'wsgi.url_scheme': 'http',
44 44 'wsgi.multithread': True,
45 45 'wsgi.multiprocess': True,
46 46 'wsgi.run_once': False,
47 47 'wsgi.input': io.BytesIO(input_data),
48 48 'wsgi.errors': DEV_NULL,
49 49 })
50 50
51 51
52 52 # pylint: disable=too-few-public-methods
53 53 class _StartResponse:
54 54 """Save the arguments of a start_response call."""
55 55
56 56 __slots__ = ['status', 'headers', 'content']
57 57
58 58 def __init__(self):
59 59 self.status = None
60 60 self.headers = None
61 61 self.content = []
62 62
63 63 def __call__(self, status, headers, exc_info=None):
64 64 # TODO(skreft): do something meaningful with the exc_info
65 65 exc_info = None # avoid dangling circular reference
66 66 self.status = status
67 67 self.headers = headers
68 68
69 69 return self.write
70 70
71 71 def write(self, content):
72 72 """Write method returning when calling this object.
73 73
74 74 All the data written is then available in content.
75 75 """
76 76 self.content.append(content)
77 77
78 78
79 79 class WSGIAppCaller:
80 80 """Calls a WSGI app."""
81 81
82 82 def __init__(self, app):
83 83 """
84 84 :param app: WSGI app to call
85 85 """
86 86 self.app = app
87 87
88 88 def handle(self, environ, input_data):
89 89 """Process a request with the WSGI app.
90 90
91 91 The returned data of the app is fully consumed into a list.
92 92
93 93 :param environ: WSGI environment to update
94 94 :type environ: dict
95 95 :param input_data: data to be read by the app
96 96 :type input_data: str/bytes
97 97
98 98 :returns: a tuple with the contents, status and headers
99 99 :rtype: (list<str>, str, list<(str, str)>)
100 100 """
101 101 _complete_environ(environ, ascii_bytes(input_data, allow_bytes=True))
102 102 start_response = _StartResponse()
103 103 log.debug("Calling wrapped WSGI application")
104 104 responses = self.app(environ, start_response)
105 105 responses_list = list(responses)
106 106 existing_responses = start_response.content
107 107 if existing_responses:
108 108 log.debug("Adding returned response to response written via write()")
109 109 existing_responses.extend(responses_list)
110 110 responses_list = existing_responses
111 111 if hasattr(responses, 'close'):
112 112 log.debug("Closing iterator from WSGI application")
113 113 responses.close()
114 114
115 115 log.debug("Handling of WSGI request done, returning response")
116 116 return responses_list, start_response.status, start_response.headers
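Reviewer note: a minimal end-to-end sketch of the caller, assuming only the API shown in this file and in the test_wsgi_app_caller tests earlier in the diff.

from vcsserver import wsgi_app_caller

def tiny_app(environ, start_response):
    # Trivial WSGI app used purely for demonstration.
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'ok']

caller = wsgi_app_caller.WSGIAppCaller(tiny_app)
environ = {'REQUEST_METHOD': 'GET', 'SERVER_NAME': 'localhost',
           'SERVER_PORT': '80', 'SCRIPT_NAME': '', 'PATH_INFO': '/',
           'QUERY_STRING': ''}
responses, status, headers = caller.handle(environ, b'')
assert status == '200 OK' and b''.join(responses) == b'ok'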