lint: auto-fixes
super-admin
r1152:a0c49580 default
@@ -1,193 +1,193 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import os
18 18 import sys
19 19 import tempfile
20 20 import logging
21 21 import urllib.parse
22 22
23 23 from vcsserver.lib.rc_cache.archive_cache import get_archival_cache_store
24 24
25 25 from vcsserver import exceptions
26 26 from vcsserver.exceptions import NoContentException
27 27 from vcsserver.hgcompat import archival
28 28 from vcsserver.str_utils import safe_bytes
29 29 from vcsserver.lib.exc_tracking import format_exc
30 30 log = logging.getLogger(__name__)
31 31
32 32
33 class RepoFactory(object):
33 class RepoFactory:
34 34 """
35 35 Utility to create repository instances
36 36
37 37 It provides internal caching of the `repo` object based on
38 38 the :term:`call context`.
39 39 """
40 40 repo_type = None
41 41
42 42 def __init__(self):
43 43 pass
44 44
45 45 def _create_config(self, path, config):
46 46 config = {}
47 47 return config
48 48
49 49 def _create_repo(self, wire, create):
50 50 raise NotImplementedError()
51 51
52 52 def repo(self, wire, create=False):
53 53 raise NotImplementedError()
54 54
55 55
56 56 def obfuscate_qs(query_string):
57 57 if query_string is None:
58 58 return None
59 59
60 60 parsed = []
61 61 for k, v in urllib.parse.parse_qsl(query_string, keep_blank_values=True):
62 62 if k in ['auth_token', 'api_key']:
63 63 v = "*****"
64 64 parsed.append((k, v))
65 65
66 66 return '&'.join('{}{}'.format(
67 67 k, f'={v}' if v else '') for k, v in parsed)
68 68
69 69
70 70 def raise_from_original(new_type, org_exc: Exception):
71 71 """
72 72 Raise a new exception type with original args and traceback.
73 73 """
74 74 exc_info = sys.exc_info()
75 75 exc_type, exc_value, exc_traceback = exc_info
76 76 new_exc = new_type(*exc_value.args)
77 77
78 78 # store the original traceback into the new exc
79 79 new_exc._org_exc_tb = format_exc(exc_info)
80 80
81 81 try:
82 82 raise new_exc.with_traceback(exc_traceback)
83 83 finally:
84 84 del exc_traceback
85 85
86 86
87 class ArchiveNode(object):
87 class ArchiveNode:
88 88 def __init__(self, path, mode, is_link, raw_bytes):
89 89 self.path = path
90 90 self.mode = mode
91 91 self.is_link = is_link
92 92 self.raw_bytes = raw_bytes
93 93
94 94
95 95 def store_archive_in_cache(node_walker, archive_key, kind, mtime, archive_at_path, archive_dir_name,
96 96 commit_id, write_metadata=True, extra_metadata=None, cache_config=None):
97 97 """
98 98 Function that stores a generated archive and sends it to a dedicated backend store.
99 99 Here we use diskcache.
100 100
101 101 :param node_walker: a generator returning nodes to add to archive
102 102 :param archive_key: key used to store the path
103 103 :param kind: archive kind
104 104 :param mtime: time of creation
105 105 :param archive_at_path: default '/', the path at which the archive was started.
106 106 If this is not '/' it means it's a partial archive
107 107 :param archive_dir_name: inside dir name when creating an archive
108 108 :param commit_id: commit sha of revision archive was created at
109 109 :param write_metadata:
110 110 :param extra_metadata:
111 111 :param cache_config:
112 112
113 113 The walker should be a file walker, for example:
114 114 def node_walker():
115 115 for file_info in files:
116 116 yield ArchiveNode(fn, mode, is_link, ctx[fn].data)
117 117 """
118 118 extra_metadata = extra_metadata or {}
119 119
120 120 d_cache = get_archival_cache_store(config=cache_config)
121 121
122 122 if archive_key in d_cache:
123 123 with d_cache as d_cache_reader:
124 124 reader, tag = d_cache_reader.get(archive_key, read=True, tag=True, retry=True)
125 125 return reader.name
126 126
127 127 archive_tmp_path = safe_bytes(tempfile.mkstemp()[1])
128 128 log.debug('Creating new temp archive in %s', archive_tmp_path)
129 129
130 130 if kind == "tgz":
131 131 archiver = archival.tarit(archive_tmp_path, mtime, b"gz")
132 132 elif kind == "tbz2":
133 133 archiver = archival.tarit(archive_tmp_path, mtime, b"bz2")
134 134 elif kind == 'zip':
135 135 archiver = archival.zipit(archive_tmp_path, mtime)
136 136 else:
137 137 raise exceptions.ArchiveException()(
138 138 f'Remote does not support: "{kind}" archive type.')
139 139
140 140 for f in node_walker(commit_id, archive_at_path):
141 141 f_path = os.path.join(safe_bytes(archive_dir_name), safe_bytes(f.path).lstrip(b'/'))
142 142 try:
143 143 archiver.addfile(f_path, f.mode, f.is_link, f.raw_bytes())
144 144 except NoContentException:
145 145 # NOTE(marcink): this is a special case for SVN so we can create "empty"
146 146 # directories which are not supported by archiver
147 147 archiver.addfile(os.path.join(f_path, b'.dir'), f.mode, f.is_link, b'')
148 148
149 149 if write_metadata:
150 150 metadata = dict([
151 151 ('commit_id', commit_id),
152 152 ('mtime', mtime),
153 153 ])
154 154 metadata.update(extra_metadata)
155 155
156 156 meta = [safe_bytes(f"{f_name}:{value}") for f_name, value in metadata.items()]
157 157 f_path = os.path.join(safe_bytes(archive_dir_name), b'.archival.txt')
158 158 archiver.addfile(f_path, 0o644, False, b'\n'.join(meta))
159 159
160 160 archiver.done()
161 161
162 162 # ensure set & get are atomic
163 163 with d_cache.transact():
164 164
165 165 with open(archive_tmp_path, 'rb') as archive_file:
166 166 add_result = d_cache.set(archive_key, archive_file, read=True, tag='db-name', retry=True)
167 167 if not add_result:
168 168 log.error('Failed to store cache for key=%s', archive_key)
169 169
170 170 os.remove(archive_tmp_path)
171 171
172 172 reader, tag = d_cache.get(archive_key, read=True, tag=True, retry=True)
173 173 if not reader:
174 174 raise AssertionError(f'empty reader on key={archive_key} added={add_result}')
175 175
176 176 return reader.name
177 177
178 178
179 class BinaryEnvelope(object):
179 class BinaryEnvelope:
180 180 def __init__(self, val):
181 181 self.val = val
182 182
183 183
184 184 class BytesEnvelope(bytes):
185 185 def __new__(cls, content):
186 186 if isinstance(content, bytes):
187 187 return super().__new__(cls, content)
188 188 else:
189 189 raise TypeError('BytesEnvelope content= param must be bytes. Use BinaryEnvelope to wrap other types')
190 190
191 191
192 192 class BinaryBytesEnvelope(BytesEnvelope):
193 193 pass
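
For reference, a quick standalone check of the query-string obfuscation above. The function body is copied verbatim from the diff so the assertions can run without a vcsserver install; the sample token values are made up.

    import urllib.parse

    def obfuscate_qs(query_string):
        # mask sensitive parameter values, keep everything else verbatim
        if query_string is None:
            return None
        parsed = []
        for k, v in urllib.parse.parse_qsl(query_string, keep_blank_values=True):
            if k in ['auth_token', 'api_key']:
                v = "*****"
            parsed.append((k, v))
        return '&'.join('{}{}'.format(k, f'={v}' if v else '') for k, v in parsed)

    assert obfuscate_qs('auth_token=secret123&page=2') == 'auth_token=*****&page=2'
    assert obfuscate_qs('flag=&api_key=abc') == 'flag&api_key=*****'
    assert obfuscate_qs(None) is None
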
@@ -1,168 +1,168 b''
1 1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19 import os
20 20 import textwrap
21 21 import string
22 22 import functools
23 23 import logging
24 24 import tempfile
25 25 import logging.config
26 26
27 27 from vcsserver.type_utils import str2bool, aslist
28 28
29 29 log = logging.getLogger(__name__)
30 30
31 31 # skip keys that are set here, so we don't double-process them
32 32 set_keys = {
33 33 '__file__': ''
34 34 }
35 35
36 36
37 class SettingsMaker(object):
37 class SettingsMaker:
38 38
39 39 def __init__(self, app_settings):
40 40 self.settings = app_settings
41 41
42 42 @classmethod
43 43 def _bool_func(cls, input_val):
44 44 if isinstance(input_val, bytes):
45 45 # decode to str
46 46 input_val = input_val.decode('utf8')
47 47 return str2bool(input_val)
48 48
49 49 @classmethod
50 50 def _int_func(cls, input_val):
51 51 return int(input_val)
52 52
53 53 @classmethod
54 54 def _list_func(cls, input_val, sep=','):
55 55 return aslist(input_val, sep=sep)
56 56
57 57 @classmethod
58 58 def _string_func(cls, input_val, lower=True):
59 59 if lower:
60 60 input_val = input_val.lower()
61 61 return input_val
62 62
63 63 @classmethod
64 64 def _float_func(cls, input_val):
65 65 return float(input_val)
66 66
67 67 @classmethod
68 68 def _dir_func(cls, input_val, ensure_dir=False, mode=0o755):
69 69
70 70 # ensure we have our dir created
71 71 if not os.path.isdir(input_val) and ensure_dir:
72 72 os.makedirs(input_val, mode=mode, exist_ok=True)
73 73
74 74 if not os.path.isdir(input_val):
75 75 raise Exception(f'Dir at {input_val} does not exist')
76 76 return input_val
77 77
78 78 @classmethod
79 79 def _file_path_func(cls, input_val, ensure_dir=False, mode=0o755):
80 80 dirname = os.path.dirname(input_val)
81 81 cls._dir_func(dirname, ensure_dir=ensure_dir)
82 82 return input_val
83 83
84 84 @classmethod
85 85 def _key_transformator(cls, key):
86 86 return "{}_{}".format('RC'.upper(), key.upper().replace('.', '_').replace('-', '_'))
87 87
88 88 def maybe_env_key(self, key):
89 89 # this KEY may also exist in the env; if so, use that value with higher priority.
90 90 transformed_key = self._key_transformator(key)
91 91 envvar_value = os.environ.get(transformed_key)
92 92 if envvar_value:
93 93 log.debug('using `%s` key instead of `%s` key for config', transformed_key, key)
94 94
95 95 return envvar_value
96 96
97 97 def env_expand(self):
98 98 replaced = {}
99 99 for k, v in self.settings.items():
100 100 if k not in set_keys:
101 101 envvar_value = self.maybe_env_key(k)
102 102 if envvar_value:
103 103 replaced[k] = envvar_value
104 104 set_keys[k] = envvar_value
105 105
106 106 # replace ALL updated keys
107 107 self.settings.update(replaced)
108 108
109 109 def enable_logging(self, logging_conf=None, level='INFO', formatter='generic'):
110 110 """
111 111 Helper to enable debug logging on a running instance
112 112 :return:
113 113 """
114 114
115 115 if not str2bool(self.settings.get('logging.autoconfigure')):
116 116 log.info('logging configuration based on main .ini file')
117 117 return
118 118
119 119 if logging_conf is None:
120 120 logging_conf = self.settings.get('logging.logging_conf_file') or ''
121 121
122 122 if not os.path.isfile(logging_conf):
123 123 log.error('Unable to set up logging based on %s, '
124 124 'file does not exist. Specify a path using the logging.logging_conf_file= config setting.', logging_conf)
125 125 return
126 126
127 127 with open(logging_conf, 'rt') as f:
128 128 ini_template = textwrap.dedent(f.read())
129 129 ini_template = string.Template(ini_template).safe_substitute(
130 130 RC_LOGGING_LEVEL=os.environ.get('RC_LOGGING_LEVEL', '') or level,
131 131 RC_LOGGING_FORMATTER=os.environ.get('RC_LOGGING_FORMATTER', '') or formatter
132 132 )
133 133
134 134 with tempfile.NamedTemporaryFile(prefix='rc_logging_', suffix='.ini', delete=False) as f:
135 135 log.info('Saved Temporary LOGGING config at %s', f.name)
136 136 f.write(ini_template)
137 137
138 138 logging.config.fileConfig(f.name)
139 139 os.remove(f.name)
140 140
141 141 def make_setting(self, key, default, lower=False, default_when_empty=False, parser=None):
142 142 input_val = self.settings.get(key, default)
143 143
144 144 if default_when_empty and not input_val:
145 145 # use default value when value is set in the config but it is empty
146 146 input_val = default
147 147
148 148 parser_func = {
149 149 'bool': self._bool_func,
150 150 'int': self._int_func,
151 151 'list': self._list_func,
152 152 'list:newline': functools.partial(self._list_func, sep='\n'),
153 153 'list:spacesep': functools.partial(self._list_func, sep=' '),
154 154 'string': functools.partial(self._string_func, lower=lower),
155 155 'dir': self._dir_func,
156 156 'dir:ensured': functools.partial(self._dir_func, ensure_dir=True),
157 157 'file': self._file_path_func,
158 158 'file:ensured': functools.partial(self._file_path_func, ensure_dir=True),
159 159 None: lambda i: i
160 160 }[parser]
161 161
162 162 envvar_value = self.maybe_env_key(key)
163 163 if envvar_value:
164 164 input_val = envvar_value
165 165 set_keys[key] = input_val
166 166
167 167 self.settings[key] = parser_func(input_val)
168 168 return self.settings[key]
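
A hedged usage sketch for SettingsMaker: the import path and the setting keys below are assumptions for illustration, but the parser names and the default_when_empty/env-override behavior follow the code above. Per _key_transformator, an RC_-prefixed environment variable overrides the .ini value.

    import os

    from vcsserver.config.settings_maker import SettingsMaker  # import path is an assumption

    settings = {
        'vcs.svn.compatible_version': '',  # set in the .ini but empty
        'startup.import_repos': 'true',
    }
    maker = SettingsMaker(settings)

    # empty value falls back to the default because of default_when_empty
    maker.make_setting('vcs.svn.compatible_version', 'pre-1.9', default_when_empty=True)
    assert settings['vcs.svn.compatible_version'] == 'pre-1.9'

    # parsed with the 'bool' parser; RC_STARTUP_IMPORT_REPOS in the env would win
    assert 'RC_STARTUP_IMPORT_REPOS' not in os.environ
    maker.make_setting('startup.import_repos', 'false', parser='bool')
    assert settings['startup.import_repos'] is True
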
@@ -1,56 +1,56 b''
1 1 # Copyright (C) 2014-2023 RhodeCode GmbH
2 2
3 3 """
4 4 Implementation of :class:`EchoApp`.
5 5
6 6 This WSGI application will just echo back the data it receives.
7 7 """
8 8
9 9 import logging
10 10
11 11
12 12 log = logging.getLogger(__name__)
13 13
14 14
15 class EchoApp(object):
15 class EchoApp:
16 16
17 17 def __init__(self, repo_path, repo_name, config):
18 18 self._repo_path = repo_path
19 19 log.info("EchoApp initialized for %s", repo_path)
20 20
21 21 def __call__(self, environ, start_response):
22 22 log.debug("EchoApp called for %s", self._repo_path)
23 23 log.debug("Content-Length: %s", environ.get('CONTENT_LENGTH'))
24 24 environ['wsgi.input'].read()
25 25 status = '200 OK'
26 26 headers = [('Content-Type', 'text/plain')]
27 27 start_response(status, headers)
28 28 return [b"ECHO"]
29 29
30 30
31 class EchoAppStream(object):
31 class EchoAppStream:
32 32
33 33 def __init__(self, repo_path, repo_name, config):
34 34 self._repo_path = repo_path
35 35 log.info("EchoApp initialized for %s", repo_path)
36 36
37 37 def __call__(self, environ, start_response):
38 38 log.debug("EchoApp called for %s", self._repo_path)
39 39 log.debug("Content-Length: %s", environ.get('CONTENT_LENGTH'))
40 40 environ['wsgi.input'].read()
41 41 status = '200 OK'
42 42 headers = [('Content-Type', 'text/plain')]
43 43 start_response(status, headers)
44 44
45 45 def generator():
46 46 for _ in range(1000000):
47 47 yield b"ECHO_STREAM"
48 48 return generator()
49 49
50 50
51 51 def create_app():
52 52 """
53 53 Allows running this app directly in a WSGI server.
54 54 """
55 55 stub_config = {}
56 56 return EchoApp('stub_path', 'stub_name', stub_config)
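
A minimal way to exercise EchoApp without starting a server, assuming the module is importable as vcsserver.echo_stub.echo_app (the path is an assumption): call it with a hand-built WSGI environ.

    from io import BytesIO

    from vcsserver.echo_stub.echo_app import EchoApp  # import path is an assumption

    app = EchoApp('stub_path', 'stub_name', config={})
    environ = {
        'REQUEST_METHOD': 'GET',
        'PATH_INFO': '/',
        'CONTENT_LENGTH': '0',
        'wsgi.input': BytesIO(b''),  # EchoApp drains the request body
    }
    captured = {}

    def start_response(status, headers):
        captured['status'] = status

    body = app(environ, start_response)
    assert captured['status'] == '200 OK'
    assert body == [b'ECHO']
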
@@ -1,47 +1,47 b''
1 1 # Copyright (C) 2014-2023 RhodeCode GmbH
2 2
3 3 """
4 4 Provides the same API as :mod:`remote_wsgi`.
5 5
6 6 Uses the `EchoApp` instead of real implementations.
7 7 """
8 8
9 9 import logging
10 10
11 11 from .echo_app import EchoApp
12 12 from vcsserver import wsgi_app_caller
13 13
14 14
15 15 log = logging.getLogger(__name__)
16 16
17 17
18 class GitRemoteWsgi(object):
18 class GitRemoteWsgi:
19 19 def handle(self, environ, input_data, *args, **kwargs):
20 20 app = wsgi_app_caller.WSGIAppCaller(
21 21 create_echo_wsgi_app(*args, **kwargs))
22 22
23 23 return app.handle(environ, input_data)
24 24
25 25
26 class HgRemoteWsgi(object):
26 class HgRemoteWsgi:
27 27 def handle(self, environ, input_data, *args, **kwargs):
28 28 app = wsgi_app_caller.WSGIAppCaller(
29 29 create_echo_wsgi_app(*args, **kwargs))
30 30
31 31 return app.handle(environ, input_data)
32 32
33 33
34 34 def create_echo_wsgi_app(repo_path, repo_name, config):
35 35 log.debug("Creating EchoApp WSGI application")
36 36
37 37 _assert_valid_config(config)
38 38
39 39 # Remaining items are forwarded to have the extras available
40 40 return EchoApp(repo_path, repo_name, config=config)
41 41
42 42
43 43 def _assert_valid_config(config):
44 44 config = config.copy()
45 45
46 46 # This is what git needs from config at this stage
47 47 config.pop(b'git_update_server_info')
@@ -1,291 +1,291 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import re
19 19 import logging
20 20
21 21 from pyramid.config import Configurator
22 22 from pyramid.response import Response, FileIter
23 23 from pyramid.httpexceptions import (
24 24 HTTPBadRequest, HTTPNotImplemented, HTTPNotFound, HTTPForbidden,
25 25 HTTPUnprocessableEntity)
26 26
27 27 from vcsserver.lib.rc_json import json
28 28 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
29 29 from vcsserver.git_lfs.utils import safe_result, get_cython_compat_decorator
30 30 from vcsserver.str_utils import safe_int
31 31
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 GIT_LFS_CONTENT_TYPE = 'application/vnd.git-lfs' # +json ?
36 36 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
37 37
38 38
39 39 def write_response_error(http_exception, text=None):
40 40 content_type = GIT_LFS_CONTENT_TYPE + '+json'
41 41 _exception = http_exception(content_type=content_type)
42 42 _exception.content_type = content_type
43 43 if text:
44 44 _exception.body = json.dumps({'message': text})
45 45 log.debug('LFS: writing response of type %s to client with text:%s',
46 46 http_exception, text)
47 47 return _exception
48 48
49 49
50 class AuthHeaderRequired(object):
50 class AuthHeaderRequired:
51 51 """
52 52 Decorator to check if request has proper auth-header
53 53 """
54 54
55 55 def __call__(self, func):
56 56 return get_cython_compat_decorator(self.__wrapper, func)
57 57
58 58 def __wrapper(self, func, *fargs, **fkwargs):
59 59 request = fargs[1]
60 60 auth = request.authorization
61 61 if not auth:
62 62 return write_response_error(HTTPForbidden)
63 63 return func(*fargs[1:], **fkwargs)
64 64
65 65
66 66 # views
67 67
68 68 def lfs_objects(request):
69 69 # indicate that the V1 API is not supported
70 70 log.warning('LFS: v1 api not supported, reporting it back to client')
71 71 return write_response_error(HTTPNotImplemented, 'LFS: v1 api not supported')
72 72
73 73
74 74 @AuthHeaderRequired()
75 75 def lfs_objects_batch(request):
76 76 """
77 77 The client sends the following information to the Batch endpoint to transfer some objects:
78 78
79 79 operation - Should be download or upload.
80 80 transfers - An optional Array of String identifiers for transfer
81 81 adapters that the client has configured. If omitted, the basic
82 82 transfer adapter MUST be assumed by the server.
83 83 objects - An Array of objects to download.
84 84 oid - String OID of the LFS object.
85 85 size - Integer byte size of the LFS object. Must be at least zero.
86 86 """
87 87 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
88 88 auth = request.authorization
89 89 repo = request.matchdict.get('repo')
90 90 data = request.json
91 91 operation = data.get('operation')
92 92 http_scheme = request.registry.git_lfs_http_scheme
93 93
94 94 if operation not in ('download', 'upload'):
95 95 log.debug('LFS: unsupported operation:%s', operation)
96 96 return write_response_error(
97 HTTPBadRequest, 'unsupported operation mode: `%s`' % operation)
97 HTTPBadRequest, f'unsupported operation mode: `{operation}`')
98 98
99 99 if 'objects' not in data:
100 100 log.debug('LFS: missing objects data')
101 101 return write_response_error(
102 102 HTTPBadRequest, 'missing objects data')
103 103
104 104 log.debug('LFS: handling operation of type: %s', operation)
105 105
106 106 objects = []
107 107 for o in data['objects']:
108 108 try:
109 109 oid = o['oid']
110 110 obj_size = o['size']
111 111 except KeyError:
112 112 log.exception('LFS, failed to extract data')
113 113 return write_response_error(
114 114 HTTPBadRequest, 'unsupported data in objects')
115 115
116 116 obj_data = {'oid': oid}
117 117
118 118 obj_href = request.route_url('lfs_objects_oid', repo=repo, oid=oid,
119 119 _scheme=http_scheme)
120 120 obj_verify_href = request.route_url('lfs_objects_verify', repo=repo,
121 121 _scheme=http_scheme)
122 122 store = LFSOidStore(
123 123 oid, repo, store_location=request.registry.git_lfs_store_path)
124 124 handler = OidHandler(
125 125 store, repo, auth, oid, obj_size, obj_data,
126 126 obj_href, obj_verify_href)
127 127
128 128 # this also verifies OIDs
129 129 actions, errors = handler.exec_operation(operation)
130 130 if errors:
131 131 log.warning('LFS: got following errors: %s', errors)
132 132 obj_data['errors'] = errors
133 133
134 134 if actions:
135 135 obj_data['actions'] = actions
136 136
137 137 obj_data['size'] = obj_size
138 138 obj_data['authenticated'] = True
139 139 objects.append(obj_data)
140 140
141 141 result = {'objects': objects, 'transfer': 'basic'}
142 142 log.debug('LFS Response %s', safe_result(result))
143 143
144 144 return result
145 145
146 146
147 147 def lfs_objects_oid_upload(request):
148 148 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
149 149 repo = request.matchdict.get('repo')
150 150 oid = request.matchdict.get('oid')
151 151 store = LFSOidStore(
152 152 oid, repo, store_location=request.registry.git_lfs_store_path)
153 153 engine = store.get_engine(mode='wb')
154 154 log.debug('LFS: starting chunked write of LFS oid: %s to storage', oid)
155 155
156 156 body = request.environ['wsgi.input']
157 157
158 158 with engine as f:
159 159 blksize = 64 * 1024 # 64kb
160 160 while True:
161 161 # read in chunks as stream comes in from Gunicorn
162 162 # this is a specific Gunicorn support function.
163 163 # might work differently on waitress
164 164 chunk = body.read(blksize)
165 165 if not chunk:
166 166 break
167 167 f.write(chunk)
168 168
169 169 return {'upload': 'ok'}
170 170
171 171
172 172 def lfs_objects_oid_download(request):
173 173 repo = request.matchdict.get('repo')
174 174 oid = request.matchdict.get('oid')
175 175
176 176 store = LFSOidStore(
177 177 oid, repo, store_location=request.registry.git_lfs_store_path)
178 178 if not store.has_oid():
179 179 log.debug('LFS: oid %s does not exist in store', oid)
180 180 return write_response_error(
181 HTTPNotFound, 'requested file with oid `%s` not found in store' % oid)
181 HTTPNotFound, f'requested file with oid `{oid}` not found in store')
182 182
183 183 # TODO(marcink): support range header ?
184 184 # Range: bytes=0-, `bytes=(\d+)\-.*`
185 185
186 186 f = open(store.oid_path, 'rb')
187 187 response = Response(
188 188 content_type='application/octet-stream', app_iter=FileIter(f))
189 189 response.headers.add('X-RC-LFS-Response-Oid', str(oid))
190 190 return response
191 191
192 192
193 193 def lfs_objects_verify(request):
194 194 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
195 195 repo = request.matchdict.get('repo')
196 196
197 197 data = request.json
198 198 oid = data.get('oid')
199 199 size = safe_int(data.get('size'))
200 200
201 201 if not (oid and size):
202 202 return write_response_error(
203 203 HTTPBadRequest, 'missing oid and size in request data')
204 204
205 205 store = LFSOidStore(
206 206 oid, repo, store_location=request.registry.git_lfs_store_path)
207 207 if not store.has_oid():
208 208 log.debug('LFS: oid %s does not exist in store', oid)
209 209 return write_response_error(
210 HTTPNotFound, 'oid `%s` does not exists in store' % oid)
210 HTTPNotFound, f'oid `{oid}` does not exist in store')
211 211
212 212 store_size = store.size_oid()
213 213 if store_size != size:
214 214 msg = 'requested file size mismatch store size:{} requested:{}'.format(
215 215 store_size, size)
216 216 return write_response_error(
217 217 HTTPUnprocessableEntity, msg)
218 218
219 219 return {'message': {'size': 'ok', 'in_store': 'ok'}}
220 220
221 221
222 222 def lfs_objects_lock(request):
223 223 return write_response_error(
224 224 HTTPNotImplemented, 'GIT LFS locking api not supported')
225 225
226 226
227 227 def not_found(request):
228 228 return write_response_error(
229 229 HTTPNotFound, 'request path not found')
230 230
231 231
232 232 def lfs_disabled(request):
233 233 return write_response_error(
234 234 HTTPNotImplemented, 'GIT LFS disabled for this repo')
235 235
236 236
237 237 def git_lfs_app(config):
238 238
239 239 # v1 API deprecation endpoint
240 240 config.add_route('lfs_objects',
241 241 '/{repo:.*?[^/]}/info/lfs/objects')
242 242 config.add_view(lfs_objects, route_name='lfs_objects',
243 243 request_method='POST', renderer='json')
244 244
245 245 # locking API
246 246 config.add_route('lfs_objects_lock',
247 247 '/{repo:.*?[^/]}/info/lfs/locks')
248 248 config.add_view(lfs_objects_lock, route_name='lfs_objects_lock',
249 249 request_method=('POST', 'GET'), renderer='json')
250 250
251 251 config.add_route('lfs_objects_lock_verify',
252 252 '/{repo:.*?[^/]}/info/lfs/locks/verify')
253 253 config.add_view(lfs_objects_lock, route_name='lfs_objects_lock_verify',
254 254 request_method=('POST', 'GET'), renderer='json')
255 255
256 256 # batch API
257 257 config.add_route('lfs_objects_batch',
258 258 '/{repo:.*?[^/]}/info/lfs/objects/batch')
259 259 config.add_view(lfs_objects_batch, route_name='lfs_objects_batch',
260 260 request_method='POST', renderer='json')
261 261
262 262 # oid upload/download API
263 263 config.add_route('lfs_objects_oid',
264 264 '/{repo:.*?[^/]}/info/lfs/objects/{oid}')
265 265 config.add_view(lfs_objects_oid_upload, route_name='lfs_objects_oid',
266 266 request_method='PUT', renderer='json')
267 267 config.add_view(lfs_objects_oid_download, route_name='lfs_objects_oid',
268 268 request_method='GET', renderer='json')
269 269
270 270 # verification API
271 271 config.add_route('lfs_objects_verify',
272 272 '/{repo:.*?[^/]}/info/lfs/verify')
273 273 config.add_view(lfs_objects_verify, route_name='lfs_objects_verify',
274 274 request_method='POST', renderer='json')
275 275
276 276 # not found handler for API
277 277 config.add_notfound_view(not_found, renderer='json')
278 278
279 279
280 280 def create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
281 281 config = Configurator()
282 282 if git_lfs_enabled:
283 283 config.include(git_lfs_app)
284 284 config.registry.git_lfs_store_path = git_lfs_store_path
285 285 config.registry.git_lfs_http_scheme = git_lfs_http_scheme
286 286 else:
287 287 # not found handler for API, reporting disabled LFS support
288 288 config.add_notfound_view(lfs_disabled, renderer='json')
289 289
290 290 app = config.make_wsgi_app()
291 291 return app
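
For orientation, a wire-level sketch of the batch endpoint wired up above. The host, port, repo name and oid are placeholders; the payload shape and content type follow lfs_objects_batch and GIT_LFS_CONTENT_TYPE.

    import json
    import urllib.request

    payload = {
        'operation': 'download',
        'objects': [{'oid': 'deadbeef', 'size': 11}],
    }
    req = urllib.request.Request(
        'http://localhost:8080/myrepo/info/lfs/objects/batch',  # placeholder host/repo
        data=json.dumps(payload).encode('utf8'),
        headers={
            'Content-Type': 'application/vnd.git-lfs+json',
            'Authorization': 'Basic XXXXX',  # AuthHeaderRequired rejects requests without this
        },
        method='POST',
    )
    with urllib.request.urlopen(req) as resp:
        batch = json.load(resp)
    # each object carries either 'actions' (href + auth header) or 'errors'
    for obj in batch['objects']:
        print(obj.get('actions') or obj.get('errors'))
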
@@ -1,175 +1,175 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import shutil
20 20 import logging
21 21 from collections import OrderedDict
22 22
23 23 log = logging.getLogger(__name__)
24 24
25 25
26 class OidHandler(object):
26 class OidHandler:
27 27
28 28 def __init__(self, store, repo_name, auth, oid, obj_size, obj_data, obj_href,
29 29 obj_verify_href=None):
30 30 self.current_store = store
31 31 self.repo_name = repo_name
32 32 self.auth = auth
33 33 self.oid = oid
34 34 self.obj_size = obj_size
35 35 self.obj_data = obj_data
36 36 self.obj_href = obj_href
37 37 self.obj_verify_href = obj_verify_href
38 38
39 39 def get_store(self, mode=None):
40 40 return self.current_store
41 41
42 42 def get_auth(self):
43 43 """returns auth header for re-use in upload/download"""
44 44 return " ".join(self.auth)
45 45
46 46 def download(self):
47 47
48 48 store = self.get_store()
49 49 response = None
50 50 has_errors = None
51 51
52 52 if not store.has_oid():
53 53 # error reply back to the client that something is wrong with the download
54 54 err_msg = f'object: {store.oid} does not exist in store'
55 55 has_errors = OrderedDict(
56 56 error=OrderedDict(
57 57 code=404,
58 58 message=err_msg
59 59 )
60 60 )
61 61
62 62 download_action = OrderedDict(
63 63 href=self.obj_href,
64 64 header=OrderedDict([("Authorization", self.get_auth())])
65 65 )
66 66 if not has_errors:
67 67 response = OrderedDict(download=download_action)
68 68 return response, has_errors
69 69
70 70 def upload(self, skip_existing=True):
71 71 """
72 72 Write upload action for git-lfs server
73 73 """
74 74
75 75 store = self.get_store()
76 76 response = None
77 77 has_errors = None
78 78
79 79 # check if we already have the OID; if we do, reply with an empty response
80 80 if store.has_oid():
81 81 log.debug('LFS: store already has oid %s', store.oid)
82 82
83 83 # validate size
84 84 store_size = store.size_oid()
85 85 size_match = store_size == self.obj_size
86 86 if not size_match:
87 87 log.warning(
88 88 'LFS: size mismatch for oid:%s, in store:%s expected: %s',
89 89 self.oid, store_size, self.obj_size)
90 90 elif skip_existing:
91 91 log.debug('LFS: skipping further action as oid already exists')
92 92 return response, has_errors
93 93
94 94 chunked = ("Transfer-Encoding", "chunked")
95 95 upload_action = OrderedDict(
96 96 href=self.obj_href,
97 97 header=OrderedDict([("Authorization", self.get_auth()), chunked])
98 98 )
99 99 if not has_errors:
100 100 response = OrderedDict(upload=upload_action)
101 101 # if specified in handler, return the verification endpoint
102 102 if self.obj_verify_href:
103 103 verify_action = OrderedDict(
104 104 href=self.obj_verify_href,
105 105 header=OrderedDict([("Authorization", self.get_auth())])
106 106 )
107 107 response['verify'] = verify_action
108 108 return response, has_errors
109 109
110 110 def exec_operation(self, operation, *args, **kwargs):
111 111 handler = getattr(self, operation)
112 112 log.debug('LFS: handling request using %s handler', handler)
113 113 return handler(*args, **kwargs)
114 114
115 115
116 class LFSOidStore(object):
116 class LFSOidStore:
117 117
118 118 def __init__(self, oid, repo, store_location=None):
119 119 self.oid = oid
120 120 self.repo = repo
121 121 self.store_path = store_location or self.get_default_store()
122 122 self.tmp_oid_path = os.path.join(self.store_path, oid + '.tmp')
123 123 self.oid_path = os.path.join(self.store_path, oid)
124 124 self.fd = None
125 125
126 126 def get_engine(self, mode):
127 127 """
128 128 engine = .get_engine(mode='wb')
129 129 with engine as f:
130 130 f.write('...')
131 131 """
132 132
133 133 class StoreEngine(object):
134 134 def __init__(self, mode, store_path, oid_path, tmp_oid_path):
135 135 self.mode = mode
136 136 self.store_path = store_path
137 137 self.oid_path = oid_path
138 138 self.tmp_oid_path = tmp_oid_path
139 139
140 140 def __enter__(self):
141 141 if not os.path.isdir(self.store_path):
142 142 os.makedirs(self.store_path)
143 143
144 144 # TODO(marcink): maybe write metadata here with size/oid ?
145 145 fd = open(self.tmp_oid_path, self.mode)
146 146 self.fd = fd
147 147 return fd
148 148
149 149 def __exit__(self, exc_type, exc_value, traceback):
150 150 # close tmp file, and rename to final destination
151 151 self.fd.close()
152 152 shutil.move(self.tmp_oid_path, self.oid_path)
153 153
154 154 return StoreEngine(
155 155 mode, self.store_path, self.oid_path, self.tmp_oid_path)
156 156
157 157 def get_default_store(self):
158 158 """
159 159 Default store, consistent with defaults of Mercurial large files store
160 160 which is /home/username/.cache/largefiles
161 161 """
162 162 user_home = os.path.expanduser("~")
163 163 return os.path.join(user_home, '.cache', 'lfs-store')
164 164
165 165 def has_oid(self):
166 166 return os.path.exists(os.path.join(self.store_path, self.oid))
167 167
168 168 def size_oid(self):
169 169 size = -1
170 170
171 171 if self.has_oid():
172 172 oid = os.path.join(self.store_path, self.oid)
173 173 size = os.stat(oid).st_size
174 174
175 175 return size
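
A short sketch of the StoreEngine contract: writes go to `<oid>.tmp` and the context-manager exit renames the file into its final place. This mirrors the tests further below; the oid/repo values are arbitrary.

    import tempfile

    from vcsserver.git_lfs.lib import LFSOidStore

    store = LFSOidStore(oid='123456789', repo='test', store_location=tempfile.mkdtemp())
    assert store.has_oid() is False

    # __enter__ opens <oid>.tmp, __exit__ moves it to the final oid path
    with store.get_engine(mode='wb') as f:
        f.write(b'CONTENT')

    assert store.has_oid() is True
    assert store.size_oid() == len(b'CONTENT')
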
@@ -1,273 +1,273 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 20 from webtest.app import TestApp as WebObTestApp
21 21
22 22 from vcsserver.lib.rc_json import json
23 23 from vcsserver.str_utils import safe_bytes
24 24 from vcsserver.git_lfs.app import create_app
25 25
26 26
27 27 @pytest.fixture(scope='function')
28 28 def git_lfs_app(tmpdir):
29 29 custom_app = WebObTestApp(create_app(
30 30 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
31 31 git_lfs_http_scheme='http'))
32 32 custom_app._store = str(tmpdir)
33 33 return custom_app
34 34
35 35
36 36 @pytest.fixture(scope='function')
37 37 def git_lfs_https_app(tmpdir):
38 38 custom_app = WebObTestApp(create_app(
39 39 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
40 40 git_lfs_http_scheme='https'))
41 41 custom_app._store = str(tmpdir)
42 42 return custom_app
43 43
44 44
45 45 @pytest.fixture()
46 46 def http_auth():
47 47 return {'HTTP_AUTHORIZATION': "Basic XXXXX"}
48 48
49 49
50 class TestLFSApplication(object):
50 class TestLFSApplication:
51 51
52 52 def test_app_wrong_path(self, git_lfs_app):
53 53 git_lfs_app.get('/repo/info/lfs/xxx', status=404)
54 54
55 55 def test_app_deprecated_endpoint(self, git_lfs_app):
56 56 response = git_lfs_app.post('/repo/info/lfs/objects', status=501)
57 57 assert response.status_code == 501
58 58 assert json.loads(response.text) == {'message': 'LFS: v1 api not supported'}
59 59
60 60 def test_app_lock_verify_api_not_available(self, git_lfs_app):
61 61 response = git_lfs_app.post('/repo/info/lfs/locks/verify', status=501)
62 62 assert response.status_code == 501
63 63 assert json.loads(response.text) == {
64 64 'message': 'GIT LFS locking api not supported'}
65 65
66 66 def test_app_lock_api_not_available(self, git_lfs_app):
67 67 response = git_lfs_app.post('/repo/info/lfs/locks', status=501)
68 68 assert response.status_code == 501
69 69 assert json.loads(response.text) == {
70 70 'message': 'GIT LFS locking api not supported'}
71 71
72 72 def test_app_batch_api_missing_auth(self, git_lfs_app):
73 73 git_lfs_app.post_json(
74 74 '/repo/info/lfs/objects/batch', params={}, status=403)
75 75
76 76 def test_app_batch_api_unsupported_operation(self, git_lfs_app, http_auth):
77 77 response = git_lfs_app.post_json(
78 78 '/repo/info/lfs/objects/batch', params={}, status=400,
79 79 extra_environ=http_auth)
80 80 assert json.loads(response.text) == {
81 81 'message': 'unsupported operation mode: `None`'}
82 82
83 83 def test_app_batch_api_missing_objects(self, git_lfs_app, http_auth):
84 84 response = git_lfs_app.post_json(
85 85 '/repo/info/lfs/objects/batch', params={'operation': 'download'},
86 86 status=400, extra_environ=http_auth)
87 87 assert json.loads(response.text) == {
88 88 'message': 'missing objects data'}
89 89
90 90 def test_app_batch_api_unsupported_data_in_objects(
91 91 self, git_lfs_app, http_auth):
92 92 params = {'operation': 'download',
93 93 'objects': [{}]}
94 94 response = git_lfs_app.post_json(
95 95 '/repo/info/lfs/objects/batch', params=params, status=400,
96 96 extra_environ=http_auth)
97 97 assert json.loads(response.text) == {
98 98 'message': 'unsupported data in objects'}
99 99
100 100 def test_app_batch_api_download_missing_object(
101 101 self, git_lfs_app, http_auth):
102 102 params = {'operation': 'download',
103 103 'objects': [{'oid': '123', 'size': '1024'}]}
104 104 response = git_lfs_app.post_json(
105 105 '/repo/info/lfs/objects/batch', params=params,
106 106 extra_environ=http_auth)
107 107
108 108 expected_objects = [
109 109 {'authenticated': True,
110 110 'errors': {'error': {
111 111 'code': 404,
112 112 'message': 'object: 123 does not exist in store'}},
113 113 'oid': '123',
114 114 'size': '1024'}
115 115 ]
116 116 assert json.loads(response.text) == {
117 117 'objects': expected_objects, 'transfer': 'basic'}
118 118
119 119 def test_app_batch_api_download(self, git_lfs_app, http_auth):
120 120 oid = '456'
121 121 oid_path = os.path.join(git_lfs_app._store, oid)
122 122 if not os.path.isdir(os.path.dirname(oid_path)):
123 123 os.makedirs(os.path.dirname(oid_path))
124 124 with open(oid_path, 'wb') as f:
125 125 f.write(safe_bytes('OID_CONTENT'))
126 126
127 127 params = {'operation': 'download',
128 128 'objects': [{'oid': oid, 'size': '1024'}]}
129 129 response = git_lfs_app.post_json(
130 130 '/repo/info/lfs/objects/batch', params=params,
131 131 extra_environ=http_auth)
132 132
133 133 expected_objects = [
134 134 {'authenticated': True,
135 135 'actions': {
136 136 'download': {
137 137 'header': {'Authorization': 'Basic XXXXX'},
138 138 'href': 'http://localhost/repo/info/lfs/objects/456'},
139 139 },
140 140 'oid': '456',
141 141 'size': '1024'}
142 142 ]
143 143 assert json.loads(response.text) == {
144 144 'objects': expected_objects, 'transfer': 'basic'}
145 145
146 146 def test_app_batch_api_upload(self, git_lfs_app, http_auth):
147 147 params = {'operation': 'upload',
148 148 'objects': [{'oid': '123', 'size': '1024'}]}
149 149 response = git_lfs_app.post_json(
150 150 '/repo/info/lfs/objects/batch', params=params,
151 151 extra_environ=http_auth)
152 152 expected_objects = [
153 153 {'authenticated': True,
154 154 'actions': {
155 155 'upload': {
156 156 'header': {'Authorization': 'Basic XXXXX',
157 157 'Transfer-Encoding': 'chunked'},
158 158 'href': 'http://localhost/repo/info/lfs/objects/123'},
159 159 'verify': {
160 160 'header': {'Authorization': 'Basic XXXXX'},
161 161 'href': 'http://localhost/repo/info/lfs/verify'}
162 162 },
163 163 'oid': '123',
164 164 'size': '1024'}
165 165 ]
166 166 assert json.loads(response.text) == {
167 167 'objects': expected_objects, 'transfer': 'basic'}
168 168
169 169 def test_app_batch_api_upload_for_https(self, git_lfs_https_app, http_auth):
170 170 params = {'operation': 'upload',
171 171 'objects': [{'oid': '123', 'size': '1024'}]}
172 172 response = git_lfs_https_app.post_json(
173 173 '/repo/info/lfs/objects/batch', params=params,
174 174 extra_environ=http_auth)
175 175 expected_objects = [
176 176 {'authenticated': True,
177 177 'actions': {
178 178 'upload': {
179 179 'header': {'Authorization': 'Basic XXXXX',
180 180 'Transfer-Encoding': 'chunked'},
181 181 'href': 'https://localhost/repo/info/lfs/objects/123'},
182 182 'verify': {
183 183 'header': {'Authorization': 'Basic XXXXX'},
184 184 'href': 'https://localhost/repo/info/lfs/verify'}
185 185 },
186 186 'oid': '123',
187 187 'size': '1024'}
188 188 ]
189 189 assert json.loads(response.text) == {
190 190 'objects': expected_objects, 'transfer': 'basic'}
191 191
192 192 def test_app_verify_api_missing_data(self, git_lfs_app):
193 193 params = {'oid': 'missing'}
194 194 response = git_lfs_app.post_json(
195 195 '/repo/info/lfs/verify', params=params,
196 196 status=400)
197 197
198 198 assert json.loads(response.text) == {
199 199 'message': 'missing oid and size in request data'}
200 200
201 201 def test_app_verify_api_missing_obj(self, git_lfs_app):
202 202 params = {'oid': 'missing', 'size': '1024'}
203 203 response = git_lfs_app.post_json(
204 204 '/repo/info/lfs/verify', params=params,
205 205 status=404)
206 206
207 207 assert json.loads(response.text) == {
208 208 'message': 'oid `missing` does not exist in store'}
209 209
210 210 def test_app_verify_api_size_mismatch(self, git_lfs_app):
211 211 oid = 'existing'
212 212 oid_path = os.path.join(git_lfs_app._store, oid)
213 213 if not os.path.isdir(os.path.dirname(oid_path)):
214 214 os.makedirs(os.path.dirname(oid_path))
215 215 with open(oid_path, 'wb') as f:
216 216 f.write(safe_bytes('OID_CONTENT'))
217 217
218 218 params = {'oid': oid, 'size': '1024'}
219 219 response = git_lfs_app.post_json(
220 220 '/repo/info/lfs/verify', params=params, status=422)
221 221
222 222 assert json.loads(response.text) == {
223 223 'message': 'requested file size mismatch '
224 224 'store size:11 requested:1024'}
225 225
226 226 def test_app_verify_api(self, git_lfs_app):
227 227 oid = 'existing'
228 228 oid_path = os.path.join(git_lfs_app._store, oid)
229 229 if not os.path.isdir(os.path.dirname(oid_path)):
230 230 os.makedirs(os.path.dirname(oid_path))
231 231 with open(oid_path, 'wb') as f:
232 232 f.write(safe_bytes('OID_CONTENT'))
233 233
234 234 params = {'oid': oid, 'size': 11}
235 235 response = git_lfs_app.post_json(
236 236 '/repo/info/lfs/verify', params=params)
237 237
238 238 assert json.loads(response.text) == {
239 239 'message': {'size': 'ok', 'in_store': 'ok'}}
240 240
241 241 def test_app_download_api_oid_not_existing(self, git_lfs_app):
242 242 oid = 'missing'
243 243
244 244 response = git_lfs_app.get(
245 245 '/repo/info/lfs/objects/{oid}'.format(oid=oid), status=404)
246 246
247 247 assert json.loads(response.text) == {
248 248 'message': 'requested file with oid `missing` not found in store'}
249 249
250 250 def test_app_download_api(self, git_lfs_app):
251 251 oid = 'existing'
252 252 oid_path = os.path.join(git_lfs_app._store, oid)
253 253 if not os.path.isdir(os.path.dirname(oid_path)):
254 254 os.makedirs(os.path.dirname(oid_path))
255 255 with open(oid_path, 'wb') as f:
256 256 f.write(safe_bytes('OID_CONTENT'))
257 257
258 258 response = git_lfs_app.get(
259 259 '/repo/info/lfs/objects/{oid}'.format(oid=oid))
260 260 assert response
261 261
262 262 def test_app_upload(self, git_lfs_app):
263 263 oid = 'uploaded'
264 264
265 265 response = git_lfs_app.put(
266 266 '/repo/info/lfs/objects/{oid}'.format(oid=oid), params='CONTENT')
267 267
268 268 assert json.loads(response.text) == {'upload': 'ok'}
269 269
270 270 # verify that we actually wrote that OID
271 271 oid_path = os.path.join(git_lfs_app._store, oid)
272 272 assert os.path.isfile(oid_path)
273 273 assert 'CONTENT' == open(oid_path).read()
@@ -1,142 +1,142 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 20 from vcsserver.str_utils import safe_bytes
21 21 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
22 22
23 23
24 24 @pytest.fixture()
25 25 def lfs_store(tmpdir):
26 26 repo = 'test'
27 27 oid = '123456789'
28 28 store = LFSOidStore(oid=oid, repo=repo, store_location=str(tmpdir))
29 29 return store
30 30
31 31
32 32 @pytest.fixture()
33 33 def oid_handler(lfs_store):
34 34 store = lfs_store
35 35 repo = store.repo
36 36 oid = store.oid
37 37
38 38 oid_handler = OidHandler(
39 39 store=store, repo_name=repo, auth=('basic', 'xxxx'),
40 40 oid=oid,
41 41 obj_size='1024', obj_data={}, obj_href='http://localhost/handle_oid',
42 42 obj_verify_href='http://localhost/verify')
43 43 return oid_handler
44 44
45 45
46 class TestOidHandler(object):
46 class TestOidHandler:
47 47
48 48 @pytest.mark.parametrize('exec_action', [
49 49 'download',
50 50 'upload',
51 51 ])
52 52 def test_exec_action(self, exec_action, oid_handler):
53 53 handler = oid_handler.exec_operation(exec_action)
54 54 assert handler
55 55
56 56 def test_exec_action_undefined(self, oid_handler):
57 57 with pytest.raises(AttributeError):
58 58 oid_handler.exec_operation('wrong')
59 59
60 60 def test_download_oid_not_existing(self, oid_handler):
61 61 response, has_errors = oid_handler.exec_operation('download')
62 62
63 63 assert response is None
64 64 assert has_errors['error'] == {
65 65 'code': 404,
66 66 'message': 'object: 123456789 does not exist in store'}
67 67
68 68 def test_download_oid(self, oid_handler):
69 69 store = oid_handler.get_store()
70 70 if not os.path.isdir(os.path.dirname(store.oid_path)):
71 71 os.makedirs(os.path.dirname(store.oid_path))
72 72
73 73 with open(store.oid_path, 'wb') as f:
74 74 f.write(safe_bytes('CONTENT'))
75 75
76 76 response, has_errors = oid_handler.exec_operation('download')
77 77
78 78 assert has_errors is None
79 79 assert response['download'] == {
80 80 'header': {'Authorization': 'basic xxxx'},
81 81 'href': 'http://localhost/handle_oid'
82 82 }
83 83
84 84 def test_upload_oid_that_exists(self, oid_handler):
85 85 store = oid_handler.get_store()
86 86 if not os.path.isdir(os.path.dirname(store.oid_path)):
87 87 os.makedirs(os.path.dirname(store.oid_path))
88 88
89 89 with open(store.oid_path, 'wb') as f:
90 90 f.write(safe_bytes('CONTENT'))
91 91 oid_handler.obj_size = 7
92 92 response, has_errors = oid_handler.exec_operation('upload')
93 93 assert has_errors is None
94 94 assert response is None
95 95
96 96 def test_upload_oid_that_exists_but_has_wrong_size(self, oid_handler):
97 97 store = oid_handler.get_store()
98 98 if not os.path.isdir(os.path.dirname(store.oid_path)):
99 99 os.makedirs(os.path.dirname(store.oid_path))
100 100
101 101 with open(store.oid_path, 'wb') as f:
102 102 f.write(safe_bytes('CONTENT'))
103 103
104 104 oid_handler.obj_size = 10240
105 105 response, has_errors = oid_handler.exec_operation('upload')
106 106 assert has_errors is None
107 107 assert response['upload'] == {
108 108 'header': {'Authorization': 'basic xxxx',
109 109 'Transfer-Encoding': 'chunked'},
110 110 'href': 'http://localhost/handle_oid',
111 111 }
112 112
113 113 def test_upload_oid(self, oid_handler):
114 114 response, has_errors = oid_handler.exec_operation('upload')
115 115 assert has_errors is None
116 116 assert response['upload'] == {
117 117 'header': {'Authorization': 'basic xxxx',
118 118 'Transfer-Encoding': 'chunked'},
119 119 'href': 'http://localhost/handle_oid'
120 120 }
121 121
122 122
123 class TestLFSStore(object):
123 class TestLFSStore:
124 124 def test_write_oid(self, lfs_store):
125 125 oid_location = lfs_store.oid_path
126 126
127 127 assert not os.path.isfile(oid_location)
128 128
129 129 engine = lfs_store.get_engine(mode='wb')
130 130 with engine as f:
131 131 f.write(safe_bytes('CONTENT'))
132 132
133 133 assert os.path.isfile(oid_location)
134 134
135 135 def test_detect_has_oid(self, lfs_store):
136 136
137 137 assert lfs_store.has_oid() is False
138 138 engine = lfs_store.get_engine(mode='wb')
139 139 with engine as f:
140 140 f.write(safe_bytes('CONTENT'))
141 141
142 142 assert lfs_store.has_oid() is True
@@ -1,202 +1,202 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import re
19 19 import os
20 20 import sys
21 21 import datetime
22 22 import logging
23 23 import pkg_resources
24 24
25 25 import vcsserver
26 26 from vcsserver.str_utils import safe_bytes
27 27
28 28 log = logging.getLogger(__name__)
29 29
30 30
31 31 def get_git_hooks_path(repo_path, bare):
32 32 hooks_path = os.path.join(repo_path, 'hooks')
33 33 if not bare:
34 34 hooks_path = os.path.join(repo_path, '.git', 'hooks')
35 35
36 36 return hooks_path
37 37
38 38
39 39 def install_git_hooks(repo_path, bare, executable=None, force_create=False):
40 40 """
41 41 Creates a RhodeCode hook inside a git repository
42 42
43 43 :param repo_path: path to repository
44 44 :param executable: binary executable to put in the hooks
45 45 :param force_create: Create even if same name hook exists
46 46 """
47 47 executable = executable or sys.executable
48 48 hooks_path = get_git_hooks_path(repo_path, bare)
49 49
50 50 if not os.path.isdir(hooks_path):
51 51 os.makedirs(hooks_path, mode=0o777, exist_ok=True)
52 52
53 53 tmpl_post = pkg_resources.resource_string(
54 54 'vcsserver', '/'.join(
55 55 ('hook_utils', 'hook_templates', 'git_post_receive.py.tmpl')))
56 56 tmpl_pre = pkg_resources.resource_string(
57 57 'vcsserver', '/'.join(
58 58 ('hook_utils', 'hook_templates', 'git_pre_receive.py.tmpl')))
59 59
60 60 path = '' # not used for now
61 61 timestamp = datetime.datetime.utcnow().isoformat()
62 62
63 63 for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
64 64 log.debug('Installing git hook in repo %s', repo_path)
65 _hook_file = os.path.join(hooks_path, '%s-receive' % h_type)
65 _hook_file = os.path.join(hooks_path, f'{h_type}-receive')
66 66 _rhodecode_hook = check_rhodecode_hook(_hook_file)
67 67
68 68 if _rhodecode_hook or force_create:
69 69 log.debug('writing git %s hook file at %s !', h_type, _hook_file)
70 70 try:
71 71 with open(_hook_file, 'wb') as f:
72 72 template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
73 73 template = template.replace(b'_DATE_', safe_bytes(timestamp))
74 74 template = template.replace(b'_ENV_', safe_bytes(executable))
75 75 template = template.replace(b'_PATH_', safe_bytes(path))
76 76 f.write(template)
77 77 os.chmod(_hook_file, 0o755)
78 78 except OSError:
79 79 log.exception('error writing hook file %s', _hook_file)
80 80 else:
81 81 log.debug('skipping writing hook file')
82 82
83 83 return True
84 84
85 85
86 86 def get_svn_hooks_path(repo_path):
87 87 hooks_path = os.path.join(repo_path, 'hooks')
88 88
89 89 return hooks_path
90 90
91 91
92 92 def install_svn_hooks(repo_path, executable=None, force_create=False):
93 93 """
94 94 Creates RhodeCode hooks inside a svn repository
95 95
96 96 :param repo_path: path to repository
97 97 :param executable: binary executable to put in the hooks
98 98 :param force_create: Create even if same name hook exists
99 99 """
100 100 executable = executable or sys.executable
101 101 hooks_path = get_svn_hooks_path(repo_path)
102 102 if not os.path.isdir(hooks_path):
103 103 os.makedirs(hooks_path, mode=0o777, exist_ok=True)
104 104
105 105 tmpl_post = pkg_resources.resource_string(
106 106 'vcsserver', '/'.join(
107 107 ('hook_utils', 'hook_templates', 'svn_post_commit_hook.py.tmpl')))
108 108 tmpl_pre = pkg_resources.resource_string(
109 109 'vcsserver', '/'.join(
110 110 ('hook_utils', 'hook_templates', 'svn_pre_commit_hook.py.tmpl')))
111 111
112 112 path = '' # not used for now
113 113 timestamp = datetime.datetime.utcnow().isoformat()
114 114
115 115 for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
116 116 log.debug('Installing svn hook in repo %s', repo_path)
117 _hook_file = os.path.join(hooks_path, '%s-commit' % h_type)
117 _hook_file = os.path.join(hooks_path, f'{h_type}-commit')
118 118 _rhodecode_hook = check_rhodecode_hook(_hook_file)
119 119
120 120 if _rhodecode_hook or force_create:
121 121 log.debug('writing svn %s hook file at %s !', h_type, _hook_file)
122 122
123 123 try:
124 124 with open(_hook_file, 'wb') as f:
125 125 template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
126 126 template = template.replace(b'_DATE_', safe_bytes(timestamp))
127 127 template = template.replace(b'_ENV_', safe_bytes(executable))
128 128 template = template.replace(b'_PATH_', safe_bytes(path))
129 129
130 130 f.write(template)
131 131 os.chmod(_hook_file, 0o755)
132 132 except OSError:
133 133 log.exception('error writing hook file %s', _hook_file)
134 134 else:
135 135 log.debug('skipping writing hook file')
136 136
137 137 return True
138 138
139 139
140 140 def get_version_from_hook(hook_path):
141 141 version = b''
142 142 hook_content = read_hook_content(hook_path)
143 143 matches = re.search(rb'RC_HOOK_VER\s*=\s*(.*)', hook_content)
144 144 if matches:
145 145 try:
146 146 version = matches.groups()[0]
147 147 log.debug('got version %s from hooks.', version)
148 148 except Exception:
149 149 log.exception("Exception while reading the hook version.")
150 150 return version.replace(b"'", b"")
151 151
152 152
153 153 def check_rhodecode_hook(hook_path):
154 154 """
155 155 Check if the hook was created by RhodeCode
156 156 """
157 157 if not os.path.exists(hook_path):
158 158 return True
159 159
160 160 log.debug('hook exists, checking if it is from RhodeCode')
161 161
162 162 version = get_version_from_hook(hook_path)
163 163 if version:
164 164 return True
165 165
166 166 return False
167 167
168 168
169 169 def read_hook_content(hook_path) -> bytes:
170 170 content = b''
171 171 if os.path.isfile(hook_path):
172 172 with open(hook_path, 'rb') as f:
173 173 content = f.read()
174 174 return content
175 175
176 176
177 177 def get_git_pre_hook_version(repo_path, bare):
178 178 hooks_path = get_git_hooks_path(repo_path, bare)
179 179 _hook_file = os.path.join(hooks_path, 'pre-receive')
180 180 version = get_version_from_hook(_hook_file)
181 181 return version
182 182
183 183
184 184 def get_git_post_hook_version(repo_path, bare):
185 185 hooks_path = get_git_hooks_path(repo_path, bare)
186 186 _hook_file = os.path.join(hooks_path, 'post-receive')
187 187 version = get_version_from_hook(_hook_file)
188 188 return version
189 189
190 190
191 191 def get_svn_pre_hook_version(repo_path):
192 192 hooks_path = get_svn_hooks_path(repo_path)
193 193 _hook_file = os.path.join(hooks_path, 'pre-commit')
194 194 version = get_version_from_hook(_hook_file)
195 195 return version
196 196
197 197
198 198 def get_svn_post_hook_version(repo_path):
199 199 hooks_path = get_svn_hooks_path(repo_path)
200 200 _hook_file = os.path.join(hooks_path, 'post-commit')
201 201 version = get_version_from_hook(_hook_file)
202 202 return version
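
A hedged usage sketch for the installers above; the repository path is a placeholder and the import path (vcsserver.hook_utils) is assumed. The templates stamp RC_HOOK_VER into the hook files, which the *_hook_version readers parse back out.

    from vcsserver.hook_utils import (  # import path is an assumption
        install_git_hooks, get_git_pre_hook_version)

    repo_path = '/srv/repos/myrepo.git'  # placeholder bare repository
    install_git_hooks(repo_path, bare=True, force_create=True)

    # returns the stamped version as bytes once the RhodeCode templates are in place
    print(get_git_pre_hook_version(repo_path, bare=True))
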
@@ -1,779 +1,779 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21 import logging
22 22 import collections
23 23 import importlib
24 24 import base64
25 25 import msgpack
26 26 import dataclasses
27 27 import pygit2
28 28
29 29 import http.client
30 30
31 31
32 32 import mercurial.scmutil
33 33 import mercurial.node
34 34
35 35 from vcsserver.lib.rc_json import json
36 36 from vcsserver import exceptions, subprocessio, settings
37 37 from vcsserver.str_utils import ascii_str, safe_str
38 38 from vcsserver.remote.git_remote import Repository
39 39
40 40 log = logging.getLogger(__name__)
41 41
42 42
43 class HooksHttpClient(object):
43 class HooksHttpClient:
44 44 proto = 'msgpack.v1'
45 45 connection = None
46 46
47 47 def __init__(self, hooks_uri):
48 48 self.hooks_uri = hooks_uri
49 49
50 50 def __repr__(self):
51 51 return f'{self.__class__}(hook_uri={self.hooks_uri}, proto={self.proto})'
52 52
53 53 def __call__(self, method, extras):
54 54 connection = http.client.HTTPConnection(self.hooks_uri)
55 55 # binary msgpack body
56 56 headers, body = self._serialize(method, extras)
57 57 log.debug('Doing a new hooks call using HTTPConnection to %s', self.hooks_uri)
58 58
59 59 try:
60 60 try:
61 61 connection.request('POST', '/', body, headers)
62 62 except Exception as error:
63 63 log.error('Hooks call connection failed on %s, org error: %s', connection.__dict__, error)
64 64 raise
65 65
66 66 response = connection.getresponse()
67 67 try:
68 68 return msgpack.load(response)
69 69 except Exception:
70 70 response_data = response.read()
71 71 log.exception('Failed to decode hook response msgpack data. '
72 72 'response_code:%s, raw_data:%s',
73 73 response.status, response_data)
74 74 raise
75 75 finally:
76 76 connection.close()
77 77
78 78 @classmethod
79 79 def _serialize(cls, hook_name, extras):
80 80 data = {
81 81 'method': hook_name,
82 82 'extras': extras
83 83 }
84 84 headers = {
85 85 "rc-hooks-protocol": cls.proto,
86 86 "Connection": "keep-alive"
87 87 }
88 88 return headers, msgpack.packb(data)
89 89
90 90
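Each hook call is one msgpack-encoded POST and one msgpack response; a sketch of the request envelope (host:port hypothetical, msgpack >= 1.0 defaults assumed):

client = HooksHttpClient('127.0.0.1:10010')
headers, body = client._serialize('repo_size', {'repository': 'demo'})
# headers carry the protocol marker: {'rc-hooks-protocol': 'msgpack.v1', ...}
assert msgpack.unpackb(body) == {'method': 'repo_size', 'extras': {'repository': 'demo'}}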
91 class HooksDummyClient(object):
91 class HooksDummyClient:
92 92 def __init__(self, hooks_module):
93 93 self._hooks_module = importlib.import_module(hooks_module)
94 94
95 95 def __call__(self, hook_name, extras):
96 96 with self._hooks_module.Hooks() as hooks:
97 97 return getattr(hooks, hook_name)(extras)
98 98
99 99
100 class HooksShadowRepoClient(object):
100 class HooksShadowRepoClient:
101 101
102 102 def __call__(self, hook_name, extras):
103 103 return {'output': '', 'status': 0}
104 104
105 105
106 class RemoteMessageWriter(object):
106 class RemoteMessageWriter:
107 107 """Writer base class."""
108 108 def write(self, message):
109 109 raise NotImplementedError()
110 110
111 111
112 112 class HgMessageWriter(RemoteMessageWriter):
113 113 """Writer that knows how to send messages to mercurial clients."""
114 114
115 115 def __init__(self, ui):
116 116 self.ui = ui
117 117
118 118 def write(self, message: str):
119 119 # TODO: Check why the quiet flag is set by default.
120 120 old = self.ui.quiet
121 121 self.ui.quiet = False
122 122 self.ui.status(message.encode('utf-8'))
123 123 self.ui.quiet = old
124 124
125 125
126 126 class GitMessageWriter(RemoteMessageWriter):
127 127 """Writer that knows how to send messages to git clients."""
128 128
129 129 def __init__(self, stdout=None):
130 130 self.stdout = stdout or sys.stdout
131 131
132 132 def write(self, message: str):
133 133 self.stdout.write(message)
134 134
135 135
136 136 class SvnMessageWriter(RemoteMessageWriter):
137 137 """Writer that knows how to send messages to svn clients."""
138 138
139 139 def __init__(self, stderr=None):
140 140 # SVN needs data sent to stderr for back-to-client messaging
141 141 self.stderr = stderr or sys.stderr
142 142
143 143 def write(self, message):
144 144 self.stderr.write(message.encode('utf-8'))
145 145
146 146
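All three writers share the same one-method contract, which makes them easy to exercise with an in-memory buffer:

buf = io.StringIO()
GitMessageWriter(stdout=buf).write('remote: processing changes\n')
assert buf.getvalue() == 'remote: processing changes\n'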
147 147 def _handle_exception(result):
148 148 exception_class = result.get('exception')
149 149 exception_traceback = result.get('exception_traceback')
150 150 log.debug('Handling hook-call exception: %s', exception_class)
151 151
152 152 if exception_traceback:
153 153 log.error('Got traceback from remote call:%s', exception_traceback)
154 154
155 155 if exception_class == 'HTTPLockedRC':
156 156 raise exceptions.RepositoryLockedException()(*result['exception_args'])
157 157 elif exception_class == 'HTTPBranchProtected':
158 158 raise exceptions.RepositoryBranchProtectedException()(*result['exception_args'])
159 159 elif exception_class == 'RepositoryError':
160 160 raise exceptions.VcsException()(*result['exception_args'])
161 161 elif exception_class:
162 162 raise Exception(
163 163 f"""Got remote exception "{exception_class}" with args "{result['exception_args']}" """
164 164 )
165 165
166 166
167 167 def _get_hooks_client(extras):
168 168 hooks_uri = extras.get('hooks_uri')
169 169 is_shadow_repo = extras.get('is_shadow_repo')
170 170
171 171 if hooks_uri:
172 172 return HooksHttpClient(extras['hooks_uri'])
173 173 elif is_shadow_repo:
174 174 return HooksShadowRepoClient()
175 175 else:
176 176 return HooksDummyClient(extras['hooks_module'])
177 177
178 178
179 179 def _call_hook(hook_name, extras, writer):
180 180 hooks_client = _get_hooks_client(extras)
181 181 log.debug('Hooks, using client:%s', hooks_client)
182 182 result = hooks_client(hook_name, extras)
183 183 log.debug('Hooks got result: %s', result)
184 184 _handle_exception(result)
185 185 writer.write(result['output'])
186 186
187 187 return result['status']
188 188
189 189
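The transport is chosen purely from the extras dict; a small illustration of the dispatch rules in _get_hooks_client:

assert isinstance(_get_hooks_client({'hooks_uri': 'localhost:10010'}), HooksHttpClient)
assert isinstance(_get_hooks_client({'hooks_uri': '', 'is_shadow_repo': True}), HooksShadowRepoClient)
# with neither set, a hooks module is imported in-process (module name hypothetical):
# _get_hooks_client({'hooks_uri': '', 'is_shadow_repo': False, 'hooks_module': 'my_hooks'})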
190 190 def _extras_from_ui(ui):
191 191 hook_data = ui.config(b'rhodecode', b'RC_SCM_DATA')
192 192 if not hook_data:
193 193 # maybe it's inside environ ?
194 194 env_hook_data = os.environ.get('RC_SCM_DATA')
195 195 if env_hook_data:
196 196 hook_data = env_hook_data
197 197
198 198 extras = {}
199 199 if hook_data:
200 200 extras = json.loads(hook_data)
201 201 return extras
202 202
203 203
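RC_SCM_DATA is a JSON blob that RhodeCode prepares on its side; a representative payload (values hypothetical, keys limited to ones this module actually reads):

os.environ['RC_SCM_DATA'] = json.dumps({
    'hooks': ['push', 'pull'],        # enabled hook groups
    'hooks_uri': '127.0.0.1:10010',   # see _get_hooks_client
    'is_shadow_repo': False,
    'detect_force_push': True,
    'SSH': False,
    'repository': 'demo/repo',
    'repo_store': '/srv/repos',
})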
204 204 def _rev_range_hash(repo, node, check_heads=False):
205 205 from vcsserver.hgcompat import get_ctx
206 206
207 207 commits = []
208 208 revs = []
209 209 start = get_ctx(repo, node).rev()
210 210 end = len(repo)
211 211 for rev in range(start, end):
212 212 revs.append(rev)
213 213 ctx = get_ctx(repo, rev)
214 214 commit_id = ascii_str(mercurial.node.hex(ctx.node()))
215 215 branch = safe_str(ctx.branch())
216 216 commits.append((commit_id, branch))
217 217
218 218 parent_heads = []
219 219 if check_heads:
220 220 parent_heads = _check_heads(repo, start, end, revs)
221 221 return commits, parent_heads
222 222
223 223
224 224 def _check_heads(repo, start, end, commits):
225 225 from vcsserver.hgcompat import get_ctx
226 226 changelog = repo.changelog
227 227 parents = set()
228 228
229 229 for new_rev in commits:
230 230 for p in changelog.parentrevs(new_rev):
231 231 if p == mercurial.node.nullrev:
232 232 continue
233 233 if p < start:
234 234 parents.add(p)
235 235
236 236 for p in parents:
237 237 branch = get_ctx(repo, p).branch()
238 238 # The heads descending from that parent, on the same branch
239 239 parent_heads = {p}
240 240 reachable = {p}
241 241 for x in range(p + 1, end):
242 242 if get_ctx(repo, x).branch() != branch:
243 243 continue
244 244 for pp in changelog.parentrevs(x):
245 245 if pp in reachable:
246 246 reachable.add(x)
247 247 parent_heads.discard(pp)
248 248 parent_heads.add(x)
249 249 # More than one head? Suggest merging
250 250 if len(parent_heads) > 1:
251 251 return list(parent_heads)
252 252
253 253 return []
254 254
255 255
256 256 def _get_git_env():
257 257 env = {}
258 258 for k, v in os.environ.items():
259 259 if k.startswith('GIT'):
260 260 env[k] = v
261 261
262 262 # serialized version
263 263 return [(k, v) for k, v in env.items()]
264 264
265 265
266 266 def _get_hg_env(old_rev, new_rev, txnid, repo_path):
267 267 env = {}
268 268 for k, v in os.environ.items():
269 269 if k.startswith('HG'):
270 270 env[k] = v
271 271
272 272 env['HG_NODE'] = old_rev
273 273 env['HG_NODE_LAST'] = new_rev
274 274 env['HG_TXNID'] = txnid
275 275 env['HG_PENDING'] = repo_path
276 276
277 277 return [(k, v) for k, v in env.items()]
278 278
279 279
280 280 def repo_size(ui, repo, **kwargs):
281 281 extras = _extras_from_ui(ui)
282 282 return _call_hook('repo_size', extras, HgMessageWriter(ui))
283 283
284 284
285 285 def pre_pull(ui, repo, **kwargs):
286 286 extras = _extras_from_ui(ui)
287 287 return _call_hook('pre_pull', extras, HgMessageWriter(ui))
288 288
289 289
290 290 def pre_pull_ssh(ui, repo, **kwargs):
291 291 extras = _extras_from_ui(ui)
292 292 if extras and extras.get('SSH'):
293 293 return pre_pull(ui, repo, **kwargs)
294 294 return 0
295 295
296 296
297 297 def post_pull(ui, repo, **kwargs):
298 298 extras = _extras_from_ui(ui)
299 299 return _call_hook('post_pull', extras, HgMessageWriter(ui))
300 300
301 301
302 302 def post_pull_ssh(ui, repo, **kwargs):
303 303 extras = _extras_from_ui(ui)
304 304 if extras and extras.get('SSH'):
305 305 return post_pull(ui, repo, **kwargs)
306 306 return 0
307 307
308 308
309 309 def pre_push(ui, repo, node=None, **kwargs):
310 310 """
311 311 Mercurial pre_push hook
312 312 """
313 313 extras = _extras_from_ui(ui)
314 314 detect_force_push = extras.get('detect_force_push')
315 315
316 316 rev_data = []
317 317 hook_type: str = safe_str(kwargs.get('hooktype'))
318 318
319 319 if node and hook_type == 'pretxnchangegroup':
320 320 branches = collections.defaultdict(list)
321 321 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
322 322 for commit_id, branch in commits:
323 323 branches[branch].append(commit_id)
324 324
325 325 for branch, commits in branches.items():
326 326 old_rev = ascii_str(kwargs.get('node_last')) or commits[0]
327 327 rev_data.append({
328 328 'total_commits': len(commits),
329 329 'old_rev': old_rev,
330 330 'new_rev': commits[-1],
331 331 'ref': '',
332 332 'type': 'branch',
333 333 'name': branch,
334 334 })
335 335
336 336 for push_ref in rev_data:
337 337 push_ref['multiple_heads'] = _heads
338 338
339 339 repo_path = os.path.join(
340 340 extras.get('repo_store', ''), extras.get('repository', ''))
341 341 push_ref['hg_env'] = _get_hg_env(
342 342 old_rev=push_ref['old_rev'],
343 343 new_rev=push_ref['new_rev'], txnid=ascii_str(kwargs.get('txnid')),
344 344 repo_path=repo_path)
345 345
346 346 extras['hook_type'] = hook_type or 'pre_push'
347 347 extras['commit_ids'] = rev_data
348 348
349 349 return _call_hook('pre_push', extras, HgMessageWriter(ui))
350 350
351 351
352 352 def pre_push_ssh(ui, repo, node=None, **kwargs):
353 353 extras = _extras_from_ui(ui)
354 354 if extras.get('SSH'):
355 355 return pre_push(ui, repo, node, **kwargs)
356 356
357 357 return 0
358 358
359 359
360 360 def pre_push_ssh_auth(ui, repo, node=None, **kwargs):
361 361 """
362 362 Mercurial pre_push hook for SSH
363 363 """
364 364 extras = _extras_from_ui(ui)
365 365 if extras.get('SSH'):
366 366 permission = extras['SSH_PERMISSIONS']
367 367
368 368 if 'repository.write' == permission or 'repository.admin' == permission:
369 369 return 0
370 370
371 371 # non-zero ret code
372 372 return 1
373 373
374 374 return 0
375 375
376 376
377 377 def post_push(ui, repo, node, **kwargs):
378 378 """
379 379 Mercurial post_push hook
380 380 """
381 381 extras = _extras_from_ui(ui)
382 382
383 383 commit_ids = []
384 384 branches = []
385 385 bookmarks = []
386 386 tags = []
387 387 hook_type: str = safe_str(kwargs.get('hooktype'))
388 388
389 389 commits, _heads = _rev_range_hash(repo, node)
390 390 for commit_id, branch in commits:
391 391 commit_ids.append(commit_id)
392 392 if branch not in branches:
393 393 branches.append(branch)
394 394
395 395 if hasattr(ui, '_rc_pushkey_bookmarks'):
396 396 bookmarks = ui._rc_pushkey_bookmarks
397 397
398 398 extras['hook_type'] = hook_type or 'post_push'
399 399 extras['commit_ids'] = commit_ids
400 400
401 401 extras['new_refs'] = {
402 402 'branches': branches,
403 403 'bookmarks': bookmarks,
404 404 'tags': tags
405 405 }
406 406
407 407 return _call_hook('post_push', extras, HgMessageWriter(ui))
408 408
409 409
410 410 def post_push_ssh(ui, repo, node, **kwargs):
411 411 """
412 412 Mercurial post_push hook for SSH
413 413 """
414 414 if _extras_from_ui(ui).get('SSH'):
415 415 return post_push(ui, repo, node, **kwargs)
416 416 return 0
417 417
418 418
419 419 def key_push(ui, repo, **kwargs):
420 420 from vcsserver.hgcompat import get_ctx
421 421
422 422 if kwargs['new'] != b'0' and kwargs['namespace'] == b'bookmarks':
423 423 # store new bookmarks in our UI object propagated later to post_push
424 424 ui._rc_pushkey_bookmarks = get_ctx(repo, kwargs['key']).bookmarks()
425 425 return
426 426
427 427
428 428 # backward compat
429 429 log_pull_action = post_pull
430 430
431 431 # backward compat
432 432 log_push_action = post_push
433 433
434 434
435 435 def handle_git_pre_receive(unused_repo_path, unused_revs, unused_env):
436 436 """
437 437 Old hook name: keep here for backward compatibility.
438 438
439 439 This is only required when the installed git hooks are not upgraded.
440 440 """
441 441 pass
442 442
443 443
444 444 def handle_git_post_receive(unused_repo_path, unused_revs, unused_env):
445 445 """
446 446 Old hook name: keep here for backward compatibility.
447 447
448 448 This is only required when the installed git hooks are not upgraded.
449 449 """
450 450 pass
451 451
452 452
453 453 @dataclasses.dataclass
454 454 class HookResponse:
455 455 status: int
456 456 output: str
457 457
458 458
459 459 def git_pre_pull(extras) -> HookResponse:
460 460 """
461 461 Pre pull hook.
462 462
463 463 :param extras: dictionary containing the keys defined in simplevcs
464 464 :type extras: dict
465 465
466 466 :return: HookResponse with the hook's status code (0 for success) and output.
467 467 :rtype: HookResponse
468 468 """
469 469
470 470 if 'pull' not in extras['hooks']:
471 471 return HookResponse(0, '')
472 472
473 473 stdout = io.StringIO()
474 474 try:
475 475 status_code = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
476 476
477 477 except Exception as error:
478 478 log.exception('Failed to call pre_pull hook')
479 479 status_code = 128
480 480 stdout.write(f'ERROR: {error}\n')
481 481
482 482 return HookResponse(status_code, stdout.getvalue())
483 483
484 484
485 485 def git_post_pull(extras) -> HookResponse:
486 486 """
487 487 Post pull hook.
488 488
489 489 :param extras: dictionary containing the keys defined in simplevcs
490 490 :type extras: dict
491 491
492 492 :return: HookResponse with the hook's status code (0 for success) and output.
493 493 :rtype: HookResponse
494 494 """
495 495 if 'pull' not in extras['hooks']:
496 496 return HookResponse(0, '')
497 497
498 498 stdout = io.StringIO()
499 499 try:
500 500 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
501 501 except Exception as error:
502 502 status = 128
503 503 stdout.write(f'ERROR: {error}\n')
504 504
505 505 return HookResponse(status, stdout.getvalue())
506 506
507 507
508 508 def _parse_git_ref_lines(revision_lines):
509 509 rev_data = []
510 510 for revision_line in revision_lines or []:
511 511 old_rev, new_rev, ref = revision_line.strip().split(' ')
512 512 ref_data = ref.split('/', 2)
513 513 if ref_data[1] in ('tags', 'heads'):
514 514 rev_data.append({
515 515 # NOTE(marcink):
516 516 # we're unable to tell total_commits for git at this point
517 517 # but we set the variable for consistency with the Mercurial hook data
518 518 'total_commits': -1,
519 519 'old_rev': old_rev,
520 520 'new_rev': new_rev,
521 521 'ref': ref,
522 522 'type': ref_data[1],
523 523 'name': ref_data[2],
524 524 })
525 525 return rev_data
526 526
527 527
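Each stdin line that git feeds a pre/post-receive hook is 'old-sha new-sha refname'; a minimal illustration of the parsing:

lines = ['0' * 40 + ' ' + 'a' * 40 + ' refs/heads/main']
parsed = _parse_git_ref_lines(lines)
assert parsed[0]['type'] == 'heads' and parsed[0]['name'] == 'main'
assert parsed[0]['total_commits'] == -1  # not computable for git here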
528 528 def git_pre_receive(unused_repo_path, revision_lines, env) -> int:
529 529 """
530 530 Pre push hook.
531 531
532 532 :return: status code of the hook. 0 for success.
533 533 """
534 534 extras = json.loads(env['RC_SCM_DATA'])
535 535 rev_data = _parse_git_ref_lines(revision_lines)
536 536 if 'push' not in extras['hooks']:
537 537 return 0
538 538 empty_commit_id = '0' * 40
539 539
540 540 detect_force_push = extras.get('detect_force_push')
541 541
542 542 for push_ref in rev_data:
543 543 # store our git-env which holds the temp store
544 544 push_ref['git_env'] = _get_git_env()
545 545 push_ref['pruned_sha'] = ''
546 546 if not detect_force_push:
547 547 # don't check for forced-push when we don't need to
548 548 continue
549 549
550 550 type_ = push_ref['type']
551 551 new_branch = push_ref['old_rev'] == empty_commit_id
552 552 delete_branch = push_ref['new_rev'] == empty_commit_id
553 553 if type_ == 'heads' and not (new_branch or delete_branch):
554 554 old_rev = push_ref['old_rev']
555 555 new_rev = push_ref['new_rev']
556 556 cmd = [settings.GIT_EXECUTABLE, 'rev-list', old_rev, f'^{new_rev}']
557 557 stdout, stderr = subprocessio.run_command(
558 558 cmd, env=os.environ.copy())
559 559 # non-empty output means there are unreachable objects, i.e. a forced push was used
560 560 if stdout:
561 561 push_ref['pruned_sha'] = stdout.splitlines()
562 562
563 563 extras['hook_type'] = 'pre_receive'
564 564 extras['commit_ids'] = rev_data
565 565
566 566 stdout = sys.stdout
567 567 status_code = _call_hook('pre_push', extras, GitMessageWriter(stdout))
568 568
569 569 return status_code
570 570
571 571
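The force-push probe above boils down to `git rev-list <old> ^<new>`: commits reachable from the old tip but not from the new one only exist when history was rewritten. An equivalent standalone sketch (SHAs and repo path are placeholders; the hook itself goes through subprocessio):

import subprocess

out = subprocess.check_output(
    ['git', 'rev-list', 'OLD_SHA', '^NEW_SHA'],  # placeholders for real SHAs
    cwd='/srv/repos/demo.git')
forced_push = bool(out.strip())  # any output means commits would be pruned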
572 572 def git_post_receive(unused_repo_path, revision_lines, env) -> int:
573 573 """
574 574 Post push hook.
575 575
576 576 :return: status code of the hook. 0 for success.
577 577 """
578 578 extras = json.loads(env['RC_SCM_DATA'])
579 579 if 'push' not in extras['hooks']:
580 580 return 0
581 581
582 582 rev_data = _parse_git_ref_lines(revision_lines)
583 583
584 584 git_revs = []
585 585
586 586 # N.B.(skreft): it is ok to just call git, as git before calling a
587 587 # subcommand sets the PATH environment variable so that it points to the
588 588 # correct version of the git executable.
589 589 empty_commit_id = '0' * 40
590 590 branches = []
591 591 tags = []
592 592 for push_ref in rev_data:
593 593 type_ = push_ref['type']
594 594
595 595 if type_ == 'heads':
596 596 # starting new branch case
597 597 if push_ref['old_rev'] == empty_commit_id:
598 598 push_ref_name = push_ref['name']
599 599
600 600 if push_ref_name not in branches:
601 601 branches.append(push_ref_name)
602 602
603 603 need_head_set = ''
604 604 with Repository(os.getcwd()) as repo:
605 605 try:
606 606 repo.head
607 607 except pygit2.GitError:
608 608 need_head_set = f'refs/heads/{push_ref_name}'
609 609
610 610 if need_head_set:
611 611 repo.set_head(need_head_set)
612 612 print(f"Setting default branch to {push_ref_name}")
613 613
614 614 cmd = [settings.GIT_EXECUTABLE, 'for-each-ref', '--format=%(refname)', 'refs/heads/*']
615 615 stdout, stderr = subprocessio.run_command(
616 616 cmd, env=os.environ.copy())
617 617 heads = safe_str(stdout)
618 618 heads = heads.replace(push_ref['ref'], '')
619 619 heads = ' '.join(head for head
620 620 in heads.splitlines() if head) or '.'
621 621 cmd = [settings.GIT_EXECUTABLE, 'log', '--reverse',
622 622 '--pretty=format:%H', '--', push_ref['new_rev'],
623 623 '--not', heads]
624 624 stdout, stderr = subprocessio.run_command(
625 625 cmd, env=os.environ.copy())
626 626 git_revs.extend(list(map(ascii_str, stdout.splitlines())))
627 627
628 628 # delete branch case
629 629 elif push_ref['new_rev'] == empty_commit_id:
630 git_revs.append('delete_branch=>%s' % push_ref['name'])
630 git_revs.append(f'delete_branch=>{push_ref["name"]}')
631 631 else:
632 632 if push_ref['name'] not in branches:
633 633 branches.append(push_ref['name'])
634 634
635 635 cmd = [settings.GIT_EXECUTABLE, 'log',
636 636 '{old_rev}..{new_rev}'.format(**push_ref),
637 637 '--reverse', '--pretty=format:%H']
638 638 stdout, stderr = subprocessio.run_command(
639 639 cmd, env=os.environ.copy())
640 640 # we get bytes from stdout, we need str to be consistent
641 641 log_revs = list(map(ascii_str, stdout.splitlines()))
642 642 git_revs.extend(log_revs)
643 643
644 644 # Pure pygit2 impl. but still 2-3x slower :/
645 645 # results = []
646 646 #
647 647 # with Repository(os.getcwd()) as repo:
648 648 # repo_new_rev = repo[push_ref['new_rev']]
649 649 # repo_old_rev = repo[push_ref['old_rev']]
650 650 # walker = repo.walk(repo_new_rev.id, pygit2.GIT_SORT_TOPOLOGICAL)
651 651 #
652 652 # for commit in walker:
653 653 # if commit.id == repo_old_rev.id:
654 654 # break
655 655 # results.append(commit.id.hex)
656 656 # # reverse the order, can't use GIT_SORT_REVERSE
657 657 # log_revs = results[::-1]
658 658
659 659 elif type_ == 'tags':
660 660 if push_ref['name'] not in tags:
661 661 tags.append(push_ref['name'])
662 git_revs.append('tag=>%s' % push_ref['name'])
662 git_revs.append(f'tag=>{push_ref["name"]}')
663 663
664 664 extras['hook_type'] = 'post_receive'
665 665 extras['commit_ids'] = git_revs
666 666 extras['new_refs'] = {
667 667 'branches': branches,
668 668 'bookmarks': [],
669 669 'tags': tags,
670 670 }
671 671
672 672 stdout = sys.stdout
673 673
674 674 if 'repo_size' in extras['hooks']:
675 675 try:
676 676 _call_hook('repo_size', extras, GitMessageWriter(stdout))
677 677 except Exception:
678 678 pass
679 679
680 680 status_code = _call_hook('post_push', extras, GitMessageWriter(stdout))
681 681 return status_code
682 682
683 683
684 684 def _get_extras_from_txn_id(path, txn_id):
685 685 extras = {}
686 686 try:
687 687 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
688 688 '-t', txn_id,
689 689 '--revprop', path, 'rc-scm-extras']
690 690 stdout, stderr = subprocessio.run_command(
691 691 cmd, env=os.environ.copy())
692 692 extras = json.loads(base64.urlsafe_b64decode(stdout))
693 693 except Exception:
694 694 log.exception('Failed to extract extras info from txn_id')
695 695
696 696 return extras
697 697
698 698
699 699 def _get_extras_from_commit_id(commit_id, path):
700 700 extras = {}
701 701 try:
702 702 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
703 703 '-r', commit_id,
704 704 '--revprop', path, 'rc-scm-extras']
705 705 stdout, stderr = subprocessio.run_command(
706 706 cmd, env=os.environ.copy())
707 707 extras = json.loads(base64.urlsafe_b64decode(stdout))
708 708 except Exception:
709 709 log.exception('Failed to extract extras info from commit_id')
710 710
711 711 return extras
712 712
713 713
714 714 def svn_pre_commit(repo_path, commit_data, env):
715 715 path, txn_id = commit_data
716 716 branches = []
717 717 tags = []
718 718
719 719 if env.get('RC_SCM_DATA'):
720 720 extras = json.loads(env['RC_SCM_DATA'])
721 721 else:
722 722 # fallback method to read from TXN-ID stored data
723 723 extras = _get_extras_from_txn_id(path, txn_id)
724 724 if not extras:
725 725 return 0
726 726
727 727 extras['hook_type'] = 'pre_commit'
728 728 extras['commit_ids'] = [txn_id]
729 729 extras['txn_id'] = txn_id
730 730 extras['new_refs'] = {
731 731 'total_commits': 1,
732 732 'branches': branches,
733 733 'bookmarks': [],
734 734 'tags': tags,
735 735 }
736 736
737 737 return _call_hook('pre_push', extras, SvnMessageWriter())
738 738
739 739
740 740 def svn_post_commit(repo_path, commit_data, env):
741 741 """
742 742 commit_data is path, rev, txn_id
743 743 """
744 744 if len(commit_data) == 3:
745 745 path, commit_id, txn_id = commit_data
746 746 elif len(commit_data) == 2:
747 747 log.error('Failed to extract txn_id from commit_data using legacy method. '
748 748 'Some functionality might be limited')
749 749 path, commit_id = commit_data
750 750 txn_id = None
751 751
752 752 branches = []
753 753 tags = []
754 754
755 755 if env.get('RC_SCM_DATA'):
756 756 extras = json.loads(env['RC_SCM_DATA'])
757 757 else:
758 758 # fallback method to read from TXN-ID stored data
759 759 extras = _get_extras_from_commit_id(commit_id, path)
760 760 if not extras:
761 761 return 0
762 762
763 763 extras['hook_type'] = 'post_commit'
764 764 extras['commit_ids'] = [commit_id]
765 765 extras['txn_id'] = txn_id
766 766 extras['new_refs'] = {
767 767 'branches': branches,
768 768 'bookmarks': [],
769 769 'tags': tags,
770 770 'total_commits': 1,
771 771 }
772 772
773 773 if 'repo_size' in extras['hooks']:
774 774 try:
775 775 _call_hook('repo_size', extras, SvnMessageWriter())
776 776 except Exception:
777 777 pass
778 778
779 779 return _call_hook('post_push', extras, SvnMessageWriter())
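Both svnlook fallbacks above read the rc-scm-extras revision property, which RhodeCode stores as urlsafe-base64-encoded JSON; decoding mirrors the two helpers exactly:

raw = b'eyJyZXBvc2l0b3J5IjogImRlbW8ifQ=='  # base64 of {"repository": "demo"}
extras = json.loads(base64.urlsafe_b64decode(raw))
assert extras == {'repository': 'demo'}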
@@ -1,775 +1,775 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import platform
21 21 import sys
22 22 import locale
23 23 import logging
24 24 import uuid
25 25 import time
26 26 import wsgiref.util
27 27 import tempfile
28 28 import psutil
29 29
30 30 from itertools import chain
31 31
32 32 import msgpack
33 33 import configparser
34 34
35 35 from pyramid.config import Configurator
36 36 from pyramid.wsgi import wsgiapp
37 37 from pyramid.response import Response
38 38
39 39 from vcsserver.base import BytesEnvelope, BinaryEnvelope
40 40 from vcsserver.lib.rc_json import json
41 41 from vcsserver.config.settings_maker import SettingsMaker
42 42 from vcsserver.str_utils import safe_int
43 43 from vcsserver.lib.statsd_client import StatsdClient
44 44 from vcsserver.tweens.request_wrapper import get_headers_call_context
45 45
46 46 import vcsserver
47 47 from vcsserver import remote_wsgi, scm_app, settings, hgpatches
48 48 from vcsserver.git_lfs.app import GIT_LFS_CONTENT_TYPE, GIT_LFS_PROTO_PAT
49 49 from vcsserver.echo_stub import remote_wsgi as remote_wsgi_stub
50 50 from vcsserver.echo_stub.echo_app import EchoApp
51 51 from vcsserver.exceptions import HTTPRepoLocked, HTTPRepoBranchProtected
52 52 from vcsserver.lib.exc_tracking import store_exception, format_exc
53 53 from vcsserver.server import VcsServer
54 54
55 55 strict_vcs = True
56 56
57 57 git_import_err = None
58 58 try:
59 59 from vcsserver.remote.git_remote import GitFactory, GitRemote
60 60 except ImportError as e:
61 61 GitFactory = None
62 62 GitRemote = None
63 63 git_import_err = e
64 64 if strict_vcs:
65 65 raise
66 66
67 67
68 68 hg_import_err = None
69 69 try:
70 70 from vcsserver.remote.hg_remote import MercurialFactory, HgRemote
71 71 except ImportError as e:
72 72 MercurialFactory = None
73 73 HgRemote = None
74 74 hg_import_err = e
75 75 if strict_vcs:
76 76 raise
77 77
78 78
79 79 svn_import_err = None
80 80 try:
81 81 from vcsserver.remote.svn_remote import SubversionFactory, SvnRemote
82 82 except ImportError as e:
83 83 SubversionFactory = None
84 84 SvnRemote = None
85 85 svn_import_err = e
86 86 if strict_vcs:
87 87 raise
88 88
89 89 log = logging.getLogger(__name__)
90 90
91 91 # due to Mercurial/glibc2.27 problems we need to detect if locale settings are
92 92 # causing problems, "fix" them if they are, and fall back to LC_ALL = C
93 93
94 94 try:
95 95 locale.setlocale(locale.LC_ALL, '')
96 96 except locale.Error as e:
97 97 log.error(
98 98 'LOCALE ERROR: failed to set LC_ALL, fallback to LC_ALL=C, org error: %s', e)
99 99 os.environ['LC_ALL'] = 'C'
100 100
101 101
102 102 def _is_request_chunked(environ):
103 103 stream = environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked'
104 104 return stream
105 105
106 106
107 107 def log_max_fd():
108 108 try:
109 109 maxfd = psutil.Process().rlimit(psutil.RLIMIT_NOFILE)[1]
110 110 log.info('Max file descriptors value: %s', maxfd)
111 111 except Exception:
112 112 pass
113 113
114 114
115 class VCS(object):
115 class VCS:
116 116 def __init__(self, locale_conf=None, cache_config=None):
117 117 self.locale = locale_conf
118 118 self.cache_config = cache_config
119 119 self._configure_locale()
120 120
121 121 log_max_fd()
122 122
123 123 if GitFactory and GitRemote:
124 124 git_factory = GitFactory()
125 125 self._git_remote = GitRemote(git_factory)
126 126 else:
127 127 log.error("Git client import failed: %s", git_import_err)
128 128
129 129 if MercurialFactory and HgRemote:
130 130 hg_factory = MercurialFactory()
131 131 self._hg_remote = HgRemote(hg_factory)
132 132 else:
133 133 log.error("Mercurial client import failed: %s", hg_import_err)
134 134
135 135 if SubversionFactory and SvnRemote:
136 136 svn_factory = SubversionFactory()
137 137
138 138 # hg factory is used for svn url validation
139 139 hg_factory = MercurialFactory()
140 140 self._svn_remote = SvnRemote(svn_factory, hg_factory=hg_factory)
141 141 else:
142 142 log.error("Subversion client import failed: %s", svn_import_err)
143 143
144 144 self._vcsserver = VcsServer()
145 145
146 146 def _configure_locale(self):
147 147 if self.locale:
148 148 log.info('Setting locale `LC_ALL` to %s', self.locale)
149 149 else:
150 150 log.info('Configuring locale subsystem based on environment variables')
151 151 try:
152 152 # If self.locale is the empty string, then the locale
153 153 # module will use the environment variables. See the
154 154 # documentation of the package `locale`.
155 155 locale.setlocale(locale.LC_ALL, self.locale)
156 156
157 157 language_code, encoding = locale.getlocale()
158 158 log.info(
159 159 'Locale set to language code "%s" with encoding "%s".',
160 160 language_code, encoding)
161 161 except locale.Error:
162 162 log.exception('Cannot set locale, not configuring the locale system')
163 163
164 164
165 class WsgiProxy(object):
165 class WsgiProxy:
166 166 def __init__(self, wsgi):
167 167 self.wsgi = wsgi
168 168
169 169 def __call__(self, environ, start_response):
170 170 input_data = environ['wsgi.input'].read()
171 171 input_data = msgpack.unpackb(input_data)
172 172
173 173 error = None
174 174 try:
175 175 data, status, headers = self.wsgi.handle(
176 176 input_data['environment'], input_data['input_data'],
177 177 *input_data['args'], **input_data['kwargs'])
178 178 except Exception as e:
179 179 data, status, headers = [], None, None
180 180 error = {
181 181 'message': str(e),
182 182 '_vcs_kind': getattr(e, '_vcs_kind', None)
183 183 }
184 184
185 185 start_response('200 OK', [])
186 186 return self._iterator(error, status, headers, data)
187 187
188 188 def _iterator(self, error, status, headers, data):
189 189 initial_data = [
190 190 error,
191 191 status,
192 192 headers,
193 193 ]
194 194
195 195 for d in chain(initial_data, data):
196 196 yield msgpack.packb(d)
197 197
198 198
199 199 def not_found(request):
200 200 return {'status': '404 NOT FOUND'}
201 201
202 202
203 class VCSViewPredicate(object):
203 class VCSViewPredicate:
204 204 def __init__(self, val, config):
205 205 self.remotes = val
206 206
207 207 def text(self):
208 208 return f'vcs view method = {list(self.remotes.keys())}'
209 209
210 210 phash = text
211 211
212 212 def __call__(self, context, request):
213 213 """
214 214 View predicate that returns true if given backend is supported by
215 215 defined remotes.
216 216 """
217 217 backend = request.matchdict.get('backend')
218 218 return backend in self.remotes
219 219
220 220
221 class HTTPApplication(object):
221 class HTTPApplication:
222 222 ALLOWED_EXCEPTIONS = ('KeyError', 'URLError')
223 223
224 224 remote_wsgi = remote_wsgi
225 225 _use_echo_app = False
226 226
227 227 def __init__(self, settings=None, global_config=None):
228 228
229 229 self.config = Configurator(settings=settings)
230 230 # Init our statsd at very start
231 231 self.config.registry.statsd = StatsdClient.statsd
232 232 self.config.registry.vcs_call_context = {}
233 233
234 234 self.global_config = global_config
235 235 self.config.include('vcsserver.lib.rc_cache')
236 236 self.config.include('vcsserver.lib.rc_cache.archive_cache')
237 237
238 238 settings_locale = settings.get('locale', '') or 'en_US.UTF-8'
239 239 vcs = VCS(locale_conf=settings_locale, cache_config=settings)
240 240 self._remotes = {
241 241 'hg': vcs._hg_remote,
242 242 'git': vcs._git_remote,
243 243 'svn': vcs._svn_remote,
244 244 'server': vcs._vcsserver,
245 245 }
246 246 if settings.get('dev.use_echo_app', 'false').lower() == 'true':
247 247 self._use_echo_app = True
248 248 log.warning("Using EchoApp for VCS operations.")
249 249 self.remote_wsgi = remote_wsgi_stub
250 250
251 251 self._configure_settings(global_config, settings)
252 252
253 253 self._configure()
254 254
255 255 def _configure_settings(self, global_config, app_settings):
256 256 """
257 257 Configure the settings module.
258 258 """
259 259 settings_merged = global_config.copy()
260 260 settings_merged.update(app_settings)
261 261
262 262 git_path = app_settings.get('git_path', None)
263 263 if git_path:
264 264 settings.GIT_EXECUTABLE = git_path
265 265 binary_dir = app_settings.get('core.binary_dir', None)
266 266 if binary_dir:
267 267 settings.BINARY_DIR = binary_dir
268 268
269 269 # Store the settings to make them available to other modules.
270 270 vcsserver.PYRAMID_SETTINGS = settings_merged
271 271 vcsserver.CONFIG = settings_merged
272 272
273 273 def _configure(self):
274 274 self.config.add_renderer(name='msgpack', factory=self._msgpack_renderer_factory)
275 275
276 276 self.config.add_route('service', '/_service')
277 277 self.config.add_route('status', '/status')
278 278 self.config.add_route('hg_proxy', '/proxy/hg')
279 279 self.config.add_route('git_proxy', '/proxy/git')
280 280
281 281 # rpc methods
282 282 self.config.add_route('vcs', '/{backend}')
283 283
284 284 # streaming rpc remote methods
285 285 self.config.add_route('vcs_stream', '/{backend}/stream')
286 286
287 287 # vcs operations clone/push as streaming
288 288 self.config.add_route('stream_git', '/stream/git/*repo_name')
289 289 self.config.add_route('stream_hg', '/stream/hg/*repo_name')
290 290
291 291 self.config.add_view(self.status_view, route_name='status', renderer='json')
292 292 self.config.add_view(self.service_view, route_name='service', renderer='msgpack')
293 293
294 294 self.config.add_view(self.hg_proxy(), route_name='hg_proxy')
295 295 self.config.add_view(self.git_proxy(), route_name='git_proxy')
296 296 self.config.add_view(self.vcs_view, route_name='vcs', renderer='msgpack',
297 297 vcs_view=self._remotes)
298 298 self.config.add_view(self.vcs_stream_view, route_name='vcs_stream',
299 299 vcs_view=self._remotes)
300 300
301 301 self.config.add_view(self.hg_stream(), route_name='stream_hg')
302 302 self.config.add_view(self.git_stream(), route_name='stream_git')
303 303
304 304 self.config.add_view_predicate('vcs_view', VCSViewPredicate)
305 305
306 306 self.config.add_notfound_view(not_found, renderer='json')
307 307
308 308 self.config.add_view(self.handle_vcs_exception, context=Exception)
309 309
310 310 self.config.add_tween(
311 311 'vcsserver.tweens.request_wrapper.RequestWrapperTween',
312 312 )
313 313 self.config.add_request_method(
314 314 'vcsserver.lib.request_counter.get_request_counter',
315 315 'request_count')
316 316
317 317 def wsgi_app(self):
318 318 return self.config.make_wsgi_app()
319 319
320 320 def _vcs_view_params(self, request):
321 321 remote = self._remotes[request.matchdict['backend']]
322 322 payload = msgpack.unpackb(request.body, use_list=True)
323 323
324 324 method = payload.get('method')
325 325 params = payload['params']
326 326 wire = params.get('wire')
327 327 args = params.get('args')
328 328 kwargs = params.get('kwargs')
329 329 context_uid = None
330 330
331 331 request.registry.vcs_call_context = {
332 332 'method': method,
333 333 'repo_name': payload.get('_repo_name'),
334 334 }
335 335
336 336 if wire:
337 337 try:
338 338 wire['context'] = context_uid = uuid.UUID(wire['context'])
339 339 except KeyError:
340 340 pass
341 341 args.insert(0, wire)
342 342 repo_state_uid = wire.get('repo_state_uid') if wire else None
343 343
344 344 # NOTE(marcink): trading complexity for a slight performance gain
345 345 if log.isEnabledFor(logging.DEBUG):
346 346 # also we SKIP printing out any of those methods' args since they may be excessive
347 347 just_args_methods = {
348 348 'commitctx': ('content', 'removed', 'updated'),
349 349 'commit': ('content', 'removed', 'updated')
350 350 }
351 351 if method in just_args_methods:
352 352 skip_args = just_args_methods[method]
353 353 call_args = ''
354 354 call_kwargs = {}
355 355 for k in kwargs:
356 356 if k in skip_args:
357 357 # replace our skip key with dummy
358 358 call_kwargs[k] = f'RemovedParam({k})'
359 359 else:
360 360 call_kwargs[k] = kwargs[k]
361 361 else:
362 362 call_args = args[1:]
363 363 call_kwargs = kwargs
364 364
365 365 log.debug('Method requested:`%s` with args:%s kwargs:%s context_uid: %s, repo_state_uid:%s',
366 366 method, call_args, call_kwargs, context_uid, repo_state_uid)
367 367
368 368 statsd = request.registry.statsd
369 369 if statsd:
370 370 statsd.incr(
371 371 'vcsserver_method_total', tags=[
372 372 f"method:{method}",
373 373 ])
374 374 return payload, remote, method, args, kwargs
375 375
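Every POST to /{backend} carries one msgpack document shaped like the envelope below (field names match the parsing above; the method and values are hypothetical):

body = msgpack.packb({
    'id': 'req-1',
    '_repo_name': 'demo/repo',
    'method': 'is_empty',  # hypothetical remote method
    'params': {
        'wire': {'context': str(uuid.uuid4()), 'repo_state_uid': 'abc'},
        'args': [],
        'kwargs': {},
    },
})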
376 376 def vcs_view(self, request):
377 377
378 378 payload, remote, method, args, kwargs = self._vcs_view_params(request)
379 379 payload_id = payload.get('id')
380 380
381 381 try:
382 382 resp = getattr(remote, method)(*args, **kwargs)
383 383 except Exception as e:
384 384 exc_info = list(sys.exc_info())
385 385 exc_type, exc_value, exc_traceback = exc_info
386 386
387 387 org_exc = getattr(e, '_org_exc', None)
388 388 org_exc_name = None
389 389 org_exc_tb = ''
390 390 if org_exc:
391 391 org_exc_name = org_exc.__class__.__name__
392 392 org_exc_tb = getattr(e, '_org_exc_tb', '')
393 393 # replace our "faked" exception with our org
394 394 exc_info[0] = org_exc.__class__
395 395 exc_info[1] = org_exc
396 396
397 397 should_store_exc = True
398 398 if org_exc:
399 399 def get_exc_fqn(_exc_obj):
400 400 module_name = getattr(org_exc.__class__, '__module__', 'UNKNOWN')
401 401 return module_name + '.' + org_exc_name
402 402
403 403 exc_fqn = get_exc_fqn(org_exc)
404 404
405 405 if exc_fqn in ['mercurial.error.RepoLookupError',
406 406 'vcsserver.exceptions.RefNotFoundException']:
407 407 should_store_exc = False
408 408
409 409 if should_store_exc:
410 410 store_exception(id(exc_info), exc_info, request_path=request.path)
411 411
412 412 tb_info = format_exc(exc_info)
413 413
414 414 type_ = e.__class__.__name__
415 415 if type_ not in self.ALLOWED_EXCEPTIONS:
416 416 type_ = None
417 417
418 418 resp = {
419 419 'id': payload_id,
420 420 'error': {
421 421 'message': str(e),
422 422 'traceback': tb_info,
423 423 'org_exc': org_exc_name,
424 424 'org_exc_tb': org_exc_tb,
425 425 'type': type_
426 426 }
427 427 }
428 428
429 429 try:
430 430 resp['error']['_vcs_kind'] = getattr(e, '_vcs_kind', None)
431 431 except AttributeError:
432 432 pass
433 433 else:
434 434 resp = {
435 435 'id': payload_id,
436 436 'result': resp
437 437 }
438 438 log.debug('Serving data for method %s', method)
439 439 return resp
440 440
441 441 def vcs_stream_view(self, request):
442 442 payload, remote, method, args, kwargs = self._vcs_view_params(request)
443 443 # this method has a 'stream:' marker; we remove it here
444 444 method = method.split('stream:')[-1]
445 445 chunk_size = safe_int(payload.get('chunk_size')) or 4096
446 446
447 447 resp = getattr(remote, method)(*args, **kwargs)
448 448
449 449 def get_chunked_data(method_resp):
450 450 stream = io.BytesIO(method_resp)
451 451 while 1:
452 452 chunk = stream.read(chunk_size)
453 453 if not chunk:
454 454 break
455 455 yield chunk
456 456
457 457 response = Response(app_iter=get_chunked_data(resp))
458 458 response.content_type = 'application/octet-stream'
459 459
460 460 return response
461 461
462 462 def status_view(self, request):
463 463 import vcsserver
464 464 _platform_id = platform.uname()[1] or 'instance'
465 465
466 466 return {
467 467 "status": "OK",
468 468 "vcsserver_version": vcsserver.__version__,
469 469 "platform": _platform_id,
470 470 "pid": os.getpid(),
471 471 }
472 472
473 473 def service_view(self, request):
474 474 import vcsserver
475 475
476 476 payload = msgpack.unpackb(request.body, use_list=True)
477 477 server_config, app_config = {}, {}
478 478
479 479 try:
480 480 path = self.global_config['__file__']
481 481 config = configparser.RawConfigParser()
482 482
483 483 config.read(path)
484 484
485 485 if config.has_section('server:main'):
486 486 server_config = dict(config.items('server:main'))
487 487 if config.has_section('app:main'):
488 488 app_config = dict(config.items('app:main'))
489 489
490 490 except Exception:
491 491 log.exception('Failed to read .ini file for display')
492 492
493 493 environ = list(os.environ.items())
494 494
495 495 resp = {
496 496 'id': payload.get('id'),
497 497 'result': dict(
498 498 version=vcsserver.__version__,
499 499 config=server_config,
500 500 app_config=app_config,
501 501 environ=environ,
502 502 payload=payload,
503 503 )
504 504 }
505 505 return resp
506 506
507 507 def _msgpack_renderer_factory(self, info):
508 508
509 509 def _render(value, system):
510 510 bin_type = False
511 511 res = value.get('result')
512 512 if isinstance(res, BytesEnvelope):
513 513 log.debug('Result is wrapped in BytesEnvelope type')
514 514 bin_type = True
515 515 elif isinstance(res, BinaryEnvelope):
516 516 log.debug('Result is wrapped in BinaryEnvelope type')
517 517 value['result'] = res.val
518 518 bin_type = True
519 519
520 520 request = system.get('request')
521 521 if request is not None:
522 522 response = request.response
523 523 ct = response.content_type
524 524 if ct == response.default_content_type:
525 525 response.content_type = 'application/x-msgpack'
526 526 if bin_type:
527 527 response.content_type = 'application/x-msgpack-bin'
528 528
529 529 return msgpack.packb(value, use_bin_type=bin_type)
530 530 return _render
531 531
532 532 def set_env_from_config(self, environ, config):
533 533 dict_conf = {}
534 534 try:
535 535 for elem in config:
536 536 if elem[0] == 'rhodecode':
537 537 dict_conf = json.loads(elem[2])
538 538 break
539 539 except Exception:
540 540 log.exception('Failed to fetch SCM CONFIG')
541 541 return
542 542
543 543 username = dict_conf.get('username')
544 544 if username:
545 545 environ['REMOTE_USER'] = username
546 546 # mercurial specific, some extension api rely on this
547 547 environ['HGUSER'] = username
548 548
549 549 ip = dict_conf.get('ip')
550 550 if ip:
551 551 environ['REMOTE_HOST'] = ip
552 552
553 553 if _is_request_chunked(environ):
554 554 # set the compatibility flag for webob
555 555 environ['wsgi.input_terminated'] = True
556 556
557 557 def hg_proxy(self):
558 558 @wsgiapp
559 559 def _hg_proxy(environ, start_response):
560 560 app = WsgiProxy(self.remote_wsgi.HgRemoteWsgi())
561 561 return app(environ, start_response)
562 562 return _hg_proxy
563 563
564 564 def git_proxy(self):
565 565 @wsgiapp
566 566 def _git_proxy(environ, start_response):
567 567 app = WsgiProxy(self.remote_wsgi.GitRemoteWsgi())
568 568 return app(environ, start_response)
569 569 return _git_proxy
570 570
571 571 def hg_stream(self):
572 572 if self._use_echo_app:
573 573 @wsgiapp
574 574 def _hg_stream(environ, start_response):
575 575 app = EchoApp('fake_path', 'fake_name', None)
576 576 return app(environ, start_response)
577 577 return _hg_stream
578 578 else:
579 579 @wsgiapp
580 580 def _hg_stream(environ, start_response):
581 581 log.debug('http-app: handling hg stream')
582 582 call_context = get_headers_call_context(environ)
583 583
584 584 repo_path = call_context['repo_path']
585 585 repo_name = call_context['repo_name']
586 586 config = call_context['repo_config']
587 587
588 588 app = scm_app.create_hg_wsgi_app(
589 589 repo_path, repo_name, config)
590 590
591 591 # Consistent path information for hgweb
592 592 environ['PATH_INFO'] = call_context['path_info']
593 593 environ['REPO_NAME'] = repo_name
594 594 self.set_env_from_config(environ, config)
595 595
596 596 log.debug('http-app: starting app handler '
597 597 'with %s to process the request', app)
598 598 return app(environ, ResponseFilter(start_response))
599 599 return _hg_stream
600 600
601 601 def git_stream(self):
602 602 if self._use_echo_app:
603 603 @wsgiapp
604 604 def _git_stream(environ, start_response):
605 605 app = EchoApp('fake_path', 'fake_name', None)
606 606 return app(environ, start_response)
607 607 return _git_stream
608 608 else:
609 609 @wsgiapp
610 610 def _git_stream(environ, start_response):
611 611 log.debug('http-app: handling git stream')
612 612
613 613 call_context = get_headers_call_context(environ)
614 614
615 615 repo_path = call_context['repo_path']
616 616 repo_name = call_context['repo_name']
617 617 config = call_context['repo_config']
618 618
619 619 environ['PATH_INFO'] = call_context['path_info']
620 620 self.set_env_from_config(environ, config)
621 621
622 622 content_type = environ.get('CONTENT_TYPE', '')
623 623
624 624 path = environ['PATH_INFO']
625 625 is_lfs_request = GIT_LFS_CONTENT_TYPE in content_type
626 626 log.debug(
627 627 'LFS: Detecting if request `%s` is LFS server path based '
628 628 'on content type:`%s`, is_lfs:%s',
629 629 path, content_type, is_lfs_request)
630 630
631 631 if not is_lfs_request:
632 632 # fallback detection by path
633 633 if GIT_LFS_PROTO_PAT.match(path):
634 634 is_lfs_request = True
635 635 log.debug(
636 636 'LFS: fallback detection by path of: `%s`, is_lfs:%s',
637 637 path, is_lfs_request)
638 638
639 639 if is_lfs_request:
640 640 app = scm_app.create_git_lfs_wsgi_app(
641 641 repo_path, repo_name, config)
642 642 else:
643 643 app = scm_app.create_git_wsgi_app(
644 644 repo_path, repo_name, config)
645 645
646 646 log.debug('http-app: starting app handler '
647 647 'with %s to process the request', app)
648 648
649 649 return app(environ, start_response)
650 650
651 651 return _git_stream
652 652
653 653 def handle_vcs_exception(self, exception, request):
654 654 _vcs_kind = getattr(exception, '_vcs_kind', '')
655 655
656 656 if _vcs_kind == 'repo_locked':
657 657 headers_call_context = get_headers_call_context(request.environ)
658 658 status_code = safe_int(headers_call_context['locked_status_code'])
659 659
660 660 return HTTPRepoLocked(
661 661 title=str(exception), status_code=status_code, headers=[('X-Rc-Locked', '1')])
662 662
663 663 elif _vcs_kind == 'repo_branch_protected':
664 664 # Get custom repo-branch-protected status code if present.
665 665 return HTTPRepoBranchProtected(
666 666 title=str(exception), headers=[('X-Rc-Branch-Protection', '1')])
667 667
668 668 exc_info = request.exc_info
669 669 store_exception(id(exc_info), exc_info)
670 670
671 671 traceback_info = 'unavailable'
672 672 if request.exc_info:
673 673 traceback_info = format_exc(request.exc_info)
674 674
675 675 log.error(
676 676 'error occurred handling this request for path: %s, \n%s',
677 677 request.path, traceback_info)
678 678
679 679 statsd = request.registry.statsd
680 680 if statsd:
681 681 exc_type = f"{exception.__class__.__module__}.{exception.__class__.__name__}"
682 682 statsd.incr('vcsserver_exception_total',
683 683 tags=[f"type:{exc_type}"])
684 684 raise exception
685 685
686 686
687 class ResponseFilter(object):
687 class ResponseFilter:
688 688
689 689 def __init__(self, start_response):
690 690 self._start_response = start_response
691 691
692 692 def __call__(self, status, response_headers, exc_info=None):
693 693 headers = tuple(
694 694 (h, v) for h, v in response_headers
695 695 if not wsgiref.util.is_hop_by_hop(h))
696 696 return self._start_response(status, headers, exc_info)
697 697
698 698
699 699 def sanitize_settings_and_apply_defaults(global_config, settings):
700 700 _global_settings_maker = SettingsMaker(global_config)
701 701 settings_maker = SettingsMaker(settings)
702 702
703 703 settings_maker.make_setting('logging.autoconfigure', False, parser='bool')
704 704
705 705 logging_conf = os.path.join(os.path.dirname(global_config.get('__file__')), 'logging.ini')
706 706 settings_maker.enable_logging(logging_conf)
707 707
708 708 # Default includes, possible to change as a user
709 709 pyramid_includes = settings_maker.make_setting('pyramid.includes', [], parser='list:newline')
710 710 log.debug("Using the following pyramid.includes: %s", pyramid_includes)
711 711
712 712 settings_maker.make_setting('__file__', global_config.get('__file__'))
713 713
714 714 settings_maker.make_setting('pyramid.default_locale_name', 'en')
715 715 settings_maker.make_setting('locale', 'en_US.UTF-8')
716 716
717 717 settings_maker.make_setting('core.binary_dir', '')
718 718
719 719 temp_store = tempfile.gettempdir()
720 720 default_cache_dir = os.path.join(temp_store, 'rc_cache')
721 721 # save the default cache dir and use it for all backends later.
722 722 default_cache_dir = settings_maker.make_setting(
723 723 'cache_dir',
724 724 default=default_cache_dir, default_when_empty=True,
725 725 parser='dir:ensured')
726 726
727 727 # exception store cache
728 728 settings_maker.make_setting(
729 729 'exception_tracker.store_path',
730 730 default=os.path.join(default_cache_dir, 'exc_store'), default_when_empty=True,
731 731 parser='dir:ensured'
732 732 )
733 733
734 734 # repo_object cache defaults
735 735 settings_maker.make_setting(
736 736 'rc_cache.repo_object.backend',
737 737 default='dogpile.cache.rc.file_namespace',
738 738 parser='string')
739 739 settings_maker.make_setting(
740 740 'rc_cache.repo_object.expiration_time',
741 741 default=30 * 24 * 60 * 60, # 30days
742 742 parser='int')
743 743 settings_maker.make_setting(
744 744 'rc_cache.repo_object.arguments.filename',
745 745 default=os.path.join(default_cache_dir, 'vcsserver_cache_repo_object.db'),
746 746 parser='string')
747 747
748 748 # statsd
749 749 settings_maker.make_setting('statsd.enabled', False, parser='bool')
750 750 settings_maker.make_setting('statsd.statsd_host', 'statsd-exporter', parser='string')
751 751 settings_maker.make_setting('statsd.statsd_port', 9125, parser='int')
752 752 settings_maker.make_setting('statsd.statsd_prefix', '')
753 753 settings_maker.make_setting('statsd.statsd_ipv6', False, parser='bool')
754 754
755 755 settings_maker.env_expand()
756 756
757 757
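For reference, when the .ini leaves these keys unset, the calls above collapse to roughly this mapping (a sketch, not an exhaustive list):

expected_defaults = {
    'locale': 'en_US.UTF-8',
    'cache_dir': os.path.join(tempfile.gettempdir(), 'rc_cache'),
    'rc_cache.repo_object.backend': 'dogpile.cache.rc.file_namespace',
    'rc_cache.repo_object.expiration_time': 30 * 24 * 60 * 60,  # 30 days
    'statsd.enabled': False,
    'statsd.statsd_host': 'statsd-exporter',
    'statsd.statsd_port': 9125,
}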
758 758 def main(global_config, **settings):
759 759 start_time = time.time()
760 760 log.info('Pyramid app config starting')
761 761
762 762 if MercurialFactory:
763 763 hgpatches.patch_largefiles_capabilities()
764 764 hgpatches.patch_subrepo_type_mapping()
765 765
766 766 # Fill in and sanitize the defaults & do ENV expansion
767 767 sanitize_settings_and_apply_defaults(global_config, settings)
768 768
769 769 # init and bootstrap StatsdClient
770 770 StatsdClient.setup(settings)
771 771
772 772 pyramid_app = HTTPApplication(settings=settings, global_config=global_config).wsgi_app()
773 773 total_time = time.time() - start_time
774 774 log.info('Pyramid app created and configured in %.2fs', total_time)
775 775 return pyramid_app
@@ -1,394 +1,394 b''
1 1
2 2 import threading
3 3 import weakref
4 4 from base64 import b64encode
5 5 from logging import getLogger
6 6 from os import urandom
7 7 from typing import Union
8 8
9 9 from redis import StrictRedis
10 10
11 11 __version__ = '4.0.0'
12 12
13 13 loggers = {
14 14 k: getLogger("vcsserver." + ".".join((__name__, k)))
15 15 for k in [
16 16 "acquire",
17 17 "refresh.thread.start",
18 18 "refresh.thread.stop",
19 19 "refresh.thread.exit",
20 20 "refresh.start",
21 21 "refresh.shutdown",
22 22 "refresh.exit",
23 23 "release",
24 24 ]
25 25 }
26 26
27 27 text_type = str
28 28 binary_type = bytes
29 29
30 30
31 31 # Check if the id matches. If not, return an error code.
32 32 UNLOCK_SCRIPT = b"""
33 33 if redis.call("get", KEYS[1]) ~= ARGV[1] then
34 34 return 1
35 35 else
36 36 redis.call("del", KEYS[2])
37 37 redis.call("lpush", KEYS[2], 1)
38 38 redis.call("pexpire", KEYS[2], ARGV[2])
39 39 redis.call("del", KEYS[1])
40 40 return 0
41 41 end
42 42 """
43 43
44 44 # Covers both cases: the key doesn't exist, or it doesn't equal the lock's id
45 45 EXTEND_SCRIPT = b"""
46 46 if redis.call("get", KEYS[1]) ~= ARGV[1] then
47 47 return 1
48 48 elseif redis.call("ttl", KEYS[1]) < 0 then
49 49 return 2
50 50 else
51 51 redis.call("expire", KEYS[1], ARGV[2])
52 52 return 0
53 53 end
54 54 """
55 55
56 56 RESET_SCRIPT = b"""
57 57 redis.call('del', KEYS[2])
58 58 redis.call('lpush', KEYS[2], 1)
59 59 redis.call('pexpire', KEYS[2], ARGV[2])
60 60 return redis.call('del', KEYS[1])
61 61 """
62 62
63 63 RESET_ALL_SCRIPT = b"""
64 64 local locks = redis.call('keys', 'lock:*')
65 65 local signal
66 66 for _, lock in pairs(locks) do
67 67 signal = 'lock-signal:' .. string.sub(lock, 6)
68 68 redis.call('del', signal)
69 69 redis.call('lpush', signal, 1)
70 70 redis.call('expire', signal, 1)
71 71 redis.call('del', lock)
72 72 end
73 73 return #locks
74 74 """
75 75
76 76
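All four scripts share one signalling trick: releasing or resetting a lock pushes a token onto lock-signal:<name>, which wakes any waiter blocked in BLPOP. A minimal release sketch, assuming a reachable local Redis:

client = StrictRedis()
unlock = client.register_script(UNLOCK_SCRIPT)
# KEYS = (lock key, signal key); ARGV = (holder id, signal expiry in ms)
rc = unlock(keys=('lock:demo', 'lock-signal:demo'), args=('holder-id', 1000))
# rc == 0: released and waiters signalled; rc == 1: not held with that id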
77 77 class AlreadyAcquired(RuntimeError):
78 78 pass
79 79
80 80
81 81 class NotAcquired(RuntimeError):
82 82 pass
83 83
84 84
85 85 class AlreadyStarted(RuntimeError):
86 86 pass
87 87
88 88
89 89 class TimeoutNotUsable(RuntimeError):
90 90 pass
91 91
92 92
93 93 class InvalidTimeout(RuntimeError):
94 94 pass
95 95
96 96
97 97 class TimeoutTooLarge(RuntimeError):
98 98 pass
99 99
100 100
101 101 class NotExpirable(RuntimeError):
102 102 pass
103 103
104 104
105 class Lock(object):
105 class Lock:
106 106 """
107 107 A Lock context manager implemented via redis SETNX/BLPOP.
108 108 """
109 109
110 110 unlock_script = None
111 111 extend_script = None
112 112 reset_script = None
113 113 reset_all_script = None
114 114
115 115 _lock_renewal_interval: float
116 116 _lock_renewal_thread: Union[threading.Thread, None]
117 117
118 118 def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000):
119 119 """
120 120 :param redis_client:
121 121 An instance of :class:`~StrictRedis`.
122 122 :param name:
123 123 The name (redis key) the lock should have.
124 124 :param expire:
125 125 The lock expiry time in seconds. If left at the default (None)
126 126 the lock will not expire.
127 127 :param id:
128 128 The ID (redis value) the lock should have. A random value is
129 129 generated when left at the default.
130 130
131 131 Note that if you specify this then the lock is marked as "held". Acquires
132 132 won't be possible.
133 133 :param auto_renewal:
134 134 If set to ``True``, Lock will automatically renew the lock so that it
135 135 doesn't expire for as long as the lock is held (acquire() called
136 136 or running in a context manager).
137 137
138 138 Implementation note: Renewal will happen using a daemon thread with
139 139 an interval of ``expire*2/3``. If wishing to use a different renewal
140 140 time, subclass Lock, call ``super().__init__()`` then set
141 141 ``self._lock_renewal_interval`` to your desired interval.
142 142 :param strict:
143 143 If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
144 144 :param signal_expire:
145 145 Advanced option to override signal list expiration in milliseconds. Increase it for very slow clients. Default: ``1000``.
146 146 """
147 147 if strict and not isinstance(redis_client, StrictRedis):
148 148 raise ValueError("redis_client must be an instance of StrictRedis. "
149 149 "Use strict=False if you know what you're doing.")
150 150 if auto_renewal and expire is None:
151 151 raise ValueError("Expire may not be None when auto_renewal is set")
152 152
153 153 self._client = redis_client
154 154
155 155 if expire:
156 156 expire = int(expire)
157 157 if expire < 0:
158 158 raise ValueError("A negative expire is not acceptable.")
159 159 else:
160 160 expire = None
161 161 self._expire = expire
162 162
163 163 self._signal_expire = signal_expire
164 164 if id is None:
165 165 self._id = b64encode(urandom(18)).decode('ascii')
166 166 elif isinstance(id, binary_type):
167 167 try:
168 168 self._id = id.decode('ascii')
169 169 except UnicodeDecodeError:
170 170 self._id = b64encode(id).decode('ascii')
171 171 elif isinstance(id, text_type):
172 172 self._id = id
173 173 else:
174 174 raise TypeError(f"Incorrect type for `id`. Must be bytes/str not {type(id)}.")
175 175 self._name = 'lock:' + name
176 176 self._signal = 'lock-signal:' + name
177 177 self._lock_renewal_interval = (float(expire) * 2 / 3
178 178 if auto_renewal
179 179 else None)
180 180 self._lock_renewal_thread = None
181 181
182 182 self.register_scripts(redis_client)
183 183
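A hedged usage sketch of the constructor above (client and key name are illustrative); with auto_renewal=True a daemon thread re-extends the key every expire*2/3 seconds (here 40s) until release:

import redis

lock = Lock(redis.StrictRedis(), "my-resource", expire=60, auto_renewal=True)
with lock:  # blocks until acquired; __exit__ releases and stops the renewer
    pass    # critical section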
184 184 @classmethod
185 185 def register_scripts(cls, redis_client):
186 186 global reset_all_script
187 187 if reset_all_script is None:
188 188 cls.unlock_script = redis_client.register_script(UNLOCK_SCRIPT)
189 189 cls.extend_script = redis_client.register_script(EXTEND_SCRIPT)
190 190 cls.reset_script = redis_client.register_script(RESET_SCRIPT)
191 191 cls.reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
192 192 reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
193 193
194 194 @property
195 195 def _held(self):
196 196 return self.id == self.get_owner_id()
197 197
198 198 def reset(self):
199 199 """
200 200 Forcibly deletes the lock. Use this with care.
201 201 """
202 202 self.reset_script(client=self._client, keys=(self._name, self._signal), args=(self.id, self._signal_expire))
203 203
204 204 @property
205 205 def id(self):
206 206 return self._id
207 207
208 208 def get_owner_id(self):
209 209 owner_id = self._client.get(self._name)
210 210 if isinstance(owner_id, binary_type):
211 211 owner_id = owner_id.decode('ascii', 'replace')
212 212 return owner_id
213 213
214 214 def acquire(self, blocking=True, timeout=None):
215 215 """
216 216 :param blocking:
217 217 Boolean value specifying whether the acquire call should block or not.
218 218 :param timeout:
219 219 An integer value specifying the maximum number of seconds to block.
220 220 """
221 221 logger = loggers["acquire"]
222 222
223 223 logger.debug("Getting blocking: %s acquire on %r ...", blocking, self._name)
224 224
225 225 if self._held:
226 226 owner_id = self.get_owner_id()
227 227 raise AlreadyAcquired(f"Already acquired from this Lock instance. Lock id: {owner_id}")
228 228
229 229 if not blocking and timeout is not None:
230 230 raise TimeoutNotUsable("Timeout cannot be used if blocking=False")
231 231
232 232 if timeout:
233 233 timeout = int(timeout)
234 234 if timeout < 0:
235 235 raise InvalidTimeout(f"Timeout ({timeout}) cannot be negative")
236 236
237 237 if self._expire and not self._lock_renewal_interval and timeout > self._expire:
238 238 raise TimeoutTooLarge(f"Timeout ({timeout}) cannot be greater than expire ({self._expire})")
239 239
240 240 busy = True
241 241 blpop_timeout = timeout or self._expire or 0
242 242 timed_out = False
243 243 while busy:
244 244 busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
245 245 if busy:
246 246 if timed_out:
247 247 return False
248 248 elif blocking:
249 249 timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
250 250 else:
251 251 logger.warning("Failed to acquire Lock(%r).", self._name)
252 252 return False
253 253
254 254 logger.debug("Acquired Lock(%r).", self._name)
255 255 if self._lock_renewal_interval is not None:
256 256 self._start_lock_renewer()
257 257 return True
258 258
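The three acquisition modes handled above, sketched with their return semantics as implemented (True on success, False on failure or timeout):

lock.acquire(blocking=False)             # single SET NX attempt, returns at once
lock.acquire(blocking=True)              # waits on the signal list, retrying every `expire` seconds when one is set
lock.acquire(blocking=True, timeout=10)  # gives up after roughly 10 seconds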
259 259 def extend(self, expire=None):
260 260 """
261 261 Extends expiration time of the lock.
262 262
263 263 :param expire:
264 264 New expiration time. If ``None`` - `expire` provided during
265 265 lock initialization will be taken.
266 266 """
267 267 if expire:
268 268 expire = int(expire)
269 269 if expire < 0:
270 270 raise ValueError("A negative expire is not acceptable.")
271 271 elif self._expire is not None:
272 272 expire = self._expire
273 273 else:
274 274 raise TypeError(
275 275 "To extend a lock 'expire' must be provided as an "
276 276 "argument to extend() method or at initialization time."
277 277 )
278 278
279 279 error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
280 280 if error == 1:
281 281 raise NotAcquired(f"Lock {self._name} is not acquired or it already expired.")
282 282 elif error == 2:
283 283 raise NotExpirable(f"Lock {self._name} has no assigned expiration time")
284 284 elif error:
285 285 raise RuntimeError(f"Unsupported error code {error} from EXTEND script")
286 286
287 287 @staticmethod
288 288 def _lock_renewer(name, lockref, interval, stop):
289 289 """
290 290 Renew the lock key in redis every `interval` seconds until the
291 291 `stop` event is set or the lock object is garbage collected.
292 292 """
293 293 while not stop.wait(timeout=interval):
294 294 loggers["refresh.thread.start"].debug("Refreshing Lock(%r).", name)
295 295 lock: "Lock" = lockref()
296 296 if lock is None:
297 297 loggers["refresh.thread.stop"].debug(
298 298 "Stopping loop because Lock(%r) was garbage collected.", name
299 299 )
300 300 break
301 301 lock.extend(expire=lock._expire)
302 302 del lock
303 303 loggers["refresh.thread.exit"].debug("Exiting renewal thread for Lock(%r).", name)
304 304
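A sketch of the subclassing hook described in the constructor docstring (the interval choice is illustrative and assumes auto_renewal=True with a numeric expire):

class FastRenewLock(Lock):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # renew at half the TTL instead of the default expire*2/3
        self._lock_renewal_interval = float(self._expire) / 2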
305 305 def _start_lock_renewer(self):
306 306 """
307 307 Starts the lock refresher thread.
308 308 """
309 309 if self._lock_renewal_thread is not None:
310 310 raise AlreadyStarted("Lock refresh thread already started")
311 311
312 312 loggers["refresh.start"].debug(
313 313 "Starting renewal thread for Lock(%r). Refresh interval: %s seconds.",
314 314 self._name, self._lock_renewal_interval
315 315 )
316 316 self._lock_renewal_stop = threading.Event()
317 317 self._lock_renewal_thread = threading.Thread(
318 318 group=None,
319 319 target=self._lock_renewer,
320 320 kwargs={
321 321 'name': self._name,
322 322 'lockref': weakref.ref(self),
323 323 'interval': self._lock_renewal_interval,
324 324 'stop': self._lock_renewal_stop,
325 325 },
326 326 )
327 327 self._lock_renewal_thread.daemon = True
328 328 self._lock_renewal_thread.start()
329 329
330 330 def _stop_lock_renewer(self):
331 331 """
332 332 Stop the lock renewer.
333 333
334 334 This signals the renewal thread and waits for its exit.
335 335 """
336 336 if self._lock_renewal_thread is None or not self._lock_renewal_thread.is_alive():
337 337 return
338 338 loggers["refresh.shutdown"].debug("Signaling renewal thread for Lock(%r) to exit.", self._name)
339 339 self._lock_renewal_stop.set()
340 340 self._lock_renewal_thread.join()
341 341 self._lock_renewal_thread = None
342 342 loggers["refresh.exit"].debug("Renewal thread for Lock(%r) exited.", self._name)
343 343
344 344 def __enter__(self):
345 345 acquired = self.acquire(blocking=True)
346 346 if not acquired:
347 347 raise AssertionError(f"Lock({self._name}) wasn't acquired, but blocking=True was used!")
348 348 return self
349 349
350 350 def __exit__(self, exc_type=None, exc_value=None, traceback=None):
351 351 self.release()
352 352
353 353 def release(self):
354 354 """Releases the lock, that was acquired with the same object.
355 355
356 356 .. note::
357 357
358 358 If you want to release a lock that you acquired in a different place you have two choices:
359 359
360 360 * Use ``Lock("name", id=id_from_other_place).release()``
361 361 * Use ``Lock("name").reset()``
362 362 """
363 363 if self._lock_renewal_thread is not None:
364 364 self._stop_lock_renewer()
365 365 loggers["release"].debug("Releasing Lock(%r).", self._name)
366 366 error = self.unlock_script(client=self._client, keys=(self._name, self._signal), args=(self._id, self._signal_expire))
367 367 if error == 1:
368 368 raise NotAcquired(f"Lock({self._name}) is not acquired or it already expired.")
369 369 elif error:
370 370 raise RuntimeError(f"Unsupported error code {error} from UNLOCK script.")
371 371
372 372 def locked(self):
373 373 """
374 374 Return true if the lock is acquired.
375 375
376 376 Checks whether a lock with the same name exists. This method returns true even
377 377 if the lock is held by another id.
378 378 """
379 379 return self._client.exists(self._name) == 1
380 380
381 381
382 382 reset_all_script = None
383 383
384 384
385 385 def reset_all(redis_client):
386 386 """
387 387 Forcibly deletes any locks that remain (e.g. after a crash). Use this with care.
388 388
389 389 :param redis_client:
390 390 An instance of :class:`~StrictRedis`.
391 391 """
392 392 Lock.register_scripts(redis_client)
393 393
394 394 reset_all_script(client=redis_client) # noqa
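Hedged example (connection details assumed): clearing any stale locks left over after a crash, before workers start:

import redis

reset_all(redis.StrictRedis(host='localhost'))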
@@ -1,154 +1,154 b''
1 1 import re
2 2 import random
3 3 from collections import deque
4 4 from datetime import timedelta
5 5 from repoze.lru import lru_cache
6 6
7 7 from .timer import Timer
8 8
9 9 TAG_INVALID_CHARS_RE = re.compile(
10 10 r"[^\w\d_\-:/\.]",
11 11 #re.UNICODE
12 12 )
13 13 TAG_INVALID_CHARS_SUBS = "_"
14 14
15 15 # we save and expose methods called by statsd for discovery
16 16 buckets_dict = {
17 17
18 18 }
19 19
20 20
21 21 @lru_cache(maxsize=500)
22 22 def _normalize_tags_with_cache(tag_list):
23 23 return [TAG_INVALID_CHARS_RE.sub(TAG_INVALID_CHARS_SUBS, tag) for tag in tag_list]
24 24
25 25
26 26 def normalize_tags(tag_list):
27 27 # We have to turn our input tag list into a non-mutable tuple for it to
28 28 # be hashable (and thus usable) by the @lru_cache decorator.
29 29 return _normalize_tags_with_cache(tuple(tag_list))
30 30
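For example, any character outside the class [\w\d_\-:/\.] is replaced with an underscore:

normalize_tags(["env:prod", "region=us east"])  # -> ["env:prod", "region_us_east"]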
31 31
32 class StatsClientBase(object):
32 class StatsClientBase:
33 33 """A Base class for various statsd clients."""
34 34
35 35 def close(self):
36 36 """Used to close and clean up any underlying resources."""
37 37 raise NotImplementedError()
38 38
39 39 def _send(self):
40 40 raise NotImplementedError()
41 41
42 42 def pipeline(self):
43 43 raise NotImplementedError()
44 44
45 45 def timer(self, stat, rate=1, tags=None, auto_send=True):
46 46 """
47 47 statsd = StatsdClient.statsd
48 48 with statsd.timer('bucket_name', auto_send=True) as tmr:
49 49 # This block will be timed.
50 50 for i in range(0, 100000):
51 51 i ** 2
52 52 # you can access time here...
53 53 elapsed_ms = tmr.ms
54 54 """
55 55 return Timer(self, stat, rate, tags, auto_send=auto_send)
56 56
57 57 def timing(self, stat, delta, rate=1, tags=None, use_decimals=True):
58 58 """
59 59 Send new timing information.
60 60
61 61 `delta` can be either a number of milliseconds or a timedelta.
62 62 """
63 63 if isinstance(delta, timedelta):
64 64 # Convert timedelta to number of milliseconds.
65 65 delta = delta.total_seconds() * 1000.
66 66 if use_decimals:
67 67 fmt = '%0.6f|ms'
68 68 else:
69 69 fmt = '%s|ms'
70 70 self._send_stat(stat, fmt % delta, rate, tags)
71 71
72 72 def incr(self, stat, count=1, rate=1, tags=None):
73 73 """Increment a stat by `count`."""
74 74 self._send_stat(stat, '%s|c' % count, rate, tags)
75 75
76 76 def decr(self, stat, count=1, rate=1, tags=None):
77 77 """Decrement a stat by `count`."""
78 78 self.incr(stat, -count, rate, tags)
79 79
80 80 def gauge(self, stat, value, rate=1, delta=False, tags=None):
81 81 """Set a gauge value."""
82 82 if value < 0 and not delta:
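# statsd treats a signed gauge value as a delta, so an absolute negative
# value must first reset the gauge to 0 and then apply the negative delta;
# the pipeline below sends both packets together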
83 83 if rate < 1:
84 84 if random.random() > rate:
85 85 return
86 86 with self.pipeline() as pipe:
87 87 pipe._send_stat(stat, '0|g', 1)
88 88 pipe._send_stat(stat, '%s|g' % value, 1)
89 89 else:
90 90 prefix = '+' if delta and value >= 0 else ''
91 91 self._send_stat(stat, '%s%s|g' % (prefix, value), rate, tags)
92 92
93 93 def set(self, stat, value, rate=1):
94 94 """Set a set value."""
95 95 self._send_stat(stat, '%s|s' % value, rate)
96 96
97 97 def histogram(self, stat, value, rate=1, tags=None):
98 98 """Set a histogram"""
99 99 self._send_stat(stat, '%s|h' % value, rate, tags)
100 100
101 101 def _send_stat(self, stat, value, rate, tags=None):
102 102 self._after(self._prepare(stat, value, rate, tags))
103 103
104 104 def _prepare(self, stat, value, rate, tags=None):
105 105 global buckets_dict
106 106 buckets_dict[stat] = 1
107 107
108 108 if rate < 1:
109 109 if random.random() > rate:
110 110 return
111 111 value = '%s|@%s' % (value, rate)
112 112
113 113 if self._prefix:
114 114 stat = '%s.%s' % (self._prefix, stat)
115 115
116 116 res = '%s:%s%s' % (
117 117 stat,
118 118 value,
119 119 ("|#" + ",".join(normalize_tags(tags))) if tags else "",
120 120 )
121 121 return res
122 122
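Illustrative wire formats produced by _prepare above (assuming a client prefix of "rc"):

# incr("req")                    -> "rc.req:1|c"
# timing("db", 12.5)             -> "rc.db:12.500000|ms"
# incr("req", rate=0.1)          -> "rc.req:1|c|@0.1"  (sent ~10% of the time)
# incr("req", tags=["env:prod"]) -> "rc.req:1|c|#env:prod"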
123 123 def _after(self, data):
124 124 if data:
125 125 self._send(data)
126 126
127 127
128 128 class PipelineBase(StatsClientBase):
129 129
130 130 def __init__(self, client):
131 131 self._client = client
132 132 self._prefix = client._prefix
133 133 self._stats = deque()
134 134
135 135 def _send(self):
136 136 raise NotImplementedError()
137 137
138 138 def _after(self, data):
139 139 if data is not None:
140 140 self._stats.append(data)
141 141
142 142 def __enter__(self):
143 143 return self
144 144
145 145 def __exit__(self, typ, value, tb):
146 146 self.send()
147 147
148 148 def send(self):
149 149 if not self._stats:
150 150 return
151 151 self._send()
152 152
153 153 def pipeline(self):
154 154 return self.__class__(self)
@@ -1,66 +1,66 b''
1 1 import functools
2 2 from time import perf_counter as time_now
3 3
4 4
5 5 def safe_wraps(wrapper, *args, **kwargs):
6 6 """Safely wraps partial functions."""
7 7 while isinstance(wrapper, functools.partial):
8 8 wrapper = wrapper.func
9 9 return functools.wraps(wrapper, *args, **kwargs)
10 10
11 11
12 class Timer(object):
12 class Timer:
13 13 """A context manager/decorator for statsd.timing()."""
14 14
15 15 def __init__(self, client, stat, rate=1, tags=None, use_decimals=True, auto_send=True):
16 16 self.client = client
17 17 self.stat = stat
18 18 self.rate = rate
19 19 self.tags = tags
20 20 self.ms = None
21 21 self._sent = False
22 22 self._start_time = None
23 23 self.use_decimals = use_decimals
24 24 self.auto_send = auto_send
25 25
26 26 def __call__(self, f):
27 27 """Thread-safe timing function decorator."""
28 28 @safe_wraps(f)
29 29 def _wrapped(*args, **kwargs):
30 30 start_time = time_now()
31 31 try:
32 32 return f(*args, **kwargs)
33 33 finally:
34 34 elapsed_time_ms = 1000.0 * (time_now() - start_time)
35 35 self.client.timing(self.stat, elapsed_time_ms, self.rate, self.tags, self.use_decimals)
36 36 self._sent = True
37 37 return _wrapped
38 38
39 39 def __enter__(self):
40 40 return self.start()
41 41
42 42 def __exit__(self, typ, value, tb):
43 43 self.stop(send=self.auto_send)
44 44
45 45 def start(self):
46 46 self.ms = None
47 47 self._sent = False
48 48 self._start_time = time_now()
49 49 return self
50 50
51 51 def stop(self, send=True):
52 52 if self._start_time is None:
53 53 raise RuntimeError('Timer has not started.')
54 54 dt = time_now() - self._start_time
55 55 self.ms = 1000.0 * dt # Convert to milliseconds.
56 56 if send:
57 57 self.send()
58 58 return self
59 59
60 60 def send(self):
61 61 if self.ms is None:
62 62 raise RuntimeError('No data recorded.')
63 63 if self._sent:
64 64 raise RuntimeError('Already sent data.')
65 65 self._sent = True
66 66 self.client.timing(self.stat, self.ms, self.rate, self.tags, self.use_decimals)
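Via __call__ above, a Timer obtained from a client also works as a decorator; a sketch (client name assumed):

@statsd_client.timer('my_func.duration')
def my_func():
    ...  # each call sends the elapsed milliseconds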
@@ -1,267 +1,267 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 #import errno
19 19 import fcntl
20 20 import functools
21 21 import logging
22 22 import os
23 23 import pickle
24 24 #import time
25 25
26 26 #import gevent
27 27 import msgpack
28 28 import redis
29 29
30 30 flock_org = fcntl.flock
31 31 from typing import Union
32 32
33 33 from dogpile.cache.api import Deserializer, Serializer
34 34 from dogpile.cache.backends import file as file_backend
35 35 from dogpile.cache.backends import memory as memory_backend
36 36 from dogpile.cache.backends import redis as redis_backend
37 37 from dogpile.cache.backends.file import FileLock
38 38 from dogpile.cache.util import memoized_property
39 39
40 40 from vcsserver.lib.memory_lru_dict import LRUDict, LRUDictDebug
41 41 from vcsserver.str_utils import safe_bytes, safe_str
42 42 from vcsserver.type_utils import str2bool
43 43
44 44 _default_max_size = 1024
45 45
46 46 log = logging.getLogger(__name__)
47 47
48 48
49 49 class LRUMemoryBackend(memory_backend.MemoryBackend):
50 50 key_prefix = 'lru_mem_backend'
51 51 pickle_values = False
52 52
53 53 def __init__(self, arguments):
54 54 self.max_size = arguments.pop('max_size', _default_max_size)
55 55
56 56 LRUDictClass = LRUDict
57 57 if arguments.pop('log_key_count', None):
58 58 LRUDictClass = LRUDictDebug
59 59
60 60 arguments['cache_dict'] = LRUDictClass(self.max_size)
61 61 super().__init__(arguments)
62 62
63 63 def __repr__(self):
64 64 return f'{self.__class__}(maxsize=`{self.max_size}`)'
65 65
66 66 def __str__(self):
67 67 return self.__repr__()
68 68
69 69 def delete(self, key):
70 70 try:
71 71 del self._cache[key]
72 72 except KeyError:
73 73 # we don't care if key isn't there at deletion
74 74 pass
75 75
76 76 def delete_multi(self, keys):
77 77 for key in keys:
78 78 self.delete(key)
79 79
80 80
81 81 class PickleSerializer:
82 82 serializer: None | Serializer = staticmethod( # type: ignore
83 83 functools.partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)
84 84 )
85 85 deserializer: None | Deserializer = staticmethod( # type: ignore
86 86 functools.partial(pickle.loads)
87 87 )
88 88
89 89
90 class MsgPackSerializer(object):
90 class MsgPackSerializer:
91 91 serializer: None | Serializer = staticmethod( # type: ignore
92 92 msgpack.packb
93 93 )
94 94 deserializer: None | Deserializer = staticmethod( # type: ignore
95 95 functools.partial(msgpack.unpackb, use_list=False)
96 96 )
97 97
98 98
99 99 class CustomLockFactory(FileLock):
100 100
101 101 pass
102 102
103 103
104 104 class FileNamespaceBackend(PickleSerializer, file_backend.DBMBackend):
105 105 key_prefix = 'file_backend'
106 106
107 107 def __init__(self, arguments):
108 108 arguments['lock_factory'] = CustomLockFactory
109 109 db_file = arguments.get('filename')
110 110
111 111 log.debug('initializing cache-backend=%s db in %s', self.__class__.__name__, db_file)
112 112 db_file_dir = os.path.dirname(db_file)
113 113 if not os.path.isdir(db_file_dir):
114 114 os.makedirs(db_file_dir)
115 115
116 116 try:
117 117 super().__init__(arguments)
118 118 except Exception:
119 119 log.exception('Failed to initialize db at: %s', db_file)
120 120 raise
121 121
122 122 def __repr__(self):
123 123 return f'{self.__class__}(file=`{self.filename}`)'
124 124
125 125 def __str__(self):
126 126 return self.__repr__()
127 127
128 128 def _get_keys_pattern(self, prefix: bytes = b''):
129 129 return b'%b:%b' % (safe_bytes(self.key_prefix), safe_bytes(prefix))
130 130
131 131 def list_keys(self, prefix: bytes = b''):
132 132 prefix = self._get_keys_pattern(prefix)
133 133
134 134 def cond(dbm_key: bytes):
135 135 if not prefix:
136 136 return True
137 137
138 138 if dbm_key.startswith(prefix):
139 139 return True
140 140 return False
141 141
142 142 with self._dbm_file(True) as dbm:
143 143 try:
144 144 return list(filter(cond, dbm.keys()))
145 145 except Exception:
146 146 log.error('Failed to fetch DBM keys from DB: %s', self.get_store())
147 147 raise
148 148
149 149 def get_store(self):
150 150 return self.filename
151 151
152 152
153 153 class BaseRedisBackend(redis_backend.RedisBackend):
154 154 key_prefix = ''
155 155
156 156 def __init__(self, arguments):
157 157 self.db_conn = arguments.get('host', '') or arguments.get('url', '') or 'redis-host'
158 158 super().__init__(arguments)
159 159
160 160 self._lock_timeout = self.lock_timeout
161 161 self._lock_auto_renewal = str2bool(arguments.pop("lock_auto_renewal", True))
162 162
163 163 if self._lock_auto_renewal and not self._lock_timeout:
164 164 # set default timeout for auto_renewal
165 165 self._lock_timeout = 30
166 166
167 167 def __repr__(self):
168 168 return f'{self.__class__}(conn=`{self.db_conn}`)'
169 169
170 170 def __str__(self):
171 171 return self.__repr__()
172 172
173 173 def _create_client(self):
174 174 args = {}
175 175
176 176 if self.url is not None:
177 177 args.update(url=self.url)
178 178
179 179 else:
180 180 args.update(
181 181 host=self.host, password=self.password,
182 182 port=self.port, db=self.db
183 183 )
184 184
185 185 connection_pool = redis.ConnectionPool(**args)
186 186 self.writer_client = redis.StrictRedis(
187 187 connection_pool=connection_pool
188 188 )
189 189 self.reader_client = self.writer_client
190 190
191 191 def _get_keys_pattern(self, prefix: bytes = b''):
192 192 return b'%b:%b*' % (safe_bytes(self.key_prefix), safe_bytes(prefix))
193 193
194 194 def list_keys(self, prefix: bytes = b''):
195 195 prefix = self._get_keys_pattern(prefix)
196 196 return self.reader_client.keys(prefix)
197 197
198 198 def get_store(self):
199 199 return self.reader_client.connection_pool
200 200
201 201 def get_mutex(self, key):
202 202 if self.distributed_lock:
203 203 lock_key = f'_lock_{safe_str(key)}'
204 204 return get_mutex_lock(
205 205 self.writer_client, lock_key,
206 206 self._lock_timeout,
207 207 auto_renewal=self._lock_auto_renewal
208 208 )
209 209 else:
210 210 return None
211 211
212 212
213 213 class RedisPickleBackend(PickleSerializer, BaseRedisBackend):
214 214 key_prefix = 'redis_pickle_backend'
215 215 pass
216 216
217 217
218 218 class RedisMsgPackBackend(MsgPackSerializer, BaseRedisBackend):
219 219 key_prefix = 'redis_msgpack_backend'
220 220 pass
221 221
222 222
223 223 def get_mutex_lock(client, lock_key, lock_timeout, auto_renewal=False):
224 224 from vcsserver.lib._vendor import redis_lock
225 225
226 class _RedisLockWrapper(object):
226 class _RedisLockWrapper:
227 227 """LockWrapper for redis_lock"""
228 228
229 229 @classmethod
230 230 def get_lock(cls):
231 231 return redis_lock.Lock(
232 232 redis_client=client,
233 233 name=lock_key,
234 234 expire=lock_timeout,
235 235 auto_renewal=auto_renewal,
236 236 strict=True,
237 237 )
238 238
239 239 def __repr__(self):
240 240 return f"{self.__class__.__name__}:{lock_key}"
241 241
242 242 def __str__(self):
243 243 return f"{self.__class__.__name__}:{lock_key}"
244 244
245 245 def __init__(self):
246 246 self.lock = self.get_lock()
247 247 self.lock_key = lock_key
248 248
249 249 def acquire(self, wait=True):
250 250 log.debug('Trying to acquire Redis lock for key %s', self.lock_key)
251 251 try:
252 252 acquired = self.lock.acquire(wait)
253 253 log.debug('Got lock for key %s, %s', self.lock_key, acquired)
254 254 return acquired
255 255 except redis_lock.AlreadyAcquired:
256 256 return False
257 257 except redis_lock.AlreadyStarted:
258 258 # refresh thread exists, but it also means we acquired the lock
259 259 return True
260 260
261 261 def release(self):
262 262 try:
263 263 self.lock.release()
264 264 except redis_lock.NotAcquired:
265 265 pass
266 266
267 267 return _RedisLockWrapper()
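A hedged sketch of how a caller, such as dogpile's mutex machinery, would drive the wrapper returned above (key name illustrative):

mutex = backend.get_mutex(b'repo_object.some-key')
if mutex is not None and mutex.acquire(wait=True):
    try:
        pass  # recompute and store the cached value
    finally:
        mutex.release()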
@@ -1,160 +1,160 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import tempfile
20 20
21 21 from svn import client
22 22 from svn import core
23 23 from svn import ra
24 24
25 25 from mercurial import error
26 26
27 27 from vcsserver.str_utils import safe_bytes
28 28
29 29 core.svn_config_ensure(None)
30 30 svn_config = core.svn_config_get_config(None)
31 31
32 32
33 33 class RaCallbacks(ra.Callbacks):
34 34 @staticmethod
35 35 def open_tmp_file(pool): # pragma: no cover
36 36 (fd, fn) = tempfile.mkstemp()
37 37 os.close(fd)
38 38 return fn
39 39
40 40 @staticmethod
41 41 def get_client_string(pool):
42 42 return b'RhodeCode-subversion-url-checker'
43 43
44 44
45 45 class SubversionException(Exception):
46 46 pass
47 47
48 48
49 49 class SubversionConnectionException(SubversionException):
50 50 """Exception raised when a generic error occurs when connecting to a repository."""
51 51
52 52
53 53 def normalize_url(url):
54 54 if not url:
55 55 return url
56 56 if url.startswith(b'svn+http://') or url.startswith(b'svn+https://'):
57 57 url = url[4:]
58 58 url = url.rstrip(b'/')
59 59 return url
60 60
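For example:

normalize_url(b'svn+https://host/repo/')  # -> b'https://host/repo'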
61 61
62 62 def _create_auth_baton(pool):
63 63 """Create a Subversion authentication baton. """
64 64 # Give the client context baton a suite of authentication
65 65 # providers.
66 66 platform_specific = [
67 67 'svn_auth_get_gnome_keyring_simple_provider',
68 68 'svn_auth_get_gnome_keyring_ssl_client_cert_pw_provider',
69 69 'svn_auth_get_keychain_simple_provider',
70 70 'svn_auth_get_keychain_ssl_client_cert_pw_provider',
71 71 'svn_auth_get_kwallet_simple_provider',
72 72 'svn_auth_get_kwallet_ssl_client_cert_pw_provider',
73 73 'svn_auth_get_ssl_client_cert_file_provider',
74 74 'svn_auth_get_windows_simple_provider',
75 75 'svn_auth_get_windows_ssl_server_trust_provider',
76 76 ]
77 77
78 78 providers = []
79 79
80 80 for p in platform_specific:
81 81 if getattr(core, p, None) is not None:
82 82 try:
83 83 providers.append(getattr(core, p)())
84 84 except RuntimeError:
85 85 pass
86 86
87 87 providers += [
88 88 client.get_simple_provider(),
89 89 client.get_username_provider(),
90 90 client.get_ssl_client_cert_file_provider(),
91 91 client.get_ssl_client_cert_pw_file_provider(),
92 92 client.get_ssl_server_trust_file_provider(),
93 93 ]
94 94
95 95 return core.svn_auth_open(providers, pool)
96 96
97 97
98 class SubversionRepo(object):
98 class SubversionRepo:
99 99 """Wrapper for a Subversion repository.
100 100
101 101 It uses the SWIG Python bindings, see above for requirements.
102 102 """
103 103 def __init__(self, svn_url: bytes = b'', username: bytes = b'', password: bytes = b''):
104 104
105 105 self.username = username
106 106 self.password = password
107 107 self.svn_url = core.svn_path_canonicalize(svn_url)
108 108
109 109 self.auth_baton_pool = core.Pool()
110 110 self.auth_baton = _create_auth_baton(self.auth_baton_pool)
111 111 # self.init_ra_and_client() assumes that a pool already exists
112 112 self.pool = core.Pool()
113 113
114 114 self.ra = self.init_ra_and_client()
115 115 self.uuid = ra.get_uuid(self.ra, self.pool)
116 116
117 117 def init_ra_and_client(self):
118 118 """Initializes the RA and client layers, because sometimes getting
119 119 unified diffs runs the remote server out of open files.
120 120 """
121 121
122 122 if self.username:
123 123 core.svn_auth_set_parameter(self.auth_baton,
124 124 core.SVN_AUTH_PARAM_DEFAULT_USERNAME,
125 125 self.username)
126 126 if self.password:
127 127 core.svn_auth_set_parameter(self.auth_baton,
128 128 core.SVN_AUTH_PARAM_DEFAULT_PASSWORD,
129 129 self.password)
130 130
131 131 callbacks = RaCallbacks()
132 132 callbacks.auth_baton = self.auth_baton
133 133
134 134 try:
135 135 return ra.open2(self.svn_url, callbacks, svn_config, self.pool)
136 136 except SubversionException as e:
137 137 # e.child contains a detailed error message
138 138 msglist = []
139 139 svn_exc = e
140 140 while svn_exc:
141 141 if svn_exc.args[0]:
142 142 msglist.append(svn_exc.args[0])
143 143 svn_exc = svn_exc.child
144 144 msg = '\n'.join(msglist)
145 145 raise SubversionConnectionException(msg)
146 146
147 147
148 class svnremoterepo(object):
148 class svnremoterepo:
149 149 """ the dumb wrapper for actual Subversion repositories """
150 150
151 151 def __init__(self, username: bytes = b'', password: bytes = b'', svn_url: bytes = b''):
152 152 self.username = username or b''
153 153 self.password = password or b''
154 154 self.path = normalize_url(svn_url)
155 155
156 156 def svn(self):
157 157 try:
158 158 return SubversionRepo(self.path, self.username, self.password)
159 159 except SubversionConnectionException as e:
160 160 raise error.Abort(safe_bytes(e))
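Hedged usage sketch (URL and credentials are illustrative):

repo = svnremoterepo(username=b'user', password=b'secret',
                     svn_url=b'https://svn.example.com/repo').svn()
print(repo.uuid)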
@@ -1,417 +1,417 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Handles the Git smart protocol."""
19 19
20 20 import os
21 21 import socket
22 22 import logging
23 23
24 24 import dulwich.protocol
25 25 from dulwich.protocol import CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K
26 26 from webob import Request, Response, exc
27 27
28 28 from vcsserver.lib.rc_json import json
29 29 from vcsserver import hooks, subprocessio
30 30 from vcsserver.str_utils import ascii_bytes
31 31
32 32
33 33 log = logging.getLogger(__name__)
34 34
35 35
36 class FileWrapper(object):
36 class FileWrapper:
37 37 """File wrapper that ensures how much data is read from it."""
38 38
39 39 def __init__(self, fd, content_length):
40 40 self.fd = fd
41 41 self.content_length = content_length
42 42 self.remain = content_length
43 43
44 44 def read(self, size):
45 45 if size <= self.remain:
46 46 try:
47 47 data = self.fd.read(size)
48 48 except socket.error:
49 49 raise IOError(self)
50 50 self.remain -= size
51 51 elif self.remain:
52 52 data = self.fd.read(self.remain)
53 53 self.remain = 0
54 54 else:
55 55 data = None
56 56 return data
57 57
58 58 def __repr__(self):
59 59 return '<FileWrapper {} len: {}, read: {}>'.format(
60 60 self.fd, self.content_length, self.content_length - self.remain
61 61 )
62 62
63 63
64 class GitRepository(object):
64 class GitRepository:
65 65 """WSGI app for handling Git smart protocol endpoints."""
66 66
67 67 git_folder_signature = frozenset(('config', 'head', 'info', 'objects', 'refs'))
68 68 commands = frozenset(('git-upload-pack', 'git-receive-pack'))
69 69 valid_accepts = frozenset(f'application/x-{c}-result' for c in commands)
70 70
71 71 # The last bytes are the SHA1 of the first 12 bytes.
72 72 EMPTY_PACK = (
73 73 b'PACK\x00\x00\x00\x02\x00\x00\x00\x00\x02\x9d\x08' +
74 74 b'\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e'
75 75 )
76 76 FLUSH_PACKET = b"0000"
77 77
78 78 SIDE_BAND_CAPS = frozenset((CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K))
79 79
80 80 def __init__(self, repo_name, content_path, git_path, update_server_info, extras):
81 81 files = frozenset(f.lower() for f in os.listdir(content_path))
82 82 valid_dir_signature = self.git_folder_signature.issubset(files)
83 83
84 84 if not valid_dir_signature:
85 85 raise OSError(f'{content_path} missing git signature')
86 86
87 87 self.content_path = content_path
88 88 self.repo_name = repo_name
89 89 self.extras = extras
90 90 self.git_path = git_path
91 91 self.update_server_info = update_server_info
92 92
93 93 def _get_fixedpath(self, path):
94 94 """
95 95 Small fix for repo_path
96 96
97 97 :param path:
98 98 """
99 99 path = path.split(self.repo_name, 1)[-1]
100 100 if path.startswith('.git'):
101 101 # for bare repos we still get the .git prefix inside, we skip it
102 102 # here, and remove from the service command
103 103 path = path[4:]
104 104
105 105 return path.strip('/')
106 106
107 107 def inforefs(self, request, unused_environ):
108 108 """
109 109 WSGI Response producer for HTTP GET Git Smart
110 110 HTTP /info/refs request.
111 111 """
112 112
113 113 git_command = request.GET.get('service')
114 114 if git_command not in self.commands:
115 115 log.debug('command %s not allowed', git_command)
116 116 return exc.HTTPForbidden()
117 117
118 118 # please, resist the urge to add '\n' to git capture and increment
119 119 # line count by 1.
120 120 # by git docs: Documentation/technical/http-protocol.txt#L214 \n is
121 121 # a part of protocol.
122 122 # The code in Git client not only does NOT need '\n', but actually
123 123 # blows up if you sprinkle "flush" (0000) as "0001\n".
124 124 # It reads binary, per number of bytes specified.
125 125 # if you do add '\n' as part of data, count it.
126 126 server_advert = '# service=%s\n' % git_command
127 127 packet_len = hex(len(server_advert) + 4)[2:].rjust(4, '0').lower()
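# e.g. for git-upload-pack: len('# service=git-upload-pack\n') == 26, plus 4
# for the length header itself, gives 30 -> packet_len == '001e'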
128 128 try:
129 129 gitenv = dict(os.environ)
130 130 # forget all configs
131 131 gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
132 132 command = [self.git_path, git_command[4:], '--stateless-rpc',
133 133 '--advertise-refs', self.content_path]
134 134 out = subprocessio.SubprocessIOChunker(
135 135 command,
136 136 env=gitenv,
137 137 starting_values=[ascii_bytes(packet_len + server_advert) + self.FLUSH_PACKET],
138 138 shell=False
139 139 )
140 140 except OSError:
141 141 log.exception('Error processing command')
142 142 raise exc.HTTPExpectationFailed()
143 143
144 144 resp = Response()
145 145 resp.content_type = f'application/x-{git_command}-advertisement'
146 146 resp.charset = None
147 147 resp.app_iter = out
148 148
149 149 return resp
150 150
151 151 def _get_want_capabilities(self, request):
152 152 """Read the capabilities found in the first want line of the request."""
153 153 pos = request.body_file_seekable.tell()
154 154 first_line = request.body_file_seekable.readline()
155 155 request.body_file_seekable.seek(pos)
156 156
157 157 return frozenset(
158 158 dulwich.protocol.extract_want_line_capabilities(first_line)[1])
159 159
160 160 def _build_failed_pre_pull_response(self, capabilities, pre_pull_messages):
161 161 """
162 162 Construct a response with an empty PACK file.
163 163
164 164 We use an empty PACK file, as that would trigger the failure of the pull
165 165 or clone command.
166 166
167 167 We also print in the error output a message explaining why the command
168 168 was aborted.
169 169
170 170 If additionally, the user is accepting messages we send them the output
171 171 of the pre-pull hook.
172 172
173 173 Note that for clients not supporting side-band we just send them the
174 174 empty PACK file.
175 175 """
176 176
177 177 if self.SIDE_BAND_CAPS.intersection(capabilities):
178 178 response = []
179 179 proto = dulwich.protocol.Protocol(None, response.append)
180 180 proto.write_pkt_line(dulwich.protocol.NAK_LINE)
181 181
182 182 self._write_sideband_to_proto(proto, ascii_bytes(pre_pull_messages, allow_bytes=True), capabilities)
183 183 # N.B.(skreft): Do not change the sideband channel to 3, as that
184 184 # produces a fatal error in the client:
185 185 # fatal: error in sideband demultiplexer
186 186 proto.write_sideband(
187 187 dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS,
188 188 ascii_bytes('Pre pull hook failed: aborting\n', allow_bytes=True))
189 189 proto.write_sideband(
190 190 dulwich.protocol.SIDE_BAND_CHANNEL_DATA,
191 191 ascii_bytes(self.EMPTY_PACK, allow_bytes=True))
192 192
193 193 # writes b"0000" as default
194 194 proto.write_pkt_line(None)
195 195
196 196 return response
197 197 else:
198 198 return [ascii_bytes(self.EMPTY_PACK, allow_bytes=True)]
199 199
200 200 def _build_post_pull_response(self, response, capabilities, start_message, end_message):
201 201 """
202 202 Given a list response we inject the post-pull messages.
203 203
204 204 We only inject the messages if the client supports sideband, and the
205 205 response has the format:
206 206 0008NAK\n...0000
207 207
208 208 Note that we do not check the no-progress capability because git sends
209 209 it by default, which would effectively block all messages.
210 210 """
211 211
212 212 if not self.SIDE_BAND_CAPS.intersection(capabilities):
213 213 return response
214 214
215 215 if not start_message and not end_message:
216 216 return response
217 217
218 218 try:
219 219 iter(response)
220 220 # the response is iterable, so we can continue
221 221 except TypeError:
222 222 raise TypeError(f'response must be an iterator: got {type(response)}')
223 223 if isinstance(response, (list, tuple)):
224 224 raise TypeError(f'response must be an iterator: got {type(response)}')
225 225
226 226 def injected_response():
227 227
228 228 do_loop = 1
229 229 header_injected = 0
230 230 next_item = None
231 231 has_item = False
232 232 item = b''
233 233
234 234 while do_loop:
235 235
236 236 try:
237 237 next_item = next(response)
238 238 except StopIteration:
239 239 do_loop = 0
240 240
241 241 if has_item:
242 242 # last item! alter it now
243 243 if do_loop == 0 and item.endswith(self.FLUSH_PACKET):
244 244 new_response = [item[:-4]]
245 245 new_response.extend(self._get_messages(end_message, capabilities))
246 246 new_response.append(self.FLUSH_PACKET)
247 247 item = b''.join(new_response)
248 248
249 249 yield item
250 250
251 251 has_item = True
252 252 item = next_item
253 253
254 254 # alter item if it's the initial chunk
255 255 if not header_injected and item.startswith(b'0008NAK\n'):
256 256 new_response = [b'0008NAK\n']
257 257 new_response.extend(self._get_messages(start_message, capabilities))
258 258 new_response.append(item[8:])
259 259 item = b''.join(new_response)
260 260 header_injected = 1
261 261
262 262 return injected_response()
263 263
264 264 def _write_sideband_to_proto(self, proto, data, capabilities):
265 265 """
266 266 Write the data to the proto's sideband number 2 == SIDE_BAND_CHANNEL_PROGRESS
267 267
268 268 We do not use dulwich's write_sideband directly as it only supports
269 269 side-band-64k.
270 270 """
271 271 if not data:
272 272 return
273 273
274 274 # N.B.(skreft): The values below are explained in the pack protocol
275 275 # documentation, section Packfile Data.
276 276 # https://github.com/git/git/blob/master/Documentation/technical/pack-protocol.txt
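# side-band caps a packet at 1000 bytes and side-band-64k at 65520 bytes;
# subtracting the 4-byte pkt-len header and the 1-byte band marker leaves
# 995 and 65515 bytes of payload respectively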
277 277 if CAPABILITY_SIDE_BAND_64K in capabilities:
278 278 chunk_size = 65515
279 279 elif CAPABILITY_SIDE_BAND in capabilities:
280 280 chunk_size = 995
281 281 else:
282 282 return
283 283
284 284 chunker = (data[i:i + chunk_size] for i in range(0, len(data), chunk_size))
285 285
286 286 for chunk in chunker:
287 287 proto.write_sideband(dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS, ascii_bytes(chunk, allow_bytes=True))
288 288
289 289 def _get_messages(self, data, capabilities):
290 290 """Return a list with packets for sending data in sideband number 2."""
291 291 response = []
292 292 proto = dulwich.protocol.Protocol(None, response.append)
293 293
294 294 self._write_sideband_to_proto(proto, data, capabilities)
295 295
296 296 return response
297 297
298 298 def backend(self, request, environ):
299 299 """
300 300 WSGI Response producer for HTTP POST Git Smart HTTP requests.
301 301 Reads commands and data from HTTP POST's body.
302 302 returns an iterator obj with contents of git command's
303 303 response to stdout
304 304 """
305 305 # TODO(skreft): think how we could detect an HTTPLockedException, as
306 306 # we probably want to have the same mechanism used by mercurial and
307 307 # simplevcs.
308 308 # For that we would need to parse the output of the command looking for
309 309 # some signs of the HTTPLockedError, parse the data and reraise it in
310 310 # pygrack. However, that would interfere with the streaming.
311 311 #
312 312 # Now the output of a blocked push is:
313 313 # Pushing to http://test_regular:test12@127.0.0.1:5001/vcs_test_git
314 314 # POST git-receive-pack (1047 bytes)
315 315 # remote: ERROR: Repository `vcs_test_git` locked by user `test_admin`. Reason:`lock_auto`
316 316 # To http://test_regular:test12@127.0.0.1:5001/vcs_test_git
317 317 # ! [remote rejected] master -> master (pre-receive hook declined)
318 318 # error: failed to push some refs to 'http://test_regular:test12@127.0.0.1:5001/vcs_test_git'
319 319
320 320 git_command = self._get_fixedpath(request.path_info)
321 321 if git_command not in self.commands:
322 322 log.debug('command %s not allowed', git_command)
323 323 return exc.HTTPForbidden()
324 324
325 325 capabilities = None
326 326 if git_command == 'git-upload-pack':
327 327 capabilities = self._get_want_capabilities(request)
328 328
329 329 if 'CONTENT_LENGTH' in environ:
330 330 inputstream = FileWrapper(request.body_file_seekable,
331 331 request.content_length)
332 332 else:
333 333 inputstream = request.body_file_seekable
334 334
335 335 resp = Response()
336 336 resp.content_type = f'application/x-{git_command}-result'
337 337 resp.charset = None
338 338
339 339 pre_pull_messages = ''
340 340 # Upload-pack == clone
341 341 if git_command == 'git-upload-pack':
342 342 hook_response = hooks.git_pre_pull(self.extras)
343 343 if hook_response.status != 0:
344 344 pre_pull_messages = hook_response.output
345 345 resp.app_iter = self._build_failed_pre_pull_response(
346 346 capabilities, pre_pull_messages)
347 347 return resp
348 348
349 349 gitenv = dict(os.environ)
350 350 # forget all configs
351 351 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
352 352 gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
353 353 cmd = [self.git_path, git_command[4:], '--stateless-rpc',
354 354 self.content_path]
355 355 log.debug('handling cmd %s', cmd)
356 356
357 357 out = subprocessio.SubprocessIOChunker(
358 358 cmd,
359 359 input_stream=inputstream,
360 360 env=gitenv,
361 361 cwd=self.content_path,
362 362 shell=False,
363 363 fail_on_stderr=False,
364 364 fail_on_return_code=False
365 365 )
366 366
367 367 if self.update_server_info and git_command == 'git-receive-pack':
368 368 # We need to fully consume the iterator here, as the
369 369 # update-server-info command needs to be run after the push.
370 370 out = list(out)
371 371
372 372 # Updating refs manually after each push.
373 373 # This is required as some clients are exposing Git repos internally
374 374 # with the dumb protocol.
375 375 cmd = [self.git_path, 'update-server-info']
376 376 log.debug('handling cmd %s', cmd)
377 377 output = subprocessio.SubprocessIOChunker(
378 378 cmd,
379 379 input_stream=inputstream,
380 380 env=gitenv,
381 381 cwd=self.content_path,
382 382 shell=False,
383 383 fail_on_stderr=False,
384 384 fail_on_return_code=False
385 385 )
386 386 # Consume all the output so the subprocess finishes
387 387 for _ in output:
388 388 pass
389 389
390 390 # Upload-pack == clone
391 391 if git_command == 'git-upload-pack':
392 392 hook_response = hooks.git_post_pull(self.extras)
393 393 post_pull_messages = hook_response.output
394 394 resp.app_iter = self._build_post_pull_response(out, capabilities, pre_pull_messages, post_pull_messages)
395 395 else:
396 396 resp.app_iter = out
397 397
398 398 return resp
399 399
400 400 def __call__(self, environ, start_response):
401 401 request = Request(environ)
402 402 _path = self._get_fixedpath(request.path_info)
403 403 if _path.startswith('info/refs'):
404 404 app = self.inforefs
405 405 else:
406 406 app = self.backend
407 407
408 408 try:
409 409 resp = app(request, environ)
410 410 except exc.HTTPException as error:
411 411 log.exception('HTTP Error')
412 412 resp = error
413 413 except Exception:
414 414 log.exception('Unknown error')
415 415 resp = exc.HTTPInternalServerError()
416 416
417 417 return resp(environ, start_response)
@@ -1,1463 +1,1462 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import collections
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import traceback
24 24 import urllib.request
25 25 import urllib.parse
26 26 import urllib.error
27 27 from functools import wraps
28 28
29 29 import more_itertools
30 30 import pygit2
31 31 from pygit2 import Repository as LibGit2Repo
32 32 from pygit2 import index as LibGit2Index
33 33 from dulwich import index, objects
34 34 from dulwich.client import HttpGitClient, LocalGitClient, FetchPackResult
35 35 from dulwich.errors import (
36 36 NotGitRepository, ChecksumMismatch, WrongObjectException,
37 37 MissingCommitError, ObjectMissing, HangupException,
38 38 UnexpectedCommandError)
39 39 from dulwich.repo import Repo as DulwichRepo
40 40 from dulwich.server import update_server_info
41 41
42 42 from vcsserver import exceptions, settings, subprocessio
43 43 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes
44 44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, store_archive_in_cache, BytesEnvelope, BinaryEnvelope
45 45 from vcsserver.hgcompat import (
46 46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
47 47 from vcsserver.git_lfs.lib import LFSOidStore
48 48 from vcsserver.vcs_base import RemoteBase
49 49
50 50 DIR_STAT = stat.S_IFDIR
51 51 FILE_MODE = stat.S_IFMT
52 52 GIT_LINK = objects.S_IFGITLINK
53 53 PEELED_REF_MARKER = b'^{}'
54 54 HEAD_MARKER = b'HEAD'
55 55
56 56 log = logging.getLogger(__name__)
57 57
58 58
59 59 def reraise_safe_exceptions(func):
60 60 """Converts Dulwich exceptions to something neutral."""
61 61
62 62 @wraps(func)
63 63 def wrapper(*args, **kwargs):
64 64 try:
65 65 return func(*args, **kwargs)
66 66 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
67 67 exc = exceptions.LookupException(org_exc=e)
68 68 raise exc(safe_str(e))
69 69 except (HangupException, UnexpectedCommandError) as e:
70 70 exc = exceptions.VcsException(org_exc=e)
71 71 raise exc(safe_str(e))
72 72 except Exception:
73 73 # NOTE(marcink): because of how dulwich handles some exceptions
74 74 # (KeyError on empty repos), we cannot track this and catch all
75 75 # exceptions; they may be exceptions raised by other handlers
76 76 #if not hasattr(e, '_vcs_kind'):
77 77 #log.exception("Unhandled exception in git remote call")
78 78 #raise_from_original(exceptions.UnhandledException)
79 79 raise
80 80 return wrapper
81 81
82 82
83 83 class Repo(DulwichRepo):
84 84 """
85 85 A wrapper for dulwich Repo class.
86 86
87 87 Since dulwich sometimes keeps .idx file descriptors open, this can lead to a
88 88 "Too many open files" error. We need to close all opened file descriptors
89 89 once the repo object is destroyed.
90 90 """
91 91 def __del__(self):
92 92 if hasattr(self, 'object_store'):
93 93 self.close()
94 94
95 95
96 96 class Repository(LibGit2Repo):
97 97
98 98 def __enter__(self):
99 99 return self
100 100
101 101 def __exit__(self, exc_type, exc_val, exc_tb):
102 102 self.free()
103 103
104 104
105 105 class GitFactory(RepoFactory):
106 106 repo_type = 'git'
107 107
108 108 def _create_repo(self, wire, create, use_libgit2=False):
109 109 if use_libgit2:
110 110 repo = Repository(safe_bytes(wire['path']))
111 111 else:
112 112 # dulwich mode
113 113 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
114 114 repo = Repo(repo_path)
115 115
116 116 log.debug('repository created: got GIT object: %s', repo)
117 117 return repo
118 118
119 119 def repo(self, wire, create=False, use_libgit2=False):
120 120 """
121 121 Get a repository instance for the given path.
122 122 """
123 123 return self._create_repo(wire, create, use_libgit2)
124 124
125 125 def repo_libgit2(self, wire):
126 126 return self.repo(wire, use_libgit2=True)
127 127
128 128
129 129 def create_signature_from_string(author_str, **kwargs):
130 130 """
131 131 Creates a pygit2.Signature object from a string of the format 'Name <email>'.
132 132
133 133 :param author_str: String of the format 'Name <email>'
134 134 :return: pygit2.Signature object
135 135 """
136 136 match = re.match(r'^(.+) <(.+)>$', author_str)
137 137 if match is None:
138 138 raise ValueError(f"Invalid format: {author_str}")
139 139
140 140 name, email = match.groups()
141 141 return pygit2.Signature(name, email, **kwargs)
142 142
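For example (timestamp illustrative):

sig = create_signature_from_string('Jane Doe <jane@example.com>', time=1700000000)
# equivalent to pygit2.Signature('Jane Doe', 'jane@example.com', time=1700000000)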
143 143
144 144 def get_obfuscated_url(url_obj):
145 145 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
146 146 url_obj.query = obfuscate_qs(url_obj.query)
147 147 obfuscated_uri = str(url_obj)
148 148 return obfuscated_uri
149 149
150 150
151 151 class GitRemote(RemoteBase):
152 152
153 153 def __init__(self, factory):
154 154 self._factory = factory
155 155 self._bulk_methods = {
156 156 "date": self.date,
157 157 "author": self.author,
158 158 "branch": self.branch,
159 159 "message": self.message,
160 160 "parents": self.parents,
161 161 "_commit": self.revision,
162 162 }
163 163 self._bulk_file_methods = {
164 164 "size": self.get_node_size,
165 165 "data": self.get_node_data,
166 166 "flags": self.get_node_flags,
167 167 "is_binary": self.get_node_is_binary,
168 168 "md5": self.md5_hash
169 169 }
170 170
171 171 def _wire_to_config(self, wire):
172 172 if 'config' in wire:
173 173 return {x[0] + '_' + x[1]: x[2] for x in wire['config']}
174 174 return {}
175 175
176 176 def _remote_conf(self, config):
177 177 params = [
178 178 '-c', 'core.askpass=""',
179 179 ]
180 180 ssl_cert_dir = config.get('vcs_ssl_dir')
181 181 if ssl_cert_dir:
182 182 params.extend(['-c', f'http.sslCAinfo={ssl_cert_dir}'])
183 183 return params
184 184
185 185 @reraise_safe_exceptions
186 186 def discover_git_version(self):
187 187 stdout, _ = self.run_git_command(
188 188 {}, ['--version'], _bare=True, _safe=True)
189 189 prefix = b'git version'
190 190 if stdout.startswith(prefix):
191 191 stdout = stdout[len(prefix):]
192 192 return safe_str(stdout.strip())
193 193
194 194 @reraise_safe_exceptions
195 195 def is_empty(self, wire):
196 196 repo_init = self._factory.repo_libgit2(wire)
197 197 with repo_init as repo:
198 198
199 199 try:
200 200 has_head = repo.head.name
201 201 if has_head:
202 202 return False
203 203
204 204 # NOTE(marcink): check again using more expensive method
205 205 return repo.is_empty
206 206 except Exception:
207 207 pass
208 208
209 209 return True
210 210
211 211 @reraise_safe_exceptions
212 212 def assert_correct_path(self, wire):
213 213 cache_on, context_uid, repo_id = self._cache_on(wire)
214 214 region = self._region(wire)
215 215
216 216 @region.conditional_cache_on_arguments(condition=cache_on)
217 217 def _assert_correct_path(_context_uid, _repo_id, fast_check):
218 218 if fast_check:
219 219 path = safe_str(wire['path'])
220 220 if pygit2.discover_repository(path):
221 221 return True
222 222 return False
223 223 else:
224 224 try:
225 225 repo_init = self._factory.repo_libgit2(wire)
226 226 with repo_init:
227 227 pass
228 228 except pygit2.GitError:
229 229 path = wire.get('path')
230 230 tb = traceback.format_exc()
231 231 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
232 232 return False
233 233 return True
234 234
235 235 return _assert_correct_path(context_uid, repo_id, True)
236 236
237 237 @reraise_safe_exceptions
238 238 def bare(self, wire):
239 239 repo_init = self._factory.repo_libgit2(wire)
240 240 with repo_init as repo:
241 241 return repo.is_bare
242 242
243 243 @reraise_safe_exceptions
244 244 def get_node_data(self, wire, commit_id, path):
245 245 repo_init = self._factory.repo_libgit2(wire)
246 246 with repo_init as repo:
247 247 commit = repo[commit_id]
248 248 blob_obj = commit.tree[path]
249 249
250 250 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
251 251 raise exceptions.LookupException()(
252 252 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
253 253
254 254 return BytesEnvelope(blob_obj.data)
255 255
256 256 @reraise_safe_exceptions
257 257 def get_node_size(self, wire, commit_id, path):
258 258 repo_init = self._factory.repo_libgit2(wire)
259 259 with repo_init as repo:
260 260 commit = repo[commit_id]
261 261 blob_obj = commit.tree[path]
262 262
263 263 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
264 264 raise exceptions.LookupException()(
265 265 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
266 266
267 267 return blob_obj.size
268 268
269 269 @reraise_safe_exceptions
270 270 def get_node_flags(self, wire, commit_id, path):
271 271 repo_init = self._factory.repo_libgit2(wire)
272 272 with repo_init as repo:
273 273 commit = repo[commit_id]
274 274 blob_obj = commit.tree[path]
275 275
276 276 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
277 277 raise exceptions.LookupException()(
278 278 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
279 279
280 280 return blob_obj.filemode
281 281
282 282 @reraise_safe_exceptions
283 283 def get_node_is_binary(self, wire, commit_id, path):
284 284 repo_init = self._factory.repo_libgit2(wire)
285 285 with repo_init as repo:
286 286 commit = repo[commit_id]
287 287 blob_obj = commit.tree[path]
288 288
289 289 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
290 290 raise exceptions.LookupException()(
291 291 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
292 292
293 293 return blob_obj.is_binary
294 294
295 295 @reraise_safe_exceptions
296 296 def blob_as_pretty_string(self, wire, sha):
297 297 repo_init = self._factory.repo_libgit2(wire)
298 298 with repo_init as repo:
299 299 blob_obj = repo[sha]
300 300 return BytesEnvelope(blob_obj.data)
301 301
302 302 @reraise_safe_exceptions
303 303 def blob_raw_length(self, wire, sha):
304 304 cache_on, context_uid, repo_id = self._cache_on(wire)
305 305 region = self._region(wire)
306 306
307 307 @region.conditional_cache_on_arguments(condition=cache_on)
308 308 def _blob_raw_length(_repo_id, _sha):
309 309
310 310 repo_init = self._factory.repo_libgit2(wire)
311 311 with repo_init as repo:
312 312 blob = repo[sha]
313 313 return blob.size
314 314
315 315 return _blob_raw_length(repo_id, sha)
316 316
317 317 def _parse_lfs_pointer(self, raw_content):
318 318 spec_string = b'version https://git-lfs.github.com/spec'
319 319 if raw_content and raw_content.startswith(spec_string):
320 320
321 321 pattern = re.compile(rb"""
322 322 (?:\n)?
323 323 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
324 324 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
325 325 ^size[ ](?P<oid_size>[0-9]+)\n
326 326 (?:\n)?
327 327 """, re.VERBOSE | re.MULTILINE)
328 328 match = pattern.match(raw_content)
329 329 if match:
330 330 return match.groupdict()
331 331
332 332 return {}
333 333
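A pointer file this regex matches looks like the following (oid illustrative):

version https://git-lfs.github.com/spec/v1
oid sha256:4d7a214614ab2935c943f9e0ff69d22eadbb8f32b1258daaa5e2ca24d17e2393
size 12345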
334 334 @reraise_safe_exceptions
335 335 def is_large_file(self, wire, commit_id):
336 336 cache_on, context_uid, repo_id = self._cache_on(wire)
337 337 region = self._region(wire)
338 338
339 339 @region.conditional_cache_on_arguments(condition=cache_on)
340 340 def _is_large_file(_repo_id, _sha):
341 341 repo_init = self._factory.repo_libgit2(wire)
342 342 with repo_init as repo:
343 343 blob = repo[commit_id]
344 344 if blob.is_binary:
345 345 return {}
346 346
347 347 return self._parse_lfs_pointer(blob.data)
348 348
349 349 return _is_large_file(repo_id, commit_id)
350 350
351 351 @reraise_safe_exceptions
352 352 def is_binary(self, wire, tree_id):
353 353 cache_on, context_uid, repo_id = self._cache_on(wire)
354 354 region = self._region(wire)
355 355
356 356 @region.conditional_cache_on_arguments(condition=cache_on)
357 357 def _is_binary(_repo_id, _tree_id):
358 358 repo_init = self._factory.repo_libgit2(wire)
359 359 with repo_init as repo:
360 360 blob_obj = repo[tree_id]
361 361 return blob_obj.is_binary
362 362
363 363 return _is_binary(repo_id, tree_id)
364 364
365 365 @reraise_safe_exceptions
366 366 def md5_hash(self, wire, commit_id, path):
367 367 cache_on, context_uid, repo_id = self._cache_on(wire)
368 368 region = self._region(wire)
369 369
370 370 @region.conditional_cache_on_arguments(condition=cache_on)
371 371 def _md5_hash(_repo_id, _commit_id, _path):
372 372 repo_init = self._factory.repo_libgit2(wire)
373 373 with repo_init as repo:
374 374 commit = repo[_commit_id]
375 375 blob_obj = commit.tree[_path]
376 376
377 377 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
378 378 raise exceptions.LookupException()(
379 379 f'Tree for commit_id:{_commit_id} is not a blob: {blob_obj.type_str}')
380 380
381 381 return ''
382 382
383 383 return _md5_hash(repo_id, commit_id, path)
384 384
385 385 @reraise_safe_exceptions
386 386 def in_largefiles_store(self, wire, oid):
387 387 conf = self._wire_to_config(wire)
388 388 repo_init = self._factory.repo_libgit2(wire)
389 389 with repo_init as repo:
390 390 repo_name = repo.path
391 391
392 392 store_location = conf.get('vcs_git_lfs_store_location')
393 393 if store_location:
394 394
395 395 store = LFSOidStore(
396 396 oid=oid, repo=repo_name, store_location=store_location)
397 397 return store.has_oid()
398 398
399 399 return False
400 400
401 401 @reraise_safe_exceptions
402 402 def store_path(self, wire, oid):
403 403 conf = self._wire_to_config(wire)
404 404 repo_init = self._factory.repo_libgit2(wire)
405 405 with repo_init as repo:
406 406 repo_name = repo.path
407 407
408 408 store_location = conf.get('vcs_git_lfs_store_location')
409 409 if store_location:
410 410 store = LFSOidStore(
411 411 oid=oid, repo=repo_name, store_location=store_location)
412 412 return store.oid_path
413 413 raise ValueError(f'Unable to fetch oid with path {oid}')
414 414
415 415 @reraise_safe_exceptions
416 416 def bulk_request(self, wire, rev, pre_load):
417 417 cache_on, context_uid, repo_id = self._cache_on(wire)
418 418 region = self._region(wire)
419 419
420 420 @region.conditional_cache_on_arguments(condition=cache_on)
421 421 def _bulk_request(_repo_id, _rev, _pre_load):
422 422 result = {}
423 423 for attr in pre_load:
424 424 try:
425 425 method = self._bulk_methods[attr]
426 426 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
427 427 args = [wire, rev]
428 428 result[attr] = method(*args)
429 429 except KeyError as e:
430 430 raise exceptions.VcsException(e)(f"Unknown bulk attribute: {attr}")
431 431 return result
432 432
433 433 return _bulk_request(repo_id, rev, sorted(pre_load))
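# Illustrative usage (example attribute names, hedged): with
# pre_load=['author', 'date', 'message'], _bulk_request dispatches to the
# matching entries of self._bulk_methods and returns something like
# {'author': ..., 'date': ..., 'message': ...}; an unknown name raises
# VcsException.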
434 434
435 435 @reraise_safe_exceptions
436 436 def bulk_file_request(self, wire, commit_id, path, pre_load):
437 437 cache_on, context_uid, repo_id = self._cache_on(wire)
438 438 region = self._region(wire)
439 439
440 440 @region.conditional_cache_on_arguments(condition=cache_on)
441 441 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
442 442 result = {}
443 443 for attr in pre_load:
444 444 try:
445 445 method = self._bulk_file_methods[attr]
446 446 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
447 447 result[attr] = method(wire, _commit_id, _path)
448 448 except KeyError as e:
449 449 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
450 450 return result
451 451
452 452 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
453 453
454 454 def _build_opener(self, url: str):
455 455 handlers = []
456 456 url_obj = url_parser(safe_bytes(url))
457 457 authinfo = url_obj.authinfo()[1]
458 458
459 459 if authinfo:
460 460 # create a password manager
461 461 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
462 462 passmgr.add_password(*authinfo)
463 463
464 464 handlers.extend((httpbasicauthhandler(passmgr),
465 465 httpdigestauthhandler(passmgr)))
466 466
467 467 return urllib.request.build_opener(*handlers)
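# Illustrative behaviour (example URL, hedged): for
# 'https://user:pass@host/repo', url_obj.authinfo() exposes the embedded
# credentials, so the opener is built with both basic and digest auth
# handlers; a URL without credentials yields a plain opener.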
468 468
469 469 @reraise_safe_exceptions
470 470 def check_url(self, url, config):
471 471 url_obj = url_parser(safe_bytes(url))
472 472
473 473 test_uri = safe_str(url_obj.authinfo()[0])
474 474 obfuscated_uri = get_obfuscated_url(url_obj)
475 475
476 476 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
477 477
478 478 if not test_uri.endswith('info/refs'):
479 479 test_uri = test_uri.rstrip('/') + '/info/refs'
480 480
481 481 o = self._build_opener(test_uri)
482 482 o.addheaders = [('User-Agent', 'git/1.7.8.0')] # fake a git client User-Agent

483 483
484 484 q = {"service": 'git-upload-pack'}
485 qs = '?%s' % urllib.parse.urlencode(q)
485 qs = f'?{urllib.parse.urlencode(q)}'
486 486 cu = f"{test_uri}{qs}"
487 487 req = urllib.request.Request(cu, None, {})
488 488
489 489 try:
490 490 log.debug("Trying to open URL %s", obfuscated_uri)
491 491 resp = o.open(req)
492 492 if resp.code != 200:
493 493 raise exceptions.URLError()('Return Code is not 200')
494 494 except Exception as e:
495 495 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
496 496 # means it cannot be cloned
497 497 raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")
498 498
499 499 # now detect if it's proper git repo
500 500 gitdata: bytes = resp.read()
501 501
502 502 if b'service=git-upload-pack' in gitdata:
503 503 pass
504 504 elif re.findall(br'[0-9a-fA-F]{40}\s+refs', gitdata):
505 505 # old-style git can return a different format!
506 506 pass
507 507 else:
508 508 e = None
509 509 raise exceptions.URLError(e)(
510 "url [%s] does not look like an hg repo org_exc: %s"
511 % (obfuscated_uri, e))
510 f"url [{obfuscated_uri}] does not look like an hg repo org_exc: {e}")
512 511
513 512 return True
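# Illustrative probe (hedged reading of the code above): for
# 'https://host/repo' the request targets
# 'https://host/repo/info/refs?service=git-upload-pack', git's smart-HTTP
# discovery endpoint; a conforming server mentions 'service=git-upload-pack'
# in its response body, while old dumb-HTTP servers answer with
# '<sha> refs/...' lines instead.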
514 513
515 514 @reraise_safe_exceptions
516 515 def clone(self, wire, url, deferred, valid_refs, update_after_clone):
517 516 # TODO(marcink): deprecate this method. Last i checked we don't use it anymore
518 517 remote_refs = self.pull(wire, url, apply_refs=False)
519 518 repo = self._factory.repo(wire)
520 519 if isinstance(valid_refs, list):
521 520 valid_refs = tuple(valid_refs)
522 521
523 522 for k in remote_refs:
524 523 # only parse heads/tags and skip so-called deferred tags
525 524 if k.startswith(valid_refs) and not k.endswith(deferred):
526 525 repo[k] = remote_refs[k]
527 526
528 527 if update_after_clone:
529 528 # we want to check out HEAD
530 529 repo["HEAD"] = remote_refs["HEAD"]
531 530 index.build_index_from_tree(repo.path, repo.index_path(),
532 531 repo.object_store, repo["HEAD"].tree)
533 532
534 533 @reraise_safe_exceptions
535 534 def branch(self, wire, commit_id):
536 535 cache_on, context_uid, repo_id = self._cache_on(wire)
537 536 region = self._region(wire)
538 537
539 538 @region.conditional_cache_on_arguments(condition=cache_on)
540 539 def _branch(_context_uid, _repo_id, _commit_id):
541 540 regex = re.compile('^refs/heads')
542 541
543 542 def filter_with(ref):
544 543 return regex.match(ref[0]) and ref[1] == _commit_id
545 544
546 545 branches = list(filter(filter_with, list(self.get_refs(wire).items())))
547 546 return [x[0].split('refs/heads/')[-1] for x in branches]
548 547
549 548 return _branch(context_uid, repo_id, commit_id)
550 549
551 550 @reraise_safe_exceptions
552 551 def commit_branches(self, wire, commit_id):
553 552 cache_on, context_uid, repo_id = self._cache_on(wire)
554 553 region = self._region(wire)
555 554
556 555 @region.conditional_cache_on_arguments(condition=cache_on)
557 556 def _commit_branches(_context_uid, _repo_id, _commit_id):
558 557 repo_init = self._factory.repo_libgit2(wire)
559 558 with repo_init as repo:
560 559 branches = list(repo.branches.with_commit(_commit_id))
561 560 return branches
562 561
563 562 return _commit_branches(context_uid, repo_id, commit_id)
564 563
565 564 @reraise_safe_exceptions
566 565 def add_object(self, wire, content):
567 566 repo_init = self._factory.repo_libgit2(wire)
568 567 with repo_init as repo:
569 568 blob = objects.Blob()
570 569 blob.set_raw_string(content)
571 570 repo.object_store.add_object(blob)
572 571 return blob.id
573 572
574 573 @reraise_safe_exceptions
575 574 def create_commit(self, wire, author, committer, message, branch, new_tree_id, date_args: list[int] | None = None):
576 575 repo_init = self._factory.repo_libgit2(wire)
577 576 with repo_init as repo:
578 577
579 578 if date_args:
580 579 current_time, offset = date_args
581 580
582 581 kw = {
583 582 'time': current_time,
584 583 'offset': offset
585 584 }
586 585 author = create_signature_from_string(author, **kw)
587 586 committer = create_signature_from_string(committer, **kw)
588 587
589 588 tree = new_tree_id
590 589 if isinstance(tree, (bytes, str)):
591 590 # validate this tree is in the repo...
592 591 tree = repo[safe_str(tree)].id
593 592
594 593 parents = []
595 594 # ensure we COMMIT on top of given branch head
596 595 # check if this repo has ANY branches; otherwise it's a new-branch case we need to handle
597 596 if branch in repo.branches.local:
598 597 parents += [repo.branches[branch].target]
599 598 elif [x for x in repo.branches.local]:
600 599 parents += [repo.head.target]
601 600 #else:
602 601 # in case we want to commit on new branch we create it on top of HEAD
603 602 #repo.branches.local.create(branch, repo.revparse_single('HEAD'))
604 603
605 604 # # Create a new commit
606 605 commit_oid = repo.create_commit(
607 606 f'refs/heads/{branch}', # the name of the reference to update
608 607 author, # the author of the commit
609 608 committer, # the committer of the commit
610 609 message, # the commit message
611 610 tree, # the tree produced by the index
612 611 parents # list of parents for the new commit, usually just one,
613 612 )
614 613
615 614 new_commit_id = safe_str(commit_oid)
616 615
617 616 return new_commit_id
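# Note on date_args (hedged): it is expected as
# [commit_time, commit_timezone], e.g. [1700000000, 60], i.e. epoch seconds
# plus an offset in minutes, which matches the time/offset pair consumed by
# create_signature_from_string() above.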
618 617
619 618 @reraise_safe_exceptions
620 619 def commit(self, wire, commit_data, branch, commit_tree, updated, removed):
621 620
622 621 def mode2pygit(mode):
623 622 """
624 623 git blobs only support two file modes, 644 and 755 (plus symlink entries)
625 624
626 625 0o100755 -> 33261
627 626 0o100644 -> 33188
628 627 """
629 628 return {
630 629 0o100644: pygit2.GIT_FILEMODE_BLOB,
631 630 0o100755: pygit2.GIT_FILEMODE_BLOB_EXECUTABLE,
632 631 0o120000: pygit2.GIT_FILEMODE_LINK
633 632 }.get(mode) or pygit2.GIT_FILEMODE_BLOB
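# Illustrative mapping: mode2pygit(0o100755) ->
# pygit2.GIT_FILEMODE_BLOB_EXECUTABLE, while any unrecognized mode falls
# back to the plain pygit2.GIT_FILEMODE_BLOB.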
634 633
635 634 repo_init = self._factory.repo_libgit2(wire)
636 635 with repo_init as repo:
637 636 repo_index = repo.index
638 637
639 638 for pathspec in updated:
640 639 blob_id = repo.create_blob(pathspec['content'])
641 640 ie = pygit2.IndexEntry(pathspec['path'], blob_id, mode2pygit(pathspec['mode']))
642 641 repo_index.add(ie)
643 642
644 643 for pathspec in removed:
645 644 repo_index.remove(pathspec)
646 645
647 646 # Write changes to the index
648 647 repo_index.write()
649 648
650 649 # Create a tree from the updated index
651 650 commit_tree = repo_index.write_tree()
652 651
653 652 new_tree_id = commit_tree
654 653
655 654 author = commit_data['author']
656 655 committer = commit_data['committer']
657 656 message = commit_data['message']
658 657
659 658 date_args = [int(commit_data['commit_time']), int(commit_data['commit_timezone'])]
660 659
661 660 new_commit_id = self.create_commit(wire, author, committer, message, branch,
662 661 new_tree_id, date_args=date_args)
663 662
664 663 # libgit2: ensure the branch exists
665 664 self.create_branch(wire, branch, new_commit_id)
666 665
667 666 # libgit2, set new ref to this created commit
668 667 self.set_refs(wire, f'refs/heads/{branch}', new_commit_id)
669 668
670 669 return new_commit_id
671 670
672 671 @reraise_safe_exceptions
673 672 def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
674 673 if url != 'default' and '://' not in url:
675 674 client = LocalGitClient(url)
676 675 else:
677 676 url_obj = url_parser(safe_bytes(url))
678 677 o = self._build_opener(url)
679 678 url = url_obj.authinfo()[0]
680 679 client = HttpGitClient(base_url=url, opener=o)
681 680 repo = self._factory.repo(wire)
682 681
683 682 determine_wants = repo.object_store.determine_wants_all
684 683 if refs:
685 684 refs = [ascii_bytes(x) for x in refs]
686 685
687 686 def determine_wants_requested(remote_refs):
688 687 determined = []
689 688 for ref_name, ref_hash in remote_refs.items():
690 689 bytes_ref_name = safe_bytes(ref_name)
691 690
692 691 if bytes_ref_name in refs:
693 692 bytes_ref_hash = safe_bytes(ref_hash)
694 693 determined.append(bytes_ref_hash)
695 694 return determined
696 695
697 696 # swap with our custom requested wants
698 697 determine_wants = determine_wants_requested
699 698
700 699 try:
701 700 remote_refs = client.fetch(
702 701 path=url, target=repo, determine_wants=determine_wants)
703 702
704 703 except NotGitRepository as e:
705 704 log.warning(
706 705 'Trying to fetch from "%s" failed, not a Git repository.', url)
707 706 # Exception can contain unicode which we convert
708 707 raise exceptions.AbortException(e)(repr(e))
709 708
710 709 # mikhail: client.fetch() returns all the remote refs, but fetches only
711 710 # refs filtered by `determine_wants` function. We need to filter result
712 711 # as well
713 712 if refs:
714 713 remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}
715 714
716 715 if apply_refs:
717 716 # TODO: johbo: Needs proper test coverage with a git repository
718 717 # that contains a tag object, so that we would end up with
719 718 # a peeled ref at this point.
720 719 for k in remote_refs:
721 720 if k.endswith(PEELED_REF_MARKER):
722 721 log.debug("Skipping peeled reference %s", k)
723 722 continue
724 723 repo[k] = remote_refs[k]
725 724
726 725 if refs and not update_after:
727 726 # mikhail: explicitly set the head to the last ref.
728 727 repo[HEAD_MARKER] = remote_refs[refs[-1]]
729 728
730 729 if update_after:
731 730 # we want to check out HEAD
732 731 repo[HEAD_MARKER] = remote_refs[HEAD_MARKER]
733 732 index.build_index_from_tree(repo.path, repo.index_path(),
734 733 repo.object_store, repo[HEAD_MARKER].tree)
735 734
736 735 if isinstance(remote_refs, FetchPackResult):
737 736 return remote_refs.refs
738 737 return remote_refs
739 738
740 739 @reraise_safe_exceptions
741 740 def sync_fetch(self, wire, url, refs=None, all_refs=False):
742 741 self._factory.repo(wire)
743 742 if refs and not isinstance(refs, (list, tuple)):
744 743 refs = [refs]
745 744
746 745 config = self._wire_to_config(wire)
747 746 # get all remote refs we'll use to fetch later
748 747 cmd = ['ls-remote']
749 748 if not all_refs:
750 749 cmd += ['--heads', '--tags']
751 750 cmd += [url]
752 751 output, __ = self.run_git_command(
753 752 wire, cmd, fail_on_stderr=False,
754 753 _copts=self._remote_conf(config),
755 754 extra_env={'GIT_TERMINAL_PROMPT': '0'})
756 755
757 756 remote_refs = collections.OrderedDict()
758 757 fetch_refs = []
759 758
760 759 for ref_line in output.splitlines():
761 760 sha, ref = ref_line.split(b'\t')
762 761 sha = sha.strip()
763 762 if ref in remote_refs:
764 763 # duplicate, skip
765 764 continue
766 765 if ref.endswith(PEELED_REF_MARKER):
767 766 log.debug("Skipping peeled reference %s", ref)
768 767 continue
769 768 # don't sync HEAD
770 769 if ref in [HEAD_MARKER]:
771 770 continue
772 771
773 772 remote_refs[ref] = sha
774 773
775 774 if refs and sha in refs:
776 775 # we filter fetch using our specified refs
777 776 fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
778 777 elif not refs:
779 778 fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
780 779 log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))
781 780
782 781 if fetch_refs:
783 782 for chunk in more_itertools.chunked(fetch_refs, 1024 * 4):
784 783 fetch_refs_chunks = list(chunk)
785 784 log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
786 785 self.run_git_command(
787 786 wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
788 787 fail_on_stderr=False,
789 788 _copts=self._remote_conf(config),
790 789 extra_env={'GIT_TERMINAL_PROMPT': '0'})
791 790
792 791 return remote_refs
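# Illustrative note: `git ls-remote` emits one '<40-hex sha>\t<refname>'
# line per ref, which the loop above parses; each kept ref then becomes an
# explicit fetch refspec such as 'refs/heads/main:refs/heads/main'
# (example branch name).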
793 792
794 793 @reraise_safe_exceptions
795 794 def sync_push(self, wire, url, refs=None):
796 795 if not self.check_url(url, wire):
797 796 return
798 797 config = self._wire_to_config(wire)
799 798 self._factory.repo(wire)
800 799 self.run_git_command(
801 800 wire, ['push', url, '--mirror'], fail_on_stderr=False,
802 801 _copts=self._remote_conf(config),
803 802 extra_env={'GIT_TERMINAL_PROMPT': '0'})
804 803
805 804 @reraise_safe_exceptions
806 805 def get_remote_refs(self, wire, url):
807 806 repo = Repo(url)
808 807 return repo.get_refs()
809 808
810 809 @reraise_safe_exceptions
811 810 def get_description(self, wire):
812 811 repo = self._factory.repo(wire)
813 812 return repo.get_description()
814 813
815 814 @reraise_safe_exceptions
816 815 def get_missing_revs(self, wire, rev1, rev2, path2):
817 816 repo = self._factory.repo(wire)
818 817 LocalGitClient(thin_packs=False).fetch(path2, repo)
819 818
820 819 wire_remote = wire.copy()
821 820 wire_remote['path'] = path2
822 821 repo_remote = self._factory.repo(wire_remote)
823 822 LocalGitClient(thin_packs=False).fetch(path2, repo_remote)
824 823
825 824 revs = [
826 825 x.commit.id
827 826 for x in repo_remote.get_walker(include=[safe_bytes(rev2)], exclude=[safe_bytes(rev1)])]
828 827 return revs
829 828
830 829 @reraise_safe_exceptions
831 830 def get_object(self, wire, sha, maybe_unreachable=False):
832 831 cache_on, context_uid, repo_id = self._cache_on(wire)
833 832 region = self._region(wire)
834 833
835 834 @region.conditional_cache_on_arguments(condition=cache_on)
836 835 def _get_object(_context_uid, _repo_id, _sha):
837 836 repo_init = self._factory.repo_libgit2(wire)
838 837 with repo_init as repo:
839 838
840 839 missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
841 840 try:
842 841 commit = repo.revparse_single(sha)
843 842 except KeyError:
844 843 # NOTE(marcink): KeyError doesn't give us any meaningful information
845 844 # here, we instead give something more explicit
846 845 e = exceptions.RefNotFoundException(f'SHA: {sha} not found')
847 846 raise exceptions.LookupException(e)(missing_commit_err)
848 847 except ValueError as e:
849 848 raise exceptions.LookupException(e)(missing_commit_err)
850 849
851 850 is_tag = False
852 851 if isinstance(commit, pygit2.Tag):
853 852 commit = repo.get(commit.target)
854 853 is_tag = True
855 854
856 855 check_dangling = True
857 856 if is_tag:
858 857 check_dangling = False
859 858
860 859 if check_dangling and maybe_unreachable:
861 860 check_dangling = False
862 861
863 862 # a symbolic reference that resolved (sha != commit.hex) cannot be a dangling commit
864 863 if sha != commit.hex:
865 864 check_dangling = False
866 865
867 866 if check_dangling:
868 867 # check for dangling commit
869 868 for branch in repo.branches.with_commit(commit.hex):
870 869 if branch:
871 870 break
872 871 else:
873 872 # NOTE(marcink): Empty error doesn't give us any meaningful information
874 873 # here, we instead give something more explicit
875 874 e = exceptions.RefNotFoundException(f'SHA: {sha} not found in branches')
876 875 raise exceptions.LookupException(e)(missing_commit_err)
877 876
878 877 commit_id = commit.hex
879 878 type_str = commit.type_str
880 879
881 880 return {
882 881 'id': commit_id,
883 882 'type': type_str,
884 883 'commit_id': commit_id,
885 884 'idx': 0
886 885 }
887 886
888 887 return _get_object(context_uid, repo_id, sha)
889 888
890 889 @reraise_safe_exceptions
891 890 def get_refs(self, wire):
892 891 cache_on, context_uid, repo_id = self._cache_on(wire)
893 892 region = self._region(wire)
894 893
895 894 @region.conditional_cache_on_arguments(condition=cache_on)
896 895 def _get_refs(_context_uid, _repo_id):
897 896
898 897 repo_init = self._factory.repo_libgit2(wire)
899 898 with repo_init as repo:
900 899 regex = re.compile('^refs/(heads|tags)/')
901 900 return {x.name: x.target.hex for x in
902 901 [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}
903 902
904 903 return _get_refs(context_uid, repo_id)
905 904
906 905 @reraise_safe_exceptions
907 906 def get_branch_pointers(self, wire):
908 907 cache_on, context_uid, repo_id = self._cache_on(wire)
909 908 region = self._region(wire)
910 909
911 910 @region.conditional_cache_on_arguments(condition=cache_on)
912 911 def _get_branch_pointers(_context_uid, _repo_id):
913 912
914 913 repo_init = self._factory.repo_libgit2(wire)
915 914 regex = re.compile('^refs/heads')
916 915 with repo_init as repo:
917 916 branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
918 917 return {x.target.hex: x.shorthand for x in branches}
919 918
920 919 return _get_branch_pointers(context_uid, repo_id)
921 920
922 921 @reraise_safe_exceptions
923 922 def head(self, wire, show_exc=True):
924 923 cache_on, context_uid, repo_id = self._cache_on(wire)
925 924 region = self._region(wire)
926 925
927 926 @region.conditional_cache_on_arguments(condition=cache_on)
928 927 def _head(_context_uid, _repo_id, _show_exc):
929 928 repo_init = self._factory.repo_libgit2(wire)
930 929 with repo_init as repo:
931 930 try:
932 931 return repo.head.peel().hex
933 932 except Exception:
934 933 if show_exc:
935 934 raise
936 935 return _head(context_uid, repo_id, show_exc)
937 936
938 937 @reraise_safe_exceptions
939 938 def init(self, wire):
940 939 repo_path = safe_str(wire['path'])
941 940 pygit2.init_repository(repo_path, bare=False)
942 941
943 942 @reraise_safe_exceptions
944 943 def init_bare(self, wire):
945 944 repo_path = safe_str(wire['path'])
946 945 pygit2.init_repository(repo_path, bare=True)
947 946
948 947 @reraise_safe_exceptions
949 948 def revision(self, wire, rev):
950 949
951 950 cache_on, context_uid, repo_id = self._cache_on(wire)
952 951 region = self._region(wire)
953 952
954 953 @region.conditional_cache_on_arguments(condition=cache_on)
955 954 def _revision(_context_uid, _repo_id, _rev):
956 955 repo_init = self._factory.repo_libgit2(wire)
957 956 with repo_init as repo:
958 957 commit = repo[rev]
959 958 obj_data = {
960 959 'id': commit.id.hex,
961 960 }
962 961 # tree objects themselves don't have a tree_id attribute
963 962 if hasattr(commit, 'tree_id'):
964 963 obj_data['tree'] = commit.tree_id.hex
965 964
966 965 return obj_data
967 966 return _revision(context_uid, repo_id, rev)
968 967
969 968 @reraise_safe_exceptions
970 969 def date(self, wire, commit_id):
971 970 cache_on, context_uid, repo_id = self._cache_on(wire)
972 971 region = self._region(wire)
973 972
974 973 @region.conditional_cache_on_arguments(condition=cache_on)
975 974 def _date(_repo_id, _commit_id):
976 975 repo_init = self._factory.repo_libgit2(wire)
977 976 with repo_init as repo:
978 977 commit = repo[commit_id]
979 978
980 979 if hasattr(commit, 'commit_time'):
981 980 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
982 981 else:
983 982 commit = commit.get_object()
984 983 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
985 984
986 985 # TODO(marcink): check dulwich difference of offset vs timezone
987 986 return [commit_time, commit_time_offset]
988 987 return _date(repo_id, commit_id)
989 988
990 989 @reraise_safe_exceptions
991 990 def author(self, wire, commit_id):
992 991 cache_on, context_uid, repo_id = self._cache_on(wire)
993 992 region = self._region(wire)
994 993
995 994 @region.conditional_cache_on_arguments(condition=cache_on)
996 995 def _author(_repo_id, _commit_id):
997 996 repo_init = self._factory.repo_libgit2(wire)
998 997 with repo_init as repo:
999 998 commit = repo[commit_id]
1000 999
1001 1000 if hasattr(commit, 'author'):
1002 1001 author = commit.author
1003 1002 else:
1004 1003 author = commit.get_object().author
1005 1004
1006 1005 if author.email:
1007 1006 return f"{author.name} <{author.email}>"
1008 1007
1009 1008 try:
1010 1009 return f"{author.name}"
1011 1010 except Exception:
1012 1011 return f"{safe_str(author.raw_name)}"
1013 1012
1014 1013 return _author(repo_id, commit_id)
1015 1014
1016 1015 @reraise_safe_exceptions
1017 1016 def message(self, wire, commit_id):
1018 1017 cache_on, context_uid, repo_id = self._cache_on(wire)
1019 1018 region = self._region(wire)
1020 1019
1021 1020 @region.conditional_cache_on_arguments(condition=cache_on)
1022 1021 def _message(_repo_id, _commit_id):
1023 1022 repo_init = self._factory.repo_libgit2(wire)
1024 1023 with repo_init as repo:
1025 1024 commit = repo[commit_id]
1026 1025 return commit.message
1027 1026 return _message(repo_id, commit_id)
1028 1027
1029 1028 @reraise_safe_exceptions
1030 1029 def parents(self, wire, commit_id):
1031 1030 cache_on, context_uid, repo_id = self._cache_on(wire)
1032 1031 region = self._region(wire)
1033 1032
1034 1033 @region.conditional_cache_on_arguments(condition=cache_on)
1035 1034 def _parents(_repo_id, _commit_id):
1036 1035 repo_init = self._factory.repo_libgit2(wire)
1037 1036 with repo_init as repo:
1038 1037 commit = repo[commit_id]
1039 1038 if hasattr(commit, 'parent_ids'):
1040 1039 parent_ids = commit.parent_ids
1041 1040 else:
1042 1041 parent_ids = commit.get_object().parent_ids
1043 1042
1044 1043 return [x.hex for x in parent_ids]
1045 1044 return _parents(repo_id, commit_id)
1046 1045
1047 1046 @reraise_safe_exceptions
1048 1047 def children(self, wire, commit_id):
1049 1048 cache_on, context_uid, repo_id = self._cache_on(wire)
1050 1049 region = self._region(wire)
1051 1050
1052 1051 head = self.head(wire)
1053 1052
1054 1053 @region.conditional_cache_on_arguments(condition=cache_on)
1055 1054 def _children(_repo_id, _commit_id):
1056 1055
1057 1056 output, __ = self.run_git_command(
1058 1057 wire, ['rev-list', '--all', '--children', f'{commit_id}^..{head}'])
1059 1058
1060 1059 child_ids = []
1061 1060 pat = re.compile(fr'^{commit_id}')
1062 1061 for line in output.splitlines():
1063 1062 line = safe_str(line)
1064 1063 if pat.match(line):
1065 1064 found_ids = line.split(' ')[1:]
1066 1065 child_ids.extend(found_ids)
1067 1066 break
1068 1067
1069 1068 return child_ids
1070 1069 return _children(repo_id, commit_id)
1071 1070
1072 1071 @reraise_safe_exceptions
1073 1072 def set_refs(self, wire, key, value):
1074 1073 repo_init = self._factory.repo_libgit2(wire)
1075 1074 with repo_init as repo:
1076 1075 repo.references.create(key, value, force=True)
1077 1076
1078 1077 @reraise_safe_exceptions
1079 1078 def create_branch(self, wire, branch_name, commit_id, force=False):
1080 1079 repo_init = self._factory.repo_libgit2(wire)
1081 1080 with repo_init as repo:
1082 1081 if commit_id:
1083 1082 commit = repo[commit_id]
1084 1083 else:
1085 1084 # if commit is not given just use the HEAD
1086 1085 commit = repo.head()
1087 1086
1088 1087 if force:
1089 1088 repo.branches.local.create(branch_name, commit, force=force)
1090 1089 elif not repo.branches.get(branch_name):
1091 1090 # create only if that branch isn't existing
1092 1091 repo.branches.local.create(branch_name, commit, force=force)
1093 1092
1094 1093 @reraise_safe_exceptions
1095 1094 def remove_ref(self, wire, key):
1096 1095 repo_init = self._factory.repo_libgit2(wire)
1097 1096 with repo_init as repo:
1098 1097 repo.references.delete(key)
1099 1098
1100 1099 @reraise_safe_exceptions
1101 1100 def tag_remove(self, wire, tag_name):
1102 1101 repo_init = self._factory.repo_libgit2(wire)
1103 1102 with repo_init as repo:
1104 1103 key = f'refs/tags/{tag_name}'
1105 1104 repo.references.delete(key)
1106 1105
1107 1106 @reraise_safe_exceptions
1108 1107 def tree_changes(self, wire, source_id, target_id):
1109 1108 repo = self._factory.repo(wire)
1110 1109 # source can be empty
1111 1110 source_id = safe_bytes(source_id if source_id else b'')
1112 1111 target_id = safe_bytes(target_id)
1113 1112
1114 1113 source = repo[source_id].tree if source_id else None
1115 1114 target = repo[target_id].tree
1116 1115 result = repo.object_store.tree_changes(source, target)
1117 1116
1118 1117 added = set()
1119 1118 modified = set()
1120 1119 deleted = set()
1121 1120 for (old_path, new_path), (_, _), (_, _) in list(result):
1122 1121 if new_path and old_path:
1123 1122 modified.add(new_path)
1124 1123 elif new_path and not old_path:
1125 1124 added.add(new_path)
1126 1125 elif not new_path and old_path:
1127 1126 deleted.add(old_path)
1128 1127
1129 1128 return list(added), list(modified), list(deleted)
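# Illustrative shape (hedged): dulwich's object_store.tree_changes() yields
# triples ((old_path, new_path), (old_mode, new_mode), (old_sha, new_sha));
# a None old_path marks an addition and a None new_path a deletion, which is
# exactly how the loop above buckets the entries.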
1130 1129
1131 1130 @reraise_safe_exceptions
1132 1131 def tree_and_type_for_path(self, wire, commit_id, path):
1133 1132
1134 1133 cache_on, context_uid, repo_id = self._cache_on(wire)
1135 1134 region = self._region(wire)
1136 1135
1137 1136 @region.conditional_cache_on_arguments(condition=cache_on)
1138 1137 def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
1139 1138 repo_init = self._factory.repo_libgit2(wire)
1140 1139
1141 1140 with repo_init as repo:
1142 1141 commit = repo[commit_id]
1143 1142 try:
1144 1143 tree = commit.tree[path]
1145 1144 except KeyError:
1146 1145 return None, None, None
1147 1146
1148 1147 return tree.id.hex, tree.type_str, tree.filemode
1149 1148 return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)
1150 1149
1151 1150 @reraise_safe_exceptions
1152 1151 def tree_items(self, wire, tree_id):
1153 1152 cache_on, context_uid, repo_id = self._cache_on(wire)
1154 1153 region = self._region(wire)
1155 1154
1156 1155 @region.conditional_cache_on_arguments(condition=cache_on)
1157 1156 def _tree_items(_repo_id, _tree_id):
1158 1157
1159 1158 repo_init = self._factory.repo_libgit2(wire)
1160 1159 with repo_init as repo:
1161 1160 try:
1162 1161 tree = repo[tree_id]
1163 1162 except KeyError:
1164 1163 raise ObjectMissing(f'No tree with id: {tree_id}')
1165 1164
1166 1165 result = []
1167 1166 for item in tree:
1168 1167 item_sha = item.hex
1169 1168 item_mode = item.filemode
1170 1169 item_type = item.type_str
1171 1170
1172 1171 if item_type == 'commit':
1173 1172 # NOTE(marcink): submodules we translate to 'link' for backward compat
1174 1173 item_type = 'link'
1175 1174
1176 1175 result.append((item.name, item_mode, item_sha, item_type))
1177 1176 return result
1178 1177 return _tree_items(repo_id, tree_id)
1179 1178
1180 1179 @reraise_safe_exceptions
1181 1180 def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1182 1181 """
1183 1182 Old version that uses subprocess to call diff
1184 1183 """
1185 1184
1186 1185 flags = [
1187 '-U%s' % context, '--patch',
1186 f'-U{context}', '--patch',
1188 1187 '--binary',
1189 1188 '--find-renames',
1190 1189 '--no-indent-heuristic',
1191 1190 # '--indent-heuristic',
1192 1191 #'--full-index',
1193 1192 #'--abbrev=40'
1194 1193 ]
1195 1194
1196 1195 if opt_ignorews:
1197 1196 flags.append('--ignore-all-space')
1198 1197
1199 1198 if commit_id_1 == self.EMPTY_COMMIT:
1200 1199 cmd = ['show'] + flags + [commit_id_2]
1201 1200 else:
1202 1201 cmd = ['diff'] + flags + [commit_id_1, commit_id_2]
1203 1202
1204 1203 if file_filter:
1205 1204 cmd.extend(['--', file_filter])
1206 1205
1207 1206 diff, __ = self.run_git_command(wire, cmd)
1208 1207 # If we used 'show' command, strip first few lines (until actual diff
1209 1208 # starts)
1210 1209 if commit_id_1 == self.EMPTY_COMMIT:
1211 1210 lines = diff.splitlines()
1212 1211 x = 0
1213 1212 for line in lines:
1214 1213 if line.startswith(b'diff'):
1215 1214 break
1216 1215 x += 1
1217 1216 # Append a trailing newline just like the 'diff' command does
1218 1217 diff = b'\n'.join(lines[x:]) + b'\n'
1219 1218 return diff
1220 1219
1221 1220 @reraise_safe_exceptions
1222 1221 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1223 1222 repo_init = self._factory.repo_libgit2(wire)
1224 1223
1225 1224 with repo_init as repo:
1226 1225 swap = True
1227 1226 flags = 0
1228 1227 flags |= pygit2.GIT_DIFF_SHOW_BINARY
1229 1228
1230 1229 if opt_ignorews:
1231 1230 flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE
1232 1231
1233 1232 if commit_id_1 == self.EMPTY_COMMIT:
1234 1233 comm1 = repo[commit_id_2]
1235 1234 diff_obj = comm1.tree.diff_to_tree(
1236 1235 flags=flags, context_lines=context, swap=swap)
1237 1236
1238 1237 else:
1239 1238 comm1 = repo[commit_id_2]
1240 1239 comm2 = repo[commit_id_1]
1241 1240 diff_obj = comm1.tree.diff_to_tree(
1242 1241 comm2.tree, flags=flags, context_lines=context, swap=swap)
1243 1242 similar_flags = 0
1244 1243 similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
1245 1244 diff_obj.find_similar(flags=similar_flags)
1246 1245
1247 1246 if file_filter:
1248 1247 for p in diff_obj:
1249 1248 if p.delta.old_file.path == file_filter:
1250 1249 return BytesEnvelope(p.data) or BytesEnvelope(b'')
1251 1250 # no matching path == no diff
1252 1251 return BytesEnvelope(b'')
1253 1252
1254 1253 return BytesEnvelope(safe_bytes(diff_obj.patch)) or BytesEnvelope(b'')
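# Note on the swap flag (hedged reading): diff_to_tree() is invoked on the
# newer commit's tree with swap=True, so the emitted patch reads in the
# conventional direction, from commit_id_1 to commit_id_2.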
1255 1254
1256 1255 @reraise_safe_exceptions
1257 1256 def node_history(self, wire, commit_id, path, limit):
1258 1257 cache_on, context_uid, repo_id = self._cache_on(wire)
1259 1258 region = self._region(wire)
1260 1259
1261 1260 @region.conditional_cache_on_arguments(condition=cache_on)
1262 1261 def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
1263 1262 # optimize for n==1, rev-list is much faster for that use-case
1264 1263 if limit == 1:
1265 1264 cmd = ['rev-list', '-1', commit_id, '--', path]
1266 1265 else:
1267 1266 cmd = ['log']
1268 1267 if limit:
1269 1268 cmd.extend(['-n', str(safe_int(limit, 0))])
1270 1269 cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])
1271 1270
1272 1271 output, __ = self.run_git_command(wire, cmd)
1273 1272 commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)
1274 1273
1275 1274 return list(commit_ids)
1276 1275 return _node_history(context_uid, repo_id, commit_id, path, limit)
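# Illustrative commands (hedged): for limit == 1 this shells out to
# `git rev-list -1 <commit_id> -- <path>`, otherwise to
# `git log -n <limit> --pretty=format: %H -s <commit_id> -- <path>`;
# the 40-hex commit ids are then scraped from the raw output with a regex.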
1277 1276
1278 1277 @reraise_safe_exceptions
1279 1278 def node_annotate_legacy(self, wire, commit_id, path):
1280 1279 # note: replaced by pygit2 implementation
1281 1280 cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
1282 1281 # -l ==> outputs long shas (and we need all 40 characters)
1283 1282 # --root ==> doesn't put '^' character for boundaries
1284 1283 # -r commit_id ==> blames for the given commit
1285 1284 output, __ = self.run_git_command(wire, cmd)
1286 1285
1287 1286 result = []
1288 1287 for i, blame_line in enumerate(output.splitlines()[:-1]):
1289 1288 line_no = i + 1
1290 1289 blame_commit_id, line = re.split(rb' ', blame_line, 1)
1291 1290 result.append((line_no, blame_commit_id, line))
1292 1291
1293 1292 return result
1294 1293
1295 1294 @reraise_safe_exceptions
1296 1295 def node_annotate(self, wire, commit_id, path):
1297 1296
1298 1297 result_libgit = []
1299 1298 repo_init = self._factory.repo_libgit2(wire)
1300 1299 with repo_init as repo:
1301 1300 commit = repo[commit_id]
1302 1301 blame_obj = repo.blame(path, newest_commit=commit_id)
1303 1302 for i, line in enumerate(commit.tree[path].data.splitlines()):
1304 1303 line_no = i + 1
1305 1304 hunk = blame_obj.for_line(line_no)
1306 1305 blame_commit_id = hunk.final_commit_id.hex
1307 1306
1308 1307 result_libgit.append((line_no, blame_commit_id, line))
1309 1308
1310 1309 return BinaryEnvelope(result_libgit)
1311 1310
1312 1311 @reraise_safe_exceptions
1313 1312 def update_server_info(self, wire):
1314 1313 repo = self._factory.repo(wire)
1315 1314 update_server_info(repo)
1316 1315
1317 1316 @reraise_safe_exceptions
1318 1317 def get_all_commit_ids(self, wire):
1319 1318
1320 1319 cache_on, context_uid, repo_id = self._cache_on(wire)
1321 1320 region = self._region(wire)
1322 1321
1323 1322 @region.conditional_cache_on_arguments(condition=cache_on)
1324 1323 def _get_all_commit_ids(_context_uid, _repo_id):
1325 1324
1326 1325 cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
1327 1326 try:
1328 1327 output, __ = self.run_git_command(wire, cmd)
1329 1328 return output.splitlines()
1330 1329 except Exception:
1331 1330 # Can be raised for empty repositories
1332 1331 return []
1333 1332
1334 1333 @region.conditional_cache_on_arguments(condition=cache_on)
1335 1334 def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
1336 1335 repo_init = self._factory.repo_libgit2(wire)
1337 1336 from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
1338 1337 results = []
1339 1338 with repo_init as repo:
1340 1339 for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
1341 1340 results.append(commit.id.hex)
return results
1342 1341
1343 1342 return _get_all_commit_ids(context_uid, repo_id)
1344 1343
1345 1344 @reraise_safe_exceptions
1346 1345 def run_git_command(self, wire, cmd, **opts):
1347 1346 path = wire.get('path', None)
1348 1347
1349 1348 if path and os.path.isdir(path):
1350 1349 opts['cwd'] = path
1351 1350
1352 1351 if '_bare' in opts:
1353 1352 _copts = []
1354 1353 del opts['_bare']
1355 1354 else:
1356 1355 _copts = ['-c', 'core.quotepath=false',]
1357 1356 safe_call = False
1358 1357 if '_safe' in opts:
1359 1358 # no exc on failure
1360 1359 del opts['_safe']
1361 1360 safe_call = True
1362 1361
1363 1362 if '_copts' in opts:
1364 1363 _copts.extend(opts['_copts'] or [])
1365 1364 del opts['_copts']
1366 1365
1367 1366 gitenv = os.environ.copy()
1368 1367 gitenv.update(opts.pop('extra_env', {}))
1369 1368 # need to clean fix GIT_DIR !
1370 1369 if 'GIT_DIR' in gitenv:
1371 1370 del gitenv['GIT_DIR']
1372 1371 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1373 1372 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1374 1373
1375 1374 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1376 1375 _opts = {'env': gitenv, 'shell': False}
1377 1376
1378 1377 proc = None
1379 1378 try:
1380 1379 _opts.update(opts)
1381 1380 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1382 1381
1383 1382 return b''.join(proc), b''.join(proc.stderr)
1384 1383 except OSError as err:
1385 1384 cmd = ' '.join(map(safe_str, cmd)) # human-friendly CMD
1386 1385 tb_err = ("Couldn't run git command (%s).\n"
1387 1386 "Original error was:%s\n"
1388 1387 "Call options:%s\n"
1389 1388 % (cmd, err, _opts))
1390 1389 log.exception(tb_err)
1391 1390 if safe_call:
1392 1391 return b'', err
1393 1392 else:
1394 1393 raise exceptions.VcsException()(tb_err)
1395 1394 finally:
1396 1395 if proc:
1397 1396 proc.close()
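# Illustrative call (assumed shape): `output, stderr =
# self.run_git_command(wire, ['rev-parse', 'HEAD'])` returns a pair of
# bytes; the _safe, _bare and _copts keys in **opts tweak error handling
# and the `-c` options prepended to the git invocation.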
1398 1397
1399 1398 @reraise_safe_exceptions
1400 1399 def install_hooks(self, wire, force=False):
1401 1400 from vcsserver.hook_utils import install_git_hooks
1402 1401 bare = self.bare(wire)
1403 1402 path = wire['path']
1404 1403 binary_dir = settings.BINARY_DIR
1405 1404 if binary_dir:
1406 1405 os.path.join(binary_dir, 'python3')
1407 1406 return install_git_hooks(path, bare, force_create=force)
1408 1407
1409 1408 @reraise_safe_exceptions
1410 1409 def get_hooks_info(self, wire):
1411 1410 from vcsserver.hook_utils import (
1412 1411 get_git_pre_hook_version, get_git_post_hook_version)
1413 1412 bare = self.bare(wire)
1414 1413 path = wire['path']
1415 1414 return {
1416 1415 'pre_version': get_git_pre_hook_version(path, bare),
1417 1416 'post_version': get_git_post_hook_version(path, bare),
1418 1417 }
1419 1418
1420 1419 @reraise_safe_exceptions
1421 1420 def set_head_ref(self, wire, head_name):
1422 1421 log.debug('Setting refs/head to `%s`', head_name)
1423 1422 repo_init = self._factory.repo_libgit2(wire)
1424 1423 with repo_init as repo:
1425 1424 repo.set_head(f'refs/heads/{head_name}')
1426 1425
1427 1426 return [head_name] + [f'set HEAD to refs/heads/{head_name}']
1428 1427
1429 1428 @reraise_safe_exceptions
1430 1429 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1431 1430 archive_dir_name, commit_id, cache_config):
1432 1431
1433 1432 def file_walker(_commit_id, path):
1434 1433 repo_init = self._factory.repo_libgit2(wire)
1435 1434
1436 1435 with repo_init as repo:
1437 1436 commit = repo[commit_id]
1438 1437
1439 1438 if path in ['', '/']:
1440 1439 tree = commit.tree
1441 1440 else:
1442 1441 tree = commit.tree[path.rstrip('/')]
1443 1442 tree_id = tree.id.hex
1444 1443 try:
1445 1444 tree = repo[tree_id]
1446 1445 except KeyError:
1447 1446 raise ObjectMissing(f'No tree with id: {tree_id}')
1448 1447
1449 1448 index = LibGit2Index.Index()
1450 1449 index.read_tree(tree)
1451 1450 file_iter = index
1452 1451
1453 1452 for file_node in file_iter:
1454 1453 file_path = file_node.path
1455 1454 mode = file_node.mode
1456 1455 is_link = stat.S_ISLNK(mode)
1457 1456 if mode == pygit2.GIT_FILEMODE_COMMIT:
1458 1457 log.debug('Skipping path %s as a commit node', file_path)
1459 1458 continue
1460 1459 yield ArchiveNode(file_path, mode, is_link, repo[file_node.hex].read_raw)
1461 1460
1462 1461 return store_archive_in_cache(
1463 1462 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
@@ -1,1201 +1,1200 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import binascii
18 18 import io
19 19 import logging
20 20 import stat
21 21 import sys
22 22 import urllib.request
23 23 import urllib.parse
24 24 import hashlib
25 25
26 26 from hgext import largefiles, rebase, purge
27 27
28 28 from mercurial import commands
29 29 from mercurial import unionrepo
30 30 from mercurial import verify
31 31 from mercurial import repair
32 32 from mercurial.error import AmbiguousPrefixLookupError
33 33
34 34 import vcsserver
35 35 from vcsserver import exceptions
36 36 from vcsserver.base import (
37 37 RepoFactory,
38 38 obfuscate_qs,
39 39 raise_from_original,
40 40 store_archive_in_cache,
41 41 ArchiveNode,
42 42 BytesEnvelope,
43 43 BinaryEnvelope,
44 44 )
45 45 from vcsserver.hgcompat import (
46 46 archival,
47 47 bin,
48 48 clone,
49 49 config as hgconfig,
50 50 diffopts,
51 51 hex,
52 52 get_ctx,
53 53 hg_url as url_parser,
54 54 httpbasicauthhandler,
55 55 httpdigestauthhandler,
56 56 makepeer,
57 57 instance,
58 58 match,
59 59 memctx,
60 60 exchange,
61 61 memfilectx,
62 62 nullrev,
63 63 hg_merge,
64 64 patch,
65 65 peer,
66 66 revrange,
67 67 ui,
68 68 hg_tag,
69 69 Abort,
70 70 LookupError,
71 71 RepoError,
72 72 RepoLookupError,
73 73 InterventionRequired,
74 74 RequirementError,
75 75 alwaysmatcher,
76 76 patternmatcher,
77 77 hgutil,
78 78 hgext_strip,
79 79 )
80 80 from vcsserver.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes
81 81 from vcsserver.vcs_base import RemoteBase
82 82 from vcsserver.config import hooks as hooks_config
83 83 from vcsserver.lib.exc_tracking import format_exc
84 84
85 85 log = logging.getLogger(__name__)
86 86
87 87
88 88 def make_ui_from_config(repo_config):
89 89
90 90 class LoggingUI(ui.ui):
91 91
92 92 def status(self, *msg, **opts):
93 93 str_msg = map(safe_str, msg)
94 94 log.info(' '.join(str_msg).rstrip('\n'))
95 95 #super(LoggingUI, self).status(*msg, **opts)
96 96
97 97 def warn(self, *msg, **opts):
98 98 str_msg = map(safe_str, msg)
99 99 log.warning('ui_logger:'+' '.join(str_msg).rstrip('\n'))
100 100 #super(LoggingUI, self).warn(*msg, **opts)
101 101
102 102 def error(self, *msg, **opts):
103 103 str_msg = map(safe_str, msg)
104 104 log.error('ui_logger:'+' '.join(str_msg).rstrip('\n'))
105 105 #super(LoggingUI, self).error(*msg, **opts)
106 106
107 107 def note(self, *msg, **opts):
108 108 str_msg = map(safe_str, msg)
109 109 log.info('ui_logger:'+' '.join(str_msg).rstrip('\n'))
110 110 #super(LoggingUI, self).note(*msg, **opts)
111 111
112 112 def debug(self, *msg, **opts):
113 113 str_msg = map(safe_str, msg)
114 114 log.debug('ui_logger:'+' '.join(str_msg).rstrip('\n'))
115 115 #super(LoggingUI, self).debug(*msg, **opts)
116 116
117 117 baseui = LoggingUI()
118 118
119 119 # clean the baseui object
120 120 baseui._ocfg = hgconfig.config()
121 121 baseui._ucfg = hgconfig.config()
122 122 baseui._tcfg = hgconfig.config()
123 123
124 124 for section, option, value in repo_config:
125 125 baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))
126 126
127 127 # make our hgweb quiet so it doesn't print output
128 128 baseui.setconfig(b'ui', b'quiet', b'true')
129 129
130 130 baseui.setconfig(b'ui', b'paginate', b'never')
131 131 # for better Error reporting of Mercurial
132 132 baseui.setconfig(b'ui', b'message-output', b'stderr')
133 133
134 134 # force mercurial to only use 1 thread, otherwise it may try to set a
135 135 # signal in a non-main thread, thus generating a ValueError.
136 136 baseui.setconfig(b'worker', b'numcpus', 1)
137 137
138 138 # If there is no config for the largefiles extension, we explicitly disable
139 139 # it here. This overrides settings from the repository's hgrc file. Recent
140 140 # mercurial versions enable largefiles in hgrc on clone from largefile
141 141 # repo.
142 142 if not baseui.hasconfig(b'extensions', b'largefiles'):
143 143 log.debug('Explicitly disable largefiles extension for repo.')
144 144 baseui.setconfig(b'extensions', b'largefiles', b'!')
145 145
146 146 return baseui
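# Illustrative input (example values, hedged): repo_config is an iterable of
# (section, option, value) string triples such as
# [('ui', 'username', 'RhodeCode'), ('phases', 'publish', 'false')], each
# applied through baseui.setconfig() after byte-encoding.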
147 147
148 148
149 149 def reraise_safe_exceptions(func):
150 150 """Decorator for converting mercurial exceptions to something neutral."""
151 151
152 152 def wrapper(*args, **kwargs):
153 153 try:
154 154 return func(*args, **kwargs)
155 155 except (Abort, InterventionRequired) as e:
156 156 raise_from_original(exceptions.AbortException(e), e)
157 157 except RepoLookupError as e:
158 158 raise_from_original(exceptions.LookupException(e), e)
159 159 except RequirementError as e:
160 160 raise_from_original(exceptions.RequirementException(e), e)
161 161 except RepoError as e:
162 162 raise_from_original(exceptions.VcsException(e), e)
163 163 except LookupError as e:
164 164 raise_from_original(exceptions.LookupException(e), e)
165 165 except Exception as e:
166 166 if not hasattr(e, '_vcs_kind'):
167 167 log.exception("Unhandled exception in hg remote call")
168 168 raise_from_original(exceptions.UnhandledException(e), e)
169 169
170 170 raise
171 171 return wrapper
172 172
173 173
174 174 class MercurialFactory(RepoFactory):
175 175 repo_type = 'hg'
176 176
177 177 def _create_config(self, config, hooks=True):
178 178 if not hooks:
179 179
180 180 hooks_to_clean = {
181 181
182 182 hooks_config.HOOK_REPO_SIZE,
183 183 hooks_config.HOOK_PRE_PULL,
184 184 hooks_config.HOOK_PULL,
185 185
186 186 hooks_config.HOOK_PRE_PUSH,
187 187 # TODO: what about PRETX_PUSH, this was disabled in pre 5.0.0
188 188 hooks_config.HOOK_PRETX_PUSH,
189 189
190 190 }
191 191 new_config = []
192 192 for section, option, value in config:
193 193 if section == 'hooks' and option in hooks_to_clean:
194 194 continue
195 195 new_config.append((section, option, value))
196 196 config = new_config
197 197
198 198 baseui = make_ui_from_config(config)
199 199 return baseui
200 200
201 201 def _create_repo(self, wire, create):
202 202 baseui = self._create_config(wire["config"])
203 203 repo = instance(baseui, safe_bytes(wire["path"]), create)
204 204 log.debug('repository created: got HG object: %s', repo)
205 205 return repo
206 206
207 207 def repo(self, wire, create=False):
208 208 """
209 209 Get a repository instance for the given path.
210 210 """
211 211 return self._create_repo(wire, create)
212 212
213 213
214 214 def patch_ui_message_output(baseui):
215 215 baseui.setconfig(b'ui', b'quiet', b'false')
216 216 output = io.BytesIO()
217 217
218 218 def write(data, **unused_kwargs):
219 219 output.write(data)
220 220
221 221 baseui.status = write
222 222 baseui.write = write
223 223 baseui.warn = write
224 224 baseui.debug = write
225 225
226 226 return baseui, output
227 227
228 228
229 229 def get_obfuscated_url(url_obj):
230 230 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
231 231 url_obj.query = obfuscate_qs(url_obj.query)
232 232 obfuscated_uri = str(url_obj)
233 233 return obfuscated_uri
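# Illustrative example: 'https://user:secret@host/repo?auth_token=abc'
# becomes 'https://user:*****@host/repo?auth_token=*****'; the password is
# masked directly and the query string via obfuscate_qs().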
234 234
235 235
236 236 def normalize_url_for_hg(url: str):
237 237 _proto = None
238 238
239 239 if '+' in url[:url.find('://')]:
240 240 _proto = url[0:url.find('+')]
241 241 url = url[url.find('+') + 1:]
242 242 return url, _proto
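# Illustrative behaviour: 'svn+https://host/repo' ->
# ('https://host/repo', 'svn'), while a plain 'https://host/repo' is
# returned unchanged with _proto None.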
243 243
244 244
245 245 class HgRemote(RemoteBase):
246 246
247 247 def __init__(self, factory):
248 248 self._factory = factory
249 249 self._bulk_methods = {
250 250 "affected_files": self.ctx_files,
251 251 "author": self.ctx_user,
252 252 "branch": self.ctx_branch,
253 253 "children": self.ctx_children,
254 254 "date": self.ctx_date,
255 255 "message": self.ctx_description,
256 256 "parents": self.ctx_parents,
257 257 "status": self.ctx_status,
258 258 "obsolete": self.ctx_obsolete,
259 259 "phase": self.ctx_phase,
260 260 "hidden": self.ctx_hidden,
261 261 "_file_paths": self.ctx_list,
262 262 }
263 263 self._bulk_file_methods = {
264 264 "size": self.fctx_size,
265 265 "data": self.fctx_node_data,
266 266 "flags": self.fctx_flags,
267 267 "is_binary": self.is_binary,
268 268 "md5": self.md5_hash,
269 269 }
270 270
271 271 def _get_ctx(self, repo, ref):
272 272 return get_ctx(repo, ref)
273 273
274 274 @reraise_safe_exceptions
275 275 def discover_hg_version(self):
276 276 from mercurial import util
277 277 return safe_str(util.version())
278 278
279 279 @reraise_safe_exceptions
280 280 def is_empty(self, wire):
281 281 repo = self._factory.repo(wire)
282 282
283 283 try:
284 284 return len(repo) == 0
285 285 except Exception:
286 286 log.exception("failed to read object_store")
287 287 return False
288 288
289 289 @reraise_safe_exceptions
290 290 def bookmarks(self, wire):
291 291 cache_on, context_uid, repo_id = self._cache_on(wire)
292 292 region = self._region(wire)
293 293
294 294 @region.conditional_cache_on_arguments(condition=cache_on)
295 295 def _bookmarks(_context_uid, _repo_id):
296 296 repo = self._factory.repo(wire)
297 297 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo._bookmarks.items()}
298 298
299 299 return _bookmarks(context_uid, repo_id)
300 300
301 301 @reraise_safe_exceptions
302 302 def branches(self, wire, normal, closed):
303 303 cache_on, context_uid, repo_id = self._cache_on(wire)
304 304 region = self._region(wire)
305 305
306 306 @region.conditional_cache_on_arguments(condition=cache_on)
307 307 def _branches(_context_uid, _repo_id, _normal, _closed):
308 308 repo = self._factory.repo(wire)
309 309 iter_branches = repo.branchmap().iterbranches()
310 310 bt = {}
311 311 for branch_name, _heads, tip_node, is_closed in iter_branches:
312 312 if normal and not is_closed:
313 313 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
314 314 if closed and is_closed:
315 315 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
316 316
317 317 return bt
318 318
319 319 return _branches(context_uid, repo_id, normal, closed)
320 320
321 321 @reraise_safe_exceptions
322 322 def bulk_request(self, wire, commit_id, pre_load):
323 323 cache_on, context_uid, repo_id = self._cache_on(wire)
324 324 region = self._region(wire)
325 325
326 326 @region.conditional_cache_on_arguments(condition=cache_on)
327 327 def _bulk_request(_repo_id, _commit_id, _pre_load):
328 328 result = {}
329 329 for attr in pre_load:
330 330 try:
331 331 method = self._bulk_methods[attr]
332 332 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
333 333 result[attr] = method(wire, commit_id)
334 334 except KeyError as e:
335 335 raise exceptions.VcsException(e)(
336 'Unknown bulk attribute: "%s"' % attr)
336 f'Unknown bulk attribute: "{attr}"')
337 337 return result
338 338
339 339 return _bulk_request(repo_id, commit_id, sorted(pre_load))
340 340
341 341 @reraise_safe_exceptions
342 342 def ctx_branch(self, wire, commit_id):
343 343 cache_on, context_uid, repo_id = self._cache_on(wire)
344 344 region = self._region(wire)
345 345
346 346 @region.conditional_cache_on_arguments(condition=cache_on)
347 347 def _ctx_branch(_repo_id, _commit_id):
348 348 repo = self._factory.repo(wire)
349 349 ctx = self._get_ctx(repo, commit_id)
350 350 return ctx.branch()
351 351 return _ctx_branch(repo_id, commit_id)
352 352
353 353 @reraise_safe_exceptions
354 354 def ctx_date(self, wire, commit_id):
355 355 cache_on, context_uid, repo_id = self._cache_on(wire)
356 356 region = self._region(wire)
357 357
358 358 @region.conditional_cache_on_arguments(condition=cache_on)
359 359 def _ctx_date(_repo_id, _commit_id):
360 360 repo = self._factory.repo(wire)
361 361 ctx = self._get_ctx(repo, commit_id)
362 362 return ctx.date()
363 363 return _ctx_date(repo_id, commit_id)
364 364
365 365 @reraise_safe_exceptions
366 366 def ctx_description(self, wire, revision):
367 367 repo = self._factory.repo(wire)
368 368 ctx = self._get_ctx(repo, revision)
369 369 return ctx.description()
370 370
371 371 @reraise_safe_exceptions
372 372 def ctx_files(self, wire, commit_id):
373 373 cache_on, context_uid, repo_id = self._cache_on(wire)
374 374 region = self._region(wire)
375 375
376 376 @region.conditional_cache_on_arguments(condition=cache_on)
377 377 def _ctx_files(_repo_id, _commit_id):
378 378 repo = self._factory.repo(wire)
379 379 ctx = self._get_ctx(repo, commit_id)
380 380 return ctx.files()
381 381
382 382 return _ctx_files(repo_id, commit_id)
383 383
384 384 @reraise_safe_exceptions
385 385 def ctx_list(self, path, revision):
386 386 repo = self._factory.repo(path)
387 387 ctx = self._get_ctx(repo, revision)
388 388 return list(ctx)
389 389
390 390 @reraise_safe_exceptions
391 391 def ctx_parents(self, wire, commit_id):
392 392 cache_on, context_uid, repo_id = self._cache_on(wire)
393 393 region = self._region(wire)
394 394
395 395 @region.conditional_cache_on_arguments(condition=cache_on)
396 396 def _ctx_parents(_repo_id, _commit_id):
397 397 repo = self._factory.repo(wire)
398 398 ctx = self._get_ctx(repo, commit_id)
399 399 return [parent.hex() for parent in ctx.parents()
400 400 if not (parent.hidden() or parent.obsolete())]
401 401
402 402 return _ctx_parents(repo_id, commit_id)
403 403
404 404 @reraise_safe_exceptions
405 405 def ctx_children(self, wire, commit_id):
406 406 cache_on, context_uid, repo_id = self._cache_on(wire)
407 407 region = self._region(wire)
408 408
409 409 @region.conditional_cache_on_arguments(condition=cache_on)
410 410 def _ctx_children(_repo_id, _commit_id):
411 411 repo = self._factory.repo(wire)
412 412 ctx = self._get_ctx(repo, commit_id)
413 413 return [child.hex() for child in ctx.children()
414 414 if not (child.hidden() or child.obsolete())]
415 415
416 416 return _ctx_children(repo_id, commit_id)
417 417
418 418 @reraise_safe_exceptions
419 419 def ctx_phase(self, wire, commit_id):
420 420 cache_on, context_uid, repo_id = self._cache_on(wire)
421 421 region = self._region(wire)
422 422
423 423 @region.conditional_cache_on_arguments(condition=cache_on)
424 424 def _ctx_phase(_context_uid, _repo_id, _commit_id):
425 425 repo = self._factory.repo(wire)
426 426 ctx = self._get_ctx(repo, commit_id)
427 427 # public=0, draft=1, secret=3
428 428 return ctx.phase()
429 429 return _ctx_phase(context_uid, repo_id, commit_id)
430 430
431 431 @reraise_safe_exceptions
432 432 def ctx_obsolete(self, wire, commit_id):
433 433 cache_on, context_uid, repo_id = self._cache_on(wire)
434 434 region = self._region(wire)
435 435
436 436 @region.conditional_cache_on_arguments(condition=cache_on)
437 437 def _ctx_obsolete(_context_uid, _repo_id, _commit_id):
438 438 repo = self._factory.repo(wire)
439 439 ctx = self._get_ctx(repo, commit_id)
440 440 return ctx.obsolete()
441 441 return _ctx_obsolete(context_uid, repo_id, commit_id)
442 442
443 443 @reraise_safe_exceptions
444 444 def ctx_hidden(self, wire, commit_id):
445 445 cache_on, context_uid, repo_id = self._cache_on(wire)
446 446 region = self._region(wire)
447 447
448 448 @region.conditional_cache_on_arguments(condition=cache_on)
449 449 def _ctx_hidden(_context_uid, _repo_id, _commit_id):
450 450 repo = self._factory.repo(wire)
451 451 ctx = self._get_ctx(repo, commit_id)
452 452 return ctx.hidden()
453 453 return _ctx_hidden(context_uid, repo_id, commit_id)
454 454
455 455 @reraise_safe_exceptions
456 456 def ctx_substate(self, wire, revision):
457 457 repo = self._factory.repo(wire)
458 458 ctx = self._get_ctx(repo, revision)
459 459 return ctx.substate
460 460
461 461 @reraise_safe_exceptions
462 462 def ctx_status(self, wire, revision):
463 463 repo = self._factory.repo(wire)
464 464 ctx = self._get_ctx(repo, revision)
465 465 status = repo[ctx.p1().node()].status(other=ctx.node())
466 466 # the status object (an odd, custom named tuple in mercurial) is not
467 467 # correctly serializable; we convert it to a list, as the underlying
468 468 # API expects a list
469 469 return list(status)
470 470
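To see why list(status) suffices here: mercurial's status object unpacks into seven file lists. A hedged stand-in (field order assumed from mercurial's scmutil.status):

from collections import namedtuple

# Stand-in for mercurial's status tuple; the real object behaves the same under list()
Status = namedtuple('Status', 'modified added removed deleted unknown ignored clean')

s = Status([b'a.txt'], [], [b'b.txt'], [], [], [], [])
assert list(s) == [[b'a.txt'], [], [b'b.txt'], [], [], [], []]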
471 471 @reraise_safe_exceptions
472 472 def ctx_user(self, wire, revision):
473 473 repo = self._factory.repo(wire)
474 474 ctx = self._get_ctx(repo, revision)
475 475 return ctx.user()
476 476
477 477 @reraise_safe_exceptions
478 478 def check_url(self, url, config):
479 479 url, _proto = normalize_url_for_hg(url)
480 480 url_obj = url_parser(safe_bytes(url))
481 481
482 482 test_uri = safe_str(url_obj.authinfo()[0])
483 483 authinfo = url_obj.authinfo()[1]
484 484 obfuscated_uri = get_obfuscated_url(url_obj)
485 485 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
486 486
487 487 handlers = []
488 488 if authinfo:
489 489 # create a password manager
490 490 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
491 491 passmgr.add_password(*authinfo)
492 492
493 493 handlers.extend((httpbasicauthhandler(passmgr),
494 494 httpdigestauthhandler(passmgr)))
495 495
496 496 o = urllib.request.build_opener(*handlers)
497 497 o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
498 498 ('Accept', 'application/mercurial-0.1')]
499 499
500 500 q = {"cmd": 'between'}
501 501 q.update({'pairs': "{}-{}".format('0' * 40, '0' * 40)})
502 qs = '?%s' % urllib.parse.urlencode(q)
502 qs = f'?{urllib.parse.urlencode(q)}'
503 503 cu = f"{test_uri}{qs}"
504 504 req = urllib.request.Request(cu, None, {})
505 505
506 506 try:
507 507 log.debug("Trying to open URL %s", obfuscated_uri)
508 508 resp = o.open(req)
509 509 if resp.code != 200:
510 510 raise exceptions.URLError()('Return Code is not 200')
511 511 except Exception as e:
512 512 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
513 513 # means it cannot be cloned
514 514 raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")
515 515
516 516 # now check if it's a proper hg repo, but don't do it for svn
517 517 try:
518 518 if _proto == 'svn':
519 519 pass
520 520 else:
521 521 # check for pure hg repos
522 522 log.debug(
523 523 "Verifying if URL is a Mercurial repository: %s", obfuscated_uri)
524 524 ui = make_ui_from_config(config)
525 525 peer_checker = makepeer(ui, safe_bytes(url))
526 526 peer_checker.lookup(b'tip')
527 527 except Exception as e:
528 528 log.warning("URL is not a valid Mercurial repository: %s",
529 529 obfuscated_uri)
530 530 raise exceptions.URLError(e)(
531 "url [%s] does not look like an hg repo org_exc: %s"
532 % (obfuscated_uri, e))
531 f"url [{obfuscated_uri}] does not look like an hg repo org_exc: {e}")
533 532
534 533 log.info("URL is a valid Mercurial repository: %s", obfuscated_uri)
535 534 return True
536 535
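A minimal sketch of the probe check_url performs, assuming a plain unauthenticated HTTP remote: a Mercurial server must answer the `between` wire-protocol command for the all-zero revision pair.

import urllib.parse
import urllib.request

def probe_hg_url(base_url: str) -> bool:
    # cmd=between with the null pair is a cheap, read-only handshake
    q = urllib.parse.urlencode({'cmd': 'between', 'pairs': f"{'0' * 40}-{'0' * 40}"})
    req = urllib.request.Request(f"{base_url}?{q}")
    with urllib.request.urlopen(req) as resp:  # raises URLError when unreachable
        return resp.status == 200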
537 536 @reraise_safe_exceptions
538 537 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_git, opt_ignorews, context):
539 538 repo = self._factory.repo(wire)
540 539
541 540 if file_filter:
542 541 # unpack the file-filter
543 542 repo_path, node_path = file_filter
544 543 match_filter = match(safe_bytes(repo_path), b'', [safe_bytes(node_path)])
545 544 else:
546 545 match_filter = file_filter
547 546 opts = diffopts(git=opt_git, ignorews=opt_ignorews, context=context, showfunc=1)
548 547
549 548 try:
550 549 diff_iter = patch.diff(
551 550 repo, node1=commit_id_1, node2=commit_id_2, match=match_filter, opts=opts)
552 551 return BytesEnvelope(b"".join(diff_iter))
553 552 except RepoLookupError as e:
554 553 raise exceptions.LookupException(e)()
555 554
556 555 @reraise_safe_exceptions
557 556 def node_history(self, wire, revision, path, limit):
558 557 cache_on, context_uid, repo_id = self._cache_on(wire)
559 558 region = self._region(wire)
560 559
561 560 @region.conditional_cache_on_arguments(condition=cache_on)
562 561 def _node_history(_context_uid, _repo_id, _revision, _path, _limit):
563 562 repo = self._factory.repo(wire)
564 563
565 564 ctx = self._get_ctx(repo, revision)
566 565 fctx = ctx.filectx(safe_bytes(path))
567 566
568 567 def history_iter():
569 568 limit_rev = fctx.rev()
570 569 for fctx_candidate in reversed(list(fctx.filelog())):
571 570 f_obj = fctx.filectx(fctx_candidate)
572 571
573 572 # NOTE: This can be problematic... hiding the ONLY history node results in an empty history
574 573 _ctx = f_obj.changectx()
575 574 if _ctx.hidden() or _ctx.obsolete():
576 575 continue
577 576
578 577 if limit_rev >= f_obj.rev():
579 578 yield f_obj
580 579
581 580 history = []
582 581 for cnt, obj in enumerate(history_iter()):
583 582 if limit and cnt >= limit:
584 583 break
585 584 history.append(hex(obj.node()))
586 585
587 586 return history
588 587 return _node_history(context_uid, repo_id, revision, path, limit)
589 588
590 589 @reraise_safe_exceptions
591 590 def node_history_until(self, wire, revision, path, limit):
592 591 cache_on, context_uid, repo_id = self._cache_on(wire)
593 592 region = self._region(wire)
594 593
595 594 @region.conditional_cache_on_arguments(condition=cache_on)
596 595 def _node_history_until(_context_uid, _repo_id, _revision, _path, _limit):
597 596 repo = self._factory.repo(wire)
598 597 ctx = self._get_ctx(repo, revision)
599 598 fctx = ctx.filectx(safe_bytes(path))
600 599
601 600 file_log = list(fctx.filelog())
602 601 if limit:
603 602 # Limit to the last n items
604 603 file_log = file_log[-limit:]
605 604
606 605 return [hex(fctx.filectx(cs).node()) for cs in reversed(file_log)]
607 606 return _node_history_until(context_uid, repo_id, revision, path, limit)
608 607
609 608 @reraise_safe_exceptions
610 609 def bulk_file_request(self, wire, commit_id, path, pre_load):
611 610 cache_on, context_uid, repo_id = self._cache_on(wire)
612 611 region = self._region(wire)
613 612
614 613 @region.conditional_cache_on_arguments(condition=cache_on)
615 614 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
616 615 result = {}
617 616 for attr in pre_load:
618 617 try:
619 618 method = self._bulk_file_methods[attr]
620 619 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
621 620 result[attr] = method(wire, _commit_id, _path)
622 621 except KeyError as e:
623 622 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
624 623 return result
625 624
626 625 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
627 626
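The method above is a plain name-to-callable dispatch. A self-contained sketch of the same pattern (names illustrative, not this module's API):

def bulk_request(methods: dict, attrs, *args):
    # mirrors _bulk_file_request: unknown attribute names fail loudly
    result = {}
    for attr in attrs:
        try:
            method = methods[attr]
        except KeyError:
            raise ValueError(f'Unknown bulk attribute: "{attr}"')
        result[attr] = method(*args)
    return result

assert bulk_request({'size': len}, ['size'], b'content') == {'size': 7}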
628 627 @reraise_safe_exceptions
629 628 def fctx_annotate(self, wire, revision, path):
630 629 repo = self._factory.repo(wire)
631 630 ctx = self._get_ctx(repo, revision)
632 631 fctx = ctx.filectx(safe_bytes(path))
633 632
634 633 result = []
635 634 for i, annotate_obj in enumerate(fctx.annotate(), 1):
636 635 ln_no = i
637 636 sha = hex(annotate_obj.fctx.node())
638 637 content = annotate_obj.text
639 638 result.append((ln_no, ascii_str(sha), content))
640 639 return BinaryEnvelope(result)
641 640
642 641 @reraise_safe_exceptions
643 642 def fctx_node_data(self, wire, revision, path):
644 643 repo = self._factory.repo(wire)
645 644 ctx = self._get_ctx(repo, revision)
646 645 fctx = ctx.filectx(safe_bytes(path))
647 646 return BytesEnvelope(fctx.data())
648 647
649 648 @reraise_safe_exceptions
650 649 def fctx_flags(self, wire, commit_id, path):
651 650 cache_on, context_uid, repo_id = self._cache_on(wire)
652 651 region = self._region(wire)
653 652
654 653 @region.conditional_cache_on_arguments(condition=cache_on)
655 654 def _fctx_flags(_repo_id, _commit_id, _path):
656 655 repo = self._factory.repo(wire)
657 656 ctx = self._get_ctx(repo, commit_id)
658 657 fctx = ctx.filectx(safe_bytes(path))
659 658 return fctx.flags()
660 659
661 660 return _fctx_flags(repo_id, commit_id, path)
662 661
663 662 @reraise_safe_exceptions
664 663 def fctx_size(self, wire, commit_id, path):
665 664 cache_on, context_uid, repo_id = self._cache_on(wire)
666 665 region = self._region(wire)
667 666
668 667 @region.conditional_cache_on_arguments(condition=cache_on)
669 668 def _fctx_size(_repo_id, _revision, _path):
670 669 repo = self._factory.repo(wire)
671 670 ctx = self._get_ctx(repo, commit_id)
672 671 fctx = ctx.filectx(safe_bytes(path))
673 672 return fctx.size()
674 673 return _fctx_size(repo_id, commit_id, path)
675 674
676 675 @reraise_safe_exceptions
677 676 def get_all_commit_ids(self, wire, name):
678 677 cache_on, context_uid, repo_id = self._cache_on(wire)
679 678 region = self._region(wire)
680 679
681 680 @region.conditional_cache_on_arguments(condition=cache_on)
682 681 def _get_all_commit_ids(_context_uid, _repo_id, _name):
683 682 repo = self._factory.repo(wire)
684 683 revs = [ascii_str(repo[x].hex()) for x in repo.filtered(b'visible').changelog.revs()]
685 684 return revs
686 685 return _get_all_commit_ids(context_uid, repo_id, name)
687 686
688 687 @reraise_safe_exceptions
689 688 def get_config_value(self, wire, section, name, untrusted=False):
690 689 repo = self._factory.repo(wire)
691 690 return repo.ui.config(ascii_bytes(section), ascii_bytes(name), untrusted=untrusted)
692 691
693 692 @reraise_safe_exceptions
694 693 def is_large_file(self, wire, commit_id, path):
695 694 cache_on, context_uid, repo_id = self._cache_on(wire)
696 695 region = self._region(wire)
697 696
698 697 @region.conditional_cache_on_arguments(condition=cache_on)
699 698 def _is_large_file(_context_uid, _repo_id, _commit_id, _path):
700 699 return largefiles.lfutil.isstandin(safe_bytes(path))
701 700
702 701 return _is_large_file(context_uid, repo_id, commit_id, path)
703 702
704 703 @reraise_safe_exceptions
705 704 def is_binary(self, wire, revision, path):
706 705 cache_on, context_uid, repo_id = self._cache_on(wire)
707 706 region = self._region(wire)
708 707
709 708 @region.conditional_cache_on_arguments(condition=cache_on)
710 709 def _is_binary(_repo_id, _sha, _path):
711 710 repo = self._factory.repo(wire)
712 711 ctx = self._get_ctx(repo, revision)
713 712 fctx = ctx.filectx(safe_bytes(path))
714 713 return fctx.isbinary()
715 714
716 715 return _is_binary(repo_id, revision, path)
717 716
718 717 @reraise_safe_exceptions
719 718 def md5_hash(self, wire, revision, path):
720 719 cache_on, context_uid, repo_id = self._cache_on(wire)
721 720 region = self._region(wire)
722 721
723 722 @region.conditional_cache_on_arguments(condition=cache_on)
724 723 def _md5_hash(_repo_id, _sha, _path):
725 724 repo = self._factory.repo(wire)
726 725 ctx = self._get_ctx(repo, revision)
727 726 fctx = ctx.filectx(safe_bytes(path))
728 727 return hashlib.md5(fctx.data()).hexdigest()
729 728
730 729 return _md5_hash(repo_id, revision, path)
731 730
732 731 @reraise_safe_exceptions
733 732 def in_largefiles_store(self, wire, sha):
734 733 repo = self._factory.repo(wire)
735 734 return largefiles.lfutil.instore(repo, sha)
736 735
737 736 @reraise_safe_exceptions
738 737 def in_user_cache(self, wire, sha):
739 738 repo = self._factory.repo(wire)
740 739 return largefiles.lfutil.inusercache(repo.ui, sha)
741 740
742 741 @reraise_safe_exceptions
743 742 def store_path(self, wire, sha):
744 743 repo = self._factory.repo(wire)
745 744 return largefiles.lfutil.storepath(repo, sha)
746 745
747 746 @reraise_safe_exceptions
748 747 def link(self, wire, sha, path):
749 748 repo = self._factory.repo(wire)
750 749 largefiles.lfutil.link(
751 750 largefiles.lfutil.usercachepath(repo.ui, sha), path)
752 751
753 752 @reraise_safe_exceptions
754 753 def localrepository(self, wire, create=False):
755 754 self._factory.repo(wire, create=create)
756 755
757 756 @reraise_safe_exceptions
758 757 def lookup(self, wire, revision, both):
759 758 cache_on, context_uid, repo_id = self._cache_on(wire)
760 759 region = self._region(wire)
761 760
762 761 @region.conditional_cache_on_arguments(condition=cache_on)
763 762 def _lookup(_context_uid, _repo_id, _revision, _both):
764 763 repo = self._factory.repo(wire)
765 764 rev = _revision
766 765 if isinstance(rev, int):
767 766 # NOTE(marcink):
768 767 # since Mercurial doesn't support negative indexes properly,
769 768 # we need to shift by one to get the proper index, e.g.
770 769 # repo[-1] => repo[-2]
771 770 # repo[0] => repo[-1]
772 771 if rev <= 0:
773 772 rev = rev - 1
774 773 try:
775 774 ctx = self._get_ctx(repo, rev)
776 775 except AmbiguousPrefixLookupError:
777 776 e = RepoLookupError(rev)
778 777 e._org_exc_tb = format_exc(sys.exc_info())
779 778 raise exceptions.LookupException(e)(rev)
780 779 except (TypeError, RepoLookupError, binascii.Error) as e:
781 780 e._org_exc_tb = format_exc(sys.exc_info())
782 781 raise exceptions.LookupException(e)(rev)
783 782 except LookupError as e:
784 783 e._org_exc_tb = format_exc(sys.exc_info())
785 784 raise exceptions.LookupException(e)(e.name)
786 785
787 786 if not both:
788 787 return ctx.hex()
789 788
790 789 ctx = repo[ctx.hex()]
791 790 return ctx.hex(), ctx.rev()
792 791
793 792 return _lookup(context_uid, repo_id, revision, both)
794 793
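The index shift described in the NOTE above, captured in isolation (hypothetical helper, illustrative values):

def shift_rev(rev):
    # repo[0] must address repo[-1], repo[-1] must address repo[-2], etc.
    if isinstance(rev, int) and rev <= 0:
        return rev - 1
    return rev

assert shift_rev(0) == -1
assert shift_rev(-1) == -2
assert shift_rev(5) == 5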
795 794 @reraise_safe_exceptions
796 795 def sync_push(self, wire, url):
797 796 if not self.check_url(url, wire['config']):
798 797 return
799 798
800 799 repo = self._factory.repo(wire)
801 800
802 801 # Disable any prompts for this repo
803 802 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
804 803
805 804 bookmarks = list(dict(repo._bookmarks).keys())
806 805 remote = peer(repo, {}, safe_bytes(url))
807 806 # Disable any prompts for this remote
808 807 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
809 808
810 809 return exchange.push(
811 810 repo, remote, newbranch=True, bookmarks=bookmarks).cgresult
812 811
813 812 @reraise_safe_exceptions
814 813 def revision(self, wire, rev):
815 814 repo = self._factory.repo(wire)
816 815 ctx = self._get_ctx(repo, rev)
817 816 return ctx.rev()
818 817
819 818 @reraise_safe_exceptions
820 819 def rev_range(self, wire, commit_filter):
821 820 cache_on, context_uid, repo_id = self._cache_on(wire)
822 821 region = self._region(wire)
823 822
824 823 @region.conditional_cache_on_arguments(condition=cache_on)
825 824 def _rev_range(_context_uid, _repo_id, _filter):
826 825 repo = self._factory.repo(wire)
827 826 revisions = [
828 827 ascii_str(repo[rev].hex())
829 828 for rev in revrange(repo, list(map(ascii_bytes, commit_filter)))
830 829 ]
831 830 return revisions
832 831
833 832 return _rev_range(context_uid, repo_id, sorted(commit_filter))
834 833
835 834 @reraise_safe_exceptions
836 835 def rev_range_hash(self, wire, node):
837 836 repo = self._factory.repo(wire)
838 837
839 838 def get_revs(repo, rev_opt):
840 839 if rev_opt:
841 840 revs = revrange(repo, rev_opt)
842 841 if len(revs) == 0:
843 842 return (nullrev, nullrev)
844 843 return max(revs), min(revs)
845 844 else:
846 845 return len(repo) - 1, 0
847 846
848 847 stop, start = get_revs(repo, [node + ':'])
849 848 revs = [ascii_str(repo[r].hex()) for r in range(start, stop + 1)]
850 849 return revs
851 850
852 851 @reraise_safe_exceptions
853 852 def revs_from_revspec(self, wire, rev_spec, *args, **kwargs):
854 853 org_path = safe_bytes(wire["path"])
855 854 other_path = safe_bytes(kwargs.pop('other_path', ''))
856 855
857 856 # case when we want to compare two independent repositories
858 857 if other_path and other_path != org_path:
859 858 baseui = self._factory._create_config(wire["config"])
860 859 repo = unionrepo.makeunionrepository(baseui, other_path, org_path)
861 860 else:
862 861 repo = self._factory.repo(wire)
863 862 return list(repo.revs(rev_spec, *args))
864 863
865 864 @reraise_safe_exceptions
866 865 def verify(self, wire):
867 866 repo = self._factory.repo(wire)
868 867 baseui = self._factory._create_config(wire['config'])
869 868
870 869 baseui, output = patch_ui_message_output(baseui)
871 870
872 871 repo.ui = baseui
873 872 verify.verify(repo)
874 873 return output.getvalue()
875 874
876 875 @reraise_safe_exceptions
877 876 def hg_update_cache(self, wire,):
878 877 repo = self._factory.repo(wire)
879 878 baseui = self._factory._create_config(wire['config'])
880 879 baseui, output = patch_ui_message_output(baseui)
881 880
882 881 repo.ui = baseui
883 882 with repo.wlock(), repo.lock():
884 883 repo.updatecaches(full=True)
885 884
886 885 return output.getvalue()
887 886
888 887 @reraise_safe_exceptions
889 888 def hg_rebuild_fn_cache(self, wire,):
890 889 repo = self._factory.repo(wire)
891 890 baseui = self._factory._create_config(wire['config'])
892 891 baseui, output = patch_ui_message_output(baseui)
893 892
894 893 repo.ui = baseui
895 894
896 895 repair.rebuildfncache(baseui, repo)
897 896
898 897 return output.getvalue()
899 898
900 899 @reraise_safe_exceptions
901 900 def tags(self, wire):
902 901 cache_on, context_uid, repo_id = self._cache_on(wire)
903 902 region = self._region(wire)
904 903
905 904 @region.conditional_cache_on_arguments(condition=cache_on)
906 905 def _tags(_context_uid, _repo_id):
907 906 repo = self._factory.repo(wire)
908 907 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo.tags().items()}
909 908
910 909 return _tags(context_uid, repo_id)
911 910
912 911 @reraise_safe_exceptions
913 912 def update(self, wire, node='', clean=False):
914 913 repo = self._factory.repo(wire)
915 914 baseui = self._factory._create_config(wire['config'])
916 915 node = safe_bytes(node)
917 916
918 917 commands.update(baseui, repo, node=node, clean=clean)
919 918
920 919 @reraise_safe_exceptions
921 920 def identify(self, wire):
922 921 repo = self._factory.repo(wire)
923 922 baseui = self._factory._create_config(wire['config'])
924 923 output = io.BytesIO()
925 924 baseui.write = output.write
926 925 # This is required to get a full node id
927 926 baseui.debugflag = True
928 927 commands.identify(baseui, repo, id=True)
929 928
930 929 return output.getvalue()
931 930
932 931 @reraise_safe_exceptions
933 932 def heads(self, wire, branch=None):
934 933 repo = self._factory.repo(wire)
935 934 baseui = self._factory._create_config(wire['config'])
936 935 output = io.BytesIO()
937 936
938 937 def write(data, **unused_kwargs):
939 938 output.write(data)
940 939
941 940 baseui.write = write
942 941 if branch:
943 942 args = [safe_bytes(branch)]
944 943 else:
945 944 args = []
946 945 commands.heads(baseui, repo, *args, template=b'{node} ')
947 946
948 947 return output.getvalue()
949 948
950 949 @reraise_safe_exceptions
951 950 def ancestor(self, wire, revision1, revision2):
952 951 repo = self._factory.repo(wire)
953 952 changelog = repo.changelog
954 953 lookup = repo.lookup
955 954 a = changelog.ancestor(lookup(safe_bytes(revision1)), lookup(safe_bytes(revision2)))
956 955 return hex(a)
957 956
958 957 @reraise_safe_exceptions
959 958 def clone(self, wire, source, dest, update_after_clone=False, hooks=True):
960 959 baseui = self._factory._create_config(wire["config"], hooks=hooks)
961 960 clone(baseui, safe_bytes(source), safe_bytes(dest), noupdate=not update_after_clone)
962 961
963 962 @reraise_safe_exceptions
964 963 def commitctx(self, wire, message, parents, commit_time, commit_timezone, user, files, extra, removed, updated):
965 964
966 965 repo = self._factory.repo(wire)
967 966 baseui = self._factory._create_config(wire['config'])
968 967 publishing = baseui.configbool(b'phases', b'publish')
969 968
970 969 def _filectxfn(_repo, ctx, path: bytes):
971 970 """
972 971 Marks the given path as added/changed/removed in the given _repo. This
973 972 is for mercurial's internal commit function.
974 973 """
975 974
976 975 # check if this path is removed
977 976 if safe_str(path) in removed:
978 977 # returning None is a way to mark node for removal
979 978 return None
980 979
981 980 # check if this path is added
982 981 for node in updated:
983 982 if safe_bytes(node['path']) == path:
984 983 return memfilectx(
985 984 _repo,
986 985 changectx=ctx,
987 986 path=safe_bytes(node['path']),
988 987 data=safe_bytes(node['content']),
989 988 islink=False,
990 989 isexec=bool(node['mode'] & stat.S_IXUSR),
991 990 copysource=False)
992 991 abort_exc = exceptions.AbortException()
993 992 raise abort_exc(f"Given path hasn't been marked as added, changed, or removed ({path})")
994 993
995 994 if publishing:
996 995 new_commit_phase = b'public'
997 996 else:
998 997 new_commit_phase = b'draft'
999 998 with repo.ui.configoverride({(b'phases', b'new-commit'): new_commit_phase}):
1000 999 kwargs = {safe_bytes(k): safe_bytes(v) for k, v in extra.items()}
1001 1000 commit_ctx = memctx(
1002 1001 repo=repo,
1003 1002 parents=parents,
1004 1003 text=safe_bytes(message),
1005 1004 files=[safe_bytes(x) for x in files],
1006 1005 filectxfn=_filectxfn,
1007 1006 user=safe_bytes(user),
1008 1007 date=(commit_time, commit_timezone),
1009 1008 extra=kwargs)
1010 1009
1011 1010 n = repo.commitctx(commit_ctx)
1012 1011 new_id = hex(n)
1013 1012
1014 1013 return new_id
1015 1014
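For reference, a hedged sketch of the node dictionaries _filectxfn consumes via `updated`, with the matching `removed` list; field names follow the lookups above, values are illustrative:

import stat

updated = [{
    'path': 'docs/readme.rst',
    'content': 'hello',
    'mode': 0o100755,  # bool(mode & stat.S_IXUSR) marks the file executable
}]
removed = ['old/file.txt']

assert bool(updated[0]['mode'] & stat.S_IXUSR)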
1016 1015 @reraise_safe_exceptions
1017 1016 def pull(self, wire, url, commit_ids=None):
1018 1017 repo = self._factory.repo(wire)
1019 1018 # Disable any prompts for this repo
1020 1019 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
1021 1020
1022 1021 remote = peer(repo, {}, safe_bytes(url))
1023 1022 # Disable any prompts for this remote
1024 1023 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
1025 1024
1026 1025 if commit_ids:
1027 1026 commit_ids = [bin(commit_id) for commit_id in commit_ids]
1028 1027
1029 1028 return exchange.pull(
1030 1029 repo, remote, heads=commit_ids, force=None).cgresult
1031 1030
1032 1031 @reraise_safe_exceptions
1033 1032 def pull_cmd(self, wire, source, bookmark='', branch='', revision='', hooks=True):
1034 1033 repo = self._factory.repo(wire)
1035 1034 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1036 1035
1037 1036 source = safe_bytes(source)
1038 1037
1039 1038 # Mercurial internally has a lot of logic that checks ONLY whether
1040 1039 # an option is defined; we pass options only when they are set
1041 1040 opts = {}
1042 1041
1043 1042 if bookmark:
1044 1043 opts['bookmark'] = [safe_bytes(x) for x in bookmark] \
1045 1044 if isinstance(bookmark, list) else safe_bytes(bookmark)
1046 1045
1047 1046 if branch:
1048 1047 opts['branch'] = [safe_bytes(x) for x in branch] \
1049 1048 if isinstance(branch, list) else safe_bytes(branch)
1050 1049
1051 1050 if revision:
1052 1051 opts['rev'] = [safe_bytes(x) for x in revision] \
1053 1052 if isinstance(revision, list) else safe_bytes(revision)
1054 1053
1055 1054 commands.pull(baseui, repo, source, **opts)
1056 1055
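The scalar-or-list normalization repeated above for bookmark, branch and rev reduces to one helper; a sketch assuming safe_bytes simply encodes str to bytes:

def to_bytes_opt(value):
    # mirrors the pattern used for bookmark/branch/rev above
    if isinstance(value, list):
        return [v.encode('utf-8') for v in value]
    return value.encode('utf-8')

assert to_bytes_opt('stable') == b'stable'
assert to_bytes_opt(['a', 'b']) == [b'a', b'b']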
1057 1056 @reraise_safe_exceptions
1058 1057 def push(self, wire, revisions, dest_path, hooks: bool = True, push_branches: bool = False):
1059 1058 repo = self._factory.repo(wire)
1060 1059 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1061 1060
1062 1061 revisions = [safe_bytes(x) for x in revisions] \
1063 1062 if isinstance(revisions, list) else safe_bytes(revisions)
1064 1063
1065 1064 commands.push(baseui, repo, safe_bytes(dest_path),
1066 1065 rev=revisions,
1067 1066 new_branch=push_branches)
1068 1067
1069 1068 @reraise_safe_exceptions
1070 1069 def strip(self, wire, revision, update, backup):
1071 1070 repo = self._factory.repo(wire)
1072 1071 ctx = self._get_ctx(repo, revision)
1073 1072 hgext_strip.strip(
1074 1073 repo.baseui, repo, ctx.node(), update=update, backup=backup)
1075 1074
1076 1075 @reraise_safe_exceptions
1077 1076 def get_unresolved_files(self, wire):
1078 1077 repo = self._factory.repo(wire)
1079 1078
1080 1079 log.debug('Calculating unresolved files for repo: %s', repo)
1081 1080 output = io.BytesIO()
1082 1081
1083 1082 def write(data, **unused_kwargs):
1084 1083 output.write(data)
1085 1084
1086 1085 baseui = self._factory._create_config(wire['config'])
1087 1086 baseui.write = write
1088 1087
1089 1088 commands.resolve(baseui, repo, list=True)
1090 1089 unresolved = output.getvalue().splitlines(False)
1091 1090 return unresolved
1092 1091
1093 1092 @reraise_safe_exceptions
1094 1093 def merge(self, wire, revision):
1095 1094 repo = self._factory.repo(wire)
1096 1095 baseui = self._factory._create_config(wire['config'])
1097 1096 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1098 1097
1099 1098 # When subrepositories are used, mercurial prompts the user in
1100 1099 # case of merge conflicts or different subrepository sources. By
1101 1100 # setting the interactive flag to `False` mercurial doesn't prompt the
1102 1101 # user but instead uses a default value.
1103 1102 repo.ui.setconfig(b'ui', b'interactive', False)
1104 1103 commands.merge(baseui, repo, rev=safe_bytes(revision))
1105 1104
1106 1105 @reraise_safe_exceptions
1107 1106 def merge_state(self, wire):
1108 1107 repo = self._factory.repo(wire)
1109 1108 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1110 1109
1111 1110 # When subrepositories are used, mercurial prompts the user in
1112 1111 # case of merge conflicts or different subrepository sources. By
1113 1112 # setting the interactive flag to `False` mercurial doesn't prompt the
1114 1113 # user but instead uses a default value.
1115 1114 repo.ui.setconfig(b'ui', b'interactive', False)
1116 1115 ms = hg_merge.mergestate(repo)
1117 1116 return list(ms.unresolved())
1118 1117
1119 1118 @reraise_safe_exceptions
1120 1119 def commit(self, wire, message, username, close_branch=False):
1121 1120 repo = self._factory.repo(wire)
1122 1121 baseui = self._factory._create_config(wire['config'])
1123 1122 repo.ui.setconfig(b'ui', b'username', safe_bytes(username))
1124 1123 commands.commit(baseui, repo, message=safe_bytes(message), close_branch=close_branch)
1125 1124
1126 1125 @reraise_safe_exceptions
1127 1126 def rebase(self, wire, source='', dest='', abort=False):
1128 1127 repo = self._factory.repo(wire)
1129 1128 baseui = self._factory._create_config(wire['config'])
1130 1129 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1131 1130 # When subrepositories are used, mercurial prompts the user in
1132 1131 # case of merge conflicts or different subrepository sources. By
1133 1132 # setting the interactive flag to `False` mercurial doesn't prompt the
1134 1133 # user but instead uses a default value.
1135 1134 repo.ui.setconfig(b'ui', b'interactive', False)
1136 1135
1137 1136 rebase.rebase(baseui, repo, base=safe_bytes(source or ''), dest=safe_bytes(dest or ''),
1138 1137 abort=abort, keep=not abort)
1139 1138
1140 1139 @reraise_safe_exceptions
1141 1140 def tag(self, wire, name, revision, message, local, user, tag_time, tag_timezone):
1142 1141 repo = self._factory.repo(wire)
1143 1142 ctx = self._get_ctx(repo, revision)
1144 1143 node = ctx.node()
1145 1144
1146 1145 date = (tag_time, tag_timezone)
1147 1146 try:
1148 1147 hg_tag.tag(repo, safe_bytes(name), node, safe_bytes(message), local, safe_bytes(user), date)
1149 1148 except Abort as e:
1150 1149 log.exception("Tag operation aborted")
1151 1150 # Exception can contain unicode which we convert
1152 1151 raise exceptions.AbortException(e)(repr(e))
1153 1152
1154 1153 @reraise_safe_exceptions
1155 1154 def bookmark(self, wire, bookmark, revision=''):
1156 1155 repo = self._factory.repo(wire)
1157 1156 baseui = self._factory._create_config(wire['config'])
1158 1157 revision = revision or ''
1159 1158 commands.bookmark(baseui, repo, safe_bytes(bookmark), rev=safe_bytes(revision), force=True)
1160 1159
1161 1160 @reraise_safe_exceptions
1162 1161 def install_hooks(self, wire, force=False):
1163 1162 # we don't need any special hooks for Mercurial
1164 1163 pass
1165 1164
1166 1165 @reraise_safe_exceptions
1167 1166 def get_hooks_info(self, wire):
1168 1167 return {
1169 1168 'pre_version': vcsserver.__version__,
1170 1169 'post_version': vcsserver.__version__,
1171 1170 }
1172 1171
1173 1172 @reraise_safe_exceptions
1174 1173 def set_head_ref(self, wire, head_name):
1175 1174 pass
1176 1175
1177 1176 @reraise_safe_exceptions
1178 1177 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1179 1178 archive_dir_name, commit_id, cache_config):
1180 1179
1181 1180 def file_walker(_commit_id, path):
1182 1181 repo = self._factory.repo(wire)
1183 1182 ctx = repo[_commit_id]
1184 1183 is_root = path in ['', '/']
1185 1184 if is_root:
1186 1185 matcher = alwaysmatcher(badfn=None)
1187 1186 else:
1188 1187 matcher = patternmatcher('', [(b'glob', safe_bytes(path)+b'/**', b'')], badfn=None)
1189 1188 file_iter = ctx.manifest().walk(matcher)
1190 1189
1191 1190 for fn in file_iter:
1192 1191 file_path = fn
1193 1192 flags = ctx.flags(fn)
1194 1193 mode = 0o755 if b'x' in flags else 0o644
1195 1194 is_link = b'l' in flags
1196 1195
1197 1196 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1198 1197
1199 1198 return store_archive_in_cache(
1200 1199 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
1201 1200
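The flag-to-mode translation in file_walker above, isolated for clarity ('x' in the manifest flags means executable, 'l' means symlink):

def hg_mode(flags: bytes) -> int:
    return 0o755 if b'x' in flags else 0o644

assert hg_mode(b'x') == 0o755
assert hg_mode(b'') == 0o644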
@@ -1,943 +1,942 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18
19 19 import os
20 20 import subprocess
21 21 from urllib.error import URLError
22 22 import urllib.parse
23 23 import logging
24 24 import posixpath as vcspath
25 25 import io
26 26 import urllib.request
27 27 import urllib.parse
28 28 import urllib.error
29 29 import traceback
30 30
31 31
32 32 import svn.client # noqa
33 33 import svn.core # noqa
34 34 import svn.delta # noqa
35 35 import svn.diff # noqa
36 36 import svn.fs # noqa
37 37 import svn.repos # noqa
38 38
39 39 from vcsserver import svn_diff, exceptions, subprocessio, settings
40 40 from vcsserver.base import (
41 41 RepoFactory,
42 42 raise_from_original,
43 43 ArchiveNode,
44 44 store_archive_in_cache,
45 45 BytesEnvelope,
46 46 BinaryEnvelope,
47 47 )
48 48 from vcsserver.exceptions import NoContentException
49 49 from vcsserver.str_utils import safe_str, safe_bytes
50 50 from vcsserver.type_utils import assert_bytes
51 51 from vcsserver.vcs_base import RemoteBase
52 52 from vcsserver.lib.svnremoterepo import svnremoterepo
53 53
54 54 log = logging.getLogger(__name__)
55 55
56 56
57 57 svn_compatible_versions_map = {
58 58 'pre-1.4-compatible': '1.3',
59 59 'pre-1.5-compatible': '1.4',
60 60 'pre-1.6-compatible': '1.5',
61 61 'pre-1.8-compatible': '1.7',
62 62 'pre-1.9-compatible': '1.8',
63 63 }
64 64
65 65 current_compatible_version = '1.14'
66 66
67 67
68 68 def reraise_safe_exceptions(func):
69 69 """Decorator for converting svn exceptions to something neutral."""
70 70 def wrapper(*args, **kwargs):
71 71 try:
72 72 return func(*args, **kwargs)
73 73 except Exception as e:
74 74 if not hasattr(e, '_vcs_kind'):
75 75 log.exception("Unhandled exception in svn remote call")
76 76 raise_from_original(exceptions.UnhandledException(e), e)
77 77 raise
78 78 return wrapper
79 79
80 80
81 81 class SubversionFactory(RepoFactory):
82 82 repo_type = 'svn'
83 83
84 84 def _create_repo(self, wire, create, compatible_version):
85 85 path = svn.core.svn_path_canonicalize(wire['path'])
86 86 if create:
87 87 fs_config = {'compatible-version': current_compatible_version}
88 88 if compatible_version:
89 89
90 90 compatible_version_string = \
91 91 svn_compatible_versions_map.get(compatible_version) \
92 92 or compatible_version
93 93 fs_config['compatible-version'] = compatible_version_string
94 94
95 95 log.debug('Create SVN repo with config `%s`', fs_config)
96 96 repo = svn.repos.create(path, "", "", None, fs_config)
97 97 else:
98 98 repo = svn.repos.open(path)
99 99
100 100 log.debug('repository created: got SVN object: %s', repo)
101 101 return repo
102 102
103 103 def repo(self, wire, create=False, compatible_version=None):
104 104 """
105 105 Get a repository instance for the given path.
106 106 """
107 107 return self._create_repo(wire, create, compatible_version)
108 108
109 109
110 110 NODE_TYPE_MAPPING = {
111 111 svn.core.svn_node_file: 'file',
112 112 svn.core.svn_node_dir: 'dir',
113 113 }
114 114
115 115
116 116 class SvnRemote(RemoteBase):
117 117
118 118 def __init__(self, factory, hg_factory=None):
119 119 self._factory = factory
120 120
121 121 self._bulk_methods = {
122 122 # NOT supported in SVN ATM...
123 123 }
124 124 self._bulk_file_methods = {
125 125 "size": self.get_file_size,
126 126 "data": self.get_file_content,
127 127 "flags": self.get_node_type,
128 128 "is_binary": self.is_binary,
129 129 "md5": self.md5_hash
130 130 }
131 131
132 132 @reraise_safe_exceptions
133 133 def bulk_file_request(self, wire, commit_id, path, pre_load):
134 134 cache_on, context_uid, repo_id = self._cache_on(wire)
135 135 region = self._region(wire)
136 136
137 137 # since we use a unified API, we need to cast from str to int for SVN
138 138 commit_id = int(commit_id)
139 139
140 140 @region.conditional_cache_on_arguments(condition=cache_on)
141 141 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
142 142 result = {}
143 143 for attr in pre_load:
144 144 try:
145 145 method = self._bulk_file_methods[attr]
146 146 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
147 147 result[attr] = method(wire, _commit_id, _path)
148 148 except KeyError as e:
149 149 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
150 150 return result
151 151
152 152 return BinaryEnvelope(_bulk_file_request(repo_id, commit_id, path, sorted(pre_load)))
153 153
154 154 @reraise_safe_exceptions
155 155 def discover_svn_version(self):
156 156 try:
157 157 import svn.core
158 158 svn_ver = svn.core.SVN_VERSION
159 159 except ImportError:
160 160 svn_ver = None
161 161 return safe_str(svn_ver)
162 162
163 163 @reraise_safe_exceptions
164 164 def is_empty(self, wire):
165 165 try:
166 166 return self.lookup(wire, -1) == 0
167 167 except Exception:
168 168 log.exception("failed to read object_store")
169 169 return False
170 170
171 171 def check_url(self, url, config):
172 172
173 173 # the uuid call returns a valid UUID only for a proper repo,
174 174 # otherwise it raises an exception
175 175 username, password, src_url = self.get_url_and_credentials(url)
176 176 try:
177 177 svnremoterepo(safe_bytes(username), safe_bytes(password), safe_bytes(src_url)).svn().uuid
178 178 except Exception:
179 179 tb = traceback.format_exc()
180 180 log.debug("Invalid Subversion url: `%s`, tb: %s", url, tb)
181 181 raise URLError(f'"{url}" is not a valid Subversion source url.')
182 182 return True
183 183
184 184 def is_path_valid_repository(self, wire, path):
185 185
186 186 # NOTE(marcink): short-circuit the check for an SVN repo;
187 187 # repos.open might be expensive to run, but we have one cheap
188 188 # precondition we can use: checking for the 'format' file
189 189
190 190 if not os.path.isfile(os.path.join(path, 'format')):
191 191 return False
192 192
193 193 try:
194 194 svn.repos.open(path)
195 195 except svn.core.SubversionException:
196 196 tb = traceback.format_exc()
197 197 log.debug("Invalid Subversion path `%s`, tb: %s", path, tb)
198 198 return False
199 199 return True
200 200
201 201 @reraise_safe_exceptions
202 202 def verify(self, wire,):
203 203 repo_path = wire['path']
204 204 if not self.is_path_valid_repository(wire, repo_path):
205 205 raise Exception(
206 "Path %s is not a valid Subversion repository." % repo_path)
206 f"Path {repo_path} is not a valid Subversion repository.")
207 207
208 208 cmd = ['svnadmin', 'info', repo_path]
209 209 stdout, stderr = subprocessio.run_command(cmd)
210 210 return stdout
211 211
212 212 @reraise_safe_exceptions
213 213 def lookup(self, wire, revision):
214 214 if revision not in [-1, None, 'HEAD']:
215 215 raise NotImplementedError
216 216 repo = self._factory.repo(wire)
217 217 fs_ptr = svn.repos.fs(repo)
218 218 head = svn.fs.youngest_rev(fs_ptr)
219 219 return head
220 220
221 221 @reraise_safe_exceptions
222 222 def lookup_interval(self, wire, start_ts, end_ts):
223 223 repo = self._factory.repo(wire)
224 224 fsobj = svn.repos.fs(repo)
225 225 start_rev = None
226 226 end_rev = None
227 227 if start_ts:
228 228 start_ts_svn = apr_time_t(start_ts)
229 229 start_rev = svn.repos.dated_revision(repo, start_ts_svn) + 1
230 230 else:
231 231 start_rev = 1
232 232 if end_ts:
233 233 end_ts_svn = apr_time_t(end_ts)
234 234 end_rev = svn.repos.dated_revision(repo, end_ts_svn)
235 235 else:
236 236 end_rev = svn.fs.youngest_rev(fsobj)
237 237 return start_rev, end_rev
238 238
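apr_time_t (defined elsewhere in this module) converts a unix timestamp into APR's representation; a minimal equivalent for illustration, assuming the usual APR microseconds-since-epoch convention:

def apr_time_t_sketch(timestamp: float) -> int:
    return int(timestamp * 1_000_000)

assert apr_time_t_sketch(1.5) == 1_500_000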
239 239 @reraise_safe_exceptions
240 240 def revision_properties(self, wire, revision):
241 241
242 242 cache_on, context_uid, repo_id = self._cache_on(wire)
243 243 region = self._region(wire)
244 244
245 245 @region.conditional_cache_on_arguments(condition=cache_on)
246 246 def _revision_properties(_repo_id, _revision):
247 247 repo = self._factory.repo(wire)
248 248 fs_ptr = svn.repos.fs(repo)
249 249 return svn.fs.revision_proplist(fs_ptr, revision)
250 250 return _revision_properties(repo_id, revision)
251 251
252 252 def revision_changes(self, wire, revision):
253 253
254 254 repo = self._factory.repo(wire)
255 255 fsobj = svn.repos.fs(repo)
256 256 rev_root = svn.fs.revision_root(fsobj, revision)
257 257
258 258 editor = svn.repos.ChangeCollector(fsobj, rev_root)
259 259 editor_ptr, editor_baton = svn.delta.make_editor(editor)
260 260 base_dir = ""
261 261 send_deltas = False
262 262 svn.repos.replay2(
263 263 rev_root, base_dir, svn.core.SVN_INVALID_REVNUM, send_deltas,
264 264 editor_ptr, editor_baton, None)
265 265
266 266 added = []
267 267 changed = []
268 268 removed = []
269 269
270 270 # TODO: CHANGE_ACTION_REPLACE: Figure out where it belongs
271 271 for path, change in editor.changes.items():
272 272 # TODO: Decide what to do with directory nodes. Subversion can add
273 273 # empty directories.
274 274
275 275 if change.item_kind == svn.core.svn_node_dir:
276 276 continue
277 277 if change.action in [svn.repos.CHANGE_ACTION_ADD]:
278 278 added.append(path)
279 279 elif change.action in [svn.repos.CHANGE_ACTION_MODIFY,
280 280 svn.repos.CHANGE_ACTION_REPLACE]:
281 281 changed.append(path)
282 282 elif change.action in [svn.repos.CHANGE_ACTION_DELETE]:
283 283 removed.append(path)
284 284 else:
285 285 raise NotImplementedError(
286 286 "Action {} not supported on path {}".format(
287 287 change.action, path))
288 288
289 289 changes = {
290 290 'added': added,
291 291 'changed': changed,
292 292 'removed': removed,
293 293 }
294 294 return changes
295 295
296 296 @reraise_safe_exceptions
297 297 def node_history(self, wire, path, revision, limit):
298 298 cache_on, context_uid, repo_id = self._cache_on(wire)
299 299 region = self._region(wire)
300 300
301 301 @region.conditional_cache_on_arguments(condition=cache_on)
302 302 def _assert_correct_path(_context_uid, _repo_id, _path, _revision, _limit):
303 303 cross_copies = False
304 304 repo = self._factory.repo(wire)
305 305 fsobj = svn.repos.fs(repo)
306 306 rev_root = svn.fs.revision_root(fsobj, revision)
307 307
308 308 history_revisions = []
309 309 history = svn.fs.node_history(rev_root, path)
310 310 history = svn.fs.history_prev(history, cross_copies)
311 311 while history:
312 312 __, node_revision = svn.fs.history_location(history)
313 313 history_revisions.append(node_revision)
314 314 if limit and len(history_revisions) >= limit:
315 315 break
316 316 history = svn.fs.history_prev(history, cross_copies)
317 317 return history_revisions
318 318 return _assert_correct_path(context_uid, repo_id, path, revision, limit)
319 319
320 320 @reraise_safe_exceptions
321 321 def node_properties(self, wire, path, revision):
322 322 cache_on, context_uid, repo_id = self._cache_on(wire)
323 323 region = self._region(wire)
324 324
325 325 @region.conditional_cache_on_arguments(condition=cache_on)
326 326 def _node_properties(_repo_id, _path, _revision):
327 327 repo = self._factory.repo(wire)
328 328 fsobj = svn.repos.fs(repo)
329 329 rev_root = svn.fs.revision_root(fsobj, revision)
330 330 return svn.fs.node_proplist(rev_root, path)
331 331 return _node_properties(repo_id, path, revision)
332 332
333 333 def file_annotate(self, wire, path, revision):
334 334 abs_path = 'file://' + urllib.request.pathname2url(
335 335 vcspath.join(wire['path'], path))
336 336 file_uri = svn.core.svn_path_canonicalize(abs_path)
337 337
338 338 start_rev = svn_opt_revision_value_t(0)
339 339 peg_rev = svn_opt_revision_value_t(revision)
340 340 end_rev = peg_rev
341 341
342 342 annotations = []
343 343
344 344 def receiver(line_no, revision, author, date, line, pool):
345 345 annotations.append((line_no, revision, line))
346 346
347 347 # TODO: Cannot use blame5, missing typemap function in the swig code
348 348 try:
349 349 svn.client.blame2(
350 350 file_uri, peg_rev, start_rev, end_rev,
351 351 receiver, svn.client.create_context())
352 352 except svn.core.SubversionException as exc:
353 353 log.exception("Error during blame operation.")
354 354 raise Exception(
355 355 f"Blame not supported or file does not exist at path {path}. "
356 356 f"Error {exc}.")
357 357
358 358 return BinaryEnvelope(annotations)
359 359
360 360 @reraise_safe_exceptions
361 361 def get_node_type(self, wire, revision=None, path=''):
362 362
363 363 cache_on, context_uid, repo_id = self._cache_on(wire)
364 364 region = self._region(wire)
365 365
366 366 @region.conditional_cache_on_arguments(condition=cache_on)
367 367 def _get_node_type(_repo_id, _revision, _path):
368 368 repo = self._factory.repo(wire)
369 369 fs_ptr = svn.repos.fs(repo)
370 370 if _revision is None:
371 371 _revision = svn.fs.youngest_rev(fs_ptr)
372 372 root = svn.fs.revision_root(fs_ptr, _revision)
373 373 node = svn.fs.check_path(root, path)
374 374 return NODE_TYPE_MAPPING.get(node, None)
375 375 return _get_node_type(repo_id, revision, path)
376 376
377 377 @reraise_safe_exceptions
378 378 def get_nodes(self, wire, revision=None, path=''):
379 379
380 380 cache_on, context_uid, repo_id = self._cache_on(wire)
381 381 region = self._region(wire)
382 382
383 383 @region.conditional_cache_on_arguments(condition=cache_on)
384 384 def _get_nodes(_repo_id, _path, _revision):
385 385 repo = self._factory.repo(wire)
386 386 fsobj = svn.repos.fs(repo)
387 387 if _revision is None:
388 388 _revision = svn.fs.youngest_rev(fsobj)
389 389 root = svn.fs.revision_root(fsobj, _revision)
390 390 entries = svn.fs.dir_entries(root, path)
391 391 result = []
392 392 for entry_path, entry_info in entries.items():
393 393 result.append(
394 394 (entry_path, NODE_TYPE_MAPPING.get(entry_info.kind, None)))
395 395 return result
396 396 return _get_nodes(repo_id, path, revision)
397 397
398 398 @reraise_safe_exceptions
399 399 def get_file_content(self, wire, rev=None, path=''):
400 400 repo = self._factory.repo(wire)
401 401 fsobj = svn.repos.fs(repo)
402 402
403 403 if rev is None:
404 404 rev = svn.fs.youngest_rev(fsobj)
405 405
406 406 root = svn.fs.revision_root(fsobj, rev)
407 407 content = svn.core.Stream(svn.fs.file_contents(root, path))
408 408 return BytesEnvelope(content.read())
409 409
410 410 @reraise_safe_exceptions
411 411 def get_file_size(self, wire, revision=None, path=''):
412 412
413 413 cache_on, context_uid, repo_id = self._cache_on(wire)
414 414 region = self._region(wire)
415 415
416 416 @region.conditional_cache_on_arguments(condition=cache_on)
417 417 def _get_file_size(_repo_id, _revision, _path):
418 418 repo = self._factory.repo(wire)
419 419 fsobj = svn.repos.fs(repo)
420 420 if _revision is None:
421 421 _revision = svn.fs.youngest_rev(fsobj)
422 422 root = svn.fs.revision_root(fsobj, _revision)
423 423 size = svn.fs.file_length(root, path)
424 424 return size
425 425 return _get_file_size(repo_id, revision, path)
426 426
427 427 def create_repository(self, wire, compatible_version=None):
428 428 log.info('Creating Subversion repository in path "%s"', wire['path'])
429 429 self._factory.repo(wire, create=True,
430 430 compatible_version=compatible_version)
431 431
432 432 def get_url_and_credentials(self, src_url) -> tuple[str, str, str]:
433 433 obj = urllib.parse.urlparse(src_url)
434 434 username = obj.username or ''
435 435 password = obj.password or ''
436 436 return username, password, src_url
437 437
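An example of the credential split performed above, with an illustrative URL:

import urllib.parse

obj = urllib.parse.urlparse('svn://user:secret@example.com/repo')
assert (obj.username or '', obj.password or '') == ('user', 'secret')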
438 438 def import_remote_repository(self, wire, src_url):
439 439 repo_path = wire['path']
440 440 if not self.is_path_valid_repository(wire, repo_path):
441 441 raise Exception(
442 "Path %s is not a valid Subversion repository." % repo_path)
442 f"Path {repo_path} is not a valid Subversion repository.")
443 443
444 444 username, password, src_url = self.get_url_and_credentials(src_url)
445 445 rdump_cmd = ['svnrdump', 'dump', '--non-interactive',
446 446 '--trust-server-cert-failures=unknown-ca']
447 447 if username and password:
448 448 rdump_cmd += ['--username', username, '--password', password]
449 449 rdump_cmd += [src_url]
450 450
451 451 rdump = subprocess.Popen(
452 452 rdump_cmd,
453 453 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
454 454 load = subprocess.Popen(
455 455 ['svnadmin', 'load', repo_path], stdin=rdump.stdout)
456 456
457 457 # TODO: johbo: This can be a very long operation, might be better
458 458 # to track some kind of status and provide an api to check if the
459 459 # import is done.
460 460 rdump.wait()
461 461 load.wait()
462 462
463 463 log.debug('Return process ended with code: %s', rdump.returncode)
464 464 if rdump.returncode != 0:
465 465 errors = rdump.stderr.read()
466 466 log.error('svnrdump dump failed: statuscode %s: message: %s', rdump.returncode, errors)
467 467
468 468 reason = 'UNKNOWN'
469 469 if b'svnrdump: E230001:' in errors:
470 470 reason = 'INVALID_CERTIFICATE'
471 471
472 472 if reason == 'UNKNOWN':
473 473 reason = f'UNKNOWN:{safe_str(errors)}'
474 474
475 475 raise Exception(
476 476 'Failed to dump the remote repository from {}. Reason:{}'.format(
477 477 src_url, reason))
478 478 if load.returncode != 0:
479 479 raise Exception(
480 'Failed to load the dump of remote repository from %s.' %
481 (src_url, ))
480 f'Failed to load the dump of remote repository from {src_url}.')
482 481
483 482 def commit(self, wire, message, author, timestamp, updated, removed):
484 483
485 484 message = safe_bytes(message)
486 485 author = safe_bytes(author)
487 486
488 487 repo = self._factory.repo(wire)
489 488 fsobj = svn.repos.fs(repo)
490 489
491 490 rev = svn.fs.youngest_rev(fsobj)
492 491 txn = svn.repos.fs_begin_txn_for_commit(repo, rev, author, message)
493 492 txn_root = svn.fs.txn_root(txn)
494 493
495 494 for node in updated:
496 495 TxnNodeProcessor(node, txn_root).update()
497 496 for node in removed:
498 497 TxnNodeProcessor(node, txn_root).remove()
499 498
500 499 commit_id = svn.repos.fs_commit_txn(repo, txn)
501 500
502 501 if timestamp:
503 502 apr_time = apr_time_t(timestamp)
504 503 ts_formatted = svn.core.svn_time_to_cstring(apr_time)
505 504 svn.fs.change_rev_prop(fsobj, commit_id, 'svn:date', ts_formatted)
506 505
507 506 log.debug('Committed revision "%s" to "%s".', commit_id, wire['path'])
508 507 return commit_id
509 508
510 509 @reraise_safe_exceptions
511 510 def diff(self, wire, rev1, rev2, path1=None, path2=None,
512 511 ignore_whitespace=False, context=3):
513 512
514 513 wire.update(cache=False)
515 514 repo = self._factory.repo(wire)
516 515 diff_creator = SvnDiffer(
517 516 repo, rev1, path1, rev2, path2, ignore_whitespace, context)
518 517 try:
519 518 return BytesEnvelope(diff_creator.generate_diff())
520 519 except svn.core.SubversionException as e:
521 520 log.exception(
522 521 "Error during diff operation. "
523 522 "Path might not exist %s, %s", path1, path2)
524 523 return BytesEnvelope(b'')
525 524
526 525 @reraise_safe_exceptions
527 526 def is_large_file(self, wire, path):
528 527 return False
529 528
530 529 @reraise_safe_exceptions
531 530 def is_binary(self, wire, rev, path):
532 531 cache_on, context_uid, repo_id = self._cache_on(wire)
533 532 region = self._region(wire)
534 533
535 534 @region.conditional_cache_on_arguments(condition=cache_on)
536 535 def _is_binary(_repo_id, _rev, _path):
537 536 raw_bytes = self.get_file_content(wire, rev, path)
538 537 if not raw_bytes:
539 538 return False
540 539 return b'\0' in raw_bytes
541 540
542 541 return _is_binary(repo_id, rev, path)
543 542
544 543 @reraise_safe_exceptions
545 544 def md5_hash(self, wire, rev, path):
546 545 cache_on, context_uid, repo_id = self._cache_on(wire)
547 546 region = self._region(wire)
548 547
549 548 @region.conditional_cache_on_arguments(condition=cache_on)
550 549 def _md5_hash(_repo_id, _rev, _path):
551 550 return ''
552 551
553 552 return _md5_hash(repo_id, rev, path)
554 553
555 554 @reraise_safe_exceptions
556 555 def run_svn_command(self, wire, cmd, **opts):
557 556 path = wire.get('path', None)
558 557
559 558 if path and os.path.isdir(path):
560 559 opts['cwd'] = path
561 560
562 561 safe_call = opts.pop('_safe', False)
563 562
564 563 svnenv = os.environ.copy()
565 564 svnenv.update(opts.pop('extra_env', {}))
566 565
567 566 _opts = {'env': svnenv, 'shell': False}
568 567
569 568 try:
570 569 _opts.update(opts)
571 570 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
572 571
573 572 return b''.join(proc), b''.join(proc.stderr)
574 573 except OSError as err:
575 574 if safe_call:
576 575 return '', safe_str(err).strip()
577 576 else:
578 577 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
579 578 tb_err = ("Couldn't run svn command (%s).\n"
580 579 "Original error was:%s\n"
581 580 "Call options:%s\n"
582 581 % (cmd, err, _opts))
583 582 log.exception(tb_err)
584 583 raise exceptions.VcsException()(tb_err)
585 584
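The environment handling above in isolation: extra_env entries are layered over the inherited environment before the subprocess is spawned (a sketch, not this module's API):

import os

def build_svn_env(extra_env=None):
    svnenv = os.environ.copy()
    svnenv.update(extra_env or {})
    return svnenv

assert build_svn_env({'LC_ALL': 'C'})['LC_ALL'] == 'C'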
586 585 @reraise_safe_exceptions
587 586 def install_hooks(self, wire, force=False):
588 587 from vcsserver.hook_utils import install_svn_hooks
589 588 repo_path = wire['path']
590 589 binary_dir = settings.BINARY_DIR
591 590 executable = None
592 591 if binary_dir:
593 592 executable = os.path.join(binary_dir, 'python3')
594 593 return install_svn_hooks(repo_path, force_create=force)
595 594
596 595 @reraise_safe_exceptions
597 596 def get_hooks_info(self, wire):
598 597 from vcsserver.hook_utils import (
599 598 get_svn_pre_hook_version, get_svn_post_hook_version)
600 599 repo_path = wire['path']
601 600 return {
602 601 'pre_version': get_svn_pre_hook_version(repo_path),
603 602 'post_version': get_svn_post_hook_version(repo_path),
604 603 }
605 604
606 605 @reraise_safe_exceptions
607 606 def set_head_ref(self, wire, head_name):
608 607 pass
609 608
610 609 @reraise_safe_exceptions
611 610 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
612 611 archive_dir_name, commit_id, cache_config):
613 612
614 613 def walk_tree(root, root_dir, _commit_id):
615 614 """
616 615 Special recursive svn repo walker
617 616 """
618 617 root_dir = safe_bytes(root_dir)
619 618
620 619 filemode_default = 0o100644
621 620 filemode_executable = 0o100755
622 621
623 622 file_iter = svn.fs.dir_entries(root, root_dir)
624 623 for f_name in file_iter:
625 624 f_type = NODE_TYPE_MAPPING.get(file_iter[f_name].kind, None)
626 625
627 626 if f_type == 'dir':
628 627 # return only DIR, and then all entries in that dir
629 628 yield os.path.join(root_dir, f_name), {'mode': filemode_default}, f_type
630 629 new_root = os.path.join(root_dir, f_name)
631 630 yield from walk_tree(root, new_root, _commit_id)
632 631 else:
633 632
634 633 f_path = os.path.join(root_dir, f_name).rstrip(b'/')
635 634 prop_list = svn.fs.node_proplist(root, f_path)
636 635
637 636 f_mode = filemode_default
638 637 if prop_list.get('svn:executable'):
639 638 f_mode = filemode_executable
640 639
641 640 f_is_link = False
642 641 if prop_list.get('svn:special'):
643 642 f_is_link = True
644 643
645 644 data = {
646 645 'is_link': f_is_link,
647 646 'mode': f_mode,
648 647 'content_stream': svn.core.Stream(svn.fs.file_contents(root, f_path)).read
649 648 }
650 649
651 650 yield f_path, data, f_type
652 651
653 652 def file_walker(_commit_id, path):
654 653 repo = self._factory.repo(wire)
655 654 root = svn.fs.revision_root(svn.repos.fs(repo), int(commit_id))
656 655
657 656 def no_content():
658 657 raise NoContentException()
659 658
660 659 for f_name, f_data, f_type in walk_tree(root, path, _commit_id):
661 660 file_path = f_name
662 661
663 662 if f_type == 'dir':
664 663 mode = f_data['mode']
665 664 yield ArchiveNode(file_path, mode, False, no_content)
666 665 else:
667 666 mode = f_data['mode']
668 667 is_link = f_data['is_link']
669 668 data_stream = f_data['content_stream']
670 669 yield ArchiveNode(file_path, mode, is_link, data_stream)
671 670
672 671 return store_archive_in_cache(
673 672 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
674 673
675 674
676 class SvnDiffer(object):
675 class SvnDiffer:
677 676 """
678 677 Utility to create diffs based on difflib and the Subversion api
679 678 """
680 679
681 680 binary_content = False
682 681
683 682 def __init__(
684 683 self, repo, src_rev, src_path, tgt_rev, tgt_path,
685 684 ignore_whitespace, context):
686 685 self.repo = repo
687 686 self.ignore_whitespace = ignore_whitespace
688 687 self.context = context
689 688
690 689 fsobj = svn.repos.fs(repo)
691 690
692 691 self.tgt_rev = tgt_rev
693 692 self.tgt_path = tgt_path or ''
694 693 self.tgt_root = svn.fs.revision_root(fsobj, tgt_rev)
695 694 self.tgt_kind = svn.fs.check_path(self.tgt_root, self.tgt_path)
696 695
697 696 self.src_rev = src_rev
698 697 self.src_path = src_path or self.tgt_path
699 698 self.src_root = svn.fs.revision_root(fsobj, src_rev)
700 699 self.src_kind = svn.fs.check_path(self.src_root, self.src_path)
701 700
702 701 self._validate()
703 702
704 703 def _validate(self):
705 704 if (self.tgt_kind != svn.core.svn_node_none and
706 705 self.src_kind != svn.core.svn_node_none and
707 706 self.src_kind != self.tgt_kind):
708 707 # TODO: johbo: proper error handling
709 708 raise Exception(
710 709 "Source and target are not compatible for diff generation. "
711 710 "Source type: %s, target type: %s" %
712 711 (self.src_kind, self.tgt_kind))
713 712
714 713 def generate_diff(self) -> bytes:
715 714 buf = io.BytesIO()
716 715 if self.tgt_kind == svn.core.svn_node_dir:
717 716 self._generate_dir_diff(buf)
718 717 else:
719 718 self._generate_file_diff(buf)
720 719 return buf.getvalue()
721 720
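A hedged usage sketch of SvnDiffer, assuming `repo` was obtained via svn.repos.open as in SubversionFactory above (kept as comments since it needs a live repository):

# differ = SvnDiffer(
#     repo, src_rev=10, src_path='trunk/a.txt',
#     tgt_rev=12, tgt_path='trunk/a.txt',
#     ignore_whitespace=False, context=3)
# patch_bytes = differ.generate_diff()  # empty diff when comparing identical revisions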
722 721 def _generate_dir_diff(self, buf: io.BytesIO):
723 722 editor = DiffChangeEditor()
724 723 editor_ptr, editor_baton = svn.delta.make_editor(editor)
725 724 svn.repos.dir_delta2(
726 725 self.src_root,
727 726 self.src_path,
728 727 '', # src_entry
729 728 self.tgt_root,
730 729 self.tgt_path,
731 730 editor_ptr, editor_baton,
732 731 authorization_callback_allow_all,
733 732 False, # text_deltas
734 733 svn.core.svn_depth_infinity, # depth
735 734 False, # entry_props
736 735 False, # ignore_ancestry
737 736 )
738 737
739 738 for path, __, change in sorted(editor.changes):
740 739 self._generate_node_diff(
741 740 buf, change, path, self.tgt_path, path, self.src_path)
742 741
743 742 def _generate_file_diff(self, buf: io.BytesIO):
744 743 change = None
745 744 if self.src_kind == svn.core.svn_node_none:
746 745 change = "add"
747 746 elif self.tgt_kind == svn.core.svn_node_none:
748 747 change = "delete"
749 748 tgt_base, tgt_path = vcspath.split(self.tgt_path)
750 749 src_base, src_path = vcspath.split(self.src_path)
751 750 self._generate_node_diff(
752 751 buf, change, tgt_path, tgt_base, src_path, src_base)
753 752
754 753 def _generate_node_diff(
755 754 self, buf: io.BytesIO, change, tgt_path, tgt_base, src_path, src_base):
756 755
757 756 tgt_path_bytes = safe_bytes(tgt_path)
758 757 tgt_path = safe_str(tgt_path)
759 758
760 759 src_path_bytes = safe_bytes(src_path)
761 760 src_path = safe_str(src_path)
762 761
763 762 if self.src_rev == self.tgt_rev and tgt_base == src_base:
764 763 # for consistent behaviour with git/hg, return an empty diff if
765 764 # we compare the same revisions
766 765 return
767 766
768 767 tgt_full_path = vcspath.join(tgt_base, tgt_path)
769 768 src_full_path = vcspath.join(src_base, src_path)
770 769
771 770 self.binary_content = False
772 771 mime_type = self._get_mime_type(tgt_full_path)
773 772
774 773 if mime_type and not mime_type.startswith(b'text'):
775 774 self.binary_content = True
776 775 buf.write(b"=" * 67 + b'\n')
777 776 buf.write(b"Cannot display: file marked as a binary type.\n")
778 777 buf.write(b"svn:mime-type = %s\n" % mime_type)
779 778 buf.write(b"Index: %b\n" % tgt_path_bytes)
780 779 buf.write(b"=" * 67 + b'\n')
781 780 buf.write(b"diff --git a/%b b/%b\n" % (tgt_path_bytes, tgt_path_bytes))
782 781
783 782 if change == 'add':
784 783 # TODO: johbo: SVN is missing a zero here compared to git
785 784 buf.write(b"new file mode 10644\n")
786 785
787 786 # TODO(marcink): introduce binary detection of svn patches
788 787 # if self.binary_content:
789 788 # buf.write(b'GIT binary patch\n')
790 789
791 790 buf.write(b"--- /dev/null\t(revision 0)\n")
792 791 src_lines = []
793 792 else:
794 793 if change == 'delete':
795 794 buf.write(b"deleted file mode 10644\n")
796 795
797 796 # TODO(marcink): introduce binary detection of svn patches
798 797 # if self.binary_content:
799 798 # buf.write('GIT binary patch\n')
800 799
801 800 buf.write(b"--- a/%b\t(revision %d)\n" % (src_path_bytes, self.src_rev))
802 801 src_lines = self._svn_readlines(self.src_root, src_full_path)
803 802
804 803 if change == 'delete':
805 804 buf.write(b"+++ /dev/null\t(revision %d)\n" % self.tgt_rev)
806 805 tgt_lines = []
807 806 else:
808 807 buf.write(b"+++ b/%b\t(revision %d)\n" % (tgt_path_bytes, self.tgt_rev))
809 808 tgt_lines = self._svn_readlines(self.tgt_root, tgt_full_path)
810 809
811 810 # the diff header is written; now generate the diff content into the buffer
812 811
813 812 if not self.binary_content:
814 813 udiff = svn_diff.unified_diff(
815 814 src_lines, tgt_lines, context=self.context,
816 815 ignore_blank_lines=self.ignore_whitespace,
817 816 ignore_case=False,
818 817 ignore_space_changes=self.ignore_whitespace)
819 818
820 819 buf.writelines(udiff)
821 820
822 821 def _get_mime_type(self, path) -> bytes:
823 822 try:
824 823 mime_type = svn.fs.node_prop(
825 824 self.tgt_root, path, svn.core.SVN_PROP_MIME_TYPE)
826 825 except svn.core.SubversionException:
827 826 mime_type = svn.fs.node_prop(
828 827 self.src_root, path, svn.core.SVN_PROP_MIME_TYPE)
829 828 return mime_type
830 829
831 830 def _svn_readlines(self, fs_root, node_path):
832 831 if self.binary_content:
833 832 return []
834 833 node_kind = svn.fs.check_path(fs_root, node_path)
835 834 if node_kind not in (
836 835 svn.core.svn_node_file, svn.core.svn_node_symlink):
837 836 return []
838 837 content = svn.core.Stream(
839 838 svn.fs.file_contents(fs_root, node_path)).read()
840 839
841 840 return content.splitlines(True)
842 841
843 842
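A minimal usage sketch for SvnDiffer: generating a unified diff for one file between two revisions. It assumes the standard SWIG Subversion bindings; the repository path and revision numbers are illustrative::

    import svn.repos

    repo = svn.repos.open('/srv/svn/my-repo')
    differ = SvnDiffer(
        repo, src_rev=41, src_path='trunk/README',
        tgt_rev=42, tgt_path='trunk/README',
        ignore_whitespace=False, context=3)
    patch = differ.generate_diff()  # git-style unified diff, as bytes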
844 843 class DiffChangeEditor(svn.delta.Editor):
845 844 """
846 845 Records changes between two given revisions
847 846 """
848 847
849 848 def __init__(self):
850 849 self.changes = []
851 850
852 851 def delete_entry(self, path, revision, parent_baton, pool=None):
853 852 self.changes.append((path, None, 'delete'))
854 853
855 854 def add_file(
856 855 self, path, parent_baton, copyfrom_path, copyfrom_revision,
857 856 file_pool=None):
858 857 self.changes.append((path, 'file', 'add'))
859 858
860 859 def open_file(self, path, parent_baton, base_revision, file_pool=None):
861 860 self.changes.append((path, 'file', 'change'))
862 861
863 862
864 863 def authorization_callback_allow_all(root, path, pool):
865 864 return True
866 865
867 866
868 class TxnNodeProcessor(object):
867 class TxnNodeProcessor:
869 868 """
870 869 Utility to process the change of one node within a transaction root.
871 870
872 871 It encapsulates the knowledge of how to add, update or remove
873 872 a node for a given transaction root. The purpose is to support the method
874 873 `SvnRemote.commit`.
875 874 """
876 875
877 876 def __init__(self, node, txn_root):
878 877 assert_bytes(node['path'])
879 878
880 879 self.node = node
881 880 self.txn_root = txn_root
882 881
883 882 def update(self):
884 883 self._ensure_parent_dirs()
885 884 self._add_file_if_node_does_not_exist()
886 885 self._update_file_content()
887 886 self._update_file_properties()
888 887
889 888 def remove(self):
890 889 svn.fs.delete(self.txn_root, self.node['path'])
891 890 # TODO: Clean up directory if empty
892 891
893 892 def _ensure_parent_dirs(self):
894 893 curdir = vcspath.dirname(self.node['path'])
895 894 dirs_to_create = []
896 895 while not self._svn_path_exists(curdir):
897 896 dirs_to_create.append(curdir)
898 897 curdir = vcspath.dirname(curdir)
899 898
900 899 for curdir in reversed(dirs_to_create):
901 900 log.debug('Creating missing directory "%s"', curdir)
902 901 svn.fs.make_dir(self.txn_root, curdir)
903 902
904 903 def _svn_path_exists(self, path):
905 904 path_status = svn.fs.check_path(self.txn_root, path)
906 905 return path_status != svn.core.svn_node_none
907 906
908 907 def _add_file_if_node_does_not_exist(self):
909 908 kind = svn.fs.check_path(self.txn_root, self.node['path'])
910 909 if kind == svn.core.svn_node_none:
911 910 svn.fs.make_file(self.txn_root, self.node['path'])
912 911
913 912 def _update_file_content(self):
914 913 assert_bytes(self.node['content'])
915 914
916 915 handler, baton = svn.fs.apply_textdelta(
917 916 self.txn_root, self.node['path'], None, None)
918 917 svn.delta.svn_txdelta_send_string(self.node['content'], handler, baton)
919 918
920 919 def _update_file_properties(self):
921 920 properties = self.node.get('properties', {})
922 921 for key, value in properties.items():
923 922 svn.fs.change_node_prop(
924 923 self.txn_root, self.node['path'], safe_bytes(key), safe_bytes(value))
925 924
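A rough sketch of the intended call pattern for TxnNodeProcessor, assuming the standard SWIG Subversion bindings (the exact transaction-creation calls may differ in your bindings); the repository path and node contents are illustrative::

    import svn.fs
    import svn.repos

    repo = svn.repos.open('/srv/svn/my-repo')
    fsobj = svn.repos.fs(repo)
    txn = svn.fs.begin_txn(fsobj, svn.fs.youngest_rev(fsobj))
    txn_root = svn.fs.txn_root(txn)

    node = {'path': b'trunk/hello.txt', 'content': b'hello\n'}
    TxnNodeProcessor(node, txn_root).update()  # ensures parent dirs, creates file, sets content
    svn.fs.commit_txn(txn)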
926 925
927 926 def apr_time_t(timestamp):
928 927 """
929 928 Convert a Python timestamp into APR timestamp type apr_time_t
930 929 """
931 930 return int(timestamp * 1E6)
932 931
933 932
934 933 def svn_opt_revision_value_t(num):
935 934 """
936 935 Put `num` into a `svn_opt_revision_value_t` structure.
937 936 """
938 937 value = svn.core.svn_opt_revision_value_t()
939 938 value.number = num
940 939 revision = svn.core.svn_opt_revision_t()
941 940 revision.kind = svn.core.svn_opt_revision_number
942 941 revision.value = value
943 942 return revision
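Both helpers above are thin adapters to Subversion's C types; a short example (the revision number is arbitrary)::

    import time

    now_apr = apr_time_t(time.time())   # microseconds since the epoch, as int
    rev = svn_opt_revision_value_t(42)  # svn_opt_revision_t pinned to r42
    assert rev.kind == svn.core.svn_opt_revision_number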
@@ -1,34 +1,34 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 from vcsserver import scm_app, wsgi_app_caller
19 19
20 20
21 class GitRemoteWsgi(object):
21 class GitRemoteWsgi:
22 22 def handle(self, environ, input_data, *args, **kwargs):
23 23 app = wsgi_app_caller.WSGIAppCaller(
24 24 scm_app.create_git_wsgi_app(*args, **kwargs))
25 25
26 26 return app.handle(environ, input_data)
27 27
28 28
29 class HgRemoteWsgi(object):
29 class HgRemoteWsgi:
30 30 def handle(self, environ, input_data, *args, **kwargs):
31 31 app = wsgi_app_caller.WSGIAppCaller(
32 32 scm_app.create_hg_wsgi_app(*args, **kwargs))
33 33
34 34 return app.handle(environ, input_data)
@@ -1,242 +1,242 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import logging
20 20 import itertools
21 21
22 22 import mercurial
23 23 import mercurial.error
24 24 import mercurial.wireprotoserver
25 25 import mercurial.hgweb.common
26 26 import mercurial.hgweb.hgweb_mod
27 27 import webob.exc
28 28
29 29 from vcsserver import pygrack, exceptions, settings, git_lfs
30 30 from vcsserver.str_utils import ascii_bytes, safe_bytes
31 31
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 # taken from the mercurial documentation
36 36 HG_UI_SECTIONS = [
37 37 'alias', 'auth', 'decode/encode', 'defaults', 'diff', 'email', 'extensions',
38 38 'format', 'merge-patterns', 'merge-tools', 'hooks', 'http_proxy', 'smtp',
39 39 'patch', 'paths', 'profiling', 'server', 'trusted', 'ui', 'web',
40 40 ]
41 41
42 42
43 43 class HgWeb(mercurial.hgweb.hgweb_mod.hgweb):
44 44 """Extension of hgweb that simplifies some functions."""
45 45
46 46 def _get_view(self, repo):
47 47 """Views are not supported."""
48 48 return repo
49 49
50 50 def loadsubweb(self):
51 51 """The result is only used in the templater method which is not used."""
52 52 return None
53 53
54 54 def run(self):
55 55 """Unused function so raise an exception if accidentally called."""
56 56 raise NotImplementedError
57 57
58 58 def templater(self, req):
59 59 """Function used in an unreachable code path.
60 60
61 61 This code is unreachable because we guarantee that the HTTP request
62 62 corresponds to a Mercurial command. See the is_hg method. So we are
63 63 never going to get a user-visible url.
64 64 """
65 65 raise NotImplementedError
66 66
67 67 def archivelist(self, nodeid):
68 68 """Unused function so raise an exception if accidentally called."""
69 69 raise NotImplementedError
70 70
71 71 def __call__(self, environ, start_response):
72 72 """Run the WSGI application.
73 73
74 74 This may be called by multiple threads.
75 75 """
76 76 from mercurial.hgweb import request as requestmod
77 77 req = requestmod.parserequestfromenv(environ)
78 78 res = requestmod.wsgiresponse(req, start_response)
79 79 gen = self.run_wsgi(req, res)
80 80
81 81 first_chunk = None
82 82
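# NOTE: next(gen) below eagerly pulls the first body chunk so that
# run_wsgi() executes (and may fail) before the iterator is handed to
# the WSGI server; the consumed chunk is re-chained in front of the
# remaining generator further down.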
83 83 try:
84 84 data = next(gen)
85 85
86 86 def first_chunk():
87 87 yield data
88 88 except StopIteration:
89 89 pass
90 90
91 91 if first_chunk:
92 92 return itertools.chain(first_chunk(), gen)
93 93 return gen
94 94
95 95 def _runwsgi(self, req, res, repo):
96 96
97 97 cmd = req.qsparams.get(b'cmd', '')
98 98 if not mercurial.wireprotoserver.iscmd(cmd):
99 99 # NOTE(marcink): for unsupported commands, we return bad request
100 100 # internally from HG
101 101 log.warning('cmd: `%s` is not supported by the mercurial wireprotocol v1', cmd)
102 102 from mercurial.hgweb.common import statusmessage
103 103 res.status = statusmessage(mercurial.hgweb.common.HTTP_BAD_REQUEST)
104 104 res.setbodybytes(b'')
105 105 return res.sendresponse()
106 106
107 107 return super()._runwsgi(req, res, repo)
108 108
109 109
110 110 def make_hg_ui_from_config(repo_config):
111 111 baseui = mercurial.ui.ui()
112 112
113 113 # clean the baseui object
114 114 baseui._ocfg = mercurial.config.config()
115 115 baseui._ucfg = mercurial.config.config()
116 116 baseui._tcfg = mercurial.config.config()
117 117
118 118 for section, option, value in repo_config:
119 119 baseui.setconfig(
120 120 ascii_bytes(section, allow_bytes=True),
121 121 ascii_bytes(option, allow_bytes=True),
122 122 ascii_bytes(value, allow_bytes=True))
123 123
124 124 # make our hgweb quiet so it doesn't print output
125 125 baseui.setconfig(b'ui', b'quiet', b'true')
126 126
127 127 return baseui
128 128
129 129
130 130 def update_hg_ui_from_hgrc(baseui, repo_path):
131 131 path = os.path.join(repo_path, '.hg', 'hgrc')
132 132
133 133 if not os.path.isfile(path):
134 134 log.debug('hgrc file is not present at %s, skipping...', path)
135 135 return
136 136 log.debug('reading hgrc from %s', path)
137 137 cfg = mercurial.config.config()
138 138 cfg.read(ascii_bytes(path))
139 139 for section in HG_UI_SECTIONS:
140 140 for k, v in cfg.items(section):
141 141 log.debug('setting ui from file: [%s] %s=%s', section, k, v)
142 142 baseui.setconfig(
143 143 ascii_bytes(section, allow_bytes=True),
144 144 ascii_bytes(k, allow_bytes=True),
145 145 ascii_bytes(v, allow_bytes=True))
146 146
147 147
148 148 def create_hg_wsgi_app(repo_path, repo_name, config):
149 149 """
150 150 Prepares a WSGI application to handle Mercurial requests.
151 151
152 152 :param config: is a list of 3-item tuples representing a ConfigObject
153 153 (it is the serialized version of the config object).
154 154 """
155 155 log.debug("Creating Mercurial WSGI application")
156 156
157 157 baseui = make_hg_ui_from_config(config)
158 158 update_hg_ui_from_hgrc(baseui, repo_path)
159 159
160 160 try:
161 161 return HgWeb(safe_bytes(repo_path), name=safe_bytes(repo_name), baseui=baseui)
162 162 except mercurial.error.RequirementError as e:
163 163 raise exceptions.RequirementException(e)(e)
164 164
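For illustration, a hedged sketch of calling create_hg_wsgi_app; the repository path and config entries are assumptions, not values from this changeset::

    config = [  # serialized config: (section, option, value) 3-item tuples
        ('web', 'allow_push', '*'),
        ('phases', 'publish', 'true'),
    ]
    app = create_hg_wsgi_app('/srv/repos/my-hg-repo', 'my-hg-repo', config)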
165 165
166 class GitHandler(object):
166 class GitHandler:
167 167 """
168 168 Handler for Git operations like push/pull, etc.
169 169 """
170 170 def __init__(self, repo_location, repo_name, git_path, update_server_info,
171 171 extras):
172 172 if not os.path.isdir(repo_location):
173 173 raise OSError(repo_location)
174 174 self.content_path = repo_location
175 175 self.repo_name = repo_name
176 176 self.repo_location = repo_location
177 177 self.extras = extras
178 178 self.git_path = git_path
179 179 self.update_server_info = update_server_info
180 180
181 181 def __call__(self, environ, start_response):
182 182 app = webob.exc.HTTPNotFound()
183 183 candidate_paths = (
184 184 self.content_path, os.path.join(self.content_path, '.git'))
185 185
186 186 for content_path in candidate_paths:
187 187 try:
188 188 app = pygrack.GitRepository(
189 189 self.repo_name, content_path, self.git_path,
190 190 self.update_server_info, self.extras)
191 191 break
192 192 except OSError:
193 193 continue
194 194
195 195 return app(environ, start_response)
196 196
197 197
198 198 def create_git_wsgi_app(repo_path, repo_name, config):
199 199 """
200 200 Creates a WSGI application to handle Git requests.
201 201
202 202 :param config: is a dictionary holding the extras.
203 203 """
204 204 git_path = settings.GIT_EXECUTABLE
205 205 update_server_info = config.pop('git_update_server_info')
206 206 app = GitHandler(
207 207 repo_path, repo_name, git_path, update_server_info, config)
208 208
209 209 return app
210 210
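Likewise, a minimal sketch for the Git variant; the path is illustrative, and any keys left in `config` are passed through to GitHandler as hook extras::

    config = {'git_update_server_info': True}
    app = create_git_wsgi_app('/srv/repos/my-git-repo', 'my-git-repo', config)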
211 211
212 class GitLFSHandler(object):
212 class GitLFSHandler:
213 213 """
214 214 Handler for Git LFS operations
215 215 """
216 216
217 217 def __init__(self, repo_location, repo_name, git_path, update_server_info,
218 218 extras):
219 219 if not os.path.isdir(repo_location):
220 220 raise OSError(repo_location)
221 221 self.content_path = repo_location
222 222 self.repo_name = repo_name
223 223 self.repo_location = repo_location
224 224 self.extras = extras
225 225 self.git_path = git_path
226 226 self.update_server_info = update_server_info
227 227
228 228 def get_app(self, git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
229 229 app = git_lfs.create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
230 230 return app
231 231
232 232
233 233 def create_git_lfs_wsgi_app(repo_path, repo_name, config):
234 234 git_path = settings.GIT_EXECUTABLE
235 235 update_server_info = config.pop(b'git_update_server_info')
236 236 git_lfs_enabled = config.pop(b'git_lfs_enabled')
237 237 git_lfs_store_path = config.pop(b'git_lfs_store_path')
238 238 git_lfs_http_scheme = config.pop(b'git_lfs_http_scheme', 'http')
239 239 app = GitLFSHandler(
240 240 repo_path, repo_name, git_path, update_server_info, config)
241 241
242 242 return app.get_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
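And a sketch for the LFS wrapper, with illustrative values; note the byte-string keys popped above::

    config = {
        b'git_update_server_info': False,
        b'git_lfs_enabled': True,
        b'git_lfs_store_path': '/var/opt/lfs_store',
        b'git_lfs_http_scheme': 'https',
    }
    lfs_app = create_git_lfs_wsgi_app('/srv/repos/my-git-repo', 'my-git-repo', config)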
@@ -1,78 +1,78 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import gc
19 19 import logging
20 20 import os
21 21 import time
22 22
23 23
24 24 log = logging.getLogger(__name__)
25 25
26 26
27 class VcsServer(object):
27 class VcsServer:
28 28 """
29 29 Exposed remote interface of the vcsserver itself.
30 30
31 31 This object can be used to manage the server remotely. Right now the main
32 32 use case is to allow shutting down the server.
33 33 """
34 34
35 35 _shutdown = False
36 36
37 37 def shutdown(self):
38 38 self._shutdown = True
39 39
40 40 def ping(self):
41 41 """
42 42 Utility to probe a server connection.
43 43 """
44 44 log.debug("Received server ping.")
45 45
46 46 def echo(self, data):
47 47 """
48 48 Utility for performance testing.
49 49
50 50 Allows passing in arbitrary data; the same data is returned.
51 51 """
52 52 log.debug("Received server echo.")
53 53 return data
54 54
55 55 def sleep(self, seconds):
56 56 """
57 57 Utility to simulate long running server interaction.
58 58 """
59 59 log.debug("Sleeping %s seconds", seconds)
60 60 time.sleep(seconds)
61 61
62 62 def get_pid(self):
63 63 """
64 64 Allows discovering the PID through a proxy object.
65 65 """
66 66 return os.getpid()
67 67
68 68 def run_gc(self):
69 69 """
70 70 Allows triggering the garbage collector.
71 71
72 72 The main intention is to support statistics gathering during test runs.
73 73 """
74 74 freed_objects = gc.collect()
75 75 return {
76 76 'freed_objects': freed_objects,
77 77 'garbage': len(gc.garbage),
78 78 }
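Although this interface is meant to be called remotely, the methods also work in-process; a tiny sketch::

    server = VcsServer()
    server.ping()                                  # logs "Received server ping."
    assert server.echo({'answer': 42}) == {'answer': 42}
    stats = server.run_gc()                        # {'freed_objects': ..., 'garbage': ...}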
@@ -1,563 +1,563 b''
1 1 """
2 2 This module provides a class that wraps communication over subprocess.Popen
3 3 input, output, and error streams in a meaningful, non-blocking, concurrent
4 4 stream processor, exposing the output data as an iterator fit to be the
5 5 return value passed by a WSGI application to a WSGI server per PEP 3333.
6 6
7 7 Copyright (c) 2011 Daniel Dotsenko <dotsa[at]hotmail.com>
8 8
9 9 This file is part of git_http_backend.py Project.
10 10
11 11 git_http_backend.py Project is free software: you can redistribute it and/or
12 12 modify it under the terms of the GNU Lesser General Public License as
13 13 published by the Free Software Foundation, either version 2.1 of the License,
14 14 or (at your option) any later version.
15 15
16 16 git_http_backend.py Project is distributed in the hope that it will be useful,
17 17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 19 GNU Lesser General Public License for more details.
20 20
21 21 You should have received a copy of the GNU Lesser General Public License
22 22 along with git_http_backend.py Project.
23 23 If not, see <http://www.gnu.org/licenses/>.
24 24 """
25 25 import os
26 26 import collections
27 27 import logging
28 28 import subprocess
29 29 import threading
30 30
31 31 from vcsserver.str_utils import safe_str
32 32
33 33 log = logging.getLogger(__name__)
34 34
35 35
36 36 class StreamFeeder(threading.Thread):
37 37 """
38 38 Writing into a pipe-like object blocks once its buffer is filled.
39 39 This thread feeds data from a file-like object into a pipe
40 40 without blocking the main thread.
41 41 We close the write end of the pipe once the end of the source stream is reached.
42 42 """
43 43
44 44 def __init__(self, source):
45 45 super().__init__()
46 46 self.daemon = True
47 47 filelike = False
48 48 self.bytes = b''
49 49 if type(source) in (str, bytes, bytearray): # string-like
50 50 self.bytes = bytes(source)
51 51 else: # can be either file pointer or file-like
52 52 if isinstance(source, int): # file pointer it is
53 53 # converting file descriptor (int) stdin into file-like
54 54 source = os.fdopen(source, 'rb', 16384)
55 55 # let's see if source is file-like by now
56 56 filelike = hasattr(source, 'read')
57 57 if not filelike and not self.bytes:
58 58 raise TypeError("StreamFeeder's source object must be a readable "
59 59 "file-like, a file descriptor, or a string-like.")
60 60 self.source = source
61 61 self.readiface, self.writeiface = os.pipe()
62 62
63 63 def run(self):
64 64 writer = self.writeiface
65 65 try:
66 66 if self.bytes:
67 67 os.write(writer, self.bytes)
68 68 else:
69 69 s = self.source
70 70
71 71 while 1:
72 72 _bytes = s.read(4096)
73 73 if not _bytes:
74 74 break
75 75 os.write(writer, _bytes)
76 76
77 77 finally:
78 78 os.close(writer)
79 79
80 80 @property
81 81 def output(self):
82 82 return self.readiface
83 83
84 84
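A small, self-contained sketch of feeding bytes into a subprocess through the pipe exposed by `output` (assumes a POSIX `cat` binary)::

    import subprocess

    feeder = StreamFeeder(b'hello subprocess\n')
    feeder.start()
    proc = subprocess.Popen(['cat'], stdin=feeder.output, stdout=subprocess.PIPE)
    out, _ = proc.communicate()  # b'hello subprocess\n'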
85 85 class InputStreamChunker(threading.Thread):
86 86 def __init__(self, source, target, buffer_size, chunk_size):
87 87
88 88 super().__init__()
89 89
90 90 self.daemon = True # die die die.
91 91
92 92 self.source = source
93 93 self.target = target
94 94 self.chunk_count_max = int(buffer_size / chunk_size) + 1
95 95 self.chunk_size = chunk_size
96 96
97 97 self.data_added = threading.Event()
98 98 self.data_added.clear()
99 99
100 100 self.keep_reading = threading.Event()
101 101 self.keep_reading.set()
102 102
103 103 self.EOF = threading.Event()
104 104 self.EOF.clear()
105 105
106 106 self.go = threading.Event()
107 107 self.go.set()
108 108
109 109 def stop(self):
110 110 self.go.clear()
111 111 self.EOF.set()
112 112 try:
113 113 # this is not proper, but is done to force the reader thread to let
114 114 # go of the input because, if successful, .close() will send EOF
115 115 # down the pipe.
116 116 self.source.close()
117 117 except Exception:
118 118 pass
119 119
120 120 def run(self):
121 121 s = self.source
122 122 t = self.target
123 123 cs = self.chunk_size
124 124 chunk_count_max = self.chunk_count_max
125 125 keep_reading = self.keep_reading
126 126 da = self.data_added
127 127 go = self.go
128 128
129 129 try:
130 130 b = s.read(cs)
131 131 except ValueError:
132 132 b = ''
133 133
134 134 timeout_input = 20
135 135 while b and go.is_set():
136 136 if len(t) > chunk_count_max:
137 137 keep_reading.clear()
138 138 keep_reading.wait(timeout_input)
139 139 if len(t) > chunk_count_max + timeout_input:
140 140 log.error("Timed out while waiting for input from subprocess.")
141 141 os._exit(-1) # this will cause the worker to recycle itself
142 142
143 143 t.append(b)
144 144 da.set()
145 145
146 146 try:
147 147 b = s.read(cs)
148 148 except ValueError: # probably "I/O operation on closed file"
149 149 b = ''
150 150
151 151 self.EOF.set()
152 152 da.set() # for cases when done but there was no input.
153 153
154 154
155 class BufferedGenerator(object):
155 class BufferedGenerator:
156 156 """
157 157 This class behaves as a non-blocking, buffered pipe reader.
158 158 It reads chunks of data (through a thread) from a blocking pipe
159 159 and appends them to a deque of chunks.
160 160 Reading is halted in the thread once the maximum number of chunks is buffered.
161 161 The .next() method may operate in blocking or non-blocking fashion: it yields
162 162 '' if no data is ready to be sent,
163 163 or it does not return until there is some data to send.
164 164 When we get EOF from the underlying source pipe, we set a marker so that
165 165 StopIteration is raised after the last chunk of data is yielded.
166 166 """
167 167
168 168 def __init__(self, name, source, buffer_size=65536, chunk_size=4096,
169 169 starting_values=None, bottomless=False):
170 170 starting_values = starting_values or []
171 171 self.name = name
172 172 self.buffer_size = buffer_size
173 173 self.chunk_size = chunk_size
174 174
175 175 if bottomless:
176 176 maxlen = int(buffer_size / chunk_size)
177 177 else:
178 178 maxlen = None
179 179
180 180 self.data_queue = collections.deque(starting_values, maxlen)
181 181 self.worker = InputStreamChunker(source, self.data_queue, buffer_size, chunk_size)
182 182 if starting_values:
183 183 self.worker.data_added.set()
184 184 self.worker.start()
185 185
186 186 ####################
187 187 # Generator's methods
188 188 ####################
189 189 def __str__(self):
190 190 return f'BufferedGenerator(name={self.name} chunk: {self.chunk_size} on buffer: {self.buffer_size})'
191 191
192 192 def __iter__(self):
193 193 return self
194 194
195 195 def __next__(self):
196 196
197 197 while not self.length and not self.worker.EOF.is_set():
198 198 self.worker.data_added.clear()
199 199 self.worker.data_added.wait(0.2)
200 200
201 201 if self.length:
202 202 self.worker.keep_reading.set()
203 203 return bytes(self.data_queue.popleft())
204 204 elif self.worker.EOF.is_set():
205 205 raise StopIteration
206 206
207 207 def throw(self, exc_type, value=None, traceback=None):
208 208 if not self.worker.EOF.is_set():
209 209 raise exc_type(value)
210 210
211 211 def start(self):
212 212 self.worker.start()
213 213
214 214 def stop(self):
215 215 self.worker.stop()
216 216
217 217 def close(self):
218 218 try:
219 219 self.worker.stop()
220 220 self.throw(GeneratorExit)
221 221 except (GeneratorExit, StopIteration):
222 222 pass
223 223
224 224 ####################
225 225 # Threaded reader's infrastructure.
226 226 ####################
227 227 @property
228 228 def input(self):
229 229 return self.worker.source  # NOTE: the chunker defines no 'w' attribute; expose its source stream
230 230
231 231 @property
232 232 def data_added_event(self):
233 233 return self.worker.data_added
234 234
235 235 @property
236 236 def data_added(self):
237 237 return self.worker.data_added.is_set()
238 238
239 239 @property
240 240 def reading_paused(self):
241 241 return not self.worker.keep_reading.is_set()
242 242
243 243 @property
244 244 def done_reading_event(self):
245 245 """
246 246 "Done reading" does not mean that the iterator's buffer is empty.
247 247 The iterator might be done reading from the underlying source, but the read
248 248 chunks might still be available for serving through the .next() method.
249 249
250 250 :returns: An Event class instance.
251 251 """
252 252 return self.worker.EOF
253 253
254 254 @property
255 255 def done_reading(self):
256 256 """
257 257 "Done reading" does not mean that the iterator's buffer is empty.
258 258 The iterator might be done reading from the underlying source, but the read
259 259 chunks might still be available for serving through the .next() method.
260 260 
261 261 :returns: A bool value.
262 262 """
263 263 return self.worker.EOF.is_set()
264 264
265 265 @property
266 266 def length(self):
267 267 """
268 268 Returns an int.
269 269 
270 270 This is the length of the queue of chunks, not the length of
271 271 the combined contents in those chunks.
272 272 
273 273 __len__() cannot be meaningfully implemented because this
274 274 reader streams through content of unknown total size and
275 275 can only know the length of what it has already seen.
276 276 
277 277 If __len__() of the WSGI return value yields something per PEP 3333,
278 278 the response's length will be set to that. In order not to
279 279 confuse WSGI PEP 3333 servers, we do not implement __len__
280 280 at all.
281 281 """
282 282 return len(self.data_queue)
283 283
284 284 def prepend(self, x):
285 285 self.data_queue.appendleft(x)
286 286
287 287 def append(self, x):
288 288 self.data_queue.append(x)
289 289
290 290 def extend(self, o):
291 291 self.data_queue.extend(o)
292 292
293 293 def __getitem__(self, i):
294 294 return self.data_queue[i]
295 295
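A hedged sketch of reading a subprocess's stdout through BufferedGenerator (assumes a POSIX `echo` binary)::

    import subprocess

    proc = subprocess.Popen(['echo', 'streamed data'], stdout=subprocess.PIPE)
    reader = BufferedGenerator('stdout', proc.stdout, buffer_size=65536, chunk_size=4096)
    body = b''.join(reader)  # b'streamed data\n'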
296 296
297 class SubprocessIOChunker(object):
297 class SubprocessIOChunker:
298 298 """
299 299 Processor class wrapping handling of subprocess IO.
300 300
301 301 .. important::
302 302
303 303 Watch out for the method `__del__` on this class. If this object
304 304 is deleted, it will kill the subprocess, so avoid returning
305 305 the `output` attribute or using it as in the following
306 306 example::
307 307
308 308 # `args` expected to run a program that produces a lot of output
309 309 output = ''.join(SubprocessIOChunker(
310 310 args, shell=False, inputstream=inputstream, env=environ).output)
311 311
312 312 # `output` will not contain all the data, because the __del__ method
313 313 # has already killed the subprocess in this case before all output
314 314 # has been consumed.
315 315
316 316
317 317
318 318 In a way, this is a "communicate()" replacement with a twist.
319 319
320 320 - We are multithreaded. Writing in and reading out/err happen in separate threads.
321 321 - We support concurrent (in and out) stream processing.
322 322 - The output is not a stream. It's a queue of read (bytes, not str)
323 323 chunks. The object behaves as an iterable; iterate over it with "for chunk in obj:".
324 324 - We are non-blocking in more respects than communicate()
325 325 (reading from subprocess out pauses when internal buffer is full, but
326 326 does not block the parent calling code. On the flip side, reading from
327 327 a slow-yielding subprocess may block the iteration until data shows up. This
328 328 does not block the inpipe reading occurring in a parallel thread.)
329 329
330 330 The purpose of the object is to allow us to wrap subprocess interactions into
331 331 an iterable that can be passed to a WSGI server as the application's return
332 332 value. Because of stream-processing-ability, WSGI does not have to read ALL
333 333 of the subprocess's output and buffer it before handing it to the WSGI server for
334 334 the HTTP response. Instead, the class initializer reads just a bit of the stream
335 335 to figure out whether an error occurred or is likely to occur, and if not, hands
336 336 further iteration over the subprocess output to the server for completion of the
337 337 HTTP response.
338 338
339 339 The real or perceived subprocess error is trapped and raised as one of
340 340 the OSError family of exceptions.
341 341
342 342 Example usage:
343 343 # try:
344 344 # answer = SubprocessIOChunker(
345 345 # cmd,
346 346 # input,
347 347 # buffer_size = 65536,
348 348 # chunk_size = 4096
349 349 # )
350 350 # except (OSError) as e:
351 351 # print(str(e))
352 352 # raise e
353 353 #
354 354 # return answer
355 355
356 356
357 357 """
358 358
359 359 # TODO: johbo: This is used to make sure that the open end of the PIPE
360 360 # is closed in the end. It would be way better to wrap this into an
361 361 # object, so that it is closed automatically once it is consumed or
362 362 # something similar.
363 363 _close_input_fd = None
364 364
365 365 _closed = False
366 366 _stdout = None
367 367 _stderr = None
368 368
369 369 def __init__(self, cmd, input_stream=None, buffer_size=65536,
370 370 chunk_size=4096, starting_values=None, fail_on_stderr=True,
371 371 fail_on_return_code=True, **kwargs):
372 372 """
373 373 Initializes SubprocessIOChunker
374 374
375 375 :param cmd: A Subprocess.Popen style "cmd". Can be string or array of strings
376 376 :param input_stream: (Default: None) A file-like, string, or file pointer.
377 377 :param buffer_size: (Default: 65536) A size of total buffer per stream in bytes.
378 378 :param chunk_size: (Default: 4096) A max size of a chunk. Actual chunk may be smaller.
379 379 :param starting_values: (Default: []) An array of strings to put in front of the output queue.
380 380 :param fail_on_stderr: (Default: True) Whether to raise an exception in
381 381 case something is written to stderr.
382 382 :param fail_on_return_code: (Default: True) Whether to raise an
383 383 exception if the return code is not 0.
384 384 """
385 385
386 386 kwargs['shell'] = kwargs.get('shell', True)
387 387
388 388 starting_values = starting_values or []
389 389 if input_stream:
390 390 input_streamer = StreamFeeder(input_stream)
391 391 input_streamer.start()
392 392 input_stream = input_streamer.output
393 393 self._close_input_fd = input_stream
394 394
395 395 self._fail_on_stderr = fail_on_stderr
396 396 self._fail_on_return_code = fail_on_return_code
397 397 self.cmd = cmd
398 398
399 399 _p = subprocess.Popen(cmd, bufsize=-1, stdin=input_stream, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
400 400 **kwargs)
401 401 self.process = _p
402 402
403 403 bg_out = BufferedGenerator('stdout', _p.stdout, buffer_size, chunk_size, starting_values)
404 404 bg_err = BufferedGenerator('stderr', _p.stderr, 10240, 1, bottomless=True)
405 405
406 406 while not bg_out.done_reading and not bg_out.reading_paused and not bg_err.length:
407 407 # doing this until we reach either end of file, or end of buffer.
408 408 bg_out.data_added_event.wait(0.2)
409 409 bg_out.data_added_event.clear()
410 410
411 411 # at this point it's still ambiguous if we are done reading or just full buffer.
412 412 # Either way, if error (returned by ended process, or implied based on
413 413 # presence of stuff in stderr output) we error out.
414 414 # Else, we are happy.
415 415 return_code = _p.poll()
416 416 ret_code_ok = return_code in [None, 0]
417 417 ret_code_fail = return_code is not None and return_code != 0
418 418 if (
419 419 (ret_code_fail and fail_on_return_code) or
420 420 (ret_code_ok and fail_on_stderr and bg_err.length)
421 421 ):
422 422
423 423 try:
424 424 _p.terminate()
425 425 except Exception:
426 426 pass
427 427
428 428 bg_out.stop()
429 429 out = b''.join(bg_out)
430 430 self._stdout = out
431 431
432 432 bg_err.stop()
433 433 err = b''.join(bg_err)
434 434 self._stderr = err
435 435
436 436 # code from https://github.com/schacon/grack/pull/7
437 437 if err.strip() == b'fatal: The remote end hung up unexpectedly' and out.startswith(b'0034shallow '):
438 438 bg_out = iter([out])
439 439 _p = None
440 440 elif err and fail_on_stderr:
441 441 text_err = err.decode()
442 442 raise OSError(
443 443 f"Subprocess exited due to an error:\n{text_err}")
444 444
445 445 if ret_code_fail and fail_on_return_code:
446 446 text_err = err.decode()
447 447 if not err:
448 448 # maybe get empty stderr, try stdout instead
449 449 # in many cases git reports the errors on stdout too
450 450 text_err = out.decode()
451 451 raise OSError(
452 452 f"Subprocess exited with non 0 ret code:{return_code}: stderr:{text_err}")
453 453
454 454 self.stdout = bg_out
455 455 self.stderr = bg_err
456 456 self.inputstream = input_stream
457 457
458 458 def __str__(self):
459 459 proc = getattr(self, 'process', 'NO_PROCESS')
460 460 return f'SubprocessIOChunker: {proc}'
461 461
462 462 def __iter__(self):
463 463 return self
464 464
465 465 def __next__(self):
466 466 # Note: mikhail: We need to be sure that we are checking the return
467 467 # code after the stdout stream is closed. Some processes, e.g. git
468 468 # are doing some magic in between closing stdout and terminating the
469 469 # process and, as a result, we are not getting return code on "slow"
470 470 # systems.
471 471 result = None
472 472 stop_iteration = None
473 473 try:
474 474 result = next(self.stdout)
475 475 except StopIteration as e:
476 476 stop_iteration = e
477 477
478 478 if self.process:
479 479 return_code = self.process.poll()
480 480 ret_code_fail = return_code is not None and return_code != 0
481 481 if ret_code_fail and self._fail_on_return_code:
482 482 self.stop_streams()
483 483 err = self.get_stderr()
484 484 raise OSError(
485 485 f"Subprocess exited (exit_code:{return_code}) due to an error during iteration:\n{err}")
486 486
487 487 if stop_iteration:
488 488 raise stop_iteration
489 489 return result
490 490
491 491 def throw(self, exc_type, value=None, traceback=None):
492 492 if self.stdout.length or not self.stdout.done_reading:
493 493 raise exc_type(value)
494 494
495 495 def close(self):
496 496 if self._closed:
497 497 return
498 498
499 499 try:
500 500 self.process.terminate()
501 501 except Exception:
502 502 pass
503 503 if self._close_input_fd:
504 504 os.close(self._close_input_fd)
505 505 try:
506 506 self.stdout.close()
507 507 except Exception:
508 508 pass
509 509 try:
510 510 self.stderr.close()
511 511 except Exception:
512 512 pass
513 513 try:
514 514 os.close(self.inputstream)
515 515 except Exception:
516 516 pass
517 517
518 518 self._closed = True
519 519
520 520 def stop_streams(self):
521 521 getattr(self.stdout, 'stop', lambda: None)()
522 522 getattr(self.stderr, 'stop', lambda: None)()
523 523
524 524 def get_stdout(self):
525 525 if self._stdout:
526 526 return self._stdout
527 527 else:
528 528 return b''.join(self.stdout)
529 529
530 530 def get_stderr(self):
531 531 if self._stderr:
532 532 return self._stderr
533 533 else:
534 534 return b''.join(self.stderr)
535 535
536 536
537 537 def run_command(arguments, env=None):
538 538 """
539 539 Run the specified command and return its stdout and stderr as bytes.
540 540
541 541 :param arguments: sequence of program arguments (including the program name)
542 542 :type arguments: list[str]
543 543 """
544 544
545 545 cmd = arguments
546 546 log.debug('Running subprocessio command %s', cmd)
547 547 proc = None
548 548 try:
549 549 _opts = {'shell': False, 'fail_on_stderr': False}
550 550 if env:
551 551 _opts.update({'env': env})
552 552 proc = SubprocessIOChunker(cmd, **_opts)
553 553 return b''.join(proc), b''.join(proc.stderr)
554 554 except OSError as err:
555 555 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
556 556 tb_err = ("Couldn't run subprocessio command (%s).\n"
557 557 "Original error was:%s\n" % (cmd, err))
558 558 log.exception(tb_err)
559 559 raise Exception(tb_err)
560 560 finally:
561 561 if proc:
562 562 proc.close()
563 563
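A quick usage sketch, assuming a `git` binary on the PATH (the printed version is illustrative)::

    stdout, stderr = run_command(['git', '--version'])
    print(stdout)  # e.g. b'git version 2.40.1\n'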
@@ -1,85 +1,85 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import shutil
20 20 import tempfile
21 21 import configparser
22 22
23 23
24 class ContextINI(object):
24 class ContextINI:
25 25 """
26 26 Allows creating a new test.ini file as a copy of an existing one with edited
27 27 data. If the existing file is not present, it creates a new one. Example usage::
28 28 
29 29 with ContextINI('test.ini', [{'section': {'key': 'val'}}]) as new_test_ini_path:
30 30 print('vcsserver --config=%s' % new_test_ini_path)
31 31 """
32 32
33 33 def __init__(self, ini_file_path, ini_params, new_file_prefix=None,
34 34 destroy=True):
35 35 self.ini_file_path = ini_file_path
36 36 self.ini_params = ini_params
37 37 self.new_path = None
38 38 self.new_path_prefix = new_file_prefix or 'test'
39 39 self.destroy = destroy
40 40
41 41 def __enter__(self):
42 42 _, pref = tempfile.mkstemp()
43 43 loc = tempfile.gettempdir()
44 44 self.new_path = os.path.join(loc, '{}_{}_{}'.format(
45 45 pref, self.new_path_prefix, self.ini_file_path))
46 46
47 47 # copy ini file and modify according to the params, if we re-use a file
48 48 if os.path.isfile(self.ini_file_path):
49 49 shutil.copy(self.ini_file_path, self.new_path)
50 50 else:
51 51 # create new dump file for configObj to write to.
52 52 with open(self.new_path, 'wb'):
53 53 pass
54 54
55 55 parser = configparser.ConfigParser()
56 56 parser.read(self.ini_file_path)
57 57
58 58 for data in self.ini_params:
59 59 section, ini_params = list(data.items())[0]
60 60 key, val = list(ini_params.items())[0]
61 61 if section not in parser:
62 62 parser[section] = {}
63 63 parser[section][key] = val
64 64 with open(self.new_path, 'w') as f:  # write edits to the copy we hand back, not to the source ini
65 65 parser.write(f)
66 66 return self.new_path
67 67
68 68 def __exit__(self, exc_type, exc_val, exc_tb):
69 69 if self.destroy:
70 70 os.remove(self.new_path)
71 71
72 72
73 73 def no_newline_id_generator(test_name):
74 74 """
75 75 Generates a test name without space or newline characters. Used for
76 76 nicer output of test progress.
77 77 """
78 78 org_name = test_name
79 79 test_name = str(test_name)\
80 80 .replace('\n', '_N') \
81 81 .replace('\r', '_N') \
82 82 .replace('\t', '_T') \
83 83 .replace(' ', '_S')
84 84
85 85 return test_name or 'test-with-empty-name'
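A worked example of the substitutions above::

    no_newline_id_generator('test name\nwith newline')
    # -> 'test_Sname_Nwith_Snewline'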
@@ -1,162 +1,162 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import inspect
19 19
20 20 import pytest
21 21 import dulwich.errors
22 22 from mock import Mock, patch
23 23
24 24 from vcsserver.remote import git_remote
25 25
26 26 SAMPLE_REFS = {
27 27 'HEAD': 'fd627b9e0dd80b47be81af07c4a98518244ed2f7',
28 28 'refs/tags/v0.1.9': '341d28f0eec5ddf0b6b77871e13c2bbd6bec685c',
29 29 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
30 30 'refs/tags/v0.1.1': 'e6ea6d16e2f26250124a1f4b4fe37a912f9d86a0',
31 31 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
32 32 }
33 33
34 34
35 35 @pytest.fixture
36 36 def git_remote_fix():
37 37 """
38 38 A GitRemote instance with a mock factory.
39 39 """
40 40 factory = Mock()
41 41 remote = git_remote.GitRemote(factory)
42 42 return remote
43 43
44 44
45 45 def test_discover_git_version(git_remote_fix):
46 46 version = git_remote_fix.discover_git_version()
47 47 assert version
48 48
49 49
50 class TestGitFetch(object):
50 class TestGitFetch:
51 51 def setup_method(self):
52 52 self.mock_repo = Mock()
53 53 factory = Mock()
54 54 factory.repo = Mock(return_value=self.mock_repo)
55 55 self.remote_git = git_remote.GitRemote(factory)
56 56
57 57 def test_fetches_all_when_no_commit_ids_specified(self):
58 58 def side_effect(determine_wants, *args, **kwargs):
59 59 determine_wants(SAMPLE_REFS)
60 60
61 61 with patch('dulwich.client.LocalGitClient.fetch') as mock_fetch:
62 62 mock_fetch.side_effect = side_effect
63 63 self.remote_git.pull(wire={}, url='/tmp/', apply_refs=False)
64 64 determine_wants = self.mock_repo.object_store.determine_wants_all
65 65 determine_wants.assert_called_once_with(SAMPLE_REFS)
66 66
67 67 def test_fetches_specified_commits(self):
68 68 selected_refs = {
69 69 'refs/tags/v0.1.8': b'74ebce002c088b8a5ecf40073db09375515ecd68',
70 70 'refs/tags/v0.1.3': b'5a3a8fb005554692b16e21dee62bf02667d8dc3e',
71 71 }
72 72
73 73 def side_effect(determine_wants, *args, **kwargs):
74 74 result = determine_wants(SAMPLE_REFS)
75 75 assert sorted(result) == sorted(selected_refs.values())
76 76 return result
77 77
78 78 with patch('dulwich.client.LocalGitClient.fetch') as mock_fetch:
79 79 mock_fetch.side_effect = side_effect
80 80 self.remote_git.pull(
81 81 wire={}, url='/tmp/', apply_refs=False,
82 82 refs=list(selected_refs.keys()))
83 83 determine_wants = self.mock_repo.object_store.determine_wants_all
84 84 assert determine_wants.call_count == 0
85 85
86 86 def test_get_remote_refs(self):
87 87 factory = Mock()
88 88 remote_git = git_remote.GitRemote(factory)
89 89 url = 'https://example.com/test/test.git'
90 90 sample_refs = {
91 91 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
92 92 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
93 93 }
94 94
95 95 with patch('vcsserver.remote.git_remote.Repo', create=False) as mock_repo:
96 96 mock_repo().get_refs.return_value = sample_refs
97 97 remote_refs = remote_git.get_remote_refs(wire={}, url=url)
98 98 mock_repo().get_refs.assert_called_once_with()
99 99 assert remote_refs == sample_refs
100 100
101 101
102 class TestReraiseSafeExceptions(object):
102 class TestReraiseSafeExceptions:
103 103
104 104 def test_method_decorated_with_reraise_safe_exceptions(self):
105 105 factory = Mock()
106 106 git_remote_instance = git_remote.GitRemote(factory)
107 107
108 108 def fake_function():
109 109 return None
110 110
111 111 decorator = git_remote.reraise_safe_exceptions(fake_function)
112 112
113 113 methods = inspect.getmembers(git_remote_instance, predicate=inspect.ismethod)
114 114 for method_name, method in methods:
115 115 if not method_name.startswith('_') and method_name not in ['vcsserver_invalidate_cache']:
116 116 assert method.__func__.__code__ == decorator.__code__
117 117
118 118 @pytest.mark.parametrize('side_effect, expected_type', [
119 119 (dulwich.errors.ChecksumMismatch('0000000', 'deadbeef'), 'lookup'),
120 120 (dulwich.errors.NotCommitError('deadbeef'), 'lookup'),
121 121 (dulwich.errors.MissingCommitError('deadbeef'), 'lookup'),
122 122 (dulwich.errors.ObjectMissing('deadbeef'), 'lookup'),
123 123 (dulwich.errors.HangupException(), 'error'),
124 124 (dulwich.errors.UnexpectedCommandError('test-cmd'), 'error'),
125 125 ])
126 126 def test_safe_exceptions_reraised(self, side_effect, expected_type):
127 127 @git_remote.reraise_safe_exceptions
128 128 def fake_method():
129 129 raise side_effect
130 130
131 131 with pytest.raises(Exception) as exc_info:
132 132 fake_method()
133 133 assert type(exc_info.value) == Exception
134 134 assert exc_info.value._vcs_kind == expected_type
135 135
136 136
137 class TestDulwichRepoWrapper(object):
137 class TestDulwichRepoWrapper:
138 138 def test_calls_close_on_delete(self):
139 139 isdir_patcher = patch('dulwich.repo.os.path.isdir', return_value=True)
140 140 with patch.object(git_remote.Repo, 'close') as close_mock:
141 141 with isdir_patcher:
142 142 repo = git_remote.Repo('/tmp/abcde')
143 143 assert repo is not None
144 144 repo.__del__()
145 145 # we can't use 'del repo', since in python3 this doesn't always call .__del__()
146 146
147 147 close_mock.assert_called_once_with()
148 148
149 149
150 class TestGitFactory(object):
150 class TestGitFactory:
151 151 def test_create_repo_returns_dulwich_wrapper(self):
152 152
153 153 with patch('vcsserver.lib.rc_cache.region_meta.dogpile_cache_regions') as mock:
154 154 mock.side_effect = {'repo_objects': ''}
155 155 factory = git_remote.GitFactory()
156 156 wire = {
157 157 'path': '/tmp/abcde'
158 158 }
159 159 isdir_patcher = patch('dulwich.repo.os.path.isdir', return_value=True)
160 160 with isdir_patcher:
161 161 result = factory._create_repo(wire, True)
162 162 assert isinstance(result, git_remote.Repo)
@@ -1,112 +1,112 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import inspect
19 19 import sys
20 20 import traceback
21 21
22 22 import pytest
23 23 from mercurial.error import LookupError
24 24 from mock import Mock, patch
25 25
26 26 from vcsserver import exceptions, hgcompat
27 27 from vcsserver.remote import hg_remote
28 28
29 29
30 class TestDiff(object):
30 class TestDiff:
31 31 def test_raising_safe_exception_when_lookup_failed(self):
32 32
33 33 factory = Mock()
34 34 hg_remote_instance = hg_remote.HgRemote(factory)
35 35 with patch('mercurial.patch.diff') as diff_mock:
36 36 diff_mock.side_effect = LookupError(b'deadbeef', b'index', b'message')
37 37
38 38 with pytest.raises(Exception) as exc_info:
39 39 hg_remote_instance.diff(
40 40 wire={}, commit_id_1='deadbeef', commit_id_2='deadbee1',
41 41 file_filter=None, opt_git=True, opt_ignorews=True,
42 42 context=3)
43 43 assert type(exc_info.value) == Exception
44 44 assert exc_info.value._vcs_kind == 'lookup'
45 45
46 46
47 class TestReraiseSafeExceptions(object):
47 class TestReraiseSafeExceptions:
48 48 original_traceback = None
49 49
50 50 def test_method_decorated_with_reraise_safe_exceptions(self):
51 51 factory = Mock()
52 52 hg_remote_instance = hg_remote.HgRemote(factory)
53 53 methods = inspect.getmembers(hg_remote_instance, predicate=inspect.ismethod)
54 54 decorator = hg_remote.reraise_safe_exceptions(None)
55 55 for method_name, method in methods:
56 56 if not method_name.startswith('_') and method_name not in ['vcsserver_invalidate_cache']:
57 57 assert method.__func__.__code__ == decorator.__code__
58 58
59 59 @pytest.mark.parametrize('side_effect, expected_type', [
60 60 (hgcompat.Abort(b'failed-abort'), 'abort'),
61 61 (hgcompat.InterventionRequired(b'intervention-required'), 'abort'),
62 62 (hgcompat.RepoLookupError(), 'lookup'),
63 63 (hgcompat.LookupError(b'deadbeef', b'index', b'message'), 'lookup'),
64 64 (hgcompat.RepoError(), 'error'),
65 65 (hgcompat.RequirementError(), 'requirement'),
66 66 ])
67 67 def test_safe_exceptions_reraised(self, side_effect, expected_type):
68 68 @hg_remote.reraise_safe_exceptions
69 69 def fake_method():
70 70 raise side_effect
71 71
72 72 with pytest.raises(Exception) as exc_info:
73 73 fake_method()
74 74 assert type(exc_info.value) == Exception
75 75 assert exc_info.value._vcs_kind == expected_type
76 76
77 77 def test_keeps_original_traceback(self):
78 78
79 79 @hg_remote.reraise_safe_exceptions
80 80 def fake_method():
81 81 try:
82 82 raise hgcompat.Abort(b'test-abort')
83 83 except:
84 84 self.original_traceback = traceback.format_tb(sys.exc_info()[2])
85 85 raise
86 86
87 87 new_traceback = None
88 88 try:
89 89 fake_method()
90 90 except Exception:
91 91 new_traceback = traceback.format_tb(sys.exc_info()[2])
92 92
93 93 new_traceback_tail = new_traceback[-len(self.original_traceback):]
94 94 assert new_traceback_tail == self.original_traceback
95 95
96 96 def test_maps_unknown_exceptions_to_unhandled(self):
97 97 @hg_remote.reraise_safe_exceptions
98 98 def stub_method():
99 99 raise ValueError('stub')
100 100
101 101 with pytest.raises(Exception) as exc_info:
102 102 stub_method()
103 103 assert exc_info.value._vcs_kind == 'unhandled'
104 104
105 105 def test_does_not_map_known_exceptions(self):
106 106 @hg_remote.reraise_safe_exceptions
107 107 def stub_method():
108 108 raise exceptions.LookupException()('stub')
109 109
110 110 with pytest.raises(Exception) as exc_info:
111 111 stub_method()
112 112 assert exc_info.value._vcs_kind == 'lookup'
@@ -1,286 +1,286 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import threading
19 19 import msgpack
20 20
21 21 from http.server import BaseHTTPRequestHandler
22 22 from socketserver import TCPServer
23 23
24 24 import mercurial.ui
25 25 import mock
26 26 import pytest
27 27
28 28 from vcsserver.hooks import HooksHttpClient
29 29 from vcsserver.lib.rc_json import json
30 30 from vcsserver import hooks
31 31
32 32
33 33 def get_hg_ui(extras=None):
34 34 """Create a Config object with a valid RC_SCM_DATA entry."""
35 35 extras = extras or {}
36 36 required_extras = {
37 37 'username': '',
38 38 'repository': '',
39 39 'locked_by': '',
40 40 'scm': '',
41 41 'make_lock': '',
42 42 'action': '',
43 43 'ip': '',
44 44 'hooks_uri': 'fake_hooks_uri',
45 45 }
46 46 required_extras.update(extras)
47 47 hg_ui = mercurial.ui.ui()
48 48 hg_ui.setconfig(b'rhodecode', b'RC_SCM_DATA', json.dumps(required_extras))
49 49
50 50 return hg_ui
51 51
52 52
53 53 def test_git_pre_receive_is_disabled():
54 54 extras = {'hooks': ['pull']}
55 55 response = hooks.git_pre_receive(None, None,
56 56 {'RC_SCM_DATA': json.dumps(extras)})
57 57
58 58 assert response == 0
59 59
60 60
61 61 def test_git_post_receive_is_disabled():
62 62 extras = {'hooks': ['pull']}
63 63 response = hooks.git_post_receive(None, '',
64 64 {'RC_SCM_DATA': json.dumps(extras)})
65 65
66 66 assert response == 0
67 67
68 68
69 69 def test_git_post_receive_calls_repo_size():
70 70 extras = {'hooks': ['push', 'repo_size']}
71 71
72 72 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
73 73 hooks.git_post_receive(
74 74 None, '', {'RC_SCM_DATA': json.dumps(extras)})
75 75 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
76 76 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
77 77 expected_calls = [
78 78 mock.call('repo_size', extras, mock.ANY),
79 79 mock.call('post_push', extras, mock.ANY),
80 80 ]
81 81 assert call_hook_mock.call_args_list == expected_calls
82 82
83 83
84 84 def test_git_post_receive_does_not_call_disabled_repo_size():
85 85 extras = {'hooks': ['push']}
86 86
87 87 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
88 88 hooks.git_post_receive(
89 89 None, '', {'RC_SCM_DATA': json.dumps(extras)})
90 90 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
91 91 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
92 92 expected_calls = [
93 93 mock.call('post_push', extras, mock.ANY)
94 94 ]
95 95 assert call_hook_mock.call_args_list == expected_calls
96 96
97 97
98 98 def test_repo_size_exception_does_not_affect_git_post_receive():
99 99 extras = {'hooks': ['push', 'repo_size']}
100 100 status = 0
101 101
102 102 def side_effect(name, *args, **kwargs):
103 103 if name == 'repo_size':
104 104 raise Exception('Fake exception')
105 105 else:
106 106 return status
107 107
108 108 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
109 109 call_hook_mock.side_effect = side_effect
110 110 result = hooks.git_post_receive(
111 111 None, '', {'RC_SCM_DATA': json.dumps(extras)})
112 112 assert result == status
113 113
114 114
115 115 def test_git_pre_pull_is_disabled():
116 116 assert hooks.git_pre_pull({'hooks': ['push']}) == hooks.HookResponse(0, '')
117 117
118 118
119 119 def test_git_post_pull_is_disabled():
120 120 assert (
121 121 hooks.git_post_pull({'hooks': ['push']}) == hooks.HookResponse(0, ''))
122 122
123 123
124 class TestGetHooksClient(object):
124 class TestGetHooksClient:
125 125
126 126 def test_returns_http_client_when_protocol_matches(self):
127 127 hooks_uri = 'localhost:8000'
128 128 result = hooks._get_hooks_client({
129 129 'hooks_uri': hooks_uri,
130 130 'hooks_protocol': 'http'
131 131 })
132 132 assert isinstance(result, hooks.HooksHttpClient)
133 133 assert result.hooks_uri == hooks_uri
134 134
135 135 def test_returns_dummy_client_when_hooks_uri_not_specified(self):
136 136 fake_module = mock.Mock()
137 137 import_patcher = mock.patch.object(
138 138 hooks.importlib, 'import_module', return_value=fake_module)
139 139 fake_module_name = 'fake.module'
140 140 with import_patcher as import_mock:
141 141 result = hooks._get_hooks_client(
142 142 {'hooks_module': fake_module_name})
143 143
144 144 import_mock.assert_called_once_with(fake_module_name)
145 145 assert isinstance(result, hooks.HooksDummyClient)
146 146 assert result._hooks_module == fake_module
147 147
148 148
149 class TestHooksHttpClient(object):
149 class TestHooksHttpClient:
150 150 def test_init_sets_hooks_uri(self):
151 151 uri = 'localhost:3000'
152 152 client = hooks.HooksHttpClient(uri)
153 153 assert client.hooks_uri == uri
154 154
155 155 def test_serialize_returns_serialized_string(self):
156 156 client = hooks.HooksHttpClient('localhost:3000')
157 157 hook_name = 'test'
158 158 extras = {
159 159 'first': 1,
160 160 'second': 'two'
161 161 }
162 162 hooks_proto, result = client._serialize(hook_name, extras)
163 163 expected_result = msgpack.packb({
164 164 'method': hook_name,
165 165 'extras': extras,
166 166 })
167 167 assert hooks_proto == {'rc-hooks-protocol': 'msgpack.v1', 'Connection': 'keep-alive'}
168 168 assert result == expected_result
169 169
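
The serialization contract checked above is deliberately small: a msgpack-encoded dict with a method and an extras key. A quick round-trip, using only the real msgpack API, shows the shape:

import msgpack

payload = msgpack.packb({'method': 'test', 'extras': {'first': 1, 'second': 'two'}})
# raw=False decodes msgpack strings back to str instead of bytes
decoded = msgpack.unpackb(payload, raw=False)
assert decoded == {'method': 'test', 'extras': {'first': 1, 'second': 'two'}}
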
170 170 def test_call_queries_http_server(self, http_mirror):
171 171 client = hooks.HooksHttpClient(http_mirror.uri)
172 172 hook_name = 'test'
173 173 extras = {
174 174 'first': 1,
175 175 'second': 'two'
176 176 }
177 177 result = client(hook_name, extras)
178 178 expected_result = msgpack.unpackb(msgpack.packb({
179 179 'method': hook_name,
180 180 'extras': extras
181 181 }), raw=False)
182 182 assert result == expected_result
183 183
184 184
185 class TestHooksDummyClient(object):
185 class TestHooksDummyClient:
186 186 def test_init_imports_hooks_module(self):
187 187 hooks_module_name = 'rhodecode.fake.module'
188 188 hooks_module = mock.MagicMock()
189 189
190 190 import_patcher = mock.patch.object(
191 191 hooks.importlib, 'import_module', return_value=hooks_module)
192 192 with import_patcher as import_mock:
193 193 client = hooks.HooksDummyClient(hooks_module_name)
194 194 import_mock.assert_called_once_with(hooks_module_name)
195 195 assert client._hooks_module == hooks_module
196 196
197 197 def test_call_returns_hook_result(self):
198 198 hooks_module_name = 'rhodecode.fake.module'
199 199 hooks_module = mock.MagicMock()
200 200 import_patcher = mock.patch.object(
201 201 hooks.importlib, 'import_module', return_value=hooks_module)
202 202 with import_patcher:
203 203 client = hooks.HooksDummyClient(hooks_module_name)
204 204
205 205 result = client('post_push', {})
206 206 hooks_module.Hooks.assert_called_once_with()
207 207 assert result == hooks_module.Hooks().__enter__().post_push()
208 208
209 209
210 210 @pytest.fixture
211 211 def http_mirror(request):
212 212 server = MirrorHttpServer()
213 213 request.addfinalizer(server.stop)
214 214 return server
215 215
216 216
217 217 class MirrorHttpHandler(BaseHTTPRequestHandler):
218 218
219 219 def do_POST(self):
220 220 length = int(self.headers['Content-Length'])
221 221 body = self.rfile.read(length)
222 222 self.send_response(200)
223 223 self.end_headers()
224 224 self.wfile.write(body)
225 225
226 226
227 class MirrorHttpServer(object):
227 class MirrorHttpServer:
228 228 ip_address = '127.0.0.1'
229 229 port = 0
230 230
231 231 def __init__(self):
232 232 self._daemon = TCPServer((self.ip_address, 0), MirrorHttpHandler)
233 233 _, self.port = self._daemon.server_address
234 234 self._thread = threading.Thread(target=self._daemon.serve_forever)
235 235 self._thread.daemon = True
236 236 self._thread.start()
237 237
238 238 def stop(self):
239 239 self._daemon.shutdown()
240 240 self._thread.join()
241 241 self._daemon = None
242 242 self._thread = None
243 243
244 244 @property
245 245 def uri(self):
246 246 return '{}:{}'.format(self.ip_address, self.port)
247 247
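
As a usage note, the mirror fixture gives each test a live HTTP endpoint that echoes request bodies back, so the client's msgpack round-trip is observable end to end. A self-contained sketch of the same pattern (stdlib only, all names here are mine):

import threading
import http.client
from http.server import BaseHTTPRequestHandler
from socketserver import TCPServer

class EchoHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        body = self.rfile.read(int(self.headers['Content-Length']))
        self.send_response(200)
        self.end_headers()
        self.wfile.write(body)  # mirror the request body back

server = TCPServer(('127.0.0.1', 0), EchoHandler)  # port 0 picks a free port
threading.Thread(target=server.serve_forever, daemon=True).start()
host, port = server.server_address

conn = http.client.HTTPConnection(host, port)
conn.request('POST', '/', body=b'ping')
assert conn.getresponse().read() == b'ping'
server.shutdown()
server.server_close()
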
248 248
249 249 def test_hooks_http_client_init():
250 250 hooks_uri = 'http://localhost:8000'
251 251 client = HooksHttpClient(hooks_uri)
252 252 assert client.hooks_uri == hooks_uri
253 253
254 254
255 255 def test_hooks_http_client_call():
256 256 hooks_uri = 'http://localhost:8000'
257 257
258 258 method = 'test_method'
259 259 extras = {'key': 'value'}
260 260
261 261 with \
262 262 mock.patch('http.client.HTTPConnection') as mock_connection,\
263 263 mock.patch('msgpack.load') as mock_load:
264 264
265 265 client = HooksHttpClient(hooks_uri)
266 266
267 267 mock_load.return_value = {'result': 'success'}
268 268 response = mock.MagicMock()
269 269 response.status = 200
270 270 mock_connection.return_value.request.side_effect = None
271 271 mock_connection.return_value.getresponse.return_value = response
272 272
273 273 result = client(method, extras)
274 274
275 275 mock_connection.assert_called_with(hooks_uri)
276 276 mock_connection.return_value.request.assert_called_once()
277 277 assert result == {'result': 'success'}
278 278
279 279
280 280 def test_hooks_http_client_serialize():
281 281 method = 'test_method'
282 282 extras = {'key': 'value'}
283 283 headers, body = HooksHttpClient._serialize(method, extras)
284 284
285 285 assert headers == {'rc-hooks-protocol': HooksHttpClient.proto, 'Connection': 'keep-alive'}
286 286 assert msgpack.unpackb(body) == {'method': method, 'extras': extras}
@@ -1,206 +1,206 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import sys
20 20 import stat
21 21 import pytest
22 22 import vcsserver
23 23 import tempfile
24 24 from vcsserver import hook_utils
25 25 from vcsserver.tests.fixture import no_newline_id_generator
26 26 from vcsserver.str_utils import safe_bytes, safe_str
27 27 from vcsserver.utils import AttributeDict
28 28
29 29
30 class TestCheckRhodecodeHook(object):
30 class TestCheckRhodecodeHook:
31 31
32 32 def test_returns_false_when_hook_file_is_wrong_found(self, tmpdir):
33 33 hook = os.path.join(str(tmpdir), 'fake_hook_file.py')
34 34 with open(hook, 'wb') as f:
35 35 f.write(b'dummy test')
36 36 result = hook_utils.check_rhodecode_hook(hook)
37 37 assert result is False
38 38
39 39 def test_returns_true_when_no_hook_file_found(self, tmpdir):
40 40 hook = os.path.join(str(tmpdir), 'fake_hook_file_not_existing.py')
41 41 result = hook_utils.check_rhodecode_hook(hook)
42 42 assert result
43 43
44 44 @pytest.mark.parametrize("file_content, expected_result", [
45 45 ("RC_HOOK_VER = '3.3.3'\n", True),
46 46 ("RC_HOOK = '3.3.3'\n", False),
47 47 ], ids=no_newline_id_generator)
48 48 def test_signatures(self, file_content, expected_result, tmpdir):
49 49 hook = os.path.join(str(tmpdir), 'fake_hook_file_1.py')
50 50 with open(hook, 'wb') as f:
51 51 f.write(safe_bytes(file_content))
52 52
53 53 result = hook_utils.check_rhodecode_hook(hook)
54 54
55 55 assert result is expected_result
56 56
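
A hedged sketch of the check these cases pin down (the real check_rhodecode_hook may be implemented differently): a hook file is treated as RhodeCode-managed, and therefore safe to overwrite, only if it carries the RC_HOOK_VER signature, while a missing file is also safe because there is nothing to preserve:

import os

def check_rhodecode_hook_sketch(hook_path: str) -> bool:  # hypothetical name
    if not os.path.exists(hook_path):
        return True  # no hook installed yet, safe to create one
    with open(hook_path, 'rb') as f:
        return b'RC_HOOK_VER' in f.read()
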
57 57
58 class BaseInstallHooks(object):
58 class BaseInstallHooks:
59 59 HOOK_FILES = ()
60 60
61 61 def _check_hook_file_mode(self, file_path):
62 62 assert os.path.exists(file_path), 'path %s missing' % file_path
63 63 stat_info = os.stat(file_path)
64 64
65 65 file_mode = stat.S_IMODE(stat_info.st_mode)
66 66 expected_mode = int('755', 8)
67 67 assert expected_mode == file_mode
68 68
69 69 def _check_hook_file_content(self, file_path, executable):
70 70 executable = executable or sys.executable
71 71 with open(file_path, 'rt') as hook_file:
72 72 content = hook_file.read()
73 73
74 74 expected_env = '#!{}'.format(executable)
75 75 expected_rc_version = "\nRC_HOOK_VER = '{}'\n".format(vcsserver.__version__)
76 76 assert content.strip().startswith(expected_env)
77 77 assert expected_rc_version in content
78 78
79 79 def _create_fake_hook(self, file_path, content):
80 80 with open(file_path, 'w') as hook_file:
81 81 hook_file.write(content)
82 82
83 83 def create_dummy_repo(self, repo_type):
84 84 tmpdir = tempfile.mkdtemp()
85 85 repo = AttributeDict()
86 86 if repo_type == 'git':
87 87 repo.path = os.path.join(tmpdir, 'test_git_hooks_installation_repo')
88 88 os.makedirs(repo.path)
89 89 os.makedirs(os.path.join(repo.path, 'hooks'))
90 90 repo.bare = True
91 91
92 92 elif repo_type == 'svn':
93 93 repo.path = os.path.join(tmpdir, 'test_svn_hooks_installation_repo')
94 94 os.makedirs(repo.path)
95 95 os.makedirs(os.path.join(repo.path, 'hooks'))
96 96
97 97 return repo
98 98
99 99 def check_hooks(self, repo_path, repo_bare=True):
100 100 for file_name in self.HOOK_FILES:
101 101 if repo_bare:
102 102 file_path = os.path.join(repo_path, 'hooks', file_name)
103 103 else:
104 104 file_path = os.path.join(repo_path, '.git', 'hooks', file_name)
105 105 self._check_hook_file_mode(file_path)
106 106 self._check_hook_file_content(file_path, sys.executable)
107 107
108 108
109 109 class TestInstallGitHooks(BaseInstallHooks):
110 110 HOOK_FILES = ('pre-receive', 'post-receive')
111 111
112 112 def test_hooks_are_installed(self):
113 113 repo = self.create_dummy_repo('git')
114 114 result = hook_utils.install_git_hooks(repo.path, repo.bare)
115 115 assert result
116 116 self.check_hooks(repo.path, repo.bare)
117 117
118 118 def test_hooks_are_replaced(self):
119 119 repo = self.create_dummy_repo('git')
120 120 hooks_path = os.path.join(repo.path, 'hooks')
121 121 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
122 122 self._create_fake_hook(
123 123 file_path, content="RC_HOOK_VER = 'abcde'\n")
124 124
125 125 result = hook_utils.install_git_hooks(repo.path, repo.bare)
126 126 assert result
127 127 self.check_hooks(repo.path, repo.bare)
128 128
129 129 def test_non_rc_hooks_are_not_replaced(self):
130 130 repo = self.create_dummy_repo('git')
131 131 hooks_path = os.path.join(repo.path, 'hooks')
132 132 non_rc_content = 'echo "non rc hook"\n'
133 133 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
134 134 self._create_fake_hook(
135 135 file_path, content=non_rc_content)
136 136
137 137 result = hook_utils.install_git_hooks(repo.path, repo.bare)
138 138 assert result
139 139
140 140 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
141 141 with open(file_path, 'rt') as hook_file:
142 142 content = hook_file.read()
143 143 assert content == non_rc_content
144 144
145 145 def test_non_rc_hooks_are_replaced_with_force_flag(self):
146 146 repo = self.create_dummy_repo('git')
147 147 hooks_path = os.path.join(repo.path, 'hooks')
148 148 non_rc_content = 'echo "non rc hook"\n'
149 149 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
150 150 self._create_fake_hook(
151 151 file_path, content=non_rc_content)
152 152
153 153 result = hook_utils.install_git_hooks(
154 154 repo.path, repo.bare, force_create=True)
155 155 assert result
156 156 self.check_hooks(repo.path, repo.bare)
157 157
158 158
159 159 class TestInstallSvnHooks(BaseInstallHooks):
160 160 HOOK_FILES = ('pre-commit', 'post-commit')
161 161
162 162 def test_hooks_are_installed(self):
163 163 repo = self.create_dummy_repo('svn')
164 164 result = hook_utils.install_svn_hooks(repo.path)
165 165 assert result
166 166 self.check_hooks(repo.path)
167 167
168 168 def test_hooks_are_replaced(self):
169 169 repo = self.create_dummy_repo('svn')
170 170 hooks_path = os.path.join(repo.path, 'hooks')
171 171 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
172 172 self._create_fake_hook(
173 173 file_path, content="RC_HOOK_VER = 'abcde'\n")
174 174
175 175 result = hook_utils.install_svn_hooks(repo.path)
176 176 assert result
177 177 self.check_hooks(repo.path)
178 178
179 179 def test_non_rc_hooks_are_not_replaced(self):
180 180 repo = self.create_dummy_repo('svn')
181 181 hooks_path = os.path.join(repo.path, 'hooks')
182 182 non_rc_content = 'echo "non rc hook"\n'
183 183 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
184 184 self._create_fake_hook(
185 185 file_path, content=non_rc_content)
186 186
187 187 result = hook_utils.install_svn_hooks(repo.path)
188 188 assert result
189 189
190 190 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
191 191 with open(file_path, 'rt') as hook_file:
192 192 content = hook_file.read()
193 193 assert content == non_rc_content
194 194
195 195 def test_non_rc_hooks_are_replaced_with_force_flag(self):
196 196 repo = self.create_dummy_repo('svn')
197 197 hooks_path = os.path.join(repo.path, 'hooks')
198 198 non_rc_content = 'echo "non rc hook"\n'
199 199 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
200 200 self._create_fake_hook(
201 201 file_path, content=non_rc_content)
202 202
203 203 result = hook_utils.install_svn_hooks(
204 204 repo.path, force_create=True)
205 205 assert result
206 206 self.check_hooks(repo.path)
@@ -1,295 +1,295 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import more_itertools
20 20
21 21 import dulwich.protocol
22 22 import mock
23 23 import pytest
24 24 import webob
25 25 import webtest
26 26
27 27 from vcsserver import hooks, pygrack
28 28
29 29 from vcsserver.str_utils import ascii_bytes
30 30
31 31
32 32 @pytest.fixture()
33 33 def pygrack_instance(tmpdir):
34 34 """
35 35 Creates a pygrack app instance.
36 36
 37 37 Right now, it is not very helpful regarding the passed directory.
38 38 It just contains the required folders to pass the signature test.
39 39 """
40 40 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
41 41 tmpdir.mkdir(dir_name)
42 42
43 43 return pygrack.GitRepository('repo_name', str(tmpdir), 'git', False, {})
44 44
45 45
46 46 @pytest.fixture()
47 47 def pygrack_app(pygrack_instance):
48 48 """
49 49 Creates a pygrack app wrapped in webtest.TestApp.
50 50 """
51 51 return webtest.TestApp(pygrack_instance)
52 52
53 53
54 54 def test_invalid_service_info_refs_returns_403(pygrack_app):
55 55 response = pygrack_app.get('/info/refs?service=git-upload-packs',
56 56 expect_errors=True)
57 57
58 58 assert response.status_int == 403
59 59
60 60
61 61 def test_invalid_endpoint_returns_403(pygrack_app):
62 62 response = pygrack_app.post('/git-upload-packs', expect_errors=True)
63 63
64 64 assert response.status_int == 403
65 65
66 66
67 67 @pytest.mark.parametrize('sideband', [
68 68 'side-band-64k',
69 69 'side-band',
70 70 'side-band no-progress',
71 71 ])
72 72 def test_pre_pull_hook_fails_with_sideband(pygrack_app, sideband):
73 73 request = ''.join([
74 74 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ',
75 'multi_ack %s ofs-delta\n' % sideband,
75 f'multi_ack {sideband} ofs-delta\n',
76 76 '0000',
77 77 '0009done\n',
78 78 ])
79 79 with mock.patch('vcsserver.hooks.git_pre_pull', return_value=hooks.HookResponse(1, 'foo')):
80 80 response = pygrack_app.post(
81 81 '/git-upload-pack', params=request,
82 82 content_type='application/x-git-upload-pack')
83 83
84 84 data = io.BytesIO(response.body)
85 85 proto = dulwich.protocol.Protocol(data.read, None)
86 86 packets = list(proto.read_pkt_seq())
87 87
88 88 expected_packets = [
89 89 b'NAK\n', b'\x02foo', b'\x02Pre pull hook failed: aborting\n',
90 90 b'\x01' + pygrack.GitRepository.EMPTY_PACK,
91 91 ]
92 92 assert packets == expected_packets
93 93
94 94
95 95 def test_pre_pull_hook_fails_no_sideband(pygrack_app):
96 96 request = ''.join([
97 97 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
98 98 'multi_ack ofs-delta\n'
99 99 '0000',
100 100 '0009done\n',
101 101 ])
102 102 with mock.patch('vcsserver.hooks.git_pre_pull',
103 103 return_value=hooks.HookResponse(1, 'foo')):
104 104 response = pygrack_app.post(
105 105 '/git-upload-pack', params=request,
106 106 content_type='application/x-git-upload-pack')
107 107
108 108 assert response.body == pygrack.GitRepository.EMPTY_PACK
109 109
110 110
111 111 def test_pull_has_hook_messages(pygrack_app):
112 112 request = ''.join([
113 113 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
114 114 'multi_ack side-band-64k ofs-delta\n'
115 115 '0000',
116 116 '0009done\n',
117 117 ])
118 118
119 119 pre_pull = 'pre_pull_output'
120 120 post_pull = 'post_pull_output'
121 121
122 122 with mock.patch('vcsserver.hooks.git_pre_pull',
123 123 return_value=hooks.HookResponse(0, pre_pull)):
124 124 with mock.patch('vcsserver.hooks.git_post_pull',
125 125 return_value=hooks.HookResponse(1, post_pull)):
126 126 with mock.patch('vcsserver.subprocessio.SubprocessIOChunker',
127 127 return_value=more_itertools.always_iterable([b'0008NAK\n0009subp\n0000'])):
128 128 response = pygrack_app.post(
129 129 '/git-upload-pack', params=request,
130 130 content_type='application/x-git-upload-pack')
131 131
132 132 data = io.BytesIO(response.body)
133 133 proto = dulwich.protocol.Protocol(data.read, None)
134 134 packets = list(proto.read_pkt_seq())
135 135
136 136 assert packets == [b'NAK\n',
137 137 # pre-pull only outputs if IT FAILS as in != 0 ret code
138 138 #b'\x02pre_pull_output',
139 139 b'subp\n',
140 140 b'\x02post_pull_output']
141 141
142 142
143 143 def test_get_want_capabilities(pygrack_instance):
144 144 data = io.BytesIO(
145 145 b'0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
146 146 b'multi_ack side-band-64k ofs-delta\n00000009done\n')
147 147
148 148 request = webob.Request({
149 149 'wsgi.input': data,
150 150 'REQUEST_METHOD': 'POST',
151 151 'webob.is_body_seekable': True
152 152 })
153 153
154 154 capabilities = pygrack_instance._get_want_capabilities(request)
155 155
156 156 assert capabilities == frozenset(
157 157 (b'ofs-delta', b'multi_ack', b'side-band-64k'))
158 158 assert data.tell() == 0
159 159
160 160
161 161 @pytest.mark.parametrize('data,capabilities,expected', [
162 162 ('foo', [], []),
163 163 ('', [pygrack.CAPABILITY_SIDE_BAND_64K], []),
164 164 ('', [pygrack.CAPABILITY_SIDE_BAND], []),
165 165 ('foo', [pygrack.CAPABILITY_SIDE_BAND_64K], [b'0008\x02foo']),
166 166 ('foo', [pygrack.CAPABILITY_SIDE_BAND], [b'0008\x02foo']),
167 167 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'03ed\x02' + b'f' * 1000]),
168 168 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995, b'000a\x02fffff']),
169 169 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'fff0\x02' + b'f' * 65515, b'000a\x02fffff']),
170 170 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995] * 65 + [b'0352\x02' + b'f' * 845]),
171 171 ], ids=[
172 172 'foo-empty',
173 173 'empty-64k', 'empty',
174 174 'foo-64k', 'foo',
175 175 'f-1000-64k', 'f-1000',
176 176 'f-65520-64k', 'f-65520'])
177 177 def test_get_messages(pygrack_instance, data, capabilities, expected):
178 178 messages = pygrack_instance._get_messages(data, capabilities)
179 179
180 180 assert messages == expected
181 181
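
The expected values above encode Git's pkt-line sideband framing: each packet is a 4-hex-digit total length, one band byte (\x02 is the progress band), and a payload chunk capped by the negotiated sideband flavour (1000 bytes for side-band, 65520 for side-band-64k). A hedged sketch, with names of my own choosing, that reproduces those expectations when a sideband capability was negotiated:

def frame_sideband(data: bytes, max_packet: int = 1000) -> list:
    # 4 length digits + 1 band byte leave max_packet - 5 bytes of payload
    chunk_size = max_packet - 5
    packets = []
    for i in range(0, len(data), chunk_size):
        chunk = data[i:i + chunk_size]
        packets.append(b'%04x\x02%s' % (len(chunk) + 5, chunk))
    return packets

assert frame_sideband(b'foo') == [b'0008\x02foo']
assert frame_sideband(b'f' * 1000) == [b'03e8\x02' + b'f' * 995, b'000a\x02fffff']
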
182 182
183 183 @pytest.mark.parametrize('response,capabilities,pre_pull_messages,post_pull_messages', [
184 184 # Unexpected response
185 185 ([b'unexpected_response[no_initial_header]'], [pygrack.CAPABILITY_SIDE_BAND_64K], 'foo', 'bar'),
186 186 # No sideband
187 187 ([b'no-sideband'], [], 'foo', 'bar'),
188 188 # No messages
189 189 ([b'no-messages'], [pygrack.CAPABILITY_SIDE_BAND_64K], '', ''),
190 190 ])
191 191 def test_inject_messages_to_response_nothing_to_do(
192 192 pygrack_instance, response, capabilities, pre_pull_messages, post_pull_messages):
193 193
194 194 new_response = pygrack_instance._build_post_pull_response(
195 195 more_itertools.always_iterable(response), capabilities, pre_pull_messages, post_pull_messages)
196 196
197 197 assert list(new_response) == response
198 198
199 199
200 200 @pytest.mark.parametrize('capabilities', [
201 201 [pygrack.CAPABILITY_SIDE_BAND],
202 202 [pygrack.CAPABILITY_SIDE_BAND_64K],
203 203 ])
204 204 def test_inject_messages_to_response_single_element(pygrack_instance, capabilities):
205 205 response = [b'0008NAK\n0009subp\n0000']
206 206 new_response = pygrack_instance._build_post_pull_response(
207 207 more_itertools.always_iterable(response), capabilities, 'foo', 'bar')
208 208
209 209 expected_response = b''.join([
210 210 b'0008NAK\n',
211 211 b'0008\x02foo',
212 212 b'0009subp\n',
213 213 b'0008\x02bar',
214 214 b'0000'])
215 215
216 216 assert b''.join(new_response) == expected_response
217 217
218 218
219 219 @pytest.mark.parametrize('capabilities', [
220 220 [pygrack.CAPABILITY_SIDE_BAND],
221 221 [pygrack.CAPABILITY_SIDE_BAND_64K],
222 222 ])
223 223 def test_inject_messages_to_response_multi_element(pygrack_instance, capabilities):
224 224 response = more_itertools.always_iterable([
225 225 b'0008NAK\n000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n0000'
226 226 ])
227 227 new_response = pygrack_instance._build_post_pull_response(response, capabilities, 'foo', 'bar')
228 228
229 229 expected_response = b''.join([
230 230 b'0008NAK\n',
231 231 b'0008\x02foo',
232 232 b'000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n',
233 233 b'0008\x02bar',
234 234 b'0000'
235 235 ])
236 236
237 237 assert b''.join(new_response) == expected_response
238 238
239 239
240 240 def test_build_failed_pre_pull_response_no_sideband(pygrack_instance):
241 241 response = pygrack_instance._build_failed_pre_pull_response([], 'foo')
242 242
243 243 assert response == [pygrack.GitRepository.EMPTY_PACK]
244 244
245 245
246 246 @pytest.mark.parametrize('capabilities', [
247 247 [pygrack.CAPABILITY_SIDE_BAND],
248 248 [pygrack.CAPABILITY_SIDE_BAND_64K],
249 249 [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'],
250 250 ])
251 251 def test_build_failed_pre_pull_response(pygrack_instance, capabilities):
252 252 response = pygrack_instance._build_failed_pre_pull_response(capabilities, 'foo')
253 253
254 254 expected_response = [
255 255 b'0008NAK\n', b'0008\x02foo', b'0024\x02Pre pull hook failed: aborting\n',
256 256 b'%04x\x01%s' % (len(pygrack.GitRepository.EMPTY_PACK) + 5, pygrack.GitRepository.EMPTY_PACK),
257 257 pygrack.GitRepository.FLUSH_PACKET,
258 258 ]
259 259
260 260 assert response == expected_response
261 261
262 262
263 263 def test_inject_messages_to_response_generator(pygrack_instance):
264 264
265 265 def response_generator():
266 266 response = [
267 267 # protocol start
268 268 b'0008NAK\n',
269 269 ]
270 270 response += [ascii_bytes(f'000asubp{x}\n') for x in range(1000)]
271 271 response += [
272 272 # protocol end
273 273 pygrack.GitRepository.FLUSH_PACKET
274 274 ]
275 275 for elem in response:
276 276 yield elem
277 277
278 278 new_response = pygrack_instance._build_post_pull_response(
279 279 response_generator(), [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'], 'PRE_PULL_MSG\n', 'POST_PULL_MSG\n')
280 280
281 281 assert iter(new_response)
282 282
283 283 expected_response = b''.join([
284 284 # start
285 285 b'0008NAK\n0012\x02PRE_PULL_MSG\n',
286 286 ] + [
287 287 # ... rest
288 288 ascii_bytes(f'000asubp{x}\n') for x in range(1000)
289 289 ] + [
290 290 # final message,
291 291 b'0013\x02POST_PULL_MSG\n0000',
292 292
293 293 ])
294 294
295 295 assert b''.join(new_response) == expected_response
@@ -1,155 +1,155 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21
22 22 import pytest
23 23
24 24 from vcsserver import subprocessio
25 25 from vcsserver.str_utils import ascii_bytes
26 26
27 27
28 class FileLikeObj(object): # pragma: no cover
28 class FileLikeObj: # pragma: no cover
29 29
30 30 def __init__(self, data: bytes, size):
31 31 chunks = size // len(data)
32 32
33 33 self.stream = self._get_stream(data, chunks)
34 34
35 35 def _get_stream(self, data, chunks):
36 36 for x in range(chunks):
37 37 yield data
38 38
39 39 def read(self, n):
40 40
41 41 buffer_stream = b''
42 42 for chunk in self.stream:
43 43 buffer_stream += chunk
44 44 if len(buffer_stream) >= n:
45 45 break
46 46
47 47 # self.stream = self.bytes[n:]
48 48 return buffer_stream
49 49
50 50
51 51 @pytest.fixture(scope='module')
52 52 def environ():
53 53 """Delete coverage variables, as they make the tests fail."""
54 54 env = dict(os.environ)
55 55 for key in list(env.keys()):
56 56 if key.startswith('COV_CORE_'):
57 57 del env[key]
58 58
59 59 return env
60 60
61 61
62 62 def _get_python_args(script):
63 63 return [sys.executable, '-c', 'import sys; import time; import shutil; ' + script]
64 64
65 65
66 66 def test_raise_exception_on_non_zero_return_code(environ):
67 67 call_args = _get_python_args('raise ValueError("fail")')
68 68 with pytest.raises(OSError):
69 69 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
70 70
71 71
72 72 def test_does_not_fail_on_non_zero_return_code(environ):
73 73 call_args = _get_python_args('sys.stdout.write("hello"); sys.exit(1)')
74 74 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
75 75 output = b''.join(proc)
76 76
77 77 assert output == b'hello'
78 78
79 79
80 80 def test_raise_exception_on_stderr(environ):
81 81 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); time.sleep(1);')
82 82
83 83 with pytest.raises(OSError) as excinfo:
84 84 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
85 85
86 86 assert 'exited due to an error:\nWRITE_TO_STDERR' in str(excinfo.value)
87 87
88 88
89 89 def test_does_not_fail_on_stderr(environ):
 90 90 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); sys.stderr.flush(); time.sleep(2);')
91 91 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_stderr=False, env=environ)
92 92 output = b''.join(proc)
93 93
94 94 assert output == b''
95 95
96 96
97 97 @pytest.mark.parametrize('size', [
98 98 1,
99 99 10 ** 5
100 100 ])
101 101 def test_output_with_no_input(size, environ):
102 102 call_args = _get_python_args(f'sys.stdout.write("X" * {size});')
103 103 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ)
104 104 output = b''.join(proc)
105 105
106 106 assert output == ascii_bytes("X" * size)
107 107
108 108
109 109 @pytest.mark.parametrize('size', [
110 110 1,
111 111 10 ** 5
112 112 ])
113 113 def test_output_with_no_input_does_not_fail(size, environ):
114 114
115 115 call_args = _get_python_args(f'sys.stdout.write("X" * {size}); sys.exit(1)')
116 116 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
117 117 output = b''.join(proc)
118 118
119 119 assert output == ascii_bytes("X" * size)
120 120
121 121
122 122 @pytest.mark.parametrize('size', [
123 123 1,
124 124 10 ** 5
125 125 ])
126 126 def test_output_with_input(size, environ):
127 127 data_len = size
128 128 inputstream = FileLikeObj(b'X', size)
129 129
130 130 # This acts like the cat command.
131 131 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
132 132 # note: in this test we explicitly don't assign the chunker to a variable and let it stream directly
133 133 output = b''.join(
134 134 subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
135 135 )
136 136
137 137 assert len(output) == data_len
138 138
139 139
140 140 @pytest.mark.parametrize('size', [
141 141 1,
142 142 10 ** 5
143 143 ])
144 144 def test_output_with_input_skipping_iterator(size, environ):
145 145 data_len = size
146 146 inputstream = FileLikeObj(b'X', size)
147 147
148 148 # This acts like the cat command.
149 149 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
150 150
151 151 # Note: assigning the chunker makes sure that it is not deleted too early
152 152 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
153 153 output = b''.join(proc.stdout)
154 154
155 155 assert len(output) == data_len
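
For orientation only, a minimal stdlib sketch (not vcsserver's implementation) of the streaming idea these tests exercise: iterate a child process's stdout in chunks instead of buffering it, and raise once the process is known to have failed:

import subprocess
import sys

def iter_process_output(args, chunk_size=8192):
    proc = subprocess.Popen(args, stdout=subprocess.PIPE)
    try:
        while True:
            chunk = proc.stdout.read(chunk_size)
            if not chunk:
                break
            yield chunk
    finally:
        if proc.wait() != 0:
            raise OSError(f'{args} exited with code {proc.returncode}')

output = b''.join(iter_process_output([sys.executable, '-c', 'print("hi")']))
assert output == b'hi\n'
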
@@ -1,103 +1,103 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import mock
20 20 import pytest
21 21 import sys
22 22
23 23 from vcsserver.str_utils import ascii_bytes
24 24
25 25
26 class MockPopen(object):
26 class MockPopen:
27 27 def __init__(self, stderr):
28 28 self.stdout = io.BytesIO(b'')
29 29 self.stderr = io.BytesIO(stderr)
30 30 self.returncode = 1
31 31
32 32 def wait(self):
33 33 pass
34 34
35 35
36 36 INVALID_CERTIFICATE_STDERR = '\n'.join([
37 37 'svnrdump: E230001: Unable to connect to a repository at URL url',
38 38 'svnrdump: E230001: Server SSL certificate verification failed: issuer is not trusted',
39 39 ])
40 40
41 41
42 42 @pytest.mark.parametrize('stderr,expected_reason', [
43 43 (INVALID_CERTIFICATE_STDERR, 'INVALID_CERTIFICATE'),
44 44 ('svnrdump: E123456', 'UNKNOWN:svnrdump: E123456'),
45 45 ], ids=['invalid-cert-stderr', 'svnrdump-err-123456'])
46 46 @pytest.mark.xfail(sys.platform == "cygwin",
47 47 reason="SVN not packaged for Cygwin")
48 48 def test_import_remote_repository_certificate_error(stderr, expected_reason):
49 49 from vcsserver.remote import svn_remote
50 50 factory = mock.Mock()
51 51 factory.repo = mock.Mock(return_value=mock.Mock())
52 52
53 53 remote = svn_remote.SvnRemote(factory)
54 54 remote.is_path_valid_repository = lambda wire, path: True
55 55
56 56 with mock.patch('subprocess.Popen',
57 57 return_value=MockPopen(ascii_bytes(stderr))):
58 58 with pytest.raises(Exception) as excinfo:
59 59 remote.import_remote_repository({'path': 'path'}, 'url')
60 60
61 61 expected_error_args = 'Failed to dump the remote repository from url. Reason:{}'.format(expected_reason)
62 62
63 63 assert excinfo.value.args[0] == expected_error_args
64 64
65 65
66 66 def test_svn_libraries_can_be_imported():
67 67 import svn.client # noqa
68 68 assert svn.client is not None
69 69
70 70
71 71 @pytest.mark.parametrize('example_url, parts', [
72 72 ('http://server.com', ('', '', 'http://server.com')),
73 73 ('http://user@server.com', ('user', '', 'http://user@server.com')),
74 74 ('http://user:pass@server.com', ('user', 'pass', 'http://user:pass@server.com')),
75 75 ('<script>', ('', '', '<script>')),
76 76 ('http://', ('', '', 'http://')),
77 77 ])
78 78 def test_username_password_extraction_from_url(example_url, parts):
79 79 from vcsserver.remote import svn_remote
80 80
81 81 factory = mock.Mock()
82 82 factory.repo = mock.Mock(return_value=mock.Mock())
83 83
84 84 remote = svn_remote.SvnRemote(factory)
85 85 remote.is_path_valid_repository = lambda wire, path: True
86 86
87 87 assert remote.get_url_and_credentials(example_url) == parts
88 88
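
The table above is consistent with plain stdlib URL parsing, so a hedged approximation (the real get_url_and_credentials may differ) is simply:

import urllib.parse

def get_url_and_credentials_sketch(url):  # hypothetical stand-in
    parsed = urllib.parse.urlparse(url)
    # credentials come from the netloc; the URL itself is returned untouched
    return parsed.username or '', parsed.password or '', url

assert get_url_and_credentials_sketch('http://user:pass@server.com') == (
    'user', 'pass', 'http://user:pass@server.com')
assert get_url_and_credentials_sketch('<script>') == ('', '', '<script>')
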
89 89
90 90 @pytest.mark.parametrize('call_url', [
91 91 b'https://svn.code.sf.net/p/svnbook/source/trunk/',
92 92 b'https://marcink@svn.code.sf.net/p/svnbook/source/trunk/',
93 93 b'https://marcink:qweqwe@svn.code.sf.net/p/svnbook/source/trunk/',
94 94 ])
95 95 def test_check_url(call_url):
96 96 from vcsserver.remote import svn_remote
97 97 factory = mock.Mock()
98 98 factory.repo = mock.Mock(return_value=mock.Mock())
99 99
100 100 remote = svn_remote.SvnRemote(factory)
101 101 remote.is_path_valid_repository = lambda wire, path: True
102 102 assert remote.check_url(call_url, {'dummy': 'config'})
103 103
@@ -1,123 +1,123 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import base64
18 18 import logging
19 19 import time
20 20
21 21 import msgpack
22 22
23 23 import vcsserver
24 24 from vcsserver.str_utils import safe_str
25 25
26 26 log = logging.getLogger(__name__)
27 27
28 28
29 29 def get_access_path(environ):
30 30 path = environ.get('PATH_INFO')
31 31 return path
32 32
33 33
34 34 def get_user_agent(environ):
35 35 return environ.get('HTTP_USER_AGENT')
36 36
37 37
38 38 def get_call_context(request) -> dict:
39 39 cc = {}
40 40 registry = request.registry
41 41 if hasattr(registry, 'vcs_call_context'):
42 42 cc.update({
43 43 'X-RC-Method': registry.vcs_call_context.get('method'),
44 44 'X-RC-Repo-Name': registry.vcs_call_context.get('repo_name')
45 45 })
46 46
47 47 return cc
48 48
49 49
50 50 def get_headers_call_context(environ, strict=True):
51 51 if 'HTTP_X_RC_VCS_STREAM_CALL_CONTEXT' in environ:
52 52 packed_cc = base64.b64decode(environ['HTTP_X_RC_VCS_STREAM_CALL_CONTEXT'])
53 53 return msgpack.unpackb(packed_cc)
54 54 elif strict:
55 55 raise ValueError('Expected header HTTP_X_RC_VCS_STREAM_CALL_CONTEXT not found')
56 56
57 57
58 class RequestWrapperTween(object):
58 class RequestWrapperTween:
59 59 def __init__(self, handler, registry):
60 60 self.handler = handler
61 61 self.registry = registry
62 62
63 63 # one-time configuration code goes here
64 64
65 65 def __call__(self, request):
66 66 start = time.time()
67 67 log.debug('Starting request time measurement')
68 68 response = None
69 69
70 70 try:
71 71 response = self.handler(request)
72 72 finally:
73 73 ua = get_user_agent(request.environ)
74 74 call_context = get_call_context(request)
75 75 vcs_method = call_context.get('X-RC-Method', '_NO_VCS_METHOD')
76 76 repo_name = call_context.get('X-RC-Repo-Name', '')
77 77
78 78 count = request.request_count()
79 79 _ver_ = vcsserver.__version__
80 80 _path = safe_str(get_access_path(request.environ))
81 81
82 82 ip = '127.0.0.1'
83 83 match_route = request.matched_route.name if request.matched_route else "NOT_FOUND"
84 84 resp_code = getattr(response, 'status_code', 'UNDEFINED')
85 85
86 86 _view_path = f"{repo_name}@{_path}/{vcs_method}"
87 87
88 88 total = time.time() - start
89 89
90 90 log.info(
91 91 'Req[%4s] IP: %s %s Request to %s time: %.4fs [%s], VCSServer %s',
92 92 count, ip, request.environ.get('REQUEST_METHOD'),
93 93 _view_path, total, ua, _ver_,
94 94 extra={"time": total, "ver": _ver_, "code": resp_code,
95 95 "path": _path, "view_name": match_route, "user_agent": ua,
96 96 "vcs_method": vcs_method, "repo_name": repo_name}
97 97 )
98 98
99 99 statsd = request.registry.statsd
100 100 if statsd:
101 101 match_route = request.matched_route.name if request.matched_route else _path
102 102 elapsed_time_ms = round(1000.0 * total) # use ms only
103 103 statsd.timing(
104 104 "vcsserver_req_timing.histogram", elapsed_time_ms,
105 105 tags=[
106 106 f"view_name:{match_route}",
107 107 f"code:{resp_code}"
108 108 ],
109 109 use_decimals=False
110 110 )
111 111 statsd.incr(
112 112 "vcsserver_req_total", tags=[
113 113 f"view_name:{match_route}",
114 114 f"code:{resp_code}"
115 115 ])
116 116
117 117 return response
118 118
119 119
120 120 def includeme(config):
121 121 config.add_tween(
122 122 'vcsserver.tweens.request_wrapper.RequestWrapperTween',
123 123 )
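
The tween boils down to a measure-then-report wrapper: time the handler in a try/finally so failures are still recorded, then emit a log line plus optional statsd metrics. A minimal sketch under those assumptions (Pyramid-style tween signature, a statsd client with a timing(name, ms) method as used above):

import time

def make_timing_tween(handler, statsd=None):
    def tween(request):
        start = time.time()
        try:
            return handler(request)
        finally:
            # report even when the handler raised
            elapsed_ms = round(1000.0 * (time.time() - start))
            if statsd:
                statsd.timing('vcsserver_req_timing.histogram', elapsed_ms)
    return tween
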
@@ -1,46 +1,46 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 from vcsserver.lib import rc_cache
19 19
20 20
21 class RemoteBase(object):
21 class RemoteBase:
22 22 EMPTY_COMMIT = '0' * 40
23 23
24 24 def _region(self, wire):
25 25 cache_repo_id = wire.get('cache_repo_id', '')
26 26 cache_namespace_uid = f'cache_repo.{rc_cache.CACHE_OBJ_CACHE_VER}.{cache_repo_id}'
27 27 return rc_cache.get_or_create_region('repo_object', cache_namespace_uid)
28 28
29 29 def _cache_on(self, wire):
30 30 context = wire.get('context', '')
31 31 context_uid = f'{context}'
32 32 repo_id = wire.get('repo_id', '')
33 33 cache = wire.get('cache', True)
34 34 cache_on = context and cache
35 35 return cache_on, context_uid, repo_id
36 36
37 37 def vcsserver_invalidate_cache(self, wire, delete):
38 38 cache_repo_id = wire.get('cache_repo_id', '')
39 39 cache_namespace_uid = f'cache_repo.{rc_cache.CACHE_OBJ_CACHE_VER}.{cache_repo_id}'
40 40
41 41 if delete:
42 42 rc_cache.clear_cache_namespace(
43 43 'repo_object', cache_namespace_uid, method=rc_cache.CLEAR_DELETE)
44 44
45 45 repo_id = wire.get('repo_id', '')
46 46 return {'invalidated': {'repo_id': repo_id, 'delete': delete}}
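
The namespace convention matters here: every repository gets its own cache_repo.<ver>.<repo_id> namespace, so one repository's entries can be invalidated without touching any other. An orientation sketch, assuming rc_cache regions behave like dogpile.cache regions (which the get_or_create-style API suggests):

from dogpile.cache import make_region

region = make_region().configure('dogpile.cache.memory')

cache_repo_id = 'repo-42'  # hypothetical repo id
namespace = f'cache_repo.1.{cache_repo_id}'  # version segment assumed to be 1

# an expensive VCS call stands behind a namespaced key
value = region.get_or_create(f'{namespace}:commit_count', lambda: 123)
assert value == 123
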
@@ -1,116 +1,116 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Extract the responses of a WSGI app."""
19 19
20 20 __all__ = ('WSGIAppCaller',)
21 21
22 22 import io
23 23 import logging
24 24 import os
25 25
26 26 from vcsserver.str_utils import ascii_bytes
27 27
28 28 log = logging.getLogger(__name__)
29 29
30 30 DEV_NULL = open(os.devnull)
31 31
32 32
33 33 def _complete_environ(environ, input_data: bytes):
34 34 """Update the missing wsgi.* variables of a WSGI environment.
35 35
36 36 :param environ: WSGI environment to update
37 37 :type environ: dict
38 38 :param input_data: data to be read by the app
39 39 :type input_data: bytes
40 40 """
41 41 environ.update({
42 42 'wsgi.version': (1, 0),
43 43 'wsgi.url_scheme': 'http',
44 44 'wsgi.multithread': True,
45 45 'wsgi.multiprocess': True,
46 46 'wsgi.run_once': False,
47 47 'wsgi.input': io.BytesIO(input_data),
48 48 'wsgi.errors': DEV_NULL,
49 49 })
50 50
51 51
52 52 # pylint: disable=too-few-public-methods
53 class _StartResponse(object):
53 class _StartResponse:
54 54 """Save the arguments of a start_response call."""
55 55
56 56 __slots__ = ['status', 'headers', 'content']
57 57
58 58 def __init__(self):
59 59 self.status = None
60 60 self.headers = None
61 61 self.content = []
62 62
63 63 def __call__(self, status, headers, exc_info=None):
64 64 # TODO(skreft): do something meaningful with the exc_info
65 65 exc_info = None # avoid dangling circular reference
66 66 self.status = status
67 67 self.headers = headers
68 68
69 69 return self.write
70 70
71 71 def write(self, content):
 72 72 """Write method returned when calling this object.
73 73
74 74 All the data written is then available in content.
75 75 """
76 76 self.content.append(content)
77 77
78 78
79 class WSGIAppCaller(object):
79 class WSGIAppCaller:
80 80 """Calls a WSGI app."""
81 81
82 82 def __init__(self, app):
83 83 """
84 84 :param app: WSGI app to call
85 85 """
86 86 self.app = app
87 87
88 88 def handle(self, environ, input_data):
89 89 """Process a request with the WSGI app.
90 90
91 91 The returned data of the app is fully consumed into a list.
92 92
93 93 :param environ: WSGI environment to update
94 94 :type environ: dict
95 95 :param input_data: data to be read by the app
96 96 :type input_data: str/bytes
97 97
98 98 :returns: a tuple with the contents, status and headers
99 99 :rtype: (list<str>, str, list<(str, str)>)
100 100 """
101 101 _complete_environ(environ, ascii_bytes(input_data, allow_bytes=True))
102 102 start_response = _StartResponse()
103 103 log.debug("Calling wrapped WSGI application")
104 104 responses = self.app(environ, start_response)
105 105 responses_list = list(responses)
106 106 existing_responses = start_response.content
107 107 if existing_responses:
108 108 log.debug("Adding returned response to response written via write()")
109 109 existing_responses.extend(responses_list)
110 110 responses_list = existing_responses
111 111 if hasattr(responses, 'close'):
112 112 log.debug("Closing iterator from WSGI application")
113 113 responses.close()
114 114
115 115 log.debug("Handling of WSGI request done, returning response")
116 116 return responses_list, start_response.status, start_response.headers
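
A usage sketch of the caller above, with a trivial WSGI app of my own as the stand-in: build a minimal environ, feed the input bytes, and receive the fully consumed body together with status and headers:

def demo_app(environ, start_response):  # hypothetical app for illustration
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'hello ', environ['wsgi.input'].read()]

caller = WSGIAppCaller(demo_app)
body, status, headers = caller.handle(
    {'REQUEST_METHOD': 'POST', 'PATH_INFO': '/'}, b'world')
assert b''.join(body) == b'hello world'
assert status == '200 OK'
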