python3: code change for py3 support...
super-admin
r1048:742e21ae python3
@@ -0,0 +1,1 b''
1 import simplejson as json
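
This one-line module becomes the single JSON entry point; later hunks in this commit import it as `from vcsserver.lib.rc_json import json`. A minimal usage sketch, assuming the file is vcsserver/lib/rc_json.py (the path is not shown in this hunk):

    # hedged sketch; the module path is assumed from the import sites below
    from vcsserver.lib.rc_json import json

    data = json.dumps({'status': 'ok'})
    assert json.loads(data) == {'status': 'ok'}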
@@ -0,0 +1,53 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 #
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
18 import pytest
19 from vcsserver.utils import ascii_bytes, ascii_str
20
21
22 @pytest.mark.parametrize('given, expected', [
23 ('a', b'a'),
24 (u'a', b'a'),
25 ])
26 def test_ascii_bytes(given, expected):
27 assert ascii_bytes(given) == expected
28
29
30 @pytest.mark.parametrize('given', [
31 'å',
32 'å'.encode('utf8')
33 ])
34 def test_ascii_bytes_raises(given):
35 with pytest.raises(ValueError):
36 ascii_bytes(given)
37
38
39 @pytest.mark.parametrize('given, expected', [
40 (b'a', 'a'),
41 ])
42 def test_ascii_str(given, expected):
43 assert ascii_str(given) == expected
44
45
46 @pytest.mark.parametrize('given', [
47 u'a',
48 'å'.encode('utf8'),
49 u'å'
50 ])
51 def test_ascii_str_raises(given):
52 with pytest.raises(ValueError):
53 ascii_str(given)
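
The tests above pin down the ascii_bytes/ascii_str contract: str-only input for ascii_bytes, bytes-only input for ascii_str, and ValueError for anything non-ASCII. A minimal sketch consistent with these tests (the shipped vcsserver.utils implementation may differ):

    def ascii_bytes(str_input) -> bytes:
        # str in, ASCII bytes out; UnicodeEncodeError subclasses ValueError,
        # so non-ASCII input satisfies test_ascii_bytes_raises
        if not isinstance(str_input, str):
            raise ValueError(f'ascii_bytes expects str, got {type(str_input)}')
        return str_input.encode('ascii')

    def ascii_str(bytes_input) -> str:
        # bytes in, ASCII str out; rejects str input per test_ascii_str_raises
        if not isinstance(bytes_input, bytes):
            raise ValueError(f'ascii_str expects bytes, got {type(bytes_input)}')
        return bytes_input.decode('ascii')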
@@ -1,132 +1,134 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import os
18 18 import sys
19 19 import traceback
20 20 import logging
21 21 import urllib.parse
22 22
23 23 from vcsserver.lib.rc_cache import region_meta
24 24
25 25 from vcsserver import exceptions
26 26 from vcsserver.exceptions import NoContentException
27 27 from vcsserver.hgcompat import (archival)
28 28
29 29 log = logging.getLogger(__name__)
30 30
31 31
32 32 class RepoFactory(object):
33 33 """
34 34 Utility to create instances of repository
35 35
36 36 It provides internal caching of the `repo` object based on
37 37 the :term:`call context`.
38 38 """
39 39 repo_type = None
40 40
41 41 def __init__(self):
42 42 self._cache_region = region_meta.dogpile_cache_regions['repo_object']
43 43
44 44 def _create_config(self, path, config):
45 45 config = {}
46 46 return config
47 47
48 48 def _create_repo(self, wire, create):
49 49 raise NotImplementedError()
50 50
51 51 def repo(self, wire, create=False):
52 52 raise NotImplementedError()
53 53
54 54
55 55 def obfuscate_qs(query_string):
56 56 if query_string is None:
57 57 return None
58 58
59 59 parsed = []
60 60 for k, v in urllib.parse.parse_qsl(query_string, keep_blank_values=True):
61 61 if k in ['auth_token', 'api_key']:
62 62 v = "*****"
63 63 parsed.append((k, v))
64 64
65 65 return '&'.join('{}{}'.format(
66 66 k, '={}'.format(v) if v else '') for k, v in parsed)
67 67
68 68
69 def raise_from_original(new_type):
69 def raise_from_original(new_type, org_exc: Exception):
70 70 """
71 71 Raise a new exception type with original args and traceback.
72 72 """
73
73 74 exc_type, exc_value, exc_traceback = sys.exc_info()
74 75 new_exc = new_type(*exc_value.args)
76
75 77 # store the original traceback into the new exc
76 new_exc._org_exc_tb = traceback.format_exc(exc_traceback)
78 new_exc._org_exc_tb = traceback.format_tb(exc_traceback)
77 79
78 80 try:
79 81 raise new_exc.with_traceback(exc_traceback)
80 82 finally:
81 83 del exc_traceback
82 84
83 85
84 86 class ArchiveNode(object):
85 87 def __init__(self, path, mode, is_link, raw_bytes):
86 88 self.path = path
87 89 self.mode = mode
88 90 self.is_link = is_link
89 91 self.raw_bytes = raw_bytes
90 92
91 93
92 94 def archive_repo(walker, archive_dest_path, kind, mtime, archive_at_path,
93 95 archive_dir_name, commit_id, write_metadata=True, extra_metadata=None):
94 96 """
95 97 walker should be a file walker, for example:
96 98 def walker():
97 99 for file_info in files:
98 100 yield ArchiveNode(fn, mode, is_link, ctx[fn].data)
99 101 """
100 102 extra_metadata = extra_metadata or {}
101 103
102 104 if kind == "tgz":
103 105 archiver = archival.tarit(archive_dest_path, mtime, "gz")
104 106 elif kind == "tbz2":
105 107 archiver = archival.tarit(archive_dest_path, mtime, "bz2")
106 108 elif kind == 'zip':
107 109 archiver = archival.zipit(archive_dest_path, mtime)
108 110 else:
109 111 raise exceptions.ArchiveException()(
110 112 'Remote does not support: "%s" archive type.' % kind)
111 113
112 114 for f in walker(commit_id, archive_at_path):
113 115 f_path = os.path.join(archive_dir_name, f.path.lstrip('/'))
114 116 try:
115 117 archiver.addfile(f_path, f.mode, f.is_link, f.raw_bytes())
116 118 except NoContentException:
117 119 # NOTE(marcink): this is a special case for SVN so we can create "empty"
118 120 # directories which aren't supported by the archiver
119 121 archiver.addfile(os.path.join(f_path, '.dir'), f.mode, f.is_link, '')
120 122
121 123 if write_metadata:
122 124 metadata = dict([
123 125 ('commit_id', commit_id),
124 126 ('mtime', mtime),
125 127 ])
126 128 metadata.update(extra_metadata)
127 129
128 130 meta = ["%s:%s" % (f_name, value) for f_name, value in metadata.items()]
129 131 f_path = os.path.join(archive_dir_name, '.archival.txt')
130 132 archiver.addfile(f_path, 0o644, False, '\n'.join(meta))
131 133
132 134 return archiver.done()
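
raise_from_original now takes the original exception explicitly; a hedged usage sketch (the wrapping exception type here is illustrative):

    try:
        int('not-a-number')
    except ValueError as org_exc:
        # must run inside an except block, since the helper reads
        # sys.exc_info() for the original args and traceback
        raise_from_original(RuntimeError, org_exc)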
@@ -1,207 +1,208 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import textwrap
23 23 import string
24 24 import functools
25 25 import logging
26 26 import tempfile
27 27 import logging.config
28 28 log = logging.getLogger(__name__)
29 29
30 30 # skip keys that are set here, so we don't double-process those
31 31 set_keys = {
32 32 '__file__': ''
33 33 }
34 34
35 35
36 36 def str2bool(_str):
37 37 """
38 38 returns a True/False value from the given string; tries to translate the
39 39 string into a boolean
40 40
41 41 :param _str: string value to translate into boolean
42 42 :rtype: boolean
43 43 :returns: boolean from given string
44 44 """
45 45 if _str is None:
46 46 return False
47 47 if _str in (True, False):
48 48 return _str
49 49 _str = str(_str).strip().lower()
50 50 return _str in ('t', 'true', 'y', 'yes', 'on', '1')
51 51
52 52
53 53 def aslist(obj, sep=None, strip=True):
54 54 """
55 55 Returns the given string, split by sep, as a list
56 56
57 57 :param obj:
58 58 :param sep:
59 59 :param strip:
60 60 """
61 61 if isinstance(obj, str):
62 62 if obj in ['', ""]:
63 63 return []
64 64
65 65 lst = obj.split(sep)
66 66 if strip:
67 67 lst = [v.strip() for v in lst]
68 68 return lst
69 69 elif isinstance(obj, (list, tuple)):
70 70 return obj
71 71 elif obj is None:
72 72 return []
73 73 else:
74 74 return [obj]
75 75
76 76
77 77 class SettingsMaker(object):
78 78
79 79 def __init__(self, app_settings):
80 80 self.settings = app_settings
81 81
82 82 @classmethod
83 83 def _bool_func(cls, input_val):
84 if isinstance(input_val, unicode):
85 input_val = input_val.encode('utf8')
84 if isinstance(input_val, bytes):
85 # decode to str
86 input_val = input_val.decode('utf8')
86 87 return str2bool(input_val)
87 88
88 89 @classmethod
89 90 def _int_func(cls, input_val):
90 91 return int(input_val)
91 92
92 93 @classmethod
93 94 def _list_func(cls, input_val, sep=','):
94 95 return aslist(input_val, sep=sep)
95 96
96 97 @classmethod
97 98 def _string_func(cls, input_val, lower=True):
98 99 if lower:
99 100 input_val = input_val.lower()
100 101 return input_val
101 102
102 103 @classmethod
103 104 def _float_func(cls, input_val):
104 105 return float(input_val)
105 106
106 107 @classmethod
107 108 def _dir_func(cls, input_val, ensure_dir=False, mode=0o755):
108 109
109 110 # ensure we have our dir created
110 111 if not os.path.isdir(input_val) and ensure_dir:
111 112 os.makedirs(input_val, mode=mode)
112 113
113 114 if not os.path.isdir(input_val):
114 115 raise Exception('Dir at {} does not exist'.format(input_val))
115 116 return input_val
116 117
117 118 @classmethod
118 119 def _file_path_func(cls, input_val, ensure_dir=False, mode=0o755):
119 120 dirname = os.path.dirname(input_val)
120 121 cls._dir_func(dirname, ensure_dir=ensure_dir)
121 122 return input_val
122 123
123 124 @classmethod
124 125 def _key_transformator(cls, key):
125 126 return "{}_{}".format('RC'.upper(), key.upper().replace('.', '_').replace('-', '_'))
126 127
127 128 def maybe_env_key(self, key):
128 129 # now maybe we have this KEY in env, search and use the value with higher priority.
129 130 transformed_key = self._key_transformator(key)
130 131 envvar_value = os.environ.get(transformed_key)
131 132 if envvar_value:
132 133 log.debug('using `%s` key instead of `%s` key for config', transformed_key, key)
133 134
134 135 return envvar_value
135 136
136 137 def env_expand(self):
137 138 replaced = {}
138 139 for k, v in self.settings.items():
139 140 if k not in set_keys:
140 141 envvar_value = self.maybe_env_key(k)
141 142 if envvar_value:
142 143 replaced[k] = envvar_value
143 144 set_keys[k] = envvar_value
144 145
145 146 # replace ALL keys updated
146 147 self.settings.update(replaced)
147 148
148 149 def enable_logging(self, logging_conf=None, level='INFO', formatter='generic'):
149 150 """
150 151 Helper to enable debug on running instance
151 152 :return:
152 153 """
153 154
154 155 if not str2bool(self.settings.get('logging.autoconfigure')):
155 156 log.info('logging configuration based on main .ini file')
156 157 return
157 158
158 159 if logging_conf is None:
159 160 logging_conf = self.settings.get('logging.logging_conf_file') or ''
160 161
161 162 if not os.path.isfile(logging_conf):
162 163 log.error('Unable to setup logging based on %s, '
163 164 'file does not exist.... specify path using logging.logging_conf_file= config setting. ', logging_conf)
164 165 return
165 166
166 167 with open(logging_conf, 'rb') as f:
167 168 ini_template = textwrap.dedent(f.read())
168 169 ini_template = string.Template(ini_template).safe_substitute(
169 170 RC_LOGGING_LEVEL=os.environ.get('RC_LOGGING_LEVEL', '') or level,
170 171 RC_LOGGING_FORMATTER=os.environ.get('RC_LOGGING_FORMATTER', '') or formatter
171 172 )
172 173
173 174 with tempfile.NamedTemporaryFile(prefix='rc_logging_', suffix='.ini', delete=False) as f:
174 175 log.info('Saved Temporary LOGGING config at %s', f.name)
175 176 f.write(ini_template)
176 177
177 178 logging.config.fileConfig(f.name)
178 179 os.remove(f.name)
179 180
180 181 def make_setting(self, key, default, lower=False, default_when_empty=False, parser=None):
181 182 input_val = self.settings.get(key, default)
182 183
183 184 if default_when_empty and not input_val:
184 185 # use default value when value is set in the config but it is empty
185 186 input_val = default
186 187
187 188 parser_func = {
188 189 'bool': self._bool_func,
189 190 'int': self._int_func,
190 191 'list': self._list_func,
191 192 'list:newline': functools.partial(self._list_func, sep='\n'),
192 193 'list:spacesep': functools.partial(self._list_func, sep=' '),
193 194 'string': functools.partial(self._string_func, lower=lower),
194 195 'dir': self._dir_func,
195 196 'dir:ensured': functools.partial(self._dir_func, ensure_dir=True),
196 197 'file': self._file_path_func,
197 198 'file:ensured': functools.partial(self._file_path_func, ensure_dir=True),
198 199 None: lambda i: i
199 200 }[parser]
200 201
201 202 envvar_value = self.maybe_env_key(key)
202 203 if envvar_value:
203 204 input_val = envvar_value
204 205 set_keys[key] = input_val
205 206
206 207 self.settings[key] = parser_func(input_val)
207 208 return self.settings[key]
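
A hedged usage sketch of SettingsMaker (the setting keys are illustrative, not from a real .ini; environment overrides use the RC_ prefix produced by _key_transformator):

    app_settings = {'app.debug': 'true', 'app.cache_dir': '/tmp/rc_cache'}
    maker = SettingsMaker(app_settings)
    maker.env_expand()  # e.g. RC_APP_DEBUG=0 would win over the .ini value

    debug = maker.make_setting('app.debug', default=False, parser='bool')
    cache_dir = maker.make_setting('app.cache_dir', '/tmp', parser='dir:ensured')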
@@ -1,54 +1,54 b''
1 1 """
2 2 Implementation of :class:`EchoApp`.
3 3
4 4 This WSGI application will just echo back the data which it receives.
5 5 """
6 6
7 7 import logging
8 8
9 9
10 10 log = logging.getLogger(__name__)
11 11
12 12
13 13 class EchoApp(object):
14 14
15 15 def __init__(self, repo_path, repo_name, config):
16 16 self._repo_path = repo_path
17 17 log.info("EchoApp initialized for %s", repo_path)
18 18
19 19 def __call__(self, environ, start_response):
20 20 log.debug("EchoApp called for %s", self._repo_path)
21 21 log.debug("Content-Length: %s", environ.get('CONTENT_LENGTH'))
22 22 environ['wsgi.input'].read()
23 23 status = '200 OK'
24 24 headers = [('Content-Type', 'text/plain')]
25 25 start_response(status, headers)
26 return ["ECHO"]
26 return [b"ECHO"]
27 27
28 28
29 29 class EchoAppStream(object):
30 30
31 31 def __init__(self, repo_path, repo_name, config):
32 32 self._repo_path = repo_path
33 33 log.info("EchoApp initialized for %s", repo_path)
34 34
35 35 def __call__(self, environ, start_response):
36 36 log.debug("EchoApp called for %s", self._repo_path)
37 37 log.debug("Content-Length: %s", environ.get('CONTENT_LENGTH'))
38 38 environ['wsgi.input'].read()
39 39 status = '200 OK'
40 40 headers = [('Content-Type', 'text/plain')]
41 41 start_response(status, headers)
42 42
43 43 def generator():
44 44 for _ in range(1000000):
45 yield "ECHO"
45 yield b"ECHO_STREAM"
46 46 return generator()
47 47
48 48
49 49 def create_app():
50 50 """
51 51 Allows to run this app directly in a WSGI server.
52 52 """
53 53 stub_config = {}
54 54 return EchoApp('stub_path', 'stub_name', stub_config)
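
The switch to byte bodies above matters because PEP 3333 requires WSGI apps to yield bytes under Python 3. A minimal sketch serving the stub app with the stdlib only:

    from wsgiref.simple_server import make_server

    # create_app() returns the stub EchoApp defined above
    with make_server('127.0.0.1', 8080, create_app()) as httpd:
        httpd.handle_request()  # serves one request, body b"ECHO"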
@@ -1,45 +1,45 b''
1 1 """
2 2 Provides the same API as :mod:`remote_wsgi`.
3 3
4 4 Uses the `EchoApp` instead of real implementations.
5 5 """
6 6
7 7 import logging
8 8
9 9 from .echo_app import EchoApp
10 10 from vcsserver import wsgi_app_caller
11 11
12 12
13 13 log = logging.getLogger(__name__)
14 14
15 15
16 16 class GitRemoteWsgi(object):
17 17 def handle(self, environ, input_data, *args, **kwargs):
18 18 app = wsgi_app_caller.WSGIAppCaller(
19 19 create_echo_wsgi_app(*args, **kwargs))
20 20
21 21 return app.handle(environ, input_data)
22 22
23 23
24 24 class HgRemoteWsgi(object):
25 25 def handle(self, environ, input_data, *args, **kwargs):
26 26 app = wsgi_app_caller.WSGIAppCaller(
27 27 create_echo_wsgi_app(*args, **kwargs))
28 28
29 29 return app.handle(environ, input_data)
30 30
31 31
32 32 def create_echo_wsgi_app(repo_path, repo_name, config):
33 33 log.debug("Creating EchoApp WSGI application")
34 34
35 35 _assert_valid_config(config)
36 36
37 37 # Remaining items are forwarded to have the extras available
38 38 return EchoApp(repo_path, repo_name, config=config)
39 39
40 40
41 41 def _assert_valid_config(config):
42 42 config = config.copy()
43 43
44 44 # This is what git needs from config at this stage
45 config.pop('git_update_server_info')
45 config.pop(b'git_update_server_info')
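
The popped key is now bytes, so callers must build the config dict with byte keys; a hedged call sketch (path and values illustrative):

    app = create_echo_wsgi_app(
        '/tmp/repo', 'repo',
        config={b'git_update_server_info': False})  # bytes key required now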
@@ -1,292 +1,292 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import re
19 19 import logging
20 20 from wsgiref.util import FileWrapper
21 21
22 import simplejson as json
23 22 from pyramid.config import Configurator
24 23 from pyramid.response import Response, FileIter
25 24 from pyramid.httpexceptions import (
26 25 HTTPBadRequest, HTTPNotImplemented, HTTPNotFound, HTTPForbidden,
27 26 HTTPUnprocessableEntity)
28 27
28 from vcsserver.lib.rc_json import json
29 29 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
30 30 from vcsserver.git_lfs.utils import safe_result, get_cython_compat_decorator
31 31 from vcsserver.utils import safe_int
32 32
33 33 log = logging.getLogger(__name__)
34 34
35 35
36 36 GIT_LFS_CONTENT_TYPE = 'application/vnd.git-lfs' #+json ?
37 37 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
38 38
39 39
40 40 def write_response_error(http_exception, text=None):
41 41 content_type = GIT_LFS_CONTENT_TYPE + '+json'
42 42 _exception = http_exception(content_type=content_type)
43 43 _exception.content_type = content_type
44 44 if text:
45 _exception.body = json.dumps({'message': text})
45 _exception.text = json.dumps({'message': text})
46 46 log.debug('LFS: writing response of type %s to client with text:%s',
47 47 http_exception, text)
48 48 return _exception
49 49
50 50
51 51 class AuthHeaderRequired(object):
52 52 """
53 53 Decorator to check if request has proper auth-header
54 54 """
55 55
56 56 def __call__(self, func):
57 57 return get_cython_compat_decorator(self.__wrapper, func)
58 58
59 59 def __wrapper(self, func, *fargs, **fkwargs):
60 60 request = fargs[1]
61 61 auth = request.authorization
62 62 if not auth:
63 63 return write_response_error(HTTPForbidden)
64 64 return func(*fargs[1:], **fkwargs)
65 65
66 66
67 67 # views
68 68
69 69 def lfs_objects(request):
70 70 # indicate not supported, V1 API
71 71 log.warning('LFS: v1 api not supported, reporting it back to client')
72 72 return write_response_error(HTTPNotImplemented, 'LFS: v1 api not supported')
73 73
74 74
75 75 @AuthHeaderRequired()
76 76 def lfs_objects_batch(request):
77 77 """
78 78 The client sends the following information to the Batch endpoint to transfer some objects:
79 79
80 80 operation - Should be download or upload.
81 81 transfers - An optional Array of String identifiers for transfer
82 82 adapters that the client has configured. If omitted, the basic
83 83 transfer adapter MUST be assumed by the server.
84 84 objects - An Array of objects to download.
85 85 oid - String OID of the LFS object.
86 86 size - Integer byte size of the LFS object. Must be at least zero.
87 87 """
88 88 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
89 89 auth = request.authorization
90 90 repo = request.matchdict.get('repo')
91 91 data = request.json
92 92 operation = data.get('operation')
93 93 http_scheme = request.registry.git_lfs_http_scheme
94 94
95 95 if operation not in ('download', 'upload'):
96 96 log.debug('LFS: unsupported operation:%s', operation)
97 97 return write_response_error(
98 98 HTTPBadRequest, 'unsupported operation mode: `%s`' % operation)
99 99
100 100 if 'objects' not in data:
101 101 log.debug('LFS: missing objects data')
102 102 return write_response_error(
103 103 HTTPBadRequest, 'missing objects data')
104 104
105 105 log.debug('LFS: handling operation of type: %s', operation)
106 106
107 107 objects = []
108 108 for o in data['objects']:
109 109 try:
110 110 oid = o['oid']
111 111 obj_size = o['size']
112 112 except KeyError:
113 113 log.exception('LFS, failed to extract data')
114 114 return write_response_error(
115 115 HTTPBadRequest, 'unsupported data in objects')
116 116
117 117 obj_data = {'oid': oid}
118 118
119 119 obj_href = request.route_url('lfs_objects_oid', repo=repo, oid=oid,
120 120 _scheme=http_scheme)
121 121 obj_verify_href = request.route_url('lfs_objects_verify', repo=repo,
122 122 _scheme=http_scheme)
123 123 store = LFSOidStore(
124 124 oid, repo, store_location=request.registry.git_lfs_store_path)
125 125 handler = OidHandler(
126 126 store, repo, auth, oid, obj_size, obj_data,
127 127 obj_href, obj_verify_href)
128 128
129 129 # this verifies also OIDs
130 130 actions, errors = handler.exec_operation(operation)
131 131 if errors:
132 132 log.warning('LFS: got following errors: %s', errors)
133 133 obj_data['errors'] = errors
134 134
135 135 if actions:
136 136 obj_data['actions'] = actions
137 137
138 138 obj_data['size'] = obj_size
139 139 obj_data['authenticated'] = True
140 140 objects.append(obj_data)
141 141
142 142 result = {'objects': objects, 'transfer': 'basic'}
143 143 log.debug('LFS Response %s', safe_result(result))
144 144
145 145 return result
146 146
147 147
148 148 def lfs_objects_oid_upload(request):
149 149 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
150 150 repo = request.matchdict.get('repo')
151 151 oid = request.matchdict.get('oid')
152 152 store = LFSOidStore(
153 153 oid, repo, store_location=request.registry.git_lfs_store_path)
154 154 engine = store.get_engine(mode='wb')
155 155 log.debug('LFS: starting chunked write of LFS oid: %s to storage', oid)
156 156
157 157 body = request.environ['wsgi.input']
158 158
159 159 with engine as f:
160 160 blksize = 64 * 1024 # 64kb
161 161 while True:
162 162 # read in chunks as stream comes in from Gunicorn
163 163 # this is a specific Gunicorn support function.
164 164 # might work differently on waitress
165 165 chunk = body.read(blksize)
166 166 if not chunk:
167 167 break
168 168 f.write(chunk)
169 169
170 170 return {'upload': 'ok'}
171 171
172 172
173 173 def lfs_objects_oid_download(request):
174 174 repo = request.matchdict.get('repo')
175 175 oid = request.matchdict.get('oid')
176 176
177 177 store = LFSOidStore(
178 178 oid, repo, store_location=request.registry.git_lfs_store_path)
179 179 if not store.has_oid():
180 181 log.debug('LFS: oid %s does not exist in store', oid)
181 181 return write_response_error(
182 182 HTTPNotFound, 'requested file with oid `%s` not found in store' % oid)
183 183
184 184 # TODO(marcink): support range header ?
185 185 # Range: bytes=0-, `bytes=(\d+)\-.*`
186 186
187 187 f = open(store.oid_path, 'rb')
188 188 response = Response(
189 189 content_type='application/octet-stream', app_iter=FileIter(f))
190 190 response.headers.add('X-RC-LFS-Response-Oid', str(oid))
191 191 return response
192 192
193 193
194 194 def lfs_objects_verify(request):
195 195 request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
196 196 repo = request.matchdict.get('repo')
197 197
198 198 data = request.json
199 199 oid = data.get('oid')
200 200 size = safe_int(data.get('size'))
201 201
202 202 if not (oid and size):
203 203 return write_response_error(
204 204 HTTPBadRequest, 'missing oid and size in request data')
205 205
206 206 store = LFSOidStore(
207 207 oid, repo, store_location=request.registry.git_lfs_store_path)
208 208 if not store.has_oid():
209 210 log.debug('LFS: oid %s does not exist in store', oid)
210 211 return write_response_error(
211 212 HTTPNotFound, 'oid `%s` does not exist in store' % oid)
212 212
213 213 store_size = store.size_oid()
214 214 if store_size != size:
215 215 msg = 'requested file size mismatch store size:%s requested:%s' % (
216 216 store_size, size)
217 217 return write_response_error(
218 218 HTTPUnprocessableEntity, msg)
219 219
220 220 return {'message': {'size': 'ok', 'in_store': 'ok'}}
221 221
222 222
223 223 def lfs_objects_lock(request):
224 224 return write_response_error(
225 225 HTTPNotImplemented, 'GIT LFS locking api not supported')
226 226
227 227
228 228 def not_found(request):
229 229 return write_response_error(
230 230 HTTPNotFound, 'request path not found')
231 231
232 232
233 233 def lfs_disabled(request):
234 234 return write_response_error(
235 235 HTTPNotImplemented, 'GIT LFS disabled for this repo')
236 236
237 237
238 238 def git_lfs_app(config):
239 239
240 240 # v1 API deprecation endpoint
241 241 config.add_route('lfs_objects',
242 242 '/{repo:.*?[^/]}/info/lfs/objects')
243 243 config.add_view(lfs_objects, route_name='lfs_objects',
244 244 request_method='POST', renderer='json')
245 245
246 246 # locking API
247 247 config.add_route('lfs_objects_lock',
248 248 '/{repo:.*?[^/]}/info/lfs/locks')
249 249 config.add_view(lfs_objects_lock, route_name='lfs_objects_lock',
250 250 request_method=('POST', 'GET'), renderer='json')
251 251
252 252 config.add_route('lfs_objects_lock_verify',
253 253 '/{repo:.*?[^/]}/info/lfs/locks/verify')
254 254 config.add_view(lfs_objects_lock, route_name='lfs_objects_lock_verify',
255 255 request_method=('POST', 'GET'), renderer='json')
256 256
257 257 # batch API
258 258 config.add_route('lfs_objects_batch',
259 259 '/{repo:.*?[^/]}/info/lfs/objects/batch')
260 260 config.add_view(lfs_objects_batch, route_name='lfs_objects_batch',
261 261 request_method='POST', renderer='json')
262 262
263 263 # oid upload/download API
264 264 config.add_route('lfs_objects_oid',
265 265 '/{repo:.*?[^/]}/info/lfs/objects/{oid}')
266 266 config.add_view(lfs_objects_oid_upload, route_name='lfs_objects_oid',
267 267 request_method='PUT', renderer='json')
268 268 config.add_view(lfs_objects_oid_download, route_name='lfs_objects_oid',
269 269 request_method='GET', renderer='json')
270 270
271 271 # verification API
272 272 config.add_route('lfs_objects_verify',
273 273 '/{repo:.*?[^/]}/info/lfs/verify')
274 274 config.add_view(lfs_objects_verify, route_name='lfs_objects_verify',
275 275 request_method='POST', renderer='json')
276 276
277 277 # not found handler for API
278 278 config.add_notfound_view(not_found, renderer='json')
279 279
280 280
281 281 def create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
282 282 config = Configurator()
283 283 if git_lfs_enabled:
284 284 config.include(git_lfs_app)
285 285 config.registry.git_lfs_store_path = git_lfs_store_path
286 286 config.registry.git_lfs_http_scheme = git_lfs_http_scheme
287 287 else:
288 288 # not found handler for API, reporting disabled LFS support
289 289 config.add_notfound_view(lfs_disabled, renderer='json')
290 290
291 291 app = config.make_wsgi_app()
292 292 return app
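
For reference, the request shape lfs_objects_batch consumes, mirroring the test fixtures later in this commit (oid and size values illustrative):

    batch_request = {
        'operation': 'download',        # or 'upload'
        'objects': [{'oid': '123', 'size': '1024'}],
    }
    # POSTed to /{repo}/info/lfs/objects/batch with Basic auth and
    # Content-Type application/vnd.git-lfs+json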
@@ -1,272 +1,273 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 20 from webtest.app import TestApp as WebObTestApp
21 import simplejson as json
22 21
22 from vcsserver.lib.rc_json import json
23 from vcsserver.utils import safe_bytes
23 24 from vcsserver.git_lfs.app import create_app
24 25
25 26
26 27 @pytest.fixture(scope='function')
27 28 def git_lfs_app(tmpdir):
28 29 custom_app = WebObTestApp(create_app(
29 30 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
30 31 git_lfs_http_scheme='http'))
31 32 custom_app._store = str(tmpdir)
32 33 return custom_app
33 34
34 35
35 36 @pytest.fixture(scope='function')
36 37 def git_lfs_https_app(tmpdir):
37 38 custom_app = WebObTestApp(create_app(
38 39 git_lfs_enabled=True, git_lfs_store_path=str(tmpdir),
39 40 git_lfs_http_scheme='https'))
40 41 custom_app._store = str(tmpdir)
41 42 return custom_app
42 43
43 44
44 45 @pytest.fixture()
45 46 def http_auth():
46 47 return {'HTTP_AUTHORIZATION': "Basic XXXXX"}
47 48
48 49
49 50 class TestLFSApplication(object):
50 51
51 52 def test_app_wrong_path(self, git_lfs_app):
52 53 git_lfs_app.get('/repo/info/lfs/xxx', status=404)
53 54
54 55 def test_app_deprecated_endpoint(self, git_lfs_app):
55 56 response = git_lfs_app.post('/repo/info/lfs/objects', status=501)
56 57 assert response.status_code == 501
57 58 assert json.loads(response.text) == {'message': 'LFS: v1 api not supported'}
58 59
59 60 def test_app_lock_verify_api_not_available(self, git_lfs_app):
60 61 response = git_lfs_app.post('/repo/info/lfs/locks/verify', status=501)
61 62 assert response.status_code == 501
62 63 assert json.loads(response.text) == {
63 64 'message': 'GIT LFS locking api not supported'}
64 65
65 66 def test_app_lock_api_not_available(self, git_lfs_app):
66 67 response = git_lfs_app.post('/repo/info/lfs/locks', status=501)
67 68 assert response.status_code == 501
68 69 assert json.loads(response.text) == {
69 70 'message': 'GIT LFS locking api not supported'}
70 71
71 72 def test_app_batch_api_missing_auth(self, git_lfs_app):
72 73 git_lfs_app.post_json(
73 74 '/repo/info/lfs/objects/batch', params={}, status=403)
74 75
75 76 def test_app_batch_api_unsupported_operation(self, git_lfs_app, http_auth):
76 77 response = git_lfs_app.post_json(
77 78 '/repo/info/lfs/objects/batch', params={}, status=400,
78 79 extra_environ=http_auth)
79 80 assert json.loads(response.text) == {
80 81 'message': 'unsupported operation mode: `None`'}
81 82
82 83 def test_app_batch_api_missing_objects(self, git_lfs_app, http_auth):
83 84 response = git_lfs_app.post_json(
84 85 '/repo/info/lfs/objects/batch', params={'operation': 'download'},
85 86 status=400, extra_environ=http_auth)
86 87 assert json.loads(response.text) == {
87 88 'message': 'missing objects data'}
88 89
89 90 def test_app_batch_api_unsupported_data_in_objects(
90 91 self, git_lfs_app, http_auth):
91 92 params = {'operation': 'download',
92 93 'objects': [{}]}
93 94 response = git_lfs_app.post_json(
94 95 '/repo/info/lfs/objects/batch', params=params, status=400,
95 96 extra_environ=http_auth)
96 97 assert json.loads(response.text) == {
97 98 'message': 'unsupported data in objects'}
98 99
99 100 def test_app_batch_api_download_missing_object(
100 101 self, git_lfs_app, http_auth):
101 102 params = {'operation': 'download',
102 103 'objects': [{'oid': '123', 'size': '1024'}]}
103 104 response = git_lfs_app.post_json(
104 105 '/repo/info/lfs/objects/batch', params=params,
105 106 extra_environ=http_auth)
106 107
107 108 expected_objects = [
108 109 {'authenticated': True,
109 110 'errors': {'error': {
110 111 'code': 404,
111 112 'message': 'object: 123 does not exist in store'}},
112 113 'oid': '123',
113 114 'size': '1024'}
114 115 ]
115 116 assert json.loads(response.text) == {
116 117 'objects': expected_objects, 'transfer': 'basic'}
117 118
118 119 def test_app_batch_api_download(self, git_lfs_app, http_auth):
119 120 oid = '456'
120 121 oid_path = os.path.join(git_lfs_app._store, oid)
121 122 if not os.path.isdir(os.path.dirname(oid_path)):
122 123 os.makedirs(os.path.dirname(oid_path))
123 124 with open(oid_path, 'wb') as f:
124 f.write('OID_CONTENT')
125 f.write(safe_bytes('OID_CONTENT'))
125 126
126 127 params = {'operation': 'download',
127 128 'objects': [{'oid': oid, 'size': '1024'}]}
128 129 response = git_lfs_app.post_json(
129 130 '/repo/info/lfs/objects/batch', params=params,
130 131 extra_environ=http_auth)
131 132
132 133 expected_objects = [
133 134 {'authenticated': True,
134 135 'actions': {
135 136 'download': {
136 137 'header': {'Authorization': 'Basic XXXXX'},
137 138 'href': 'http://localhost/repo/info/lfs/objects/456'},
138 139 },
139 140 'oid': '456',
140 141 'size': '1024'}
141 142 ]
142 143 assert json.loads(response.text) == {
143 144 'objects': expected_objects, 'transfer': 'basic'}
144 145
145 146 def test_app_batch_api_upload(self, git_lfs_app, http_auth):
146 147 params = {'operation': 'upload',
147 148 'objects': [{'oid': '123', 'size': '1024'}]}
148 149 response = git_lfs_app.post_json(
149 150 '/repo/info/lfs/objects/batch', params=params,
150 151 extra_environ=http_auth)
151 152 expected_objects = [
152 153 {'authenticated': True,
153 154 'actions': {
154 155 'upload': {
155 156 'header': {'Authorization': 'Basic XXXXX',
156 157 'Transfer-Encoding': 'chunked'},
157 158 'href': 'http://localhost/repo/info/lfs/objects/123'},
158 159 'verify': {
159 160 'header': {'Authorization': 'Basic XXXXX'},
160 161 'href': 'http://localhost/repo/info/lfs/verify'}
161 162 },
162 163 'oid': '123',
163 164 'size': '1024'}
164 165 ]
165 166 assert json.loads(response.text) == {
166 167 'objects': expected_objects, 'transfer': 'basic'}
167 168
168 169 def test_app_batch_api_upload_for_https(self, git_lfs_https_app, http_auth):
169 170 params = {'operation': 'upload',
170 171 'objects': [{'oid': '123', 'size': '1024'}]}
171 172 response = git_lfs_https_app.post_json(
172 173 '/repo/info/lfs/objects/batch', params=params,
173 174 extra_environ=http_auth)
174 175 expected_objects = [
175 176 {'authenticated': True,
176 177 'actions': {
177 178 'upload': {
178 179 'header': {'Authorization': 'Basic XXXXX',
179 180 'Transfer-Encoding': 'chunked'},
180 181 'href': 'https://localhost/repo/info/lfs/objects/123'},
181 182 'verify': {
182 183 'header': {'Authorization': 'Basic XXXXX'},
183 184 'href': 'https://localhost/repo/info/lfs/verify'}
184 185 },
185 186 'oid': '123',
186 187 'size': '1024'}
187 188 ]
188 189 assert json.loads(response.text) == {
189 190 'objects': expected_objects, 'transfer': 'basic'}
190 191
191 192 def test_app_verify_api_missing_data(self, git_lfs_app):
192 193 params = {'oid': 'missing'}
193 194 response = git_lfs_app.post_json(
194 195 '/repo/info/lfs/verify', params=params,
195 196 status=400)
196 197
197 198 assert json.loads(response.text) == {
198 199 'message': 'missing oid and size in request data'}
199 200
200 201 def test_app_verify_api_missing_obj(self, git_lfs_app):
201 202 params = {'oid': 'missing', 'size': '1024'}
202 203 response = git_lfs_app.post_json(
203 204 '/repo/info/lfs/verify', params=params,
204 205 status=404)
205 206
206 207 assert json.loads(response.text) == {
207 208 'message': 'oid `missing` does not exist in store'}
208 209
209 210 def test_app_verify_api_size_mismatch(self, git_lfs_app):
210 211 oid = 'existing'
211 212 oid_path = os.path.join(git_lfs_app._store, oid)
212 213 if not os.path.isdir(os.path.dirname(oid_path)):
213 214 os.makedirs(os.path.dirname(oid_path))
214 215 with open(oid_path, 'wb') as f:
215 f.write('OID_CONTENT')
216 f.write(safe_bytes('OID_CONTENT'))
216 217
217 218 params = {'oid': oid, 'size': '1024'}
218 219 response = git_lfs_app.post_json(
219 220 '/repo/info/lfs/verify', params=params, status=422)
220 221
221 222 assert json.loads(response.text) == {
222 223 'message': 'requested file size mismatch '
223 224 'store size:11 requested:1024'}
224 225
225 226 def test_app_verify_api(self, git_lfs_app):
226 227 oid = 'existing'
227 228 oid_path = os.path.join(git_lfs_app._store, oid)
228 229 if not os.path.isdir(os.path.dirname(oid_path)):
229 230 os.makedirs(os.path.dirname(oid_path))
230 231 with open(oid_path, 'wb') as f:
231 f.write('OID_CONTENT')
232 f.write(safe_bytes('OID_CONTENT'))
232 233
233 234 params = {'oid': oid, 'size': 11}
234 235 response = git_lfs_app.post_json(
235 236 '/repo/info/lfs/verify', params=params)
236 237
237 238 assert json.loads(response.text) == {
238 239 'message': {'size': 'ok', 'in_store': 'ok'}}
239 240
240 241 def test_app_download_api_oid_not_existing(self, git_lfs_app):
241 242 oid = 'missing'
242 243
243 244 response = git_lfs_app.get(
244 245 '/repo/info/lfs/objects/{oid}'.format(oid=oid), status=404)
245 246
246 247 assert json.loads(response.text) == {
247 248 'message': 'requested file with oid `missing` not found in store'}
248 249
249 250 def test_app_download_api(self, git_lfs_app):
250 251 oid = 'existing'
251 252 oid_path = os.path.join(git_lfs_app._store, oid)
252 253 if not os.path.isdir(os.path.dirname(oid_path)):
253 254 os.makedirs(os.path.dirname(oid_path))
254 255 with open(oid_path, 'wb') as f:
255 f.write('OID_CONTENT')
256 f.write(safe_bytes('OID_CONTENT'))
256 257
257 258 response = git_lfs_app.get(
258 259 '/repo/info/lfs/objects/{oid}'.format(oid=oid))
259 260 assert response
260 261
261 262 def test_app_upload(self, git_lfs_app):
262 263 oid = 'uploaded'
263 264
264 265 response = git_lfs_app.put(
265 266 '/repo/info/lfs/objects/{oid}'.format(oid=oid), params='CONTENT')
266 267
267 268 assert json.loads(response.text) == {'upload': 'ok'}
268 269
269 270 # verify that we actually wrote that OID
270 271 oid_path = os.path.join(git_lfs_app._store, oid)
271 272 assert os.path.isfile(oid_path)
272 273 assert 'CONTENT' == open(oid_path).read()
@@ -1,141 +1,142 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import pytest
20 from vcsserver.utils import safe_bytes
20 21 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
21 22
22 23
23 24 @pytest.fixture()
24 25 def lfs_store(tmpdir):
25 26 repo = 'test'
26 27 oid = '123456789'
27 28 store = LFSOidStore(oid=oid, repo=repo, store_location=str(tmpdir))
28 29 return store
29 30
30 31
31 32 @pytest.fixture()
32 33 def oid_handler(lfs_store):
33 34 store = lfs_store
34 35 repo = store.repo
35 36 oid = store.oid
36 37
37 38 oid_handler = OidHandler(
38 39 store=store, repo_name=repo, auth=('basic', 'xxxx'),
39 40 oid=oid,
40 41 obj_size='1024', obj_data={}, obj_href='http://localhost/handle_oid',
41 42 obj_verify_href='http://localhost/verify')
42 43 return oid_handler
43 44
44 45
45 46 class TestOidHandler(object):
46 47
47 48 @pytest.mark.parametrize('exec_action', [
48 49 'download',
49 50 'upload',
50 51 ])
51 52 def test_exec_action(self, exec_action, oid_handler):
52 53 handler = oid_handler.exec_operation(exec_action)
53 54 assert handler
54 55
55 56 def test_exec_action_undefined(self, oid_handler):
56 57 with pytest.raises(AttributeError):
57 58 oid_handler.exec_operation('wrong')
58 59
59 60 def test_download_oid_not_existing(self, oid_handler):
60 61 response, has_errors = oid_handler.exec_operation('download')
61 62
62 63 assert response is None
63 64 assert has_errors['error'] == {
64 65 'code': 404,
65 66 'message': 'object: 123456789 does not exist in store'}
66 67
67 68 def test_download_oid(self, oid_handler):
68 69 store = oid_handler.get_store()
69 70 if not os.path.isdir(os.path.dirname(store.oid_path)):
70 71 os.makedirs(os.path.dirname(store.oid_path))
71 72
72 73 with open(store.oid_path, 'wb') as f:
73 f.write('CONTENT')
74 f.write(safe_bytes('CONTENT'))
74 75
75 76 response, has_errors = oid_handler.exec_operation('download')
76 77
77 78 assert has_errors is None
78 79 assert response['download'] == {
79 80 'header': {'Authorization': 'basic xxxx'},
80 81 'href': 'http://localhost/handle_oid'
81 82 }
82 83
83 84 def test_upload_oid_that_exists(self, oid_handler):
84 85 store = oid_handler.get_store()
85 86 if not os.path.isdir(os.path.dirname(store.oid_path)):
86 87 os.makedirs(os.path.dirname(store.oid_path))
87 88
88 89 with open(store.oid_path, 'wb') as f:
89 f.write('CONTENT')
90 f.write(safe_bytes('CONTENT'))
90 91 oid_handler.obj_size = 7
91 92 response, has_errors = oid_handler.exec_operation('upload')
92 93 assert has_errors is None
93 94 assert response is None
94 95
95 96 def test_upload_oid_that_exists_but_has_wrong_size(self, oid_handler):
96 97 store = oid_handler.get_store()
97 98 if not os.path.isdir(os.path.dirname(store.oid_path)):
98 99 os.makedirs(os.path.dirname(store.oid_path))
99 100
100 101 with open(store.oid_path, 'wb') as f:
101 f.write('CONTENT')
102 f.write(safe_bytes('CONTENT'))
102 103
103 104 oid_handler.obj_size = 10240
104 105 response, has_errors = oid_handler.exec_operation('upload')
105 106 assert has_errors is None
106 107 assert response['upload'] == {
107 108 'header': {'Authorization': 'basic xxxx',
108 109 'Transfer-Encoding': 'chunked'},
109 110 'href': 'http://localhost/handle_oid',
110 111 }
111 112
112 113 def test_upload_oid(self, oid_handler):
113 114 response, has_errors = oid_handler.exec_operation('upload')
114 115 assert has_errors is None
115 116 assert response['upload'] == {
116 117 'header': {'Authorization': 'basic xxxx',
117 118 'Transfer-Encoding': 'chunked'},
118 119 'href': 'http://localhost/handle_oid'
119 120 }
120 121
121 122
122 123 class TestLFSStore(object):
123 124 def test_write_oid(self, lfs_store):
124 125 oid_location = lfs_store.oid_path
125 126
126 127 assert not os.path.isfile(oid_location)
127 128
128 129 engine = lfs_store.get_engine(mode='wb')
129 130 with engine as f:
130 f.write('CONTENT')
131 f.write(safe_bytes('CONTENT'))
131 132
132 133 assert os.path.isfile(oid_location)
133 134
134 135 def test_detect_has_oid(self, lfs_store):
135 136
136 137 assert lfs_store.has_oid() is False
137 138 engine = lfs_store.get_engine(mode='wb')
138 139 with engine as f:
139 f.write('CONTENT')
140 f.write(safe_bytes('CONTENT'))
140 141
141 assert lfs_store.has_oid() is True
\ No newline at end of file
142 assert lfs_store.has_oid() is True
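
These tests lean on the new safe_bytes helper; a minimal sketch consistent with its use here (the shipped vcsserver.utils version may accept more input types):

    def safe_bytes(str_, from_encoding='utf8'):
        # pass bytes through unchanged, encode str; enough for the
        # literal 'CONTENT' / 'OID_CONTENT' writes in these tests
        if isinstance(str_, bytes):
            return str_
        return str_.encode(from_encoding)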
@@ -1,205 +1,204 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # RhodeCode VCSServer provides access to different vcs backends via network.
4 4 # Copyright (C) 2014-2020 RhodeCode GmbH
5 5 #
6 6 # This program is free software; you can redistribute it and/or modify
7 7 # it under the terms of the GNU General Public License as published by
8 8 # the Free Software Foundation; either version 3 of the License, or
9 9 # (at your option) any later version.
10 10 #
11 11 # This program is distributed in the hope that it will be useful,
12 12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 14 # GNU General Public License for more details.
15 15 #
16 16 # You should have received a copy of the GNU General Public License
17 17 # along with this program; if not, write to the Free Software Foundation,
18 18 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 19
20 20 import re
21 21 import os
22 22 import sys
23 23 import datetime
24 24 import logging
25 25 import pkg_resources
26 26
27 27 import vcsserver
28 from vcsserver.utils import safe_bytes
28 29
29 30 log = logging.getLogger(__name__)
30 31
31 32
32 33 def get_git_hooks_path(repo_path, bare):
33 34 hooks_path = os.path.join(repo_path, 'hooks')
34 35 if not bare:
35 36 hooks_path = os.path.join(repo_path, '.git', 'hooks')
36 37
37 38 return hooks_path
38 39
39 40
40 41 def install_git_hooks(repo_path, bare, executable=None, force_create=False):
41 42 """
42 43 Creates a RhodeCode hook inside a git repository
43 44
44 45 :param repo_path: path to repository
45 46 :param executable: binary executable to put in the hooks
46 47 :param force_create: Create even if same name hook exists
47 48 """
48 49 executable = executable or sys.executable
49 50 hooks_path = get_git_hooks_path(repo_path, bare)
50 51
51 52 if not os.path.isdir(hooks_path):
52 53 os.makedirs(hooks_path, mode=0o777)
53 54
54 55 tmpl_post = pkg_resources.resource_string(
55 56 'vcsserver', '/'.join(
56 57 ('hook_utils', 'hook_templates', 'git_post_receive.py.tmpl')))
57 58 tmpl_pre = pkg_resources.resource_string(
58 59 'vcsserver', '/'.join(
59 60 ('hook_utils', 'hook_templates', 'git_pre_receive.py.tmpl')))
60 61
61 62 path = '' # not used for now
62 63 timestamp = datetime.datetime.utcnow().isoformat()
63 64
64 65 for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
65 66 log.debug('Installing git hook in repo %s', repo_path)
66 67 _hook_file = os.path.join(hooks_path, '%s-receive' % h_type)
67 68 _rhodecode_hook = check_rhodecode_hook(_hook_file)
68 69
69 70 if _rhodecode_hook or force_create:
70 71 log.debug('writing git %s hook file at %s !', h_type, _hook_file)
71 72 try:
72 73 with open(_hook_file, 'wb') as f:
73 template = template.replace(
74 '_TMPL_', vcsserver.__version__)
75 template = template.replace('_DATE_', timestamp)
76 template = template.replace('_ENV_', executable)
77 template = template.replace('_PATH_', path)
74 template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
75 template = template.replace(b'_DATE_', safe_bytes(timestamp))
76 template = template.replace(b'_ENV_', safe_bytes(executable))
77 template = template.replace(b'_PATH_', safe_bytes(path))
78 78 f.write(template)
79 79 os.chmod(_hook_file, 0o755)
80 80 except IOError:
81 81 log.exception('error writing hook file %s', _hook_file)
82 82 else:
83 83 log.debug('skipping writing hook file')
84 84
85 85 return True
86 86
87 87
88 88 def get_svn_hooks_path(repo_path):
89 89 hooks_path = os.path.join(repo_path, 'hooks')
90 90
91 91 return hooks_path
92 92
93 93
94 94 def install_svn_hooks(repo_path, executable=None, force_create=False):
95 95 """
96 96 Creates RhodeCode hooks inside a svn repository
97 97
98 98 :param repo_path: path to repository
99 99 :param executable: binary executable to put in the hooks
100 100 :param force_create: Create even if same name hook exists
101 101 """
102 102 executable = executable or sys.executable
103 103 hooks_path = get_svn_hooks_path(repo_path)
104 104 if not os.path.isdir(hooks_path):
105 105 os.makedirs(hooks_path, mode=0o777)
106 106
107 107 tmpl_post = pkg_resources.resource_string(
108 108 'vcsserver', '/'.join(
109 109 ('hook_utils', 'hook_templates', 'svn_post_commit_hook.py.tmpl')))
110 110 tmpl_pre = pkg_resources.resource_string(
111 111 'vcsserver', '/'.join(
112 112 ('hook_utils', 'hook_templates', 'svn_pre_commit_hook.py.tmpl')))
113 113
114 114 path = '' # not used for now
115 115 timestamp = datetime.datetime.utcnow().isoformat()
116 116
117 117 for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
118 118 log.debug('Installing svn hook in repo %s', repo_path)
119 119 _hook_file = os.path.join(hooks_path, '%s-commit' % h_type)
120 120 _rhodecode_hook = check_rhodecode_hook(_hook_file)
121 121
122 122 if _rhodecode_hook or force_create:
123 123 log.debug('writing svn %s hook file at %s !', h_type, _hook_file)
124 124
125 125 try:
126 126 with open(_hook_file, 'wb') as f:
127 template = template.replace(
128 '_TMPL_', vcsserver.__version__)
129 template = template.replace('_DATE_', timestamp)
130 template = template.replace('_ENV_', executable)
131 template = template.replace('_PATH_', path)
127 template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
128 template = template.replace(b'_DATE_', safe_bytes(timestamp))
129 template = template.replace(b'_ENV_', safe_bytes(executable))
130 template = template.replace(b'_PATH_', safe_bytes(path))
132 131
133 132 f.write(template)
134 133 os.chmod(_hook_file, 0o755)
135 134 except IOError:
136 135 log.exception('error writing hook file %s', _hook_file)
137 136 else:
138 137 log.debug('skipping writing hook file')
139 138
140 139 return True
141 140
142 141
143 142 def get_version_from_hook(hook_path):
144 version = ''
143 version = b''
145 144 hook_content = read_hook_content(hook_path)
146 matches = re.search(r'(?:RC_HOOK_VER)\s*=\s*(.*)', hook_content)
145 matches = re.search(rb'(?:RC_HOOK_VER)\s*=\s*(.*)', hook_content)
147 146 if matches:
148 147 try:
149 148 version = matches.groups()[0]
150 149 log.debug('got version %s from hooks.', version)
151 150 except Exception:
152 151 log.exception("Exception while reading the hook version.")
153 return version.replace("'", "")
152 return version.replace(b"'", b"")
154 153
155 154
156 155 def check_rhodecode_hook(hook_path):
157 156 """
158 157 Check if the hook was created by RhodeCode
159 158 """
160 159 if not os.path.exists(hook_path):
161 160 return True
162 161
163 162 log.debug('hook exists, checking if it is from RhodeCode')
164 163
165 164 version = get_version_from_hook(hook_path)
166 165 if version:
167 166 return True
168 167
169 168 return False
170 169
171 170
172 171 def read_hook_content(hook_path):
173 172 content = ''
174 173 if os.path.isfile(hook_path):
175 174 with open(hook_path, 'rb') as f:
176 175 content = f.read()
177 176 return content
178 177
179 178
180 179 def get_git_pre_hook_version(repo_path, bare):
181 180 hooks_path = get_git_hooks_path(repo_path, bare)
182 181 _hook_file = os.path.join(hooks_path, 'pre-receive')
183 182 version = get_version_from_hook(_hook_file)
184 183 return version
185 184
186 185
187 186 def get_git_post_hook_version(repo_path, bare):
188 187 hooks_path = get_git_hooks_path(repo_path, bare)
189 188 _hook_file = os.path.join(hooks_path, 'post-receive')
190 189 version = get_version_from_hook(_hook_file)
191 190 return version
192 191
193 192
194 193 def get_svn_pre_hook_version(repo_path):
195 194 hooks_path = get_svn_hooks_path(repo_path)
196 195 _hook_file = os.path.join(hooks_path, 'pre-commit')
197 196 version = get_version_from_hook(_hook_file)
198 197 return version
199 198
200 199
201 200 def get_svn_post_hook_version(repo_path):
202 201 hooks_path = get_svn_hooks_path(repo_path)
203 202 _hook_file = os.path.join(hooks_path, 'post-commit')
204 203 version = get_version_from_hook(_hook_file)
205 204 return version
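
Hook files are read as bytes, so the version regex, its groups, and the cleanup in get_version_from_hook all stay in bytes; a quick illustration of the matching above (the hook content is a made-up sample):

    import re

    hook_content = b"# RC_HOOK_VER = '4.28.0'\n"
    match = re.search(rb'(?:RC_HOOK_VER)\s*=\s*(.*)', hook_content)
    assert match.groups()[0].replace(b"'", b"") == b'4.28.0'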
@@ -1,729 +1,738 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # RhodeCode VCSServer provides access to different vcs backends via network.
4 4 # Copyright (C) 2014-2020 RhodeCode GmbH
5 5 #
6 6 # This program is free software; you can redistribute it and/or modify
7 7 # it under the terms of the GNU General Public License as published by
8 8 # the Free Software Foundation; either version 3 of the License, or
9 9 # (at your option) any later version.
10 10 #
11 11 # This program is distributed in the hope that it will be useful,
12 12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 14 # GNU General Public License for more details.
15 15 #
16 16 # You should have received a copy of the GNU General Public License
17 17 # along with this program; if not, write to the Free Software Foundation,
18 18 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 19
20 20 import io
21 21 import os
22 22 import sys
23 23 import logging
24 24 import collections
25 25 import importlib
26 26 import base64
27 import msgpack
27 28
28 29 from http.client import HTTPConnection
29 30
30 31
31 32 import mercurial.scmutil
32 33 import mercurial.node
33 import simplejson as json
34 34
35 from vcsserver.lib.rc_json import json
35 36 from vcsserver import exceptions, subprocessio, settings
37 from vcsserver.utils import safe_bytes
36 38
37 39 log = logging.getLogger(__name__)
38 40
39 41
40 42 class HooksHttpClient(object):
43 proto = 'msgpack.v1'
41 44 connection = None
42 45
43 46 def __init__(self, hooks_uri):
44 47 self.hooks_uri = hooks_uri
45 48
46 49 def __call__(self, method, extras):
47 50 connection = HTTPConnection(self.hooks_uri)
48 body = self._serialize(method, extras)
51 # binary msgpack body
52 headers, body = self._serialize(method, extras)
49 53 try:
50 connection.request('POST', '/', body)
51 except Exception:
52 log.error('Hooks calling Connection failed on %s', connection.__dict__)
54 connection.request('POST', '/', body, headers)
55 except Exception as error:
56 log.error('Hooks calling Connection failed on %s, org error: %s', connection.__dict__, error)
53 57 raise
54 58 response = connection.getresponse()
55
56 response_data = response.read()
57
58 59 try:
59 return json.loads(response_data)
60 return msgpack.load(response, raw=False)
60 61 except Exception:
62 response_data = response.read()
61 63 log.exception('Failed to decode hook response data. '
62 64 'response_code:%s, raw_data:%s',
63 65 response.status, response_data)
64 66 raise
65 67
66 def _serialize(self, hook_name, extras):
68 @classmethod
69 def _serialize(cls, hook_name, extras):
67 70 data = {
68 71 'method': hook_name,
69 72 'extras': extras
70 73 }
71 return json.dumps(data)
74 headers = {
75 'rc-hooks-protocol': cls.proto
76 }
77 return headers, msgpack.packb(data)
72 78
73 79
74 80 class HooksDummyClient(object):
75 81 def __init__(self, hooks_module):
76 82 self._hooks_module = importlib.import_module(hooks_module)
77 83
78 84 def __call__(self, hook_name, extras):
79 85 with self._hooks_module.Hooks() as hooks:
80 86 return getattr(hooks, hook_name)(extras)
81 87
82 88
83 89 class HooksShadowRepoClient(object):
84 90
85 91 def __call__(self, hook_name, extras):
86 92 return {'output': '', 'status': 0}
87 93
88 94
89 95 class RemoteMessageWriter(object):
90 96 """Writer base class."""
91 97 def write(self, message):
92 98 raise NotImplementedError()
93 99
94 100
95 101 class HgMessageWriter(RemoteMessageWriter):
96 102 """Writer that knows how to send messages to mercurial clients."""
97 103
98 104 def __init__(self, ui):
99 105 self.ui = ui
100 106
101 107 def write(self, message):
102 108 # TODO: Check why the quiet flag is set by default.
103 109 old = self.ui.quiet
104 110 self.ui.quiet = False
105 111 self.ui.status(message.encode('utf-8'))
106 112 self.ui.quiet = old
107 113
108 114
109 115 class GitMessageWriter(RemoteMessageWriter):
110 116 """Writer that knows how to send messages to git clients."""
111 117
112 118 def __init__(self, stdout=None):
113 119 self.stdout = stdout or sys.stdout
114 120
115 121 def write(self, message):
116 self.stdout.write(message.encode('utf-8'))
122 self.stdout.write(safe_bytes(message))
117 123
118 124
119 125 class SvnMessageWriter(RemoteMessageWriter):
120 126 """Writer that knows how to send messages to svn clients."""
121 127
122 128 def __init__(self, stderr=None):
123 129 # SVN needs data sent to stderr for back-to-client messaging
124 130 self.stderr = stderr or sys.stderr
125 131
126 132 def write(self, message):
127 133 self.stderr.write(message.encode('utf-8'))
128 134
129 135
130 136 def _handle_exception(result):
131 137 exception_class = result.get('exception')
132 138 exception_traceback = result.get('exception_traceback')
133 139
134 140 if exception_traceback:
135 141 log.error('Got traceback from remote call:%s', exception_traceback)
136 142
137 143 if exception_class == 'HTTPLockedRC':
138 144 raise exceptions.RepositoryLockedException()(*result['exception_args'])
139 145 elif exception_class == 'HTTPBranchProtected':
140 146 raise exceptions.RepositoryBranchProtectedException()(*result['exception_args'])
141 147 elif exception_class == 'RepositoryError':
142 148 raise exceptions.VcsException()(*result['exception_args'])
143 149 elif exception_class:
144 150 raise Exception('Got remote exception "%s" with args "%s"' %
145 151 (exception_class, result['exception_args']))
146 152
147 153
148 154 def _get_hooks_client(extras):
149 155 hooks_uri = extras.get('hooks_uri')
150 156 is_shadow_repo = extras.get('is_shadow_repo')
151 157 if hooks_uri:
152 158 return HooksHttpClient(extras['hooks_uri'])
153 159 elif is_shadow_repo:
154 160 return HooksShadowRepoClient()
155 161 else:
156 162 return HooksDummyClient(extras['hooks_module'])
157 163
158 164
159 165 def _call_hook(hook_name, extras, writer):
160 166 hooks_client = _get_hooks_client(extras)
161 167 log.debug('Hooks, using client:%s', hooks_client)
162 168 result = hooks_client(hook_name, extras)
163 169 log.debug('Hooks got result: %s', result)
164 170
165 171 _handle_exception(result)
166 172 writer.write(result['output'])
167 173
168 174 return result['status']
169 175
170 176
171 177 def _extras_from_ui(ui):
172 178 hook_data = ui.config('rhodecode', 'RC_SCM_DATA')
173 179 if not hook_data:
174 180 # maybe it's inside environ ?
175 181 env_hook_data = os.environ.get('RC_SCM_DATA')
176 182 if env_hook_data:
177 183 hook_data = env_hook_data
178 184
179 185 extras = {}
180 186 if hook_data:
181 187 extras = json.loads(hook_data)
182 188 return extras
183 189
184 190
185 191 def _rev_range_hash(repo, node, check_heads=False):
186 192 from vcsserver.hgcompat import get_ctx
187 193
188 194 commits = []
189 195 revs = []
190 196 start = get_ctx(repo, node).rev()
191 197 end = len(repo)
192 198 for rev in range(start, end):
193 199 revs.append(rev)
194 200 ctx = get_ctx(repo, rev)
195 201 commit_id = mercurial.node.hex(ctx.node())
196 202 branch = ctx.branch()
197 203 commits.append((commit_id, branch))
198 204
199 205 parent_heads = []
200 206 if check_heads:
201 207 parent_heads = _check_heads(repo, start, end, revs)
202 208 return commits, parent_heads
203 209
204 210
205 211 def _check_heads(repo, start, end, commits):
206 212 from vcsserver.hgcompat import get_ctx
207 213 changelog = repo.changelog
208 214 parents = set()
209 215
210 216 for new_rev in commits:
211 217 for p in changelog.parentrevs(new_rev):
212 218 if p == mercurial.node.nullrev:
213 219 continue
214 220 if p < start:
215 221 parents.add(p)
216 222
217 223 for p in parents:
218 224 branch = get_ctx(repo, p).branch()
219 225 # The heads descending from that parent, on the same branch
220 226 parent_heads = {p}
221 227 reachable = {p}
222 228 for x in range(p + 1, end):
223 229 if get_ctx(repo, x).branch() != branch:
224 230 continue
225 231 for pp in changelog.parentrevs(x):
226 232 if pp in reachable:
227 233 reachable.add(x)
228 234 parent_heads.discard(pp)
229 235 parent_heads.add(x)
230 236 # More than one head? Suggest merging
231 237 if len(parent_heads) > 1:
232 238 return list(parent_heads)
233 239
234 240 return []
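
The loop above tracks, for each pre-existing parent, which newer revisions remain heads on the same branch; if more than one head survives, a merge is suggested. A toy sketch of the same bookkeeping, with a plain dict standing in for Mercurial's changelog (branch filtering omitted):

    # revision -> parent revisions; rev 1 and rev 2 both branch off rev 0
    parentrevs = {0: (), 1: (0,), 2: (0,)}

    def heads_after(parent, end, parentrevs):
        reachable = {parent}
        parent_heads = {parent}
        for x in range(parent + 1, end):
            for pp in parentrevs.get(x, ()):
                if pp in reachable:
                    reachable.add(x)
                    parent_heads.discard(pp)
                    parent_heads.add(x)
        return parent_heads

    assert heads_after(0, 3, parentrevs) == {1, 2}  # two heads -> merge suggested
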
235 241
236 242
237 243 def _get_git_env():
238 244 env = {}
239 245 for k, v in os.environ.items():
240 246 if k.startswith('GIT'):
241 247 env[k] = v
242 248
243 249 # serialized version
244 250 return [(k, v) for k, v in env.items()]
245 251
246 252
247 253 def _get_hg_env(old_rev, new_rev, txnid, repo_path):
248 254 env = {}
249 255 for k, v in os.environ.items():
250 256 if k.startswith('HG'):
251 257 env[k] = v
252 258
253 259 env['HG_NODE'] = old_rev
254 260 env['HG_NODE_LAST'] = new_rev
255 261 env['HG_TXNID'] = txnid
256 262 env['HG_PENDING'] = repo_path
257 263
258 264 return [(k, v) for k, v in env.items()]
259 265
260 266
261 267 def repo_size(ui, repo, **kwargs):
262 268 extras = _extras_from_ui(ui)
263 269 return _call_hook('repo_size', extras, HgMessageWriter(ui))
264 270
265 271
266 272 def pre_pull(ui, repo, **kwargs):
267 273 extras = _extras_from_ui(ui)
268 274 return _call_hook('pre_pull', extras, HgMessageWriter(ui))
269 275
270 276
271 277 def pre_pull_ssh(ui, repo, **kwargs):
272 278 extras = _extras_from_ui(ui)
273 279 if extras and extras.get('SSH'):
274 280 return pre_pull(ui, repo, **kwargs)
275 281 return 0
276 282
277 283
278 284 def post_pull(ui, repo, **kwargs):
279 285 extras = _extras_from_ui(ui)
280 286 return _call_hook('post_pull', extras, HgMessageWriter(ui))
281 287
282 288
283 289 def post_pull_ssh(ui, repo, **kwargs):
284 290 extras = _extras_from_ui(ui)
285 291 if extras and extras.get('SSH'):
286 292 return post_pull(ui, repo, **kwargs)
287 293 return 0
288 294
289 295
290 296 def pre_push(ui, repo, node=None, **kwargs):
291 297 """
292 298 Mercurial pre_push hook
293 299 """
294 300 extras = _extras_from_ui(ui)
295 301 detect_force_push = extras.get('detect_force_push')
296 302
297 303 rev_data = []
298 304 if node and kwargs.get('hooktype') == 'pretxnchangegroup':
299 305 branches = collections.defaultdict(list)
300 306 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
301 307 for commit_id, branch in commits:
302 308 branches[branch].append(commit_id)
303 309
304 310 for branch, commits in branches.items():
305 311 old_rev = kwargs.get('node_last') or commits[0]
306 312 rev_data.append({
307 313 'total_commits': len(commits),
308 314 'old_rev': old_rev,
309 315 'new_rev': commits[-1],
310 316 'ref': '',
311 317 'type': 'branch',
312 318 'name': branch,
313 319 })
314 320
315 321 for push_ref in rev_data:
316 322 push_ref['multiple_heads'] = _heads
317 323
318 324 repo_path = os.path.join(
319 325 extras.get('repo_store', ''), extras.get('repository', ''))
320 326 push_ref['hg_env'] = _get_hg_env(
321 327 old_rev=push_ref['old_rev'],
322 328 new_rev=push_ref['new_rev'], txnid=kwargs.get('txnid'),
323 329 repo_path=repo_path)
324 330
325 331 extras['hook_type'] = kwargs.get('hooktype', 'pre_push')
326 332 extras['commit_ids'] = rev_data
327 333
328 334 return _call_hook('pre_push', extras, HgMessageWriter(ui))
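
For reference, one entry of the `rev_data` list built above looks roughly like this (values are illustrative placeholders, not real hashes):

    push_ref = {
        'total_commits': 2,
        'old_rev': 'aaaaaaaaaaaa',                 # placeholder hash
        'new_rev': 'bbbbbbbbbbbb',                 # placeholder hash
        'ref': '',
        'type': 'branch',
        'name': 'default',
        'multiple_heads': [],                      # filled from _rev_range_hash()
        'hg_env': [('HG_NODE', 'aaaaaaaaaaaa')],   # serialized HG_* environment
    }
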
329 335
330 336
331 337 def pre_push_ssh(ui, repo, node=None, **kwargs):
332 338 extras = _extras_from_ui(ui)
333 339 if extras.get('SSH'):
334 340 return pre_push(ui, repo, node, **kwargs)
335 341
336 342 return 0
337 343
338 344
339 345 def pre_push_ssh_auth(ui, repo, node=None, **kwargs):
340 346 """
341 347 Mercurial pre_push hook for SSH
342 348 """
343 349 extras = _extras_from_ui(ui)
344 350 if extras.get('SSH'):
345 351 permission = extras['SSH_PERMISSIONS']
346 352
347 353 if 'repository.write' == permission or 'repository.admin' == permission:
348 354 return 0
349 355
350 356 # non-zero ret code
351 357 return 1
352 358
353 359 return 0
354 360
355 361
356 362 def post_push(ui, repo, node, **kwargs):
357 363 """
358 364 Mercurial post_push hook
359 365 """
360 366 extras = _extras_from_ui(ui)
361 367
362 368 commit_ids = []
363 369 branches = []
364 370 bookmarks = []
365 371 tags = []
366 372
367 373 commits, _heads = _rev_range_hash(repo, node)
368 374 for commit_id, branch in commits:
369 375 commit_ids.append(commit_id)
370 376 if branch not in branches:
371 377 branches.append(branch)
372 378
373 379 if hasattr(ui, '_rc_pushkey_branches'):
374 380 bookmarks = ui._rc_pushkey_branches
375 381
376 382 extras['hook_type'] = kwargs.get('hooktype', 'post_push')
377 383 extras['commit_ids'] = commit_ids
378 384 extras['new_refs'] = {
379 385 'branches': branches,
380 386 'bookmarks': bookmarks,
381 387 'tags': tags
382 388 }
383 389
384 390 return _call_hook('post_push', extras, HgMessageWriter(ui))
385 391
386 392
387 393 def post_push_ssh(ui, repo, node, **kwargs):
388 394 """
389 395 Mercurial post_push hook for SSH
390 396 """
391 397 if _extras_from_ui(ui).get('SSH'):
392 398 return post_push(ui, repo, node, **kwargs)
393 399 return 0
394 400
395 401
396 402 def key_push(ui, repo, **kwargs):
397 403 from vcsserver.hgcompat import get_ctx
398 404 if kwargs['new'] != '0' and kwargs['namespace'] == 'bookmarks':
399 405 # store new bookmarks in our UI object propagated later to post_push
400 406 ui._rc_pushkey_branches = get_ctx(repo, kwargs['key']).bookmarks()
401 407 return
402 408
403 409
404 410 # backward compat
405 411 log_pull_action = post_pull
406 412
407 413 # backward compat
408 414 log_push_action = post_push
409 415
410 416
411 417 def handle_git_pre_receive(unused_repo_path, unused_revs, unused_env):
412 418 """
413 419 Old hook name: keep here for backward compatibility.
414 420
415 421 This is only required when the installed git hooks are not upgraded.
416 422 """
417 423 pass
418 424
419 425
420 426 def handle_git_post_receive(unused_repo_path, unused_revs, unused_env):
421 427 """
422 428 Old hook name: keep here for backward compatibility.
423 429
424 430 This is only required when the installed git hooks are not upgraded.
425 431 """
426 432 pass
427 433
428 434
429 435 HookResponse = collections.namedtuple('HookResponse', ('status', 'output'))
430 436
431 437
432 438 def git_pre_pull(extras):
433 439 """
434 440 Pre pull hook.
435 441
436 442 :param extras: dictionary containing the keys defined in simplevcs
437 443 :type extras: dict
438 444
439 445 :return: status code of the hook. 0 for success.
440 446 :rtype: int
441 447 """
448
442 449 if 'pull' not in extras['hooks']:
443 450 return HookResponse(0, '')
444 451
445 452 stdout = io.BytesIO()
446 453 try:
447 454 status = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
455
448 456 except Exception as error:
457 log.exception('Failed to call pre_pull hook')
449 458 status = 128
450 stdout.write('ERROR: %s\n' % str(error))
459 stdout.write(safe_bytes(f'ERROR: {error}\n'))
451 460
452 461 return HookResponse(status, stdout.getvalue())
453 462
454 463
455 464 def git_post_pull(extras):
456 465 """
457 466 Post pull hook.
458 467
459 468 :param extras: dictionary containing the keys defined in simplevcs
460 469 :type extras: dict
461 470
462 471 :return: status code of the hook. 0 for success.
463 472 :rtype: int
464 473 """
465 474 if 'pull' not in extras['hooks']:
466 475 return HookResponse(0, '')
467 476
468 477 stdout = io.BytesIO()
469 478 try:
470 479 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
471 480 except Exception as error:
472 481 status = 128
473 stdout.write('ERROR: %s\n' % error)
482 stdout.write(safe_bytes(f'ERROR: {error}\n'))
474 483
475 484 return HookResponse(status, stdout.getvalue())
476 485
477 486
478 487 def _parse_git_ref_lines(revision_lines):
479 488 rev_data = []
480 489 for revision_line in revision_lines or []:
481 490 old_rev, new_rev, ref = revision_line.strip().split(' ')
482 491 ref_data = ref.split('/', 2)
483 492 if ref_data[1] in ('tags', 'heads'):
484 493 rev_data.append({
485 494 # NOTE(marcink):
486 495 # we're unable to tell total_commits for git at this point
487 496 # but we set the variable for consistency with the Mercurial hook payload
488 497 'total_commits': -1,
489 498 'old_rev': old_rev,
490 499 'new_rev': new_rev,
491 500 'ref': ref,
492 501 'type': ref_data[1],
493 502 'name': ref_data[2],
494 503 })
495 504 return rev_data
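
A standalone check of the parsing above; git feeds pre-/post-receive hooks one `<old-sha> <new-sha> <ref>` line per updated ref:

    line = '0' * 40 + ' ' + 'b' * 40 + ' refs/heads/main'
    old_rev, new_rev, ref = line.strip().split(' ')
    assert ref.split('/', 2) == ['refs', 'heads', 'main']  # type 'heads', name 'main'
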
496 505
497 506
498 507 def git_pre_receive(unused_repo_path, revision_lines, env):
499 508 """
500 509 Pre push hook.
501 510
502 511 :param extras: dictionary containing the keys defined in simplevcs
503 512 :type extras: dict
504 513
505 514 :return: status code of the hook. 0 for success.
506 515 :rtype: int
507 516 """
508 517 extras = json.loads(env['RC_SCM_DATA'])
509 518 rev_data = _parse_git_ref_lines(revision_lines)
510 519 if 'push' not in extras['hooks']:
511 520 return 0
512 521 empty_commit_id = '0' * 40
513 522
514 523 detect_force_push = extras.get('detect_force_push')
515 524
516 525 for push_ref in rev_data:
517 526 # store our git-env which holds the temp store
518 527 push_ref['git_env'] = _get_git_env()
519 528 push_ref['pruned_sha'] = ''
520 529 if not detect_force_push:
521 530 # don't check for forced-push when we don't need to
522 531 continue
523 532
524 533 type_ = push_ref['type']
525 534 new_branch = push_ref['old_rev'] == empty_commit_id
526 535 delete_branch = push_ref['new_rev'] == empty_commit_id
527 536 if type_ == 'heads' and not (new_branch or delete_branch):
528 537 old_rev = push_ref['old_rev']
529 538 new_rev = push_ref['new_rev']
530 539 cmd = [settings.GIT_EXECUTABLE, 'rev-list', old_rev, '^{}'.format(new_rev)]
531 540 stdout, stderr = subprocessio.run_command(
532 541 cmd, env=os.environ.copy())
533 542 # non-empty output means there are unreachable objects, i.e. this was a forced push
534 543 if stdout:
535 544 push_ref['pruned_sha'] = stdout.splitlines()
536 545
537 546 extras['hook_type'] = 'pre_receive'
538 547 extras['commit_ids'] = rev_data
539 548 return _call_hook('pre_push', extras, GitMessageWriter())
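
The force-push test above relies on `git rev-list old_rev ^new_rev`: commits reachable from the old tip but not from the new one can only exist if history was rewritten. The same idea with the plain `subprocess` module (a sketch with placeholder SHAs, not the project's `subprocessio` helper):

    import subprocess

    old_rev, new_rev = 'OLD_SHA', 'NEW_SHA'   # placeholders for real 40-char hashes
    out = subprocess.check_output(['git', 'rev-list', old_rev, '^' + new_rev])
    forced_push = bool(out.strip())  # unreachable commits found => forced push
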
540 549
541 550
542 551 def git_post_receive(unused_repo_path, revision_lines, env):
543 552 """
544 553 Post push hook.
545 554
546 555 :param extras: dictionary containing the keys defined in simplevcs
547 556 :type extras: dict
548 557
549 558 :return: status code of the hook. 0 for success.
550 559 :rtype: int
551 560 """
552 561 extras = json.loads(env['RC_SCM_DATA'])
553 562 if 'push' not in extras['hooks']:
554 563 return 0
555 564
556 565 rev_data = _parse_git_ref_lines(revision_lines)
557 566
558 567 git_revs = []
559 568
560 569 # N.B.(skreft): it is ok to just call git, as git before calling a
561 570 # subcommand sets the PATH environment variable so that it points to the
562 571 # correct version of the git executable.
563 572 empty_commit_id = '0' * 40
564 573 branches = []
565 574 tags = []
566 575 for push_ref in rev_data:
567 576 type_ = push_ref['type']
568 577
569 578 if type_ == 'heads':
570 579 if push_ref['old_rev'] == empty_commit_id:
571 580 # starting new branch case
572 581 if push_ref['name'] not in branches:
573 582 branches.append(push_ref['name'])
574 583
575 584 # Fix up head revision if needed
576 585 cmd = [settings.GIT_EXECUTABLE, 'show', 'HEAD']
577 586 try:
578 587 subprocessio.run_command(cmd, env=os.environ.copy())
579 588 except Exception:
580 589 cmd = [settings.GIT_EXECUTABLE, 'symbolic-ref', '"HEAD"',
581 590 '"refs/heads/%s"' % push_ref['name']]
582 591 print("Setting default branch to %s" % push_ref['name'])
583 592 subprocessio.run_command(cmd, env=os.environ.copy())
584 593
585 594 cmd = [settings.GIT_EXECUTABLE, 'for-each-ref',
586 595 '--format=%(refname)', 'refs/heads/*']
587 596 stdout, stderr = subprocessio.run_command(
588 597 cmd, env=os.environ.copy())
589 598 heads = stdout
590 599 heads = heads.replace(push_ref['ref'], '')
591 600 heads = ' '.join(head for head
592 601 in heads.splitlines() if head) or '.'
593 602 cmd = [settings.GIT_EXECUTABLE, 'log', '--reverse',
594 603 '--pretty=format:%H', '--', push_ref['new_rev'],
595 604 '--not', heads]
596 605 stdout, stderr = subprocessio.run_command(
597 606 cmd, env=os.environ.copy())
598 607 git_revs.extend(stdout.splitlines())
599 608 elif push_ref['new_rev'] == empty_commit_id:
600 609 # delete branch case
601 610 git_revs.append('delete_branch=>%s' % push_ref['name'])
602 611 else:
603 612 if push_ref['name'] not in branches:
604 613 branches.append(push_ref['name'])
605 614
606 615 cmd = [settings.GIT_EXECUTABLE, 'log',
607 616 '{old_rev}..{new_rev}'.format(**push_ref),
608 617 '--reverse', '--pretty=format:%H']
609 618 stdout, stderr = subprocessio.run_command(
610 619 cmd, env=os.environ.copy())
611 620 git_revs.extend(stdout.splitlines())
612 621 elif type_ == 'tags':
613 622 if push_ref['name'] not in tags:
614 623 tags.append(push_ref['name'])
615 624 git_revs.append('tag=>%s' % push_ref['name'])
616 625
617 626 extras['hook_type'] = 'post_receive'
618 627 extras['commit_ids'] = git_revs
619 628 extras['new_refs'] = {
620 629 'branches': branches,
621 630 'bookmarks': [],
622 631 'tags': tags,
623 632 }
624 633
625 634 if 'repo_size' in extras['hooks']:
626 635 try:
627 636 _call_hook('repo_size', extras, GitMessageWriter())
628 637 except Exception:
629 638 pass
630 639
631 640 return _call_hook('post_push', extras, GitMessageWriter())
632 641
633 642
634 643 def _get_extras_from_txn_id(path, txn_id):
635 644 extras = {}
636 645 try:
637 646 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
638 647 '-t', txn_id,
639 648 '--revprop', path, 'rc-scm-extras']
640 649 stdout, stderr = subprocessio.run_command(
641 650 cmd, env=os.environ.copy())
642 651 extras = json.loads(base64.urlsafe_b64decode(stdout))
643 652 except Exception:
644 653 log.exception('Failed to extract extras info from txn_id')
645 654
646 655 return extras
647 656
648 657
649 658 def _get_extras_from_commit_id(commit_id, path):
650 659 extras = {}
651 660 try:
652 661 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
653 662 '-r', commit_id,
654 663 '--revprop', path, 'rc-scm-extras']
655 664 stdout, stderr = subprocessio.run_command(
656 665 cmd, env=os.environ.copy())
657 666 extras = json.loads(base64.urlsafe_b64decode(stdout))
658 667 except Exception:
659 668 log.exception('Failed to extract extras info from commit_id')
660 669
661 670 return extras
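
Both fallbacks above read the `rc-scm-extras` revision property, which holds urlsafe-base64-encoded JSON; a self-contained round-trip of that encoding:

    import base64
    import json

    extras = {'username': 'demo', 'hooks': ['push']}
    encoded = base64.urlsafe_b64encode(json.dumps(extras).encode('utf8'))
    assert json.loads(base64.urlsafe_b64decode(encoded)) == extras
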
662 671
663 672
664 673 def svn_pre_commit(repo_path, commit_data, env):
665 674 path, txn_id = commit_data
666 675 branches = []
667 676 tags = []
668 677
669 678 if env.get('RC_SCM_DATA'):
670 679 extras = json.loads(env['RC_SCM_DATA'])
671 680 else:
672 681 # fallback method to read from TXN-ID stored data
673 682 extras = _get_extras_from_txn_id(path, txn_id)
674 683 if not extras:
675 684 return 0
676 685
677 686 extras['hook_type'] = 'pre_commit'
678 687 extras['commit_ids'] = [txn_id]
679 688 extras['txn_id'] = txn_id
680 689 extras['new_refs'] = {
681 690 'total_commits': 1,
682 691 'branches': branches,
683 692 'bookmarks': [],
684 693 'tags': tags,
685 694 }
686 695
687 696 return _call_hook('pre_push', extras, SvnMessageWriter())
688 697
689 698
690 699 def svn_post_commit(repo_path, commit_data, env):
691 700 """
692 701 commit_data is path, rev, txn_id
693 702 """
694 703 if len(commit_data) == 3:
695 704 path, commit_id, txn_id = commit_data
696 705 elif len(commit_data) == 2:
697 706 log.error('Failed to extract txn_id from commit_data using legacy method. '
698 707 'Some functionality might be limited')
699 708 path, commit_id = commit_data
700 709 txn_id = None
701 710
702 711 branches = []
703 712 tags = []
704 713
705 714 if env.get('RC_SCM_DATA'):
706 715 extras = json.loads(env['RC_SCM_DATA'])
707 716 else:
708 717 # fallback method to read from TXN-ID stored data
709 718 extras = _get_extras_from_commit_id(commit_id, path)
710 719 if not extras:
711 720 return 0
712 721
713 722 extras['hook_type'] = 'post_commit'
714 723 extras['commit_ids'] = [commit_id]
715 724 extras['txn_id'] = txn_id
716 725 extras['new_refs'] = {
717 726 'branches': branches,
718 727 'bookmarks': [],
719 728 'tags': tags,
720 729 'total_commits': 1,
721 730 }
722 731
723 732 if 'repo_size' in extras['hooks']:
724 733 try:
725 734 _call_hook('repo_size', extras, SvnMessageWriter())
726 735 except Exception:
727 736 pass
728 737
729 738 return _call_hook('post_push', extras, SvnMessageWriter())
@@ -1,739 +1,740 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 import io
18 19 import os
19 20 import sys
20 21 import base64
21 22 import locale
22 23 import logging
23 24 import uuid
24 25 import time
25 26 import wsgiref.util
26 27 import traceback
27 28 import tempfile
28 29 import psutil
29 30
30 31 from itertools import chain
31 from io import StringIO
32 32
33 import simplejson as json
34 33 import msgpack
35 34 import configparser
36 35
37 36 from pyramid.config import Configurator
38 37 from pyramid.wsgi import wsgiapp
39 38 from pyramid.response import Response
40 39
40 from vcsserver.lib.rc_json import json
41 41 from vcsserver.config.settings_maker import SettingsMaker
42 42 from vcsserver.utils import safe_int
43 43 from vcsserver.lib.statsd_client import StatsdClient
44 44
45 45 log = logging.getLogger(__name__)
46 46
47 47 # due to Mercurial/glibc2.27 problems we need to detect if locale settings are
48 48 # causing problems and "fix" them if they do, falling back to LC_ALL = C
49 49
50 50 try:
51 51 locale.setlocale(locale.LC_ALL, '')
52 52 except locale.Error as e:
53 53 log.error(
54 54 'LOCALE ERROR: failed to set LC_ALL, fallback to LC_ALL=C, org error: %s', e)
55 55 os.environ['LC_ALL'] = 'C'
56 56
57 57
58 58 import vcsserver
59 59 from vcsserver import remote_wsgi, scm_app, settings, hgpatches
60 60 from vcsserver.git_lfs.app import GIT_LFS_CONTENT_TYPE, GIT_LFS_PROTO_PAT
61 61 from vcsserver.echo_stub import remote_wsgi as remote_wsgi_stub
62 62 from vcsserver.echo_stub.echo_app import EchoApp
63 63 from vcsserver.exceptions import HTTPRepoLocked, HTTPRepoBranchProtected
64 64 from vcsserver.lib.exc_tracking import store_exception
65 65 from vcsserver.server import VcsServer
66 66
67 67 strict_vcs = True
68 68
69 69 git_import_err = None
70 70 try:
71 71 from vcsserver.remote.git import GitFactory, GitRemote
72 72 except ImportError as e:
73 73 GitFactory = None
74 74 GitRemote = None
75 75 git_import_err = e
76 76 if strict_vcs:
77 77 raise
78 78
79 79
80 80 hg_import_err = None
81 81 try:
82 82 from vcsserver.remote.hg import MercurialFactory, HgRemote
83 83 except ImportError as e:
84 84 MercurialFactory = None
85 85 HgRemote = None
86 86 hg_import_err = e
87 87 if strict_vcs:
88 88 raise
89 89
90 90
91 91 svn_import_err = None
92 92 try:
93 93 from vcsserver.remote.svn import SubversionFactory, SvnRemote
94 94 except ImportError as e:
95 95 SubversionFactory = None
96 96 SvnRemote = None
97 97 svn_import_err = e
98 98 if strict_vcs:
99 99 raise
100 100
101 101
102 102 def _is_request_chunked(environ):
103 103 stream = environ.get('HTTP_TRANSFER_ENCODING', '') == 'chunked'
104 104 return stream
105 105
106 106
107 107 def log_max_fd():
108 108 try:
109 109 maxfd = psutil.Process().rlimit(psutil.RLIMIT_NOFILE)[1]
110 110 log.info('Max file descriptors value: %s', maxfd)
111 111 except Exception:
112 112 pass
113 113
114 114
115 115 class VCS(object):
116 116 def __init__(self, locale_conf=None, cache_config=None):
117 117 self.locale = locale_conf
118 118 self.cache_config = cache_config
119 119 self._configure_locale()
120 120
121 121 log_max_fd()
122 122
123 123 if GitFactory and GitRemote:
124 124 git_factory = GitFactory()
125 125 self._git_remote = GitRemote(git_factory)
126 126 else:
127 127 log.error("Git client import failed: %s", git_import_err)
128 128
129 129 if MercurialFactory and HgRemote:
130 130 hg_factory = MercurialFactory()
131 131 self._hg_remote = HgRemote(hg_factory)
132 132 else:
133 133 log.error("Mercurial client import failed: %s", hg_import_err)
134 134
135 135 if SubversionFactory and SvnRemote:
136 136 svn_factory = SubversionFactory()
137 137
138 138 # hg factory is used for svn url validation
139 139 hg_factory = MercurialFactory()
140 140 self._svn_remote = SvnRemote(svn_factory, hg_factory=hg_factory)
141 141 else:
142 142 log.error("Subversion client import failed: %s", svn_import_err)
143 143
144 144 self._vcsserver = VcsServer()
145 145
146 146 def _configure_locale(self):
147 147 if self.locale:
148 148 log.info('Setting locale `LC_ALL` to %s', self.locale)
149 149 else:
150 150 log.info('Configuring locale subsystem based on environment variables')
151 151 try:
152 152 # If self.locale is the empty string, then the locale
153 153 # module will use the environment variables. See the
154 154 # documentation of the package `locale`.
155 155 locale.setlocale(locale.LC_ALL, self.locale)
156 156
157 157 language_code, encoding = locale.getlocale()
158 158 log.info(
159 159 'Locale set to language code "%s" with encoding "%s".',
160 160 language_code, encoding)
161 161 except locale.Error:
162 162 log.exception('Cannot set locale, not configuring the locale system')
163 163
164 164
165 165 class WsgiProxy(object):
166 166 def __init__(self, wsgi):
167 167 self.wsgi = wsgi
168 168
169 169 def __call__(self, environ, start_response):
170 170 input_data = environ['wsgi.input'].read()
171 171 input_data = msgpack.unpackb(input_data)
172 172
173 173 error = None
174 174 try:
175 175 data, status, headers = self.wsgi.handle(
176 176 input_data['environment'], input_data['input_data'],
177 177 *input_data['args'], **input_data['kwargs'])
178 178 except Exception as e:
179 179 data, status, headers = [], None, None
180 180 error = {
181 181 'message': str(e),
182 182 '_vcs_kind': getattr(e, '_vcs_kind', None)
183 183 }
184 184
185 185 start_response('200 OK', [])  # WSGI status must be a str, headers a list
186 186 return self._iterator(error, status, headers, data)
187 187
188 188 def _iterator(self, error, status, headers, data):
189 189 initial_data = [
190 190 error,
191 191 status,
192 192 headers,
193 193 ]
194 194
195 195 for d in chain(initial_data, data):
196 196 yield msgpack.packb(d)
197 197
198 198
199 199 def not_found(request):
200 200 return {'status': '404 NOT FOUND'}
201 201
202 202
203 203 class VCSViewPredicate(object):
204 204 def __init__(self, val, config):
205 205 self.remotes = val
206 206
207 207 def text(self):
208 208 return 'vcs view method = %s' % (list(self.remotes.keys()),)
209 209
210 210 phash = text
211 211
212 212 def __call__(self, context, request):
213 213 """
214 214 View predicate that returns true if given backend is supported by
215 215 defined remotes.
216 216 """
217 217 backend = request.matchdict.get('backend')
218 218 return backend in self.remotes
219 219
220 220
221 221 class HTTPApplication(object):
222 222 ALLOWED_EXCEPTIONS = ('KeyError', 'URLError')
223 223
224 224 remote_wsgi = remote_wsgi
225 225 _use_echo_app = False
226 226
227 227 def __init__(self, settings=None, global_config=None):
228 228
229 229 self.config = Configurator(settings=settings)
230 230 # Init our statsd at very start
231 231 self.config.registry.statsd = StatsdClient.statsd
232 232
233 233 self.global_config = global_config
234 234 self.config.include('vcsserver.lib.rc_cache')
235 235
236 236 settings_locale = settings.get('locale', '') or 'en_US.UTF-8'
237 237 vcs = VCS(locale_conf=settings_locale, cache_config=settings)
238 238 self._remotes = {
239 239 'hg': vcs._hg_remote,
240 240 'git': vcs._git_remote,
241 241 'svn': vcs._svn_remote,
242 242 'server': vcs._vcsserver,
243 243 }
244 244 if settings.get('dev.use_echo_app', 'false').lower() == 'true':
245 245 self._use_echo_app = True
246 246 log.warning("Using EchoApp for VCS operations.")
247 247 self.remote_wsgi = remote_wsgi_stub
248 248
249 249 self._configure_settings(global_config, settings)
250 250
251 251 self._configure()
252 252
253 253 def _configure_settings(self, global_config, app_settings):
254 254 """
255 255 Configure the settings module.
256 256 """
257 257 settings_merged = global_config.copy()
258 258 settings_merged.update(app_settings)
259 259
260 260 git_path = app_settings.get('git_path', None)
261 261 if git_path:
262 262 settings.GIT_EXECUTABLE = git_path
263 263 binary_dir = app_settings.get('core.binary_dir', None)
264 264 if binary_dir:
265 265 settings.BINARY_DIR = binary_dir
266 266
267 267 # Store the settings to make them available to other modules.
268 268 vcsserver.PYRAMID_SETTINGS = settings_merged
269 269 vcsserver.CONFIG = settings_merged
270 270
271 271 def _configure(self):
272 272 self.config.add_renderer(name='msgpack', factory=self._msgpack_renderer_factory)
273 273
274 274 self.config.add_route('service', '/_service')
275 275 self.config.add_route('status', '/status')
276 276 self.config.add_route('hg_proxy', '/proxy/hg')
277 277 self.config.add_route('git_proxy', '/proxy/git')
278 278
279 279 # rpc methods
280 280 self.config.add_route('vcs', '/{backend}')
281 281
282 282 # streaming rpc remote methods
283 283 self.config.add_route('vcs_stream', '/{backend}/stream')
284 284
285 285 # vcs operations clone/push as streaming
286 286 self.config.add_route('stream_git', '/stream/git/*repo_name')
287 287 self.config.add_route('stream_hg', '/stream/hg/*repo_name')
288 288
289 289 self.config.add_view(self.status_view, route_name='status', renderer='json')
290 290 self.config.add_view(self.service_view, route_name='service', renderer='msgpack')
291 291
292 292 self.config.add_view(self.hg_proxy(), route_name='hg_proxy')
293 293 self.config.add_view(self.git_proxy(), route_name='git_proxy')
294 294 self.config.add_view(self.vcs_view, route_name='vcs', renderer='msgpack',
295 295 vcs_view=self._remotes)
296 296 self.config.add_view(self.vcs_stream_view, route_name='vcs_stream',
297 297 vcs_view=self._remotes)
298 298
299 299 self.config.add_view(self.hg_stream(), route_name='stream_hg')
300 300 self.config.add_view(self.git_stream(), route_name='stream_git')
301 301
302 302 self.config.add_view_predicate('vcs_view', VCSViewPredicate)
303 303
304 304 self.config.add_notfound_view(not_found, renderer='json')
305 305
306 306 self.config.add_view(self.handle_vcs_exception, context=Exception)
307 307
308 308 self.config.add_tween(
309 309 'vcsserver.tweens.request_wrapper.RequestWrapperTween',
310 310 )
311 311 self.config.add_request_method(
312 312 'vcsserver.lib.request_counter.get_request_counter',
313 313 'request_count')
314 314
315 315 def wsgi_app(self):
316 316 return self.config.make_wsgi_app()
317 317
318 318 def _vcs_view_params(self, request):
319 319 remote = self._remotes[request.matchdict['backend']]
320 payload = msgpack.unpackb(request.body, use_list=True)
320 payload = msgpack.unpackb(request.body, use_list=True, raw=False)
321
321 322 method = payload.get('method')
322 323 params = payload['params']
323 324 wire = params.get('wire')
324 325 args = params.get('args')
325 326 kwargs = params.get('kwargs')
326 327 context_uid = None
327 328
328 329 if wire:
329 330 try:
330 331 wire['context'] = context_uid = uuid.UUID(wire['context'])
331 332 except KeyError:
332 333 pass
333 334 args.insert(0, wire)
334 335 repo_state_uid = wire.get('repo_state_uid') if wire else None
335 336
336 337 # NOTE(marcink): trading complexity for slight performance
337 338 if log.isEnabledFor(logging.DEBUG):
338 339 no_args_methods = [
339 340
340 341 ]
341 342 if method in no_args_methods:
342 343 call_args = ''
343 344 else:
344 345 call_args = args[1:]
345 346
346 347 log.debug('Method requested:`%s` with args:%s kwargs:%s context_uid: %s, repo_state_uid:%s',
347 348 method, call_args, kwargs, context_uid, repo_state_uid)
348 349
349 350 statsd = request.registry.statsd
350 351 if statsd:
351 352 statsd.incr(
352 353 'vcsserver_method_total', tags=[
353 354 "method:{}".format(method),
354 355 ])
355 356 return payload, remote, method, args, kwargs
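
For orientation, an RPC body a client might POST to `/{backend}`, shaped after the parsing above (the method name and values are made up for illustration):

    import uuid
    import msgpack

    payload = {
        'id': str(uuid.uuid4()),
        'method': 'some_remote_method',     # hypothetical method name
        'params': {
            'wire': {'context': str(uuid.uuid4()), 'repo_state_uid': '1'},
            'args': [],
            'kwargs': {},
        },
    }
    body = msgpack.packb(payload)
    assert msgpack.unpackb(body, use_list=True, raw=False) == payload
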
356 357
357 358 def vcs_view(self, request):
358 359
359 360 payload, remote, method, args, kwargs = self._vcs_view_params(request)
360 361 payload_id = payload.get('id')
361 362
362 363 try:
363 364 resp = getattr(remote, method)(*args, **kwargs)
364 365 except Exception as e:
365 366 exc_info = list(sys.exc_info())
366 367 exc_type, exc_value, exc_traceback = exc_info
367 368
368 369 org_exc = getattr(e, '_org_exc', None)
369 370 org_exc_name = None
370 371 org_exc_tb = ''
371 372 if org_exc:
372 373 org_exc_name = org_exc.__class__.__name__
373 374 org_exc_tb = getattr(e, '_org_exc_tb', '')
374 375 # replace our "faked" exception with our org
375 376 exc_info[0] = org_exc.__class__
376 377 exc_info[1] = org_exc
377 378
378 379 should_store_exc = True
379 380 if org_exc:
380 381 def get_exc_fqn(_exc_obj):
381 382 module_name = getattr(org_exc.__class__, '__module__', 'UNKNOWN')
382 383 return module_name + '.' + org_exc_name
383 384
384 385 exc_fqn = get_exc_fqn(org_exc)
385 386
386 387 if exc_fqn in ['mercurial.error.RepoLookupError',
387 388 'vcsserver.exceptions.RefNotFoundException']:
388 389 should_store_exc = False
389 390
390 391 if should_store_exc:
391 392 store_exception(id(exc_info), exc_info, request_path=request.path)
392 393
393 394 tb_info = ''.join(
394 395 traceback.format_exception(exc_type, exc_value, exc_traceback))
395 396
396 397 type_ = e.__class__.__name__
397 398 if type_ not in self.ALLOWED_EXCEPTIONS:
398 399 type_ = None
399 400
400 401 resp = {
401 402 'id': payload_id,
402 403 'error': {
403 'message': e.message,
404 'message': str(e),
404 405 'traceback': tb_info,
405 406 'org_exc': org_exc_name,
406 407 'org_exc_tb': org_exc_tb,
407 408 'type': type_
408 409 }
409 410 }
410 411
411 412 try:
412 413 resp['error']['_vcs_kind'] = getattr(e, '_vcs_kind', None)
413 414 except AttributeError:
414 415 pass
415 416 else:
416 417 resp = {
417 418 'id': payload_id,
418 419 'result': resp
419 420 }
420 421
421 422 return resp
422 423
423 424 def vcs_stream_view(self, request):
424 425 payload, remote, method, args, kwargs = self._vcs_view_params(request)
425 426 # this method has a stream: marker we remove it here
426 427 method = method.split('stream:')[-1]
427 428 chunk_size = safe_int(payload.get('chunk_size')) or 4096
428 429
429 430 try:
430 431 resp = getattr(remote, method)(*args, **kwargs)
431 432 except Exception as e:
432 433 raise
433 434
434 435 def get_chunked_data(method_resp):
435 stream = StringIO(method_resp)
436 stream = io.BytesIO(method_resp)
436 437 while 1:
437 438 chunk = stream.read(chunk_size)
438 439 if not chunk:
439 440 break
440 441 yield chunk
441 442
442 443 response = Response(app_iter=get_chunked_data(resp))
443 444 response.content_type = 'application/octet-stream'
444 445
445 446 return response
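
The streaming view above chunks an in-memory bytes payload into fixed-size pieces; a standalone version of that generator:

    import io

    def get_chunked_data(method_resp: bytes, chunk_size: int = 4096):
        stream = io.BytesIO(method_resp)
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                break
            yield chunk

    assert list(get_chunked_data(b'x' * 5, chunk_size=2)) == [b'xx', b'xx', b'x']
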
446 447
447 448 def status_view(self, request):
448 449 import vcsserver
449 450 return {'status': 'OK', 'vcsserver_version': vcsserver.__version__,
450 451 'pid': os.getpid()}
451 452
452 453 def service_view(self, request):
453 454 import vcsserver
454 455
455 456 payload = msgpack.unpackb(request.body, use_list=True)
456 457 server_config, app_config = {}, {}
457 458
458 459 try:
459 460 path = self.global_config['__file__']
460 461 config = configparser.RawConfigParser()
461 462
462 463 config.read(path)
463 464
464 465 if config.has_section('server:main'):
465 466 server_config = dict(config.items('server:main'))
466 467 if config.has_section('app:main'):
467 468 app_config = dict(config.items('app:main'))
468 469
469 470 except Exception:
470 471 log.exception('Failed to read .ini file for display')
471 472
472 473 environ = list(os.environ.items())
473 474
474 475 resp = {
475 476 'id': payload.get('id'),
476 477 'result': dict(
477 478 version=vcsserver.__version__,
478 479 config=server_config,
479 480 app_config=app_config,
480 481 environ=environ,
481 482 payload=payload,
482 483 )
483 484 }
484 485 return resp
485 486
486 487 def _msgpack_renderer_factory(self, info):
487 488 def _render(value, system):
488 489 request = system.get('request')
489 490 if request is not None:
490 491 response = request.response
491 492 ct = response.content_type
492 493 if ct == response.default_content_type:
493 494 response.content_type = 'application/x-msgpack'
494 495 return msgpack.packb(value)
495 496 return _render
496 497
497 498 def set_env_from_config(self, environ, config):
498 499 dict_conf = {}
499 500 try:
500 501 for elem in config:
501 502 if elem[0] == 'rhodecode':
502 503 dict_conf = json.loads(elem[2])
503 504 break
504 505 except Exception:
505 506 log.exception('Failed to fetch SCM CONFIG')
506 507 return
507 508
508 509 username = dict_conf.get('username')
509 510 if username:
510 511 environ['REMOTE_USER'] = username
511 512 # mercurial specific, some extension api rely on this
512 513 environ['HGUSER'] = username
513 514
514 515 ip = dict_conf.get('ip')
515 516 if ip:
516 517 environ['REMOTE_HOST'] = ip
517 518
518 519 if _is_request_chunked(environ):
519 520 # set the compatibility flag for webob
520 521 environ['wsgi.input_terminated'] = True
521 522
522 523 def hg_proxy(self):
523 524 @wsgiapp
524 525 def _hg_proxy(environ, start_response):
525 526 app = WsgiProxy(self.remote_wsgi.HgRemoteWsgi())
526 527 return app(environ, start_response)
527 528 return _hg_proxy
528 529
529 530 def git_proxy(self):
530 531 @wsgiapp
531 532 def _git_proxy(environ, start_response):
532 533 app = WsgiProxy(self.remote_wsgi.GitRemoteWsgi())
533 534 return app(environ, start_response)
534 535 return _git_proxy
535 536
536 537 def hg_stream(self):
537 538 if self._use_echo_app:
538 539 @wsgiapp
539 540 def _hg_stream(environ, start_response):
540 541 app = EchoApp('fake_path', 'fake_name', None)
541 542 return app(environ, start_response)
542 543 return _hg_stream
543 544 else:
544 545 @wsgiapp
545 546 def _hg_stream(environ, start_response):
546 547 log.debug('http-app: handling hg stream')
547 548 repo_path = environ['HTTP_X_RC_REPO_PATH']
548 549 repo_name = environ['HTTP_X_RC_REPO_NAME']
549 550 packed_config = base64.b64decode(
550 551 environ['HTTP_X_RC_REPO_CONFIG'])
551 552 config = msgpack.unpackb(packed_config)
552 553 app = scm_app.create_hg_wsgi_app(
553 554 repo_path, repo_name, config)
554 555
555 556 # Consistent path information for hgweb
556 557 environ['PATH_INFO'] = environ['HTTP_X_RC_PATH_INFO']
557 558 environ['REPO_NAME'] = repo_name
558 559 self.set_env_from_config(environ, config)
559 560
560 561 log.debug('http-app: starting app handler '
561 562 'with %s and process request', app)
562 563 return app(environ, ResponseFilter(start_response))
563 564 return _hg_stream
564 565
565 566 def git_stream(self):
566 567 if self._use_echo_app:
567 568 @wsgiapp
568 569 def _git_stream(environ, start_response):
569 570 app = EchoApp('fake_path', 'fake_name', None)
570 571 return app(environ, start_response)
571 572 return _git_stream
572 573 else:
573 574 @wsgiapp
574 575 def _git_stream(environ, start_response):
575 576 log.debug('http-app: handling git stream')
576 577 repo_path = environ['HTTP_X_RC_REPO_PATH']
577 578 repo_name = environ['HTTP_X_RC_REPO_NAME']
578 579 packed_config = base64.b64decode(
579 580 environ['HTTP_X_RC_REPO_CONFIG'])
580 config = msgpack.unpackb(packed_config)
581 config = msgpack.unpackb(packed_config, raw=False)
581 582
582 583 environ['PATH_INFO'] = environ['HTTP_X_RC_PATH_INFO']
583 584 self.set_env_from_config(environ, config)
584 585
585 586 content_type = environ.get('CONTENT_TYPE', '')
586 587
587 588 path = environ['PATH_INFO']
588 589 is_lfs_request = GIT_LFS_CONTENT_TYPE in content_type
589 590 log.debug(
590 591 'LFS: Detecting if request `%s` is LFS server path based '
591 592 'on content type:`%s`, is_lfs:%s',
592 593 path, content_type, is_lfs_request)
593 594
594 595 if not is_lfs_request:
595 596 # fallback detection by path
596 597 if GIT_LFS_PROTO_PAT.match(path):
597 598 is_lfs_request = True
598 599 log.debug(
599 600 'LFS: fallback detection by path of: `%s`, is_lfs:%s',
600 601 path, is_lfs_request)
601 602
602 603 if is_lfs_request:
603 604 app = scm_app.create_git_lfs_wsgi_app(
604 605 repo_path, repo_name, config)
605 606 else:
606 607 app = scm_app.create_git_wsgi_app(
607 608 repo_path, repo_name, config)
608 609
609 610 log.debug('http-app: starting app handler '
610 611 'with %s and process request', app)
611 612
612 613 return app(environ, start_response)
613 614
614 615 return _git_stream
615 616
616 617 def handle_vcs_exception(self, exception, request):
617 618 _vcs_kind = getattr(exception, '_vcs_kind', '')
618 619 if _vcs_kind == 'repo_locked':
619 620 # Get custom repo-locked status code if present.
620 621 status_code = request.headers.get('X-RC-Locked-Status-Code')
621 622 return HTTPRepoLocked(
622 623 title=exception.message, status_code=status_code)
623 624
624 625 elif _vcs_kind == 'repo_branch_protected':
625 626 # Get custom repo-branch-protected status code if present.
626 627 return HTTPRepoBranchProtected(title=exception.message)
627 628
628 629 exc_info = request.exc_info
629 630 store_exception(id(exc_info), exc_info)
630 631
631 632 traceback_info = 'unavailable'
632 633 if request.exc_info:
633 634 exc_type, exc_value, exc_tb = request.exc_info
634 635 traceback_info = ''.join(traceback.format_exception(exc_type, exc_value, exc_tb))
635 636
636 637 log.error(
637 638 'error occurred handling this request for path: %s, \n tb: %s',
638 639 request.path, traceback_info)
639 640
640 641 statsd = request.registry.statsd
641 642 if statsd:
642 643 exc_type = "{}.{}".format(exception.__class__.__module__, exception.__class__.__name__)
643 644 statsd.incr('vcsserver_exception_total',
644 645 tags=["type:{}".format(exc_type)])
645 646 raise exception
646 647
647 648
648 649 class ResponseFilter(object):
649 650
650 651 def __init__(self, start_response):
651 652 self._start_response = start_response
652 653
653 654 def __call__(self, status, response_headers, exc_info=None):
654 655 headers = tuple(
655 656 (h, v) for h, v in response_headers
656 657 if not wsgiref.util.is_hop_by_hop(h))
657 658 return self._start_response(status, headers, exc_info)
658 659
659 660
660 661 def sanitize_settings_and_apply_defaults(global_config, settings):
661 662 global_settings_maker = SettingsMaker(global_config)
662 663 settings_maker = SettingsMaker(settings)
663 664
664 665 settings_maker.make_setting('logging.autoconfigure', False, parser='bool')
665 666
666 667 logging_conf = os.path.join(os.path.dirname(global_config.get('__file__')), 'logging.ini')
667 668 settings_maker.enable_logging(logging_conf)
668 669
669 670 # Default includes, possible to change as a user
670 671 pyramid_includes = settings_maker.make_setting('pyramid.includes', [], parser='list:newline')
671 672 log.debug("Using the following pyramid.includes: %s", pyramid_includes)
672 673
673 674 settings_maker.make_setting('__file__', global_config.get('__file__'))
674 675
675 676 settings_maker.make_setting('pyramid.default_locale_name', 'en')
676 677 settings_maker.make_setting('locale', 'en_US.UTF-8')
677 678
678 679 settings_maker.make_setting('core.binary_dir', '')
679 680
680 681 temp_store = tempfile.gettempdir()
681 682 default_cache_dir = os.path.join(temp_store, 'rc_cache')
682 683 # save the default cache dir and use it for all backends later.
683 684 default_cache_dir = settings_maker.make_setting(
684 685 'cache_dir',
685 686 default=default_cache_dir, default_when_empty=True,
686 687 parser='dir:ensured')
687 688
688 689 # exception store cache
689 690 settings_maker.make_setting(
690 691 'exception_tracker.store_path',
691 692 default=os.path.join(default_cache_dir, 'exc_store'), default_when_empty=True,
692 693 parser='dir:ensured'
693 694 )
694 695
695 696 # repo_object cache defaults
696 697 settings_maker.make_setting(
697 698 'rc_cache.repo_object.backend',
698 699 default='dogpile.cache.rc.file_namespace',
699 700 parser='string')
700 701 settings_maker.make_setting(
701 702 'rc_cache.repo_object.expiration_time',
702 703 default=30 * 24 * 60 * 60, # 30days
703 704 parser='int')
704 705 settings_maker.make_setting(
705 706 'rc_cache.repo_object.arguments.filename',
706 707 default=os.path.join(default_cache_dir, 'vcsserver_cache_repo_object.db'),
707 708 parser='string')
708 709
709 710 # statsd
710 711 settings_maker.make_setting('statsd.enabled', False, parser='bool')
711 712 settings_maker.make_setting('statsd.statsd_host', 'statsd-exporter', parser='string')
712 713 settings_maker.make_setting('statsd.statsd_port', 9125, parser='int')
713 714 settings_maker.make_setting('statsd.statsd_prefix', '')
714 715 settings_maker.make_setting('statsd.statsd_ipv6', False, parser='bool')
715 716
716 717 settings_maker.env_expand()
717 718
718 719
719 720 def main(global_config, **settings):
720 721 start_time = time.time()
721 722 log.info('Pyramid app config starting')
722 723
723 724 if MercurialFactory:
724 725 hgpatches.patch_largefiles_capabilities()
725 726 hgpatches.patch_subrepo_type_mapping()
726 727
727 728 # Fill in and sanitize the defaults & do ENV expansion
728 729 sanitize_settings_and_apply_defaults(global_config, settings)
729 730
730 731 # init and bootstrap StatsdClient
731 732 StatsdClient.setup(settings)
732 733
733 734 pyramid_app = HTTPApplication(settings=settings, global_config=global_config).wsgi_app()
734 735 total_time = time.time() - start_time
735 736 log.info('Pyramid app `%s` created and configured in %.2fs',
736 737 getattr(pyramid_app, 'func_name', 'pyramid_app'), total_time)
737 738 return pyramid_app
738 739
739 740
@@ -1,329 +1,330 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import time
19 19 import errno
20 20 import logging
21 21
22 22 import msgpack
23 23 import redis
24 import pickle
24 25
25 26 from dogpile.cache.api import CachedValue
26 27 from dogpile.cache.backends import memory as memory_backend
27 28 from dogpile.cache.backends import file as file_backend
28 29 from dogpile.cache.backends import redis as redis_backend
29 30 from dogpile.cache.backends.file import NO_VALUE, FileLock
30 31 from dogpile.cache.util import memoized_property
31 32
32 33 from pyramid.settings import asbool
33 34
34 35 from vcsserver.lib.memory_lru_dict import LRUDict, LRUDictDebug
35 from vcsserver.utils import safe_str, safe_unicode
36 from vcsserver.utils import safe_str
36 37
37 38
38 39 _default_max_size = 1024
39 40
40 41 log = logging.getLogger(__name__)
41 42
42 43
43 44 class LRUMemoryBackend(memory_backend.MemoryBackend):
44 45 key_prefix = 'lru_mem_backend'
45 46 pickle_values = False
46 47
47 48 def __init__(self, arguments):
48 49 max_size = arguments.pop('max_size', _default_max_size)
49 50
50 51 LRUDictClass = LRUDict
51 52 if arguments.pop('log_key_count', None):
52 53 LRUDictClass = LRUDictDebug
53 54
54 55 arguments['cache_dict'] = LRUDictClass(max_size)
55 56 super(LRUMemoryBackend, self).__init__(arguments)
56 57
57 58 def delete(self, key):
58 59 try:
59 60 del self._cache[key]
60 61 except KeyError:
61 62 # we don't care if key isn't there at deletion
62 63 pass
63 64
64 65 def delete_multi(self, keys):
65 66 for key in keys:
66 67 self.delete(key)
67 68
68 69
69 70 class PickleSerializer(object):
70 71
71 72 def _dumps(self, value, safe=False):
72 73 try:
73 74 return pickle.dumps(value)
74 75 except Exception:
75 76 if safe:
76 77 return NO_VALUE
77 78 else:
78 79 raise
79 80
80 81 def _loads(self, value, safe=True):
81 82 try:
82 83 return pickle.loads(value)
83 84 except Exception:
84 85 if safe:
85 86 return NO_VALUE
86 87 else:
87 88 raise
88 89
89 90
90 91 class MsgPackSerializer(object):
91 92
92 93 def _dumps(self, value, safe=False):
93 94 try:
94 95 return msgpack.packb(value)
95 96 except Exception:
96 97 if safe:
97 98 return NO_VALUE
98 99 else:
99 100 raise
100 101
101 102 def _loads(self, value, safe=True):
102 103 """
103 104 pickle maintains the `CachedValue` wrapper around the tuple;
104 105 msgpack does not, so it must be added back in.
105 106 """
106 107 try:
107 108 value = msgpack.unpackb(value, use_list=False)
108 109 return CachedValue(*value)
109 110 except Exception:
110 111 if safe:
111 112 return NO_VALUE
112 113 else:
113 114 raise
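
A round-trip sketch of the serializer above (assuming `msgpack` >= 1.0, where `unpackb` defaults to `raw=False`, and an installed `dogpile.cache`): `CachedValue` is a tuple subclass, so it packs as a plain array and has to be rewrapped on load:

    import msgpack
    from dogpile.cache.api import CachedValue

    value = CachedValue('payload', {'ct': 1234567890.0, 'v': 1})
    packed = msgpack.packb(value)                      # stored as a plain 2-tuple
    restored = CachedValue(*msgpack.unpackb(packed, use_list=False))
    assert restored == value
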
114 115
115 116
116 117 import fcntl
117 118 flock_org = fcntl.flock
118 119
119 120
120 121 class CustomLockFactory(FileLock):
121 122
122 123 pass
123 124
124 125
125 126 class FileNamespaceBackend(PickleSerializer, file_backend.DBMBackend):
126 127 key_prefix = 'file_backend'
127 128
128 129 def __init__(self, arguments):
129 130 arguments['lock_factory'] = CustomLockFactory
130 131 db_file = arguments.get('filename')
131 132
132 133 log.debug('initializing %s DB in %s', self.__class__.__name__, db_file)
133 134 try:
134 135 super(FileNamespaceBackend, self).__init__(arguments)
135 136 except Exception:
136 137 log.exception('Failed to initialize db at: %s', db_file)
137 138 raise
138 139
139 140 def __repr__(self):
140 141 return '{} `{}`'.format(self.__class__, self.filename)
141 142
142 143 def list_keys(self, prefix=''):
143 144 prefix = '{}:{}'.format(self.key_prefix, prefix)
144 145
145 146 def cond(v):
146 147 if not prefix:
147 148 return True
148 149
149 150 if v.startswith(prefix):
150 151 return True
151 152 return False
152 153
153 154 with self._dbm_file(True) as dbm:
154 155 try:
155 156 return filter(cond, dbm.keys())
156 157 except Exception:
157 158 log.error('Failed to fetch DBM keys from DB: %s', self.get_store())
158 159 raise
159 160
160 161 def get_store(self):
161 162 return self.filename
162 163
163 164 def _dbm_get(self, key):
164 165 with self._dbm_file(False) as dbm:
165 166 if hasattr(dbm, 'get'):
166 167 value = dbm.get(key, NO_VALUE)
167 168 else:
168 169 # gdbm objects lack a .get method
169 170 try:
170 171 value = dbm[key]
171 172 except KeyError:
172 173 value = NO_VALUE
173 174 if value is not NO_VALUE:
174 175 value = self._loads(value)
175 176 return value
176 177
177 178 def get(self, key):
178 179 try:
179 180 return self._dbm_get(key)
180 181 except Exception:
181 182 log.error('Failed to fetch DBM key %s from DB: %s', key, self.get_store())
182 183 raise
183 184
184 185 def set(self, key, value):
185 186 with self._dbm_file(True) as dbm:
186 187 dbm[key] = self._dumps(value)
187 188
188 189 def set_multi(self, mapping):
189 190 with self._dbm_file(True) as dbm:
190 191 for key, value in mapping.items():
191 192 dbm[key] = self._dumps(value)
192 193
193 194
194 195 class BaseRedisBackend(redis_backend.RedisBackend):
195 196 key_prefix = ''
196 197
197 198 def __init__(self, arguments):
198 199 super(BaseRedisBackend, self).__init__(arguments)
199 200 self._lock_timeout = self.lock_timeout
200 201 self._lock_auto_renewal = asbool(arguments.pop("lock_auto_renewal", True))
201 202
202 203 if self._lock_auto_renewal and not self._lock_timeout:
203 204 # set default timeout for auto_renewal
204 205 self._lock_timeout = 30
205 206
206 207 def _create_client(self):
207 208 args = {}
208 209
209 210 if self.url is not None:
210 211 args.update(url=self.url)
211 212
212 213 else:
213 214 args.update(
214 215 host=self.host, password=self.password,
215 216 port=self.port, db=self.db
216 217 )
217 218
218 219 connection_pool = redis.ConnectionPool(**args)
219 220
220 221 return redis.StrictRedis(connection_pool=connection_pool)
221 222
222 223 def list_keys(self, prefix=''):
223 224 prefix = '{}:{}*'.format(self.key_prefix, prefix)
224 225 return self.client.keys(prefix)
225 226
226 227 def get_store(self):
227 228 return self.client.connection_pool
228 229
229 230 def get(self, key):
230 231 value = self.client.get(key)
231 232 if value is None:
232 233 return NO_VALUE
233 234 return self._loads(value)
234 235
235 236 def get_multi(self, keys):
236 237 if not keys:
237 238 return []
238 239 values = self.client.mget(keys)
239 240 loads = self._loads
240 241 return [
241 242 loads(v) if v is not None else NO_VALUE
242 243 for v in values]
243 244
244 245 def set(self, key, value):
245 246 if self.redis_expiration_time:
246 247 self.client.setex(key, self.redis_expiration_time,
247 248 self._dumps(value))
248 249 else:
249 250 self.client.set(key, self._dumps(value))
250 251
251 252 def set_multi(self, mapping):
252 253 dumps = self._dumps
253 254 mapping = dict(
254 255 (k, dumps(v))
255 256 for k, v in mapping.items()
256 257 )
257 258
258 259 if not self.redis_expiration_time:
259 260 self.client.mset(mapping)
260 261 else:
261 262 pipe = self.client.pipeline()
262 263 for key, value in mapping.items():
263 264 pipe.setex(key, self.redis_expiration_time, value)
264 265 pipe.execute()
265 266
266 267 def get_mutex(self, key):
267 268 if self.distributed_lock:
268 lock_key = '_lock_{0}'.format(safe_unicode(key))
269 lock_key = '_lock_{0}'.format(safe_str(key))
269 270 return get_mutex_lock(self.client, lock_key, self._lock_timeout,
270 271 auto_renewal=self._lock_auto_renewal)
271 272 else:
272 273 return None
273 274
274 275
275 276 class RedisPickleBackend(PickleSerializer, BaseRedisBackend):
276 277 key_prefix = 'redis_pickle_backend'
277 278 pass
278 279
279 280
280 281 class RedisMsgPackBackend(MsgPackSerializer, BaseRedisBackend):
281 282 key_prefix = 'redis_msgpack_backend'
282 283 pass
283 284
284 285
285 286 def get_mutex_lock(client, lock_key, lock_timeout, auto_renewal=False):
286 287 import redis_lock
287 288
288 289 class _RedisLockWrapper(object):
289 290 """LockWrapper for redis_lock"""
290 291
291 292 @classmethod
292 293 def get_lock(cls):
293 294 return redis_lock.Lock(
294 295 redis_client=client,
295 296 name=lock_key,
296 297 expire=lock_timeout,
297 298 auto_renewal=auto_renewal,
298 299 strict=True,
299 300 )
300 301
301 302 def __repr__(self):
302 303 return "{}:{}".format(self.__class__.__name__, lock_key)
303 304
304 305 def __str__(self):
305 306 return "{}:{}".format(self.__class__.__name__, lock_key)
306 307
307 308 def __init__(self):
308 309 self.lock = self.get_lock()
309 310 self.lock_key = lock_key
310 311
311 312 def acquire(self, wait=True):
312 313 log.debug('Trying to acquire Redis lock for key %s', self.lock_key)
313 314 try:
314 315 acquired = self.lock.acquire(wait)
315 316 log.debug('Got lock for key %s, %s', self.lock_key, acquired)
316 317 return acquired
317 318 except redis_lock.AlreadyAcquired:
318 319 return False
319 320 except redis_lock.AlreadyStarted:
320 321 # refresh thread exists, but it also means we acquired the lock
321 322 return True
322 323
323 324 def release(self):
324 325 try:
325 326 self.lock.release()
326 327 except redis_lock.NotAcquired:
327 328 pass
328 329
329 330 return _RedisLockWrapper()
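
A hedged usage sketch of the factory above; it needs a reachable Redis instance and the `python-redis-lock` package, and the key and timeout values here are illustrative:

    import redis

    client = redis.StrictRedis()               # assumes a local Redis instance
    mutex = get_mutex_lock(client, '_lock_repo_object', lock_timeout=30,
                           auto_renewal=True)
    if mutex.acquire(wait=True):
        try:
            pass  # compute and store the cached value under the lock
        finally:
            mutex.release()
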
@@ -1,207 +1,207 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import time
20 20 import logging
21 21 import functools
22 22 import decorator
23 23
24 24 from dogpile.cache import CacheRegion
25 25
26 from vcsserver.utils import safe_str, sha1
26 from vcsserver.utils import safe_bytes, sha1
27 27 from vcsserver.lib.rc_cache import region_meta
28 28
29 29 log = logging.getLogger(__name__)
30 30
31 31
32 32 class RhodeCodeCacheRegion(CacheRegion):
33 33
34 34 def conditional_cache_on_arguments(
35 35 self, namespace=None,
36 36 expiration_time=None,
37 37 should_cache_fn=None,
38 38 to_str=str,
39 39 function_key_generator=None,
40 40 condition=True):
41 41 """
42 42 Custom conditional decorator, that will not touch any dogpile internals if
43 43 condition isn't met. This works a bit differently from should_cache_fn,
44 44 and it's faster in cases where we never want to compute cached values.
45 45 """
46 46 expiration_time_is_callable = callable(expiration_time)
47 47
48 48 if function_key_generator is None:
49 49 function_key_generator = self.function_key_generator
50 50
51 51 def get_or_create_for_user_func(key_generator, user_func, *arg, **kw):
52 52
53 53 if not condition:
54 54 log.debug('Calling un-cached method:%s', user_func.__name__)
55 55 start = time.time()
56 56 result = user_func(*arg, **kw)
57 57 total = time.time() - start
58 58 log.debug('un-cached method:%s took %.4fs', user_func.__name__, total)
59 59 return result
60 60
61 61 key = key_generator(*arg, **kw)
62 62
63 63 timeout = expiration_time() if expiration_time_is_callable \
64 64 else expiration_time
65 65
66 66 log.debug('Calling cached method:`%s`', user_func.__name__)
67 67 return self.get_or_create(key, user_func, timeout, should_cache_fn, (arg, kw))
68 68
69 69 def cache_decorator(user_func):
70 70 if to_str is str:
71 71 # backwards compatible
72 72 key_generator = function_key_generator(namespace, user_func)
73 73 else:
74 74 key_generator = function_key_generator(namespace, user_func, to_str=to_str)
75 75
76 76 def refresh(*arg, **kw):
77 77 """
78 78 Like invalidate, but regenerates the value instead
79 79 """
80 80 key = key_generator(*arg, **kw)
81 81 value = user_func(*arg, **kw)
82 82 self.set(key, value)
83 83 return value
84 84
85 85 def invalidate(*arg, **kw):
86 86 key = key_generator(*arg, **kw)
87 87 self.delete(key)
88 88
89 89 def set_(value, *arg, **kw):
90 90 key = key_generator(*arg, **kw)
91 91 self.set(key, value)
92 92
93 93 def get(*arg, **kw):
94 94 key = key_generator(*arg, **kw)
95 95 return self.get(key)
96 96
97 97 user_func.set = set_
98 98 user_func.invalidate = invalidate
99 99 user_func.get = get
100 100 user_func.refresh = refresh
101 101 user_func.key_generator = key_generator
102 102 user_func.original = user_func
103 103
104 104 # Use `decorate` to preserve the signature of :param:`user_func`.
105 105 return decorator.decorate(user_func, functools.partial(
106 106 get_or_create_for_user_func, key_generator))
107 107
108 108 return cache_decorator
109 109
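For orientation, a minimal sketch of how the callers in this module use the decorator (region and function names are illustrative):

    region = get_or_create_region('repo_object', 'example-namespace')

    @region.conditional_cache_on_arguments(condition=True)
    def _compute(_context_uid, _repo_id):
        return expensive_lookup(_repo_id)  # hypothetical helper

    result = _compute('ctx-1', 'repo-1')    # cached under the generated key
    _compute.invalidate('ctx-1', 'repo-1')  # one of the helpers attached above
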
110 110
111 111 def make_region(*arg, **kw):
112 112 return RhodeCodeCacheRegion(*arg, **kw)
113 113
114 114
115 115 def get_default_cache_settings(settings, prefixes=None):
116 116 prefixes = prefixes or []
117 117 cache_settings = {}
118 118 for key in settings.keys():
119 119 for prefix in prefixes:
120 120 if key.startswith(prefix):
121 121 name = key.split(prefix)[1].strip()
122 122 val = settings[key]
123 123 if isinstance(val, str):
124 124 val = val.strip()
125 125 cache_settings[name] = val
126 126 return cache_settings
127 127
128 128
129 129 def compute_key_from_params(*args):
130 130 """
131 131 Helper to compute key from given params to be used in cache manager
132 132 """
133 return sha1("_".join(map(safe_str, args)))
133 return sha1(safe_bytes("_".join(map(str, args))))
134 134
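As a worked illustration (using hashlib directly instead of the project's sha1/safe_bytes helpers), the cache key for params ('repo1', 42) is the SHA1 hex digest of the joined string:

    import hashlib
    arg_key = hashlib.sha1('_'.join(map(str, ('repo1', 42))).encode('utf8')).hexdigest()
    # sha1 of b'repo1_42' -> a 40-character hex string used as the key suffix
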
135 135
136 136 def backend_key_generator(backend):
137 137 """
138 138 Special wrapper that also sends over the backend to the key generator
139 139 """
140 140 def wrapper(namespace, fn):
141 141 return key_generator(backend, namespace, fn)
142 142 return wrapper
143 143
144 144
145 145 def key_generator(backend, namespace, fn):
146 146 fname = fn.__name__
147 147
148 148 def generate_key(*args):
149 149 backend_prefix = getattr(backend, 'key_prefix', None) or 'backend_prefix'
150 150 namespace_pref = namespace or 'default_namespace'
151 151 arg_key = compute_key_from_params(*args)
152 152 final_key = "{}:{}:{}_{}".format(backend_prefix, namespace_pref, fname, arg_key)
153 153
154 154 return final_key
155 155
156 156 return generate_key
157 157
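Putting the two helpers together, a generated key has the shape backend_prefix:namespace:fname_argkey; for example (values illustrative):

    backend_prefix = 'redis_msgpack_backend'  # the backend's key_prefix
    namespace_pref = 'repo_object'
    arg_key = compute_key_from_params('ctx-1', 'repo-1')
    final_key = '{}:{}:{}_{}'.format(backend_prefix, namespace_pref, '_get_refs', arg_key)
    # -> 'redis_msgpack_backend:repo_object:_get_refs_<40-char sha1>'
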
158 158
159 159 def get_or_create_region(region_name, region_namespace=None):
160 160 from vcsserver.lib.rc_cache.backends import FileNamespaceBackend
161 161 region_obj = region_meta.dogpile_cache_regions.get(region_name)
162 162 if not region_obj:
163 163 raise EnvironmentError(
164 164 'Region `{}` not found in configured regions: {}.'.format(
165 165 region_name, region_meta.dogpile_cache_regions.keys()))
166 166
167 167 region_uid_name = '{}:{}'.format(region_name, region_namespace)
168 168 if isinstance(region_obj.actual_backend, FileNamespaceBackend):
169 169 region_exist = region_meta.dogpile_cache_regions.get(region_namespace)
170 170 if region_exist:
171 171 log.debug('Using already configured region: %s', region_namespace)
172 172 return region_exist
173 173 cache_dir = region_meta.dogpile_config_defaults['cache_dir']
174 174 expiration_time = region_obj.expiration_time
175 175
176 176 if not os.path.isdir(cache_dir):
177 177 os.makedirs(cache_dir)
178 178 new_region = make_region(
179 179 name=region_uid_name,
180 180 function_key_generator=backend_key_generator(region_obj.actual_backend)
181 181 )
182 182 namespace_filename = os.path.join(
183 183 cache_dir, "{}.cache.dbm".format(region_namespace))
184 184 # special type that allows 1db per namespace
185 185 new_region.configure(
186 186 backend='dogpile.cache.rc.file_namespace',
187 187 expiration_time=expiration_time,
188 188 arguments={"filename": namespace_filename}
189 189 )
190 190
191 191 # create and save in region caches
192 192 log.debug('configuring new region: %s', region_uid_name)
193 193 region_obj = region_meta.dogpile_cache_regions[region_namespace] = new_region
194 194
195 195 return region_obj
196 196
197 197
198 198 def clear_cache_namespace(cache_region, cache_namespace_uid, invalidate=False):
199 199 region = get_or_create_region(cache_region, cache_namespace_uid)
200 200 cache_keys = region.backend.list_keys(prefix=cache_namespace_uid)
201 201 num_delete_keys = len(cache_keys)
202 202 if invalidate:
203 203 region.invalidate(hard=False)
204 204 else:
205 205 if num_delete_keys:
206 206 region.delete_multi(cache_keys)
207 207 return num_delete_keys
@@ -1,386 +1,413 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Handles the Git smart protocol."""
19 19
20 20 import os
21 21 import socket
22 22 import logging
23 23
24 import simplejson as json
25 24 import dulwich.protocol
25 from dulwich.protocol import CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K
26 26 from webob import Request, Response, exc
27 27
28 from vcsserver.lib.rc_json import json
28 29 from vcsserver import hooks, subprocessio
30 from vcsserver.utils import ascii_bytes
29 31
30 32
31 33 log = logging.getLogger(__name__)
32 34
33 35
34 36 class FileWrapper(object):
35 37 """File wrapper that ensures how much data is read from it."""
36 38
37 39 def __init__(self, fd, content_length):
38 40 self.fd = fd
39 41 self.content_length = content_length
40 42 self.remain = content_length
41 43
42 44 def read(self, size):
43 45 if size <= self.remain:
44 46 try:
45 47 data = self.fd.read(size)
46 48 except socket.error:
47 49 raise IOError(self)
48 50 self.remain -= size
49 51 elif self.remain:
50 52 data = self.fd.read(self.remain)
51 53 self.remain = 0
52 54 else:
53 55 data = None
54 56 return data
55 57
56 58 def __repr__(self):
57 59 return '<FileWrapper %s len: %s, read: %s>' % (
58 60 self.fd, self.content_length, self.content_length - self.remain
59 61 )
60 62
61 63
62 64 class GitRepository(object):
63 65 """WSGI app for handling Git smart protocol endpoints."""
64 66
65 git_folder_signature = frozenset(
66 ('config', 'head', 'info', 'objects', 'refs'))
67 git_folder_signature = frozenset(('config', 'head', 'info', 'objects', 'refs'))
67 68 commands = frozenset(('git-upload-pack', 'git-receive-pack'))
68 valid_accepts = frozenset(('application/x-%s-result' %
69 c for c in commands))
69 valid_accepts = frozenset(('application/x-{}-result'.format(c) for c in commands))
70 70
71 71 # The last bytes are the SHA1 of the first 12 bytes.
72 72 EMPTY_PACK = (
73 'PACK\x00\x00\x00\x02\x00\x00\x00\x00' +
74 '\x02\x9d\x08\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e'
73 b'PACK\x00\x00\x00\x02\x00\x00\x00\x00\x02\x9d\x08' +
74 b'\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e'
75 75 )
76 SIDE_BAND_CAPS = frozenset(('side-band', 'side-band-64k'))
76 FLUSH_PACKET = b"0000"
77 77
78 def __init__(self, repo_name, content_path, git_path, update_server_info,
79 extras):
78 SIDE_BAND_CAPS = frozenset((CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K))
79
80 def __init__(self, repo_name, content_path, git_path, update_server_info, extras):
80 81 files = frozenset(f.lower() for f in os.listdir(content_path))
81 82 valid_dir_signature = self.git_folder_signature.issubset(files)
82 83
83 84 if not valid_dir_signature:
84 85 raise OSError('%s missing git signature' % content_path)
85 86
86 87 self.content_path = content_path
87 88 self.repo_name = repo_name
88 89 self.extras = extras
89 90 self.git_path = git_path
90 91 self.update_server_info = update_server_info
91 92
92 93 def _get_fixedpath(self, path):
93 94 """
94 95 Strip the repo name (and the ``.git`` prefix of bare repos) from the path.
95 96
96 97 :param path: request path to normalize
97 98 """
98 99 path = path.split(self.repo_name, 1)[-1]
99 100 if path.startswith('.git'):
100 101 # for bare repos we still get the .git prefix inside, we skip it
101 102 # here, and remove from the service command
102 103 path = path[4:]
103 104
104 105 return path.strip('/')
105 106
106 107 def inforefs(self, request, unused_environ):
107 108 """
108 109 WSGI Response producer for HTTP GET Git Smart
109 110 HTTP /info/refs request.
110 111 """
111 112
112 113 git_command = request.GET.get('service')
113 114 if git_command not in self.commands:
114 115 log.debug('command %s not allowed', git_command)
115 116 return exc.HTTPForbidden()
116 117
117 118 # please, resist the urge to add '\n' to git capture and increment
118 119 # line count by 1.
119 120 # by git docs: Documentation/technical/http-protocol.txt#L214 \n is
120 121 # a part of protocol.
121 122 # The code in Git client not only does NOT need '\n', but actually
122 123 # blows up if you sprinkle "flush" (0000) as "0001\n".
123 124 # It reads binary, per number of bytes specified.
124 125 # if you do add '\n' as part of data, count it.
125 126 server_advert = '# service=%s\n' % git_command
126 packet_len = str(hex(len(server_advert) + 4)[2:].rjust(4, '0')).lower()
127 packet_len = hex(len(server_advert) + 4)[2:].rjust(4, '0').lower()
127 128 try:
128 129 gitenv = dict(os.environ)
129 130 # forget all configs
130 131 gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
131 132 command = [self.git_path, git_command[4:], '--stateless-rpc',
132 133 '--advertise-refs', self.content_path]
133 134 out = subprocessio.SubprocessIOChunker(
134 135 command,
135 136 env=gitenv,
136 starting_values=[packet_len + server_advert + '0000'],
137 starting_values=[ascii_bytes(packet_len + server_advert) + self.FLUSH_PACKET],
137 138 shell=False
138 139 )
139 except EnvironmentError:
140 except OSError:
140 141 log.exception('Error processing command')
141 142 raise exc.HTTPExpectationFailed()
142 143
143 144 resp = Response()
144 resp.content_type = 'application/x-%s-advertisement' % str(git_command)
145 resp.content_type = f'application/x-{git_command}-advertisement'
145 146 resp.charset = None
146 147 resp.app_iter = out
147 148
148 149 return resp
149 150
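To make the pkt-line framing above concrete (a sketch, not part of the commit): the 4-byte hex length counts itself plus the payload, and the flush packet 0000 terminates the advertisement:

    server_advert = '# service=git-upload-pack\n'
    packet_len = hex(len(server_advert) + 4)[2:].rjust(4, '0')  # 26 + 4 -> '001e'
    frame = packet_len + server_advert + '0000'
    assert frame.startswith('001e# service=git-upload-pack')
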
150 151 def _get_want_capabilities(self, request):
151 152 """Read the capabilities found in the first want line of the request."""
152 153 pos = request.body_file_seekable.tell()
153 154 first_line = request.body_file_seekable.readline()
154 155 request.body_file_seekable.seek(pos)
155 156
156 157 return frozenset(
157 158 dulwich.protocol.extract_want_line_capabilities(first_line)[1])
158 159
159 160 def _build_failed_pre_pull_response(self, capabilities, pre_pull_messages):
160 161 """
161 162 Construct a response with an empty PACK file.
162 163
163 164 We use an empty PACK file, as that would trigger the failure of the pull
164 165 or clone command.
165 166
166 167 We also print in the error output a message explaining why the command
167 168 was aborted.
168 169
169 If aditionally, the user is accepting messages we send them the output
170 If additionally, the user is accepting messages we send them the output
170 171 of the pre-pull hook.
171 172
172 173 Note that for clients not supporting side-band we just send them the
173 174 emtpy PACK file.
174 175 """
176
175 177 if self.SIDE_BAND_CAPS.intersection(capabilities):
176 178 response = []
177 179 proto = dulwich.protocol.Protocol(None, response.append)
178 proto.write_pkt_line('NAK\n')
179 self._write_sideband_to_proto(pre_pull_messages, proto,
180 capabilities)
180 proto.write_pkt_line(dulwich.protocol.NAK_LINE)
181
182 self._write_sideband_to_proto(proto, ascii_bytes(pre_pull_messages, allow_bytes=True), capabilities)
181 183 # N.B.(skreft): Do not change the sideband channel to 3, as that
182 184 # produces a fatal error in the client:
183 185 # fatal: error in sideband demultiplexer
184 proto.write_sideband(2, 'Pre pull hook failed: aborting\n')
185 proto.write_sideband(1, self.EMPTY_PACK)
186 proto.write_sideband(
187 dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS,
188 ascii_bytes('Pre pull hook failed: aborting\n', allow_bytes=True))
189 proto.write_sideband(
190 dulwich.protocol.SIDE_BAND_CHANNEL_DATA,
191 ascii_bytes(self.EMPTY_PACK, allow_bytes=True))
186 192
187 # writes 0000
193 # writes b"0000" as default
188 194 proto.write_pkt_line(None)
189 195
190 196 return response
191 197 else:
192 return [self.EMPTY_PACK]
198 return [ascii_bytes(self.EMPTY_PACK, allow_bytes=True)]
199
200 def _build_post_pull_response(self, response, capabilities, start_message, end_message):
201 """
202 Given a list response we inject the post-pull messages.
203
204 We only inject the messages if the client supports sideband, and the
205 response has the format:
206 0008NAK\n...0000
207
208 Note that we do not check the no-progress capability as by default, git
209 sends it, which effectively would block all messages.
210 """
211
212 if not self.SIDE_BAND_CAPS.intersection(capabilities):
213 return response
214
215 if not start_message and not end_message:
216 return response
217
218 try:
219 iter(response)
220 # response is iterable, we can continue (lists/tuples are rejected below)
221 except TypeError:
222 raise TypeError(f'response must be an iterator: got {type(response)}')
223 if isinstance(response, (list, tuple)):
224 raise TypeError(f'response must be an iterator: got {type(response)}')
225
226 def injected_response():
193 227
194 def _write_sideband_to_proto(self, data, proto, capabilities):
228 do_loop = 1
229 header_injected = 0
230 next_item = None
231 has_item = False
232 while do_loop:
233
234 try:
235 next_item = next(response)
236 except StopIteration:
237 do_loop = 0
238
239 if has_item:
240 # last item! alter it now
241 if do_loop == 0 and item.endswith(self.FLUSH_PACKET):
242 new_response = [item[:-4]]
243 new_response.extend(self._get_messages(end_message, capabilities))
244 new_response.append(self.FLUSH_PACKET)
245 item = b''.join(new_response)
246
247 yield item
248 has_item = True
249 item = next_item
250
251 # alter item if it's the initial chunk
252 if not header_injected and item.startswith(b'0008NAK\n'):
253 new_response = [b'0008NAK\n']
254 new_response.extend(self._get_messages(start_message, capabilities))
255 new_response.append(item[8:])
256 item = b''.join(new_response)
257 header_injected = 1
258
259 return injected_response()
260
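The framing that injected_response() relies on can be checked in isolation (a hedged sketch; the pack data and messages are placeholders): b'0008NAK\n' is the pkt-line for 'NAK\n' (4 length bytes + 4 payload bytes), and b'0000' is the terminating flush packet:

    body = b'0008NAK\n' + b'<pack data>' + b'0000'
    assert body.startswith(b'0008NAK\n') and body.endswith(b'0000')
    # start/end messages are spliced just inside those two markers:
    injected = b'0008NAK\n' + b'<start msgs>' + b'<pack data>' + b'<end msgs>' + b'0000'
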
261 def _write_sideband_to_proto(self, proto, data, capabilities):
195 262 """
196 Write the data to the proto's sideband number 2.
263 Write the data to the proto's sideband number 2 == SIDE_BAND_CHANNEL_PROGRESS
197 264
198 265 We do not use dulwich's write_sideband directly as it only supports
199 266 side-band-64k.
200 267 """
201 268 if not data:
202 269 return
203 270
204 271 # N.B.(skreft): The values below are explained in the pack protocol
205 272 # documentation, section Packfile Data.
206 273 # https://github.com/git/git/blob/master/Documentation/technical/pack-protocol.txt
207 if 'side-band-64k' in capabilities:
274 if CAPABILITY_SIDE_BAND_64K in capabilities:
208 275 chunk_size = 65515
209 elif 'side-band' in capabilities:
276 elif CAPABILITY_SIDE_BAND in capabilities:
210 277 chunk_size = 995
211 278 else:
212 279 return
213 280
214 chunker = (
215 data[i:i + chunk_size] for i in range(0, len(data), chunk_size))
281 chunker = (data[i:i + chunk_size] for i in range(0, len(data), chunk_size))
216 282
217 283 for chunk in chunker:
218 proto.write_sideband(2, chunk)
284 proto.write_sideband(dulwich.protocol.SIDE_BAND_CHANNEL_PROGRESS, ascii_bytes(chunk, allow_bytes=True))
219 285
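The chunk sizes above follow from the pkt-line limits: every sideband packet spends 4 bytes on the length header and 1 byte on the band number, so (a quick sanity check):

    assert 65520 - 4 - 1 == 65515  # max pkt-line size with side-band-64k
    assert 1000 - 4 - 1 == 995     # max data packet size with plain side-band
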
220 286 def _get_messages(self, data, capabilities):
221 287 """Return a list with packets for sending data in sideband number 2."""
222 288 response = []
223 289 proto = dulwich.protocol.Protocol(None, response.append)
224 290
225 self._write_sideband_to_proto(data, proto, capabilities)
291 self._write_sideband_to_proto(proto, data, capabilities)
226 292
227 293 return response
228 294
229 def _inject_messages_to_response(self, response, capabilities,
230 start_messages, end_messages):
231 """
232 Given a list response we inject the pre/post-pull messages.
233
234 We only inject the messages if the client supports sideband, and the
235 response has the format:
236 0008NAK\n...0000
237
238 Note that we do not check the no-progress capability as by default, git
239 sends it, which effectively would block all messages.
240 """
241 if not self.SIDE_BAND_CAPS.intersection(capabilities):
242 return response
243
244 if not start_messages and not end_messages:
245 return response
246
247 # make a list out of response if it's an iterator
248 # so we can investigate it for message injection.
249 if hasattr(response, '__iter__'):
250 response = list(response)
251
252 if (not response[0].startswith('0008NAK\n') or
253 not response[-1].endswith('0000')):
254 return response
255
256 new_response = ['0008NAK\n']
257 new_response.extend(self._get_messages(start_messages, capabilities))
258 if len(response) == 1:
259 new_response.append(response[0][8:-4])
260 else:
261 new_response.append(response[0][8:])
262 new_response.extend(response[1:-1])
263 new_response.append(response[-1][:-4])
264 new_response.extend(self._get_messages(end_messages, capabilities))
265 new_response.append('0000')
266
267 return new_response
268
269 295 def backend(self, request, environ):
270 296 """
271 297 WSGI Response producer for HTTP POST Git Smart HTTP requests.
272 298 Reads commands and data from HTTP POST's body.
273 299 returns an iterator obj with contents of git command's
274 300 response to stdout
275 301 """
276 302 # TODO(skreft): think how we could detect an HTTPLockedException, as
277 303 # we probably want to have the same mechanism used by mercurial and
278 304 # simplevcs.
279 305 # For that we would need to parse the output of the command looking for
280 306 # some signs of the HTTPLockedError, parse the data and reraise it in
281 307 # pygrack. However, that would interfere with the streaming.
282 308 #
283 309 # Now the output of a blocked push is:
284 310 # Pushing to http://test_regular:test12@127.0.0.1:5001/vcs_test_git
285 311 # POST git-receive-pack (1047 bytes)
286 312 # remote: ERROR: Repository `vcs_test_git` locked by user `test_admin`. Reason:`lock_auto`
287 313 # To http://test_regular:test12@127.0.0.1:5001/vcs_test_git
288 314 # ! [remote rejected] master -> master (pre-receive hook declined)
289 315 # error: failed to push some refs to 'http://test_regular:test12@127.0.0.1:5001/vcs_test_git'
290 316
291 317 git_command = self._get_fixedpath(request.path_info)
292 318 if git_command not in self.commands:
293 319 log.debug('command %s not allowed', git_command)
294 320 return exc.HTTPForbidden()
295 321
296 322 capabilities = None
297 323 if git_command == 'git-upload-pack':
298 324 capabilities = self._get_want_capabilities(request)
299 325
300 326 if 'CONTENT_LENGTH' in environ:
301 327 inputstream = FileWrapper(request.body_file_seekable,
302 328 request.content_length)
303 329 else:
304 330 inputstream = request.body_file_seekable
305 331
306 332 resp = Response()
307 resp.content_type = ('application/x-%s-result' %
308 git_command.encode('utf8'))
333 resp.content_type = 'application/x-{}-result'.format(git_command)
309 334 resp.charset = None
310 335
311 336 pre_pull_messages = ''
337 # Upload-pack == clone
312 338 if git_command == 'git-upload-pack':
313 339 status, pre_pull_messages = hooks.git_pre_pull(self.extras)
314 340 if status != 0:
315 341 resp.app_iter = self._build_failed_pre_pull_response(
316 342 capabilities, pre_pull_messages)
317 343 return resp
318 344
319 345 gitenv = dict(os.environ)
320 346 # forget all configs
321 347 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
322 348 gitenv['RC_SCM_DATA'] = json.dumps(self.extras)
323 349 cmd = [self.git_path, git_command[4:], '--stateless-rpc',
324 350 self.content_path]
325 351 log.debug('handling cmd %s', cmd)
326 352
327 353 out = subprocessio.SubprocessIOChunker(
328 354 cmd,
329 inputstream=inputstream,
355 input_stream=inputstream,
330 356 env=gitenv,
331 357 cwd=self.content_path,
332 358 shell=False,
333 359 fail_on_stderr=False,
334 360 fail_on_return_code=False
335 361 )
336 362
337 363 if self.update_server_info and git_command == 'git-receive-pack':
338 364 # We need to fully consume the iterator here, as the
339 365 # update-server-info command needs to be run after the push.
340 366 out = list(out)
341 367
342 368 # Updating refs manually after each push.
343 369 # This is required as some clients are exposing Git repos internally
344 370 # with the dumb protocol.
345 371 cmd = [self.git_path, 'update-server-info']
346 372 log.debug('handling cmd %s', cmd)
347 373 output = subprocessio.SubprocessIOChunker(
348 374 cmd,
349 inputstream=inputstream,
375 input_stream=inputstream,
350 376 env=gitenv,
351 377 cwd=self.content_path,
352 378 shell=False,
353 379 fail_on_stderr=False,
354 380 fail_on_return_code=False
355 381 )
356 382 # Consume all the output so the subprocess finishes
357 383 for _ in output:
358 384 pass
359 385
386 # Upload-pack == clone
360 387 if git_command == 'git-upload-pack':
361 388 unused_status, post_pull_messages = hooks.git_post_pull(self.extras)
362 resp.app_iter = self._inject_messages_to_response(
363 out, capabilities, pre_pull_messages, post_pull_messages)
389
390 resp.app_iter = self._build_post_pull_response(out, capabilities, pre_pull_messages, post_pull_messages)
364 391 else:
365 392 resp.app_iter = out
366 393
367 394 return resp
368 395
369 396 def __call__(self, environ, start_response):
370 397 request = Request(environ)
371 398 _path = self._get_fixedpath(request.path_info)
372 399 if _path.startswith('info/refs'):
373 400 app = self.inforefs
374 401 else:
375 402 app = self.backend
376 403
377 404 try:
378 405 resp = app(request, environ)
379 406 except exc.HTTPException as error:
380 407 log.exception('HTTP Error')
381 408 resp = error
382 409 except Exception:
383 410 log.exception('Unknown error')
384 411 resp = exc.HTTPInternalServerError()
385 412
386 413 return resp(environ, start_response)
@@ -1,1281 +1,1317 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import collections
19 19 import logging
20 20 import os
21 21 import posixpath as vcspath
22 22 import re
23 23 import stat
24 24 import traceback
25 25 import urllib.request, urllib.parse, urllib.error
27 27 from functools import wraps
28 28
29 29 import more_itertools
30 30 import pygit2
31 31 from pygit2 import Repository as LibGit2Repo
32 32 from pygit2 import index as LibGit2Index
33 33 from dulwich import index, objects
34 34 from dulwich.client import HttpGitClient, LocalGitClient
35 35 from dulwich.errors import (
36 36 NotGitRepository, ChecksumMismatch, WrongObjectException,
37 37 MissingCommitError, ObjectMissing, HangupException,
38 38 UnexpectedCommandError)
39 39 from dulwich.repo import Repo as DulwichRepo
40 40 from dulwich.server import update_server_info
41 41
42 42 from vcsserver import exceptions, settings, subprocessio
43 from vcsserver.utils import safe_str, safe_int, safe_unicode
43 from vcsserver.utils import safe_str, safe_int
44 44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, archive_repo
45 45 from vcsserver.hgcompat import (
46 46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
47 47 from vcsserver.git_lfs.lib import LFSOidStore
48 48 from vcsserver.vcs_base import RemoteBase
49 49
50 50 DIR_STAT = stat.S_IFDIR
51 51 FILE_MODE = stat.S_IFMT
52 52 GIT_LINK = objects.S_IFGITLINK
53 53 PEELED_REF_MARKER = '^{}'
54 54
55 55
56 56 log = logging.getLogger(__name__)
57 57
58 58
59 def str_to_dulwich(value):
60 """
61 Dulwich 0.10.1a requires `unicode` objects to be passed in.
62 """
63 return value.decode(settings.WIRE_ENCODING)
64
65
66 59 def reraise_safe_exceptions(func):
67 60 """Converts Dulwich exceptions to something neutral."""
68 61
69 62 @wraps(func)
70 63 def wrapper(*args, **kwargs):
71 64 try:
72 65 return func(*args, **kwargs)
73 66 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
74 67 exc = exceptions.LookupException(org_exc=e)
75 68 raise exc(safe_str(e))
76 69 except (HangupException, UnexpectedCommandError) as e:
77 70 exc = exceptions.VcsException(org_exc=e)
78 71 raise exc(safe_str(e))
79 72 except Exception as e:
80 73 # NOTE(marcink): because of how dulwich handles some exceptions
81 74 # (KeyError on empty repos), we cannot track this and catch all
82 75 # exceptions, it's an exceptions from other handlers
83 76 #if not hasattr(e, '_vcs_kind'):
84 77 #log.exception("Unhandled exception in git remote call")
85 78 #raise_from_original(exceptions.UnhandledException)
86 79 raise
87 80 return wrapper
88 81
89 82
90 83 class Repo(DulwichRepo):
91 84 """
92 85 A wrapper for dulwich Repo class.
93 86
94 87 Since dulwich is sometimes keeping .idx file descriptors open, it leads to
95 88 "Too many open files" error. We need to close all opened file descriptors
96 89 once the repo object is destroyed.
97 90 """
98 91 def __del__(self):
99 92 if hasattr(self, 'object_store'):
100 93 self.close()
101 94
102 95
103 96 class Repository(LibGit2Repo):
104 97
105 98 def __enter__(self):
106 99 return self
107 100
108 101 def __exit__(self, exc_type, exc_val, exc_tb):
109 102 self.free()
110 103
111 104
112 105 class GitFactory(RepoFactory):
113 106 repo_type = 'git'
114 107
115 108 def _create_repo(self, wire, create, use_libgit2=False):
116 109 if use_libgit2:
117 110 return Repository(wire['path'])
118 111 else:
119 repo_path = str_to_dulwich(wire['path'])
112 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
120 113 return Repo(repo_path)
121 114
122 115 def repo(self, wire, create=False, use_libgit2=False):
123 116 """
124 117 Get a repository instance for the given path.
125 118 """
126 119 return self._create_repo(wire, create, use_libgit2)
127 120
128 121 def repo_libgit2(self, wire):
129 122 return self.repo(wire, use_libgit2=True)
130 123
131 124
132 125 class GitRemote(RemoteBase):
133 126
134 127 def __init__(self, factory):
135 128 self._factory = factory
136 129 self._bulk_methods = {
137 130 "date": self.date,
138 131 "author": self.author,
139 132 "branch": self.branch,
140 133 "message": self.message,
141 134 "parents": self.parents,
142 135 "_commit": self.revision,
143 136 }
144 137
145 138 def _wire_to_config(self, wire):
146 139 if 'config' in wire:
147 140 return dict([(x[0] + '_' + x[1], x[2]) for x in wire['config']])
148 141 return {}
149 142
150 143 def _remote_conf(self, config):
151 144 params = [
152 145 '-c', 'core.askpass=""',
153 146 ]
154 147 ssl_cert_dir = config.get('vcs_ssl_dir')
155 148 if ssl_cert_dir:
156 149 params.extend(['-c', 'http.sslCAinfo={}'.format(ssl_cert_dir)])
157 150 return params
158 151
159 152 @reraise_safe_exceptions
160 153 def discover_git_version(self):
161 154 stdout, _ = self.run_git_command(
162 155 {}, ['--version'], _bare=True, _safe=True)
163 prefix = 'git version'
156 prefix = b'git version'
164 157 if stdout.startswith(prefix):
165 158 stdout = stdout[len(prefix):]
166 159 return stdout.strip()
167 160
168 161 @reraise_safe_exceptions
169 162 def is_empty(self, wire):
170 163 repo_init = self._factory.repo_libgit2(wire)
171 164 with repo_init as repo:
172 165
173 166 try:
174 167 has_head = repo.head.name
175 168 if has_head:
176 169 return False
177 170
178 171 # NOTE(marcink): check again using more expensive method
179 172 return repo.is_empty
180 173 except Exception:
181 174 pass
182 175
183 176 return True
184 177
185 178 @reraise_safe_exceptions
186 179 def assert_correct_path(self, wire):
187 180 cache_on, context_uid, repo_id = self._cache_on(wire)
188 181 region = self._region(wire)
182
189 183 @region.conditional_cache_on_arguments(condition=cache_on)
190 184 def _assert_correct_path(_context_uid, _repo_id):
191 185 try:
192 186 repo_init = self._factory.repo_libgit2(wire)
193 187 with repo_init as repo:
194 188 pass
195 189 except pygit2.GitError:
196 190 path = wire.get('path')
197 191 tb = traceback.format_exc()
198 192 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
199 193 return False
200 194
201 195 return True
202 196 return _assert_correct_path(context_uid, repo_id)
203 197
204 198 @reraise_safe_exceptions
205 199 def bare(self, wire):
206 200 repo_init = self._factory.repo_libgit2(wire)
207 201 with repo_init as repo:
208 202 return repo.is_bare
209 203
210 204 @reraise_safe_exceptions
211 205 def blob_as_pretty_string(self, wire, sha):
212 206 repo_init = self._factory.repo_libgit2(wire)
213 207 with repo_init as repo:
214 208 blob_obj = repo[sha]
215 209 blob = blob_obj.data
216 210 return blob
217 211
218 212 @reraise_safe_exceptions
219 213 def blob_raw_length(self, wire, sha):
220 214 cache_on, context_uid, repo_id = self._cache_on(wire)
221 215 region = self._region(wire)
216
222 217 @region.conditional_cache_on_arguments(condition=cache_on)
223 218 def _blob_raw_length(_repo_id, _sha):
224 219
225 220 repo_init = self._factory.repo_libgit2(wire)
226 221 with repo_init as repo:
227 222 blob = repo[sha]
228 223 return blob.size
229 224
230 225 return _blob_raw_length(repo_id, sha)
231 226
232 227 def _parse_lfs_pointer(self, raw_content):
228 spec_string = b'version https://git-lfs.github.com/spec'
229 if raw_content and raw_content.startswith(spec_string):
233 230
234 spec_string = 'version https://git-lfs.github.com/spec'
235 if raw_content and raw_content.startswith(spec_string):
236 pattern = re.compile(r"""
231 pattern = re.compile(rb"""
237 232 (?:\n)?
238 233 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
239 234 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
240 235 ^size[ ](?P<oid_size>[0-9]+)\n
241 236 (?:\n)?
242 237 """, re.VERBOSE | re.MULTILINE)
243 238 match = pattern.match(raw_content)
244 239 if match:
245 240 return match.groupdict()
246 241
247 242 return {}
248 243
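For reference, a pointer file that the pattern above matches can be built like this (a dummy all-zero oid; real values are 64 hex chars):

    oid = b'0' * 64  # dummy digest for illustration
    raw_content = b'version https://git-lfs.github.com/spec/v1\noid sha256:' + oid + b'\nsize 12345\n'
    # _parse_lfs_pointer(raw_content) would yield
    # {'spec_ver': b'v1', 'oid_hash': oid, 'oid_size': b'12345'}
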
249 244 @reraise_safe_exceptions
250 245 def is_large_file(self, wire, commit_id):
251 246 cache_on, context_uid, repo_id = self._cache_on(wire)
247 region = self._region(wire)
252 248
253 region = self._region(wire)
254 249 @region.conditional_cache_on_arguments(condition=cache_on)
255 250 def _is_large_file(_repo_id, _sha):
256 251 repo_init = self._factory.repo_libgit2(wire)
257 252 with repo_init as repo:
258 253 blob = repo[commit_id]
259 254 if blob.is_binary:
260 255 return {}
261 256
262 257 return self._parse_lfs_pointer(blob.data)
263 258
264 259 return _is_large_file(repo_id, commit_id)
265 260
266 261 @reraise_safe_exceptions
267 262 def is_binary(self, wire, tree_id):
268 263 cache_on, context_uid, repo_id = self._cache_on(wire)
264 region = self._region(wire)
269 265
270 region = self._region(wire)
271 266 @region.conditional_cache_on_arguments(condition=cache_on)
272 267 def _is_binary(_repo_id, _tree_id):
273 268 repo_init = self._factory.repo_libgit2(wire)
274 269 with repo_init as repo:
275 270 blob_obj = repo[tree_id]
276 271 return blob_obj.is_binary
277 272
278 273 return _is_binary(repo_id, tree_id)
279 274
280 275 @reraise_safe_exceptions
281 276 def in_largefiles_store(self, wire, oid):
282 277 conf = self._wire_to_config(wire)
283 278 repo_init = self._factory.repo_libgit2(wire)
284 279 with repo_init as repo:
285 280 repo_name = repo.path
286 281
287 282 store_location = conf.get('vcs_git_lfs_store_location')
288 283 if store_location:
289 284
290 285 store = LFSOidStore(
291 286 oid=oid, repo=repo_name, store_location=store_location)
292 287 return store.has_oid()
293 288
294 289 return False
295 290
296 291 @reraise_safe_exceptions
297 292 def store_path(self, wire, oid):
298 293 conf = self._wire_to_config(wire)
299 294 repo_init = self._factory.repo_libgit2(wire)
300 295 with repo_init as repo:
301 296 repo_name = repo.path
302 297
303 298 store_location = conf.get('vcs_git_lfs_store_location')
304 299 if store_location:
305 300 store = LFSOidStore(
306 301 oid=oid, repo=repo_name, store_location=store_location)
307 302 return store.oid_path
308 303 raise ValueError('Unable to fetch oid with path {}'.format(oid))
309 304
310 305 @reraise_safe_exceptions
311 306 def bulk_request(self, wire, rev, pre_load):
312 307 cache_on, context_uid, repo_id = self._cache_on(wire)
313 308 region = self._region(wire)
309
314 310 @region.conditional_cache_on_arguments(condition=cache_on)
315 311 def _bulk_request(_repo_id, _rev, _pre_load):
316 312 result = {}
317 313 for attr in pre_load:
318 314 try:
319 315 method = self._bulk_methods[attr]
320 316 args = [wire, rev]
321 317 result[attr] = method(*args)
322 318 except KeyError as e:
323 319 raise exceptions.VcsException(e)(
324 320 "Unknown bulk attribute: %s" % attr)
325 321 return result
326 322
327 323 return _bulk_request(repo_id, rev, sorted(pre_load))
328 324
329 325 def _build_opener(self, url):
330 326 handlers = []
331 327 url_obj = url_parser(url)
332 328 _, authinfo = url_obj.authinfo()
333 329
334 330 if authinfo:
335 331 # create a password manager
336 332 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
337 333 passmgr.add_password(*authinfo)
338 334
339 335 handlers.extend((httpbasicauthhandler(passmgr),
340 336 httpdigestauthhandler(passmgr)))
341 337
342 338 return urllib.request.build_opener(*handlers)
343 339
344 def _type_id_to_name(self, type_id):
340 def _type_id_to_name(self, type_id: int):
345 341 return {
346 1: b'commit',
347 2: b'tree',
348 3: b'blob',
349 4: b'tag'
342 1: 'commit',
343 2: 'tree',
344 3: 'blob',
345 4: 'tag'
350 346 }[type_id]
351 347
352 348 @reraise_safe_exceptions
353 349 def check_url(self, url, config):
354 350 url_obj = url_parser(url)
355 351 test_uri, _ = url_obj.authinfo()
356 352 url_obj.passwd = '*****' if url_obj.passwd else url_obj.passwd
357 353 url_obj.query = obfuscate_qs(url_obj.query)
358 354 cleaned_uri = str(url_obj)
359 355 log.info("Checking URL for remote cloning/import: %s", cleaned_uri)
360 356
361 357 if not test_uri.endswith('info/refs'):
362 358 test_uri = test_uri.rstrip('/') + '/info/refs'
363 359
364 360 o = self._build_opener(url)
365 361 o.addheaders = [('User-Agent', 'git/1.7.8.0')] # fake some git
366 362
367 363 q = {"service": 'git-upload-pack'}
368 364 qs = '?%s' % urllib.parse.urlencode(q)
369 365 cu = "%s%s" % (test_uri, qs)
370 366 req = urllib.request.Request(cu, None, {})
371 367
372 368 try:
373 369 log.debug("Trying to open URL %s", cleaned_uri)
374 370 resp = o.open(req)
375 371 if resp.code != 200:
376 372 raise exceptions.URLError()('Return Code is not 200')
377 373 except Exception as e:
378 374 log.warning("URL cannot be opened: %s", cleaned_uri, exc_info=True)
379 375 # means it cannot be cloned
380 376 raise exceptions.URLError(e)("[%s] org_exc: %s" % (cleaned_uri, e))
381 377
382 378 # now detect if it's proper git repo
383 379 gitdata = resp.read()
384 380 if b'service=git-upload-pack' in gitdata:
385 381 pass
386 382 elif re.findall(rb'[0-9a-fA-F]{40}\s+refs', gitdata):
387 383 # old style git can return some other format !
388 384 pass
389 385 else:
390 386 raise exceptions.URLError()(
391 387 "url [%s] does not look like an git" % (cleaned_uri,))
392 388
393 389 return True
394 390
395 391 @reraise_safe_exceptions
396 392 def clone(self, wire, url, deferred, valid_refs, update_after_clone):
397 393 # TODO(marcink): deprecate this method. Last i checked we don't use it anymore
398 394 remote_refs = self.pull(wire, url, apply_refs=False)
399 395 repo = self._factory.repo(wire)
400 396 if isinstance(valid_refs, list):
401 397 valid_refs = tuple(valid_refs)
402 398
403 399 for k in remote_refs:
404 400 # only parse heads/tags and skip so called deferred tags
405 401 if k.startswith(valid_refs) and not k.endswith(deferred):
406 402 repo[k] = remote_refs[k]
407 403
408 404 if update_after_clone:
409 405 # we want to checkout HEAD
410 406 repo["HEAD"] = remote_refs["HEAD"]
411 407 index.build_index_from_tree(repo.path, repo.index_path(),
412 408 repo.object_store, repo["HEAD"].tree)
413 409
414 410 @reraise_safe_exceptions
415 411 def branch(self, wire, commit_id):
416 412 cache_on, context_uid, repo_id = self._cache_on(wire)
417 413 region = self._region(wire)
418 414 @region.conditional_cache_on_arguments(condition=cache_on)
419 415 def _branch(_context_uid, _repo_id, _commit_id):
420 416 regex = re.compile('^refs/heads')
421 417
422 418 def filter_with(ref):
423 419 return regex.match(ref[0]) and ref[1] == _commit_id
424 420
425 421 branches = list(filter(filter_with, list(self.get_refs(wire).items())))
426 422 return [x[0].split('refs/heads/')[-1] for x in branches]
427 423
428 424 return _branch(context_uid, repo_id, commit_id)
429 425
430 426 @reraise_safe_exceptions
431 427 def commit_branches(self, wire, commit_id):
432 428 cache_on, context_uid, repo_id = self._cache_on(wire)
433 429 region = self._region(wire)
434 430 @region.conditional_cache_on_arguments(condition=cache_on)
435 431 def _commit_branches(_context_uid, _repo_id, _commit_id):
436 432 repo_init = self._factory.repo_libgit2(wire)
437 433 with repo_init as repo:
438 434 branches = [x for x in repo.branches.with_commit(_commit_id)]
439 435 return branches
440 436
441 437 return _commit_branches(context_uid, repo_id, commit_id)
442 438
443 439 @reraise_safe_exceptions
444 440 def add_object(self, wire, content):
445 441 repo_init = self._factory.repo_libgit2(wire)
446 442 with repo_init as repo:
447 443 blob = objects.Blob()
448 444 blob.set_raw_string(content)
449 445 repo.object_store.add_object(blob)
450 446 return blob.id
451 447
452 448 # TODO: this is quite complex, check if that can be simplified
453 449 @reraise_safe_exceptions
454 450 def commit(self, wire, commit_data, branch, commit_tree, updated, removed):
455 451 # Defines the root tree
456 452 class _Root(object):
457 453 def __repr__(self):
458 454 return 'ROOT TREE'
459 455 ROOT = _Root()
460 456
461 457 repo = self._factory.repo(wire)
462 458 object_store = repo.object_store
463 459
464 460 # Create tree and populates it with blobs
465 461
466 462 if commit_tree and repo[commit_tree]:
467 463 git_commit = repo[commit_data['parents'][0]]
468 464 commit_tree = repo[git_commit.tree] # root tree
469 465 else:
470 466 commit_tree = objects.Tree()
471 467
472 468 for node in updated:
473 469 # Compute subdirs if needed
474 470 dirpath, nodename = vcspath.split(node['path'])
475 471 dirnames = list(map(safe_str, dirpath and dirpath.split('/') or []))
476 472 parent = commit_tree
477 473 ancestors = [('', parent)]
478 474
479 475 # Tries to dig for the deepest existing tree
480 476 while dirnames:
481 477 curdir = dirnames.pop(0)
482 478 try:
483 479 dir_id = parent[curdir][1]
484 480 except KeyError:
485 481 # put curdir back into dirnames and stops
486 482 dirnames.insert(0, curdir)
487 483 break
488 484 else:
489 485 # If found, updates parent
490 486 parent = repo[dir_id]
491 487 ancestors.append((curdir, parent))
492 488 # Now parent is deepest existing tree and we need to create
493 489 # subtrees for dirnames (in reverse order)
494 490 # [this only applies for nodes from added]
495 491 new_trees = []
496 492
497 493 blob = objects.Blob.from_string(node['content'])
498 494
499 495 if dirnames:
500 496 # If there are trees which should be created we need to build
501 497 # them now (in reverse order)
502 498 reversed_dirnames = list(reversed(dirnames))
503 499 curtree = objects.Tree()
504 500 curtree[node['node_path']] = node['mode'], blob.id
505 501 new_trees.append(curtree)
506 502 for dirname in reversed_dirnames[:-1]:
507 503 newtree = objects.Tree()
508 504 newtree[dirname] = (DIR_STAT, curtree.id)
509 505 new_trees.append(newtree)
510 506 curtree = newtree
511 507 parent[reversed_dirnames[-1]] = (DIR_STAT, curtree.id)
512 508 else:
513 509 parent.add(name=node['node_path'], mode=node['mode'], hexsha=blob.id)
514 510
515 511 new_trees.append(parent)
516 512 # Update ancestors
517 513 reversed_ancestors = reversed(
518 514 [(a[1], b[1], b[0]) for a, b in zip(ancestors, ancestors[1:])])
519 515 for parent, tree, path in reversed_ancestors:
520 516 parent[path] = (DIR_STAT, tree.id)
521 517 object_store.add_object(tree)
522 518
523 519 object_store.add_object(blob)
524 520 for tree in new_trees:
525 521 object_store.add_object(tree)
526 522
527 523 for node_path in removed:
528 524 paths = node_path.split('/')
529 525 tree = commit_tree # start with top-level
530 526 trees = [{'tree': tree, 'path': ROOT}]
531 527 # Traverse deep into the forest...
532 528 # resolve final tree by iterating the path.
533 529 # e.g a/b/c.txt will get
534 530 # - root as tree then
535 531 # - 'a' as tree,
536 532 # - 'b' as tree,
537 533 # - stop at c as blob.
538 534 for path in paths:
539 535 try:
540 536 obj = repo[tree[path][1]]
541 537 if isinstance(obj, objects.Tree):
542 538 trees.append({'tree': obj, 'path': path})
543 539 tree = obj
544 540 except KeyError:
545 541 break
546 542 #PROBLEM:
547 543 """
548 544 We're not editing same reference tree object
549 545 """
550 546 # Cut down the blob and all rotten trees on the way back...
551 547 for path, tree_data in reversed(list(zip(paths, trees))):
552 548 tree = tree_data['tree']
553 549 tree.__delitem__(path)
554 550 # This operation edits the tree, we need to mark new commit back
555 551
556 552 if len(tree) > 0:
557 553 # This tree still has elements - don't remove it or any
558 554 # of it's parents
559 555 break
560 556
561 557 object_store.add_object(commit_tree)
562 558
563 559 # Create commit
564 560 commit = objects.Commit()
565 561 commit.tree = commit_tree.id
566 562 for k, v in commit_data.items():
567 563 setattr(commit, k, v)
568 564 object_store.add_object(commit)
569 565
570 566 self.create_branch(wire, branch, commit.id)
571 567
572 568 # dulwich set-ref
573 569 ref = 'refs/heads/%s' % branch
574 570 repo.refs[ref] = commit.id
575 571
576 572 return commit.id
577 573
578 574 @reraise_safe_exceptions
579 575 def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
580 576 if url != 'default' and '://' not in url:
581 577 client = LocalGitClient(url)
582 578 else:
583 579 url_obj = url_parser(url)
584 580 o = self._build_opener(url)
585 581 url, _ = url_obj.authinfo()
586 582 client = HttpGitClient(base_url=url, opener=o)
587 583 repo = self._factory.repo(wire)
588 584
589 585 determine_wants = repo.object_store.determine_wants_all
590 586 if refs:
591 587 def determine_wants_requested(references):
592 588 return [references[r] for r in references if r in refs]
593 589 determine_wants = determine_wants_requested
594 590
595 591 try:
596 592 remote_refs = client.fetch(
597 593 path=url, target=repo, determine_wants=determine_wants)
598 594 except NotGitRepository as e:
599 595 log.warning(
600 596 'Trying to fetch from "%s" failed, not a Git repository.', url)
601 597 # Exception can contain unicode which we convert
602 598 raise exceptions.AbortException(e)(repr(e))
603 599
604 600 # mikhail: client.fetch() returns all the remote refs, but fetches only
605 601 # refs filtered by `determine_wants` function. We need to filter result
606 602 # as well
607 603 if refs:
608 604 remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}
609 605
610 606 if apply_refs:
611 607 # TODO: johbo: Needs proper test coverage with a git repository
612 608 # that contains a tag object, so that we would end up with
613 609 # a peeled ref at this point.
614 610 for k in remote_refs:
615 611 if k.endswith(PEELED_REF_MARKER):
616 612 log.debug("Skipping peeled reference %s", k)
617 613 continue
618 614 repo[k] = remote_refs[k]
619 615
620 616 if refs and not update_after:
621 617 # mikhail: explicitly set the head to the last ref.
622 618 repo["HEAD"] = remote_refs[refs[-1]]
623 619
624 620 if update_after:
625 621 # we want to checkout HEAD
626 622 repo["HEAD"] = remote_refs["HEAD"]
627 623 index.build_index_from_tree(repo.path, repo.index_path(),
628 624 repo.object_store, repo["HEAD"].tree)
629 625 return remote_refs
630 626
631 627 @reraise_safe_exceptions
632 628 def sync_fetch(self, wire, url, refs=None, all_refs=False):
633 629 repo = self._factory.repo(wire)
634 630 if refs and not isinstance(refs, (list, tuple)):
635 631 refs = [refs]
636 632
637 633 config = self._wire_to_config(wire)
638 634 # get all remote refs we'll use to fetch later
639 635 cmd = ['ls-remote']
640 636 if not all_refs:
641 637 cmd += ['--heads', '--tags']
642 638 cmd += [url]
643 639 output, __ = self.run_git_command(
644 640 wire, cmd, fail_on_stderr=False,
645 641 _copts=self._remote_conf(config),
646 642 extra_env={'GIT_TERMINAL_PROMPT': '0'})
647 643
648 644 remote_refs = collections.OrderedDict()
649 645 fetch_refs = []
650 646
651 647 for ref_line in output.splitlines():
652 648 sha, ref = ref_line.split('\t')
653 649 sha = sha.strip()
654 650 if ref in remote_refs:
655 651 # duplicate, skip
656 652 continue
657 653 if ref.endswith(PEELED_REF_MARKER):
658 654 log.debug("Skipping peeled reference %s", ref)
659 655 continue
660 656 # don't sync HEAD
661 657 if ref in ['HEAD']:
662 658 continue
663 659
664 660 remote_refs[ref] = sha
665 661
666 662 if refs and sha in refs:
667 663 # we filter fetch using our specified refs
668 664 fetch_refs.append('{}:{}'.format(ref, ref))
669 665 elif not refs:
670 666 fetch_refs.append('{}:{}'.format(ref, ref))
671 667 log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))
672 668
673 669 if fetch_refs:
674 670 for chunk in more_itertools.chunked(fetch_refs, 1024 * 4):
675 671 fetch_refs_chunks = list(chunk)
676 672 log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
677 _out, _err = self.run_git_command(
673 self.run_git_command(
678 674 wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
679 675 fail_on_stderr=False,
680 676 _copts=self._remote_conf(config),
681 677 extra_env={'GIT_TERMINAL_PROMPT': '0'})
682 678
683 679 return remote_refs
684 680
685 681 @reraise_safe_exceptions
686 682 def sync_push(self, wire, url, refs=None):
687 683 if not self.check_url(url, wire):
688 684 return
689 685 config = self._wire_to_config(wire)
690 686 self._factory.repo(wire)
691 687 self.run_git_command(
692 688 wire, ['push', url, '--mirror'], fail_on_stderr=False,
693 689 _copts=self._remote_conf(config),
694 690 extra_env={'GIT_TERMINAL_PROMPT': '0'})
695 691
696 692 @reraise_safe_exceptions
697 693 def get_remote_refs(self, wire, url):
698 694 repo = Repo(url)
699 695 return repo.get_refs()
700 696
701 697 @reraise_safe_exceptions
702 698 def get_description(self, wire):
703 699 repo = self._factory.repo(wire)
704 700 return repo.get_description()
705 701
706 702 @reraise_safe_exceptions
707 703 def get_missing_revs(self, wire, rev1, rev2, path2):
708 704 repo = self._factory.repo(wire)
709 705 LocalGitClient(thin_packs=False).fetch(path2, repo)
710 706
711 707 wire_remote = wire.copy()
712 708 wire_remote['path'] = path2
713 709 repo_remote = self._factory.repo(wire_remote)
714 710 LocalGitClient(thin_packs=False).fetch(wire["path"], repo_remote)
715 711
716 712 revs = [
717 713 x.commit.id
718 714 for x in repo_remote.get_walker(include=[rev2], exclude=[rev1])]
719 715 return revs
720 716
721 717 @reraise_safe_exceptions
722 718 def get_object(self, wire, sha, maybe_unreachable=False):
723 719 cache_on, context_uid, repo_id = self._cache_on(wire)
724 720 region = self._region(wire)
721
725 722 @region.conditional_cache_on_arguments(condition=cache_on)
726 723 def _get_object(_context_uid, _repo_id, _sha):
727 724 repo_init = self._factory.repo_libgit2(wire)
728 725 with repo_init as repo:
729 726
730 727 missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
731 728 try:
732 729 commit = repo.revparse_single(sha)
733 730 except KeyError:
734 731 # NOTE(marcink): KeyError doesn't give us any meaningful information
735 732 # here, we instead give something more explicit
736 733 e = exceptions.RefNotFoundException('SHA: %s not found', sha)
737 734 raise exceptions.LookupException(e)(missing_commit_err)
738 735 except ValueError as e:
739 736 raise exceptions.LookupException(e)(missing_commit_err)
740 737
741 738 is_tag = False
742 739 if isinstance(commit, pygit2.Tag):
743 740 commit = repo.get(commit.target)
744 741 is_tag = True
745 742
746 743 check_dangling = True
747 744 if is_tag:
748 745 check_dangling = False
749 746
750 747 if check_dangling and maybe_unreachable:
751 748 check_dangling = False
752 749
753 750 # we used a reference and it parsed means we're not having a dangling commit
754 751 if sha != commit.hex:
755 752 check_dangling = False
756 753
757 754 if check_dangling:
758 755 # check for dangling commit
759 756 for branch in repo.branches.with_commit(commit.hex):
760 757 if branch:
761 758 break
762 759 else:
763 760 # NOTE(marcink): Empty error doesn't give us any meaningful information
764 761 # here, we instead give something more explicit
765 762 e = exceptions.RefNotFoundException('SHA: %s not found in branches', sha)
766 763 raise exceptions.LookupException(e)(missing_commit_err)
767 764
768 765 commit_id = commit.hex
769 type_id = commit.type_str
766 type_id = commit.type
770 767
771 768 return {
772 769 'id': commit_id,
773 770 'type': self._type_id_to_name(type_id),
774 771 'commit_id': commit_id,
775 772 'idx': 0
776 773 }
777 774
778 775 return _get_object(context_uid, repo_id, sha)
779 776
780 777 @reraise_safe_exceptions
781 778 def get_refs(self, wire):
782 779 cache_on, context_uid, repo_id = self._cache_on(wire)
783 780 region = self._region(wire)
781
784 782 @region.conditional_cache_on_arguments(condition=cache_on)
785 783 def _get_refs(_context_uid, _repo_id):
786 784
787 785 repo_init = self._factory.repo_libgit2(wire)
788 786 with repo_init as repo:
789 787 regex = re.compile('^refs/(heads|tags)/')
790 788 return {x.name: x.target.hex for x in
791 789 [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}
792 790
793 791 return _get_refs(context_uid, repo_id)
794 792
795 793 @reraise_safe_exceptions
796 794 def get_branch_pointers(self, wire):
797 795 cache_on, context_uid, repo_id = self._cache_on(wire)
798 796 region = self._region(wire)
797
799 798 @region.conditional_cache_on_arguments(condition=cache_on)
800 799 def _get_branch_pointers(_context_uid, _repo_id):
801 800
802 801 repo_init = self._factory.repo_libgit2(wire)
803 802 regex = re.compile('^refs/heads')
804 803 with repo_init as repo:
805 804 branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
806 805 return {x.target.hex: x.shorthand for x in branches}
807 806
808 807 return _get_branch_pointers(context_uid, repo_id)
809 808
810 809 @reraise_safe_exceptions
811 810 def head(self, wire, show_exc=True):
812 811 cache_on, context_uid, repo_id = self._cache_on(wire)
813 812 region = self._region(wire)
813
814 814 @region.conditional_cache_on_arguments(condition=cache_on)
815 815 def _head(_context_uid, _repo_id, _show_exc):
816 816 repo_init = self._factory.repo_libgit2(wire)
817 817 with repo_init as repo:
818 818 try:
819 819 return repo.head.peel().hex
820 820 except Exception:
821 821 if show_exc:
822 822 raise
823 823 return _head(context_uid, repo_id, show_exc)
824 824
825 825 @reraise_safe_exceptions
826 826 def init(self, wire):
827 827 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
828 828 self.repo = Repo.init(repo_path)
829 829
830 830 @reraise_safe_exceptions
831 831 def init_bare(self, wire):
832 832 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
833 833 self.repo = Repo.init_bare(repo_path)
834 834
835 835 @reraise_safe_exceptions
836 836 def revision(self, wire, rev):
837 837
838 838 cache_on, context_uid, repo_id = self._cache_on(wire)
839 839 region = self._region(wire)
840
840 841 @region.conditional_cache_on_arguments(condition=cache_on)
841 842 def _revision(_context_uid, _repo_id, _rev):
842 843 repo_init = self._factory.repo_libgit2(wire)
843 844 with repo_init as repo:
844 845 commit = repo[rev]
845 846 obj_data = {
846 847 'id': commit.id.hex,
847 848 }
848 849 # tree objects itself don't have tree_id attribute
849 850 if hasattr(commit, 'tree_id'):
850 851 obj_data['tree'] = commit.tree_id.hex
851 852
852 853 return obj_data
853 854 return _revision(context_uid, repo_id, rev)
854 855
855 856 @reraise_safe_exceptions
856 857 def date(self, wire, commit_id):
857 858 cache_on, context_uid, repo_id = self._cache_on(wire)
858 859 region = self._region(wire)
860
859 861 @region.conditional_cache_on_arguments(condition=cache_on)
860 862 def _date(_repo_id, _commit_id):
861 863 repo_init = self._factory.repo_libgit2(wire)
862 864 with repo_init as repo:
863 865 commit = repo[commit_id]
864 866
865 867 if hasattr(commit, 'commit_time'):
866 868 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
867 869 else:
868 870 commit = commit.get_object()
869 871 commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
870 872
871 873 # TODO(marcink): check dulwich difference of offset vs timezone
872 874 return [commit_time, commit_time_offset]
873 875 return _date(repo_id, commit_id)
874 876
875 877 @reraise_safe_exceptions
876 878 def author(self, wire, commit_id):
877 879 cache_on, context_uid, repo_id = self._cache_on(wire)
878 880 region = self._region(wire)
881
879 882 @region.conditional_cache_on_arguments(condition=cache_on)
880 883 def _author(_repo_id, _commit_id):
881 884 repo_init = self._factory.repo_libgit2(wire)
882 885 with repo_init as repo:
883 886 commit = repo[commit_id]
884 887
885 888 if hasattr(commit, 'author'):
886 889 author = commit.author
887 890 else:
888 891 author = commit.get_object().author
889 892
890 893 if author.email:
891 894 return "{} <{}>".format(author.name, author.email)
892 895
893 896 try:
894 897 return "{}".format(author.name)
895 898 except Exception:
896 return "{}".format(safe_unicode(author.raw_name))
899 return "{}".format(safe_str(author.raw_name))
897 900
898 901 return _author(repo_id, commit_id)
899 902
900 903 @reraise_safe_exceptions
901 904 def message(self, wire, commit_id):
902 905 cache_on, context_uid, repo_id = self._cache_on(wire)
903 906 region = self._region(wire)
904 907 @region.conditional_cache_on_arguments(condition=cache_on)
905 908 def _message(_repo_id, _commit_id):
906 909 repo_init = self._factory.repo_libgit2(wire)
907 910 with repo_init as repo:
908 911 commit = repo[commit_id]
909 912 return commit.message
910 913 return _message(repo_id, commit_id)
911 914
912 915 @reraise_safe_exceptions
913 916 def parents(self, wire, commit_id):
914 917 cache_on, context_uid, repo_id = self._cache_on(wire)
915 918 region = self._region(wire)
916 919 @region.conditional_cache_on_arguments(condition=cache_on)
917 920 def _parents(_repo_id, _commit_id):
918 921 repo_init = self._factory.repo_libgit2(wire)
919 922 with repo_init as repo:
920 923 commit = repo[commit_id]
921 924 if hasattr(commit, 'parent_ids'):
922 925 parent_ids = commit.parent_ids
923 926 else:
924 927 parent_ids = commit.get_object().parent_ids
925 928
926 929 return [x.hex for x in parent_ids]
927 930 return _parents(repo_id, commit_id)
928 931
929 932 @reraise_safe_exceptions
930 933 def children(self, wire, commit_id):
931 934 cache_on, context_uid, repo_id = self._cache_on(wire)
932 935 region = self._region(wire)
936
933 937 @region.conditional_cache_on_arguments(condition=cache_on)
934 938 def _children(_repo_id, _commit_id):
935 939 output, __ = self.run_git_command(
936 940 wire, ['rev-list', '--all', '--children'])
937 941
938 942 child_ids = []
939 943 pat = re.compile(b'^%s' % commit_id.encode('ascii'))
940 944 for line in output.splitlines():
941 945 if pat.match(line):
942 946 found_ids = line.split(b' ')[1:]
943 947 child_ids.extend(found_ids)
944 948
945 949 return child_ids
946 950 return _children(repo_id, commit_id)
947 951
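Since `run_git_command` now returns bytes under Python 3 (see its definition below), the rev-list parsing above has to stay in the bytes domain end to end. A small self-contained sketch of that parsing, with hypothetical sample output:

import re

def children_from_rev_list(output: bytes, commit_id: str):
    # `git rev-list --all --children` prints: <commit> <child> <child> ...
    pat = re.compile(b'^%s' % commit_id.encode('ascii'))
    child_ids = []
    for line in output.splitlines():
        if pat.match(line):
            child_ids.extend(line.split(b' ')[1:])
    return child_ids

sample = b'deadbeef cafebabe f00dbabe\nfeedc0de\n'
assert children_from_rev_list(sample, 'deadbeef') == [b'cafebabe', b'f00dbabe']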
948 952 @reraise_safe_exceptions
949 953 def set_refs(self, wire, key, value):
950 954 repo_init = self._factory.repo_libgit2(wire)
951 955 with repo_init as repo:
952 956 repo.references.create(key, value, force=True)
953 957
954 958 @reraise_safe_exceptions
955 959 def create_branch(self, wire, branch_name, commit_id, force=False):
956 960 repo_init = self._factory.repo_libgit2(wire)
957 961 with repo_init as repo:
958 962 commit = repo[commit_id]
959 963
960 964 if force:
961 965 repo.branches.local.create(branch_name, commit, force=force)
962 966 elif not repo.branches.get(branch_name):
963 967 # create it only if the branch doesn't exist yet
964 968 repo.branches.local.create(branch_name, commit, force=force)
965 969
966 970 @reraise_safe_exceptions
967 971 def remove_ref(self, wire, key):
968 972 repo_init = self._factory.repo_libgit2(wire)
969 973 with repo_init as repo:
970 974 repo.references.delete(key)
971 975
972 976 @reraise_safe_exceptions
973 977 def tag_remove(self, wire, tag_name):
974 978 repo_init = self._factory.repo_libgit2(wire)
975 979 with repo_init as repo:
976 980 key = 'refs/tags/{}'.format(tag_name)
977 981 repo.references.delete(key)
978 982
979 983 @reraise_safe_exceptions
980 984 def tree_changes(self, wire, source_id, target_id):
981 985 # TODO(marcink): remove this seems it's only used by tests
982 986 repo = self._factory.repo(wire)
983 987 source = repo[source_id].tree if source_id else None
984 988 target = repo[target_id].tree
985 989 result = repo.object_store.tree_changes(source, target)
986 990 return list(result)
987 991
988 992 @reraise_safe_exceptions
989 993 def tree_and_type_for_path(self, wire, commit_id, path):
990 994
991 995 cache_on, context_uid, repo_id = self._cache_on(wire)
992 996 region = self._region(wire)
997
993 998 @region.conditional_cache_on_arguments(condition=cache_on)
994 999 def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
995 1000 repo_init = self._factory.repo_libgit2(wire)
996 1001
997 1002 with repo_init as repo:
998 1003 commit = repo[commit_id]
999 1004 try:
1000 1005 tree = commit.tree[path]
1001 1006 except KeyError:
1002 1007 return None, None, None
1003 1008
1004 1009 return tree.id.hex, tree.type_str, tree.filemode
1005 1010 return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)
1006 1011
1007 1012 @reraise_safe_exceptions
1008 1013 def tree_items(self, wire, tree_id):
1009 1014 cache_on, context_uid, repo_id = self._cache_on(wire)
1010 1015 region = self._region(wire)
1016
1011 1017 @region.conditional_cache_on_arguments(condition=cache_on)
1012 1018 def _tree_items(_repo_id, _tree_id):
1013 1019
1014 1020 repo_init = self._factory.repo_libgit2(wire)
1015 1021 with repo_init as repo:
1016 1022 try:
1017 1023 tree = repo[tree_id]
1018 1024 except KeyError:
1019 1025 raise ObjectMissing('No tree with id: {}'.format(tree_id))
1020 1026
1021 1027 result = []
1022 1028 for item in tree:
1023 1029 item_sha = item.hex
1024 1030 item_mode = item.filemode
1025 1031 item_type = item.type_str
1026 1032
1027 1033 if item_type == 'commit':
1028 1034 # NOTE(marcink): we translate submodules to 'link' for backward compat
1029 1035 item_type = 'link'
1030 1036
1031 1037 result.append((item.name, item_mode, item_sha, item_type))
1032 1038 return result
1033 1039 return _tree_items(repo_id, tree_id)
1034 1040
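For reference, iterating a pygit2 tree outside of this class looks roughly like the sketch below; the repository path is hypothetical and the attribute names mirror the pygit2 versions this code targets:

import pygit2

repo = pygit2.Repository('/path/to/repo.git')  # hypothetical path
tree = repo[repo.head.target].tree
for entry in tree:
    # type_str is 'blob', 'tree', or 'commit' (a submodule, shown as 'link' above)
    print(entry.name, oct(entry.filemode), entry.hex, entry.type_str)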
1035 1041 @reraise_safe_exceptions
1036 1042 def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1037 1043 """
1038 1044 Old version that uses subprocess to call diff
1039 1045 """
1040 1046
1041 1047 flags = [
1042 1048 '-U%s' % context, '--patch',
1043 1049 '--binary',
1044 1050 '--find-renames',
1045 1051 '--no-indent-heuristic',
1046 1052 # '--indent-heuristic',
1047 1053 #'--full-index',
1048 1054 #'--abbrev=40'
1049 1055 ]
1050 1056
1051 1057 if opt_ignorews:
1052 1058 flags.append('--ignore-all-space')
1053 1059
1054 1060 if commit_id_1 == self.EMPTY_COMMIT:
1055 1061 cmd = ['show'] + flags + [commit_id_2]
1056 1062 else:
1057 1063 cmd = ['diff'] + flags + [commit_id_1, commit_id_2]
1058 1064
1059 1065 if file_filter:
1060 1066 cmd.extend(['--', file_filter])
1061 1067
1062 1068 diff, __ = self.run_git_command(wire, cmd)
1063 1069 # If we used 'show' command, strip first few lines (until actual diff
1064 1070 # starts)
1065 1071 if commit_id_1 == self.EMPTY_COMMIT:
1066 1072 lines = diff.splitlines()
1067 1073 x = 0
1068 1074 for line in lines:
1069 if line.startswith('diff'):
1075 if line.startswith(b'diff'):
1070 1076 break
1071 1077 x += 1
1072 1078 # Append a newline just like the 'diff' command does
1073 1079 diff = b'\n'.join(lines[x:]) + b'\n'
1074 1080 return diff
1075 1081
1076 1082 @reraise_safe_exceptions
1077 1083 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
1078 1084 repo_init = self._factory.repo_libgit2(wire)
1079 1085 with repo_init as repo:
1080 1086 swap = True
1081 1087 flags = 0
1082 1088 flags |= pygit2.GIT_DIFF_SHOW_BINARY
1083 1089
1084 1090 if opt_ignorews:
1085 1091 flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE
1086 1092
1087 1093 if commit_id_1 == self.EMPTY_COMMIT:
1088 1094 comm1 = repo[commit_id_2]
1089 1095 diff_obj = comm1.tree.diff_to_tree(
1090 1096 flags=flags, context_lines=context, swap=swap)
1091 1097
1092 1098 else:
1093 1099 comm1 = repo[commit_id_2]
1094 1100 comm2 = repo[commit_id_1]
1095 1101 diff_obj = comm1.tree.diff_to_tree(
1096 1102 comm2.tree, flags=flags, context_lines=context, swap=swap)
1097 1103 similar_flags = 0
1098 1104 similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
1099 1105 diff_obj.find_similar(flags=similar_flags)
1100 1106
1101 1107 if file_filter:
1102 1108 for p in diff_obj:
1103 1109 if p.delta.old_file.path == file_filter:
1104 1110 return p.patch or ''
1105 1111 # no matching path == no diff
1106 1112 return ''
1107 1113 return diff_obj.patch or ''
1108 1114
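A hedged usage sketch of the libgit2 diff path above, assuming a repository with at least one parent commit; note how `swap=True` keeps the patch direction old -> new, matching the subprocess variant:

import pygit2

repo = pygit2.Repository('/path/to/repo.git')  # hypothetical path
new = repo[repo.head.target]
old = new.parents[0]
diff_obj = new.tree.diff_to_tree(old.tree, context_lines=3, swap=True)
diff_obj.find_similar(flags=pygit2.GIT_DIFF_FIND_RENAMES)
print(diff_obj.patch or '')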
1109 1115 @reraise_safe_exceptions
1110 1116 def node_history(self, wire, commit_id, path, limit):
1111 1117 cache_on, context_uid, repo_id = self._cache_on(wire)
1112 1118 region = self._region(wire)
1119
1113 1120 @region.conditional_cache_on_arguments(condition=cache_on)
1114 1121 def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
1115 1122 # optimize for n==1, rev-list is much faster for that use-case
1116 1123 if limit == 1:
1117 1124 cmd = ['rev-list', '-1', commit_id, '--', path]
1118 1125 else:
1119 1126 cmd = ['log']
1120 1127 if limit:
1121 1128 cmd.extend(['-n', str(safe_int(limit, 0))])
1122 1129 cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])
1123 1130
1124 1131 output, __ = self.run_git_command(wire, cmd)
1125 commit_ids = re.findall(r'[0-9a-fA-F]{40}', output)
1132 commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)
1126 1133
1127 1134 return [x for x in commit_ids]
1128 1135 return _node_history(context_uid, repo_id, commit_id, path, limit)
1129 1136
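The switch from r'...' to rb'...' above matters because `re` refuses to mix str patterns with bytes input; a quick standalone illustration:

import re

output = b'commit: 0123456789abcdef0123456789abcdef01234567\n'
assert re.findall(rb'[0-9a-fA-F]{40}', output) == \
    [b'0123456789abcdef0123456789abcdef01234567']
# re.findall(r'[0-9a-fA-F]{40}', output) would raise TypeError on bytes input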
1130 1137 @reraise_safe_exceptions
1131 def node_annotate(self, wire, commit_id, path):
1132
1138 def node_annotate_legacy(self, wire, commit_id, path):
1139 # NOTE: replaced by the pygit2 implementation below
1133 1140 cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
1134 1141 # -l ==> outputs long shas (and we need all 40 characters)
1135 1142 # --root ==> doesn't put '^' character for boundaries
1136 1143 # -r commit_id ==> blames for the given commit
1137 1144 output, __ = self.run_git_command(wire, cmd)
1138 1145
1139 1146 result = []
1140 for i, blame_line in enumerate(output.split('\n')[:-1]):
1147 for i, blame_line in enumerate(output.splitlines()):
1141 1148 line_no = i + 1
1142 commit_id, line = re.split(r' ', blame_line, 1)
1143 result.append((line_no, commit_id, line))
1149 blame_commit_id, line = re.split(rb' ', blame_line, 1)
1150 result.append((line_no, blame_commit_id, line))
1151
1144 1152 return result
1145 1153
1146 1154 @reraise_safe_exceptions
1155 def node_annotate(self, wire, commit_id, path):
1156
1157 result_libgit = []
1158 repo_init = self._factory.repo_libgit2(wire)
1159 with repo_init as repo:
1160 commit = repo[commit_id]
1161 blame_obj = repo.blame(path, newest_commit=commit_id)
1162 for i, line in enumerate(commit.tree[path].data.splitlines()):
1163 line_no = i + 1
1164 hunk = blame_obj.for_line(line_no)
1165 blame_commit_id = hunk.final_commit_id.hex
1166
1167 result_libgit.append((line_no, blame_commit_id, line))
1168
1169 return result_libgit
1170
1171 @reraise_safe_exceptions
1147 1172 def update_server_info(self, wire):
1148 1173 repo = self._factory.repo(wire)
1149 1174 update_server_info(repo)
1150 1175
1151 1176 @reraise_safe_exceptions
1152 1177 def get_all_commit_ids(self, wire):
1153 1178
1154 1179 cache_on, context_uid, repo_id = self._cache_on(wire)
1155 1180 region = self._region(wire)
1181
1156 1182 @region.conditional_cache_on_arguments(condition=cache_on)
1157 1183 def _get_all_commit_ids(_context_uid, _repo_id):
1158 1184
1159 1185 cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
1160 1186 try:
1161 1187 output, __ = self.run_git_command(wire, cmd)
1162 1188 return output.splitlines()
1163 1189 except Exception:
1164 1190 # Can be raised for empty repositories
1165 1191 return []
1192
1193 @region.conditional_cache_on_arguments(condition=cache_on)
1194 def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
1195 repo_init = self._factory.repo_libgit2(wire)
1196 from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
1197 results = []
1198 with repo_init as repo:
1199 for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
1200 results.append(commit.id.hex)
1201 return results
1166 1202 return _get_all_commit_ids(context_uid, repo_id)
1167 1203
1168 1204 @reraise_safe_exceptions
1169 1205 def run_git_command(self, wire, cmd, **opts):
1170 1206 path = wire.get('path', None)
1171 1207
1172 1208 if path and os.path.isdir(path):
1173 1209 opts['cwd'] = path
1174 1210
1175 1211 if '_bare' in opts:
1176 1212 _copts = []
1177 1213 del opts['_bare']
1178 1214 else:
1179 1215 _copts = ['-c', 'core.quotepath=false', ]
1180 1216 safe_call = False
1181 1217 if '_safe' in opts:
1182 1218 # no exc on failure
1183 1219 del opts['_safe']
1184 1220 safe_call = True
1185 1221
1186 1222 if '_copts' in opts:
1187 1223 _copts.extend(opts['_copts'] or [])
1188 1224 del opts['_copts']
1189 1225
1190 1226 gitenv = os.environ.copy()
1191 1227 gitenv.update(opts.pop('extra_env', {}))
1192 1228 # need to clean up GIT_DIR!
1193 1229 if 'GIT_DIR' in gitenv:
1194 1230 del gitenv['GIT_DIR']
1195 1231 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1196 1232 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1197 1233
1198 1234 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1199 1235 _opts = {'env': gitenv, 'shell': False}
1200 1236
1201 1237 proc = None
1202 1238 try:
1203 1239 _opts.update(opts)
1204 1240 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1205 1241
1206 return ''.join(proc), ''.join(proc.error)
1207 except (EnvironmentError, OSError) as err:
1242 return b''.join(proc), b''.join(proc.stderr)
1243 except OSError as err:
1208 1244 cmd = ' '.join(cmd) # human friendly CMD
1209 1245 tb_err = ("Couldn't run git command (%s).\n"
1210 1246 "Original error was:%s\n"
1211 1247 "Call options:%s\n"
1212 1248 % (cmd, err, _opts))
1213 1249 log.exception(tb_err)
1214 1250 if safe_call:
1215 1251 return b'', err
1216 1252 else:
1217 1253 raise exceptions.VcsException()(tb_err)
1218 1254 finally:
1219 1255 if proc:
1220 1256 proc.close()
1221 1257
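`SubprocessIOChunker` is vcsserver's streaming wrapper; the environment handling above can be reproduced with the stdlib alone. A rough equivalent under those assumptions (the example command is arbitrary):

import os
import subprocess

gitenv = os.environ.copy()
gitenv.pop('GIT_DIR', None)                  # same GIT_DIR cleanup as above
gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'

proc = subprocess.run(
    ['git', '-c', 'core.quotepath=false', 'rev-parse', 'HEAD'],
    env=gitenv, capture_output=True, check=False)
stdout, stderr = proc.stdout, proc.stderr    # both are bytes on Python 3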
1222 1258 @reraise_safe_exceptions
1223 1259 def install_hooks(self, wire, force=False):
1224 1260 from vcsserver.hook_utils import install_git_hooks
1225 1261 bare = self.bare(wire)
1226 1262 path = wire['path']
1227 1263 return install_git_hooks(path, bare, force_create=force)
1228 1264
1229 1265 @reraise_safe_exceptions
1230 1266 def get_hooks_info(self, wire):
1231 1267 from vcsserver.hook_utils import (
1232 1268 get_git_pre_hook_version, get_git_post_hook_version)
1233 1269 bare = self.bare(wire)
1234 1270 path = wire['path']
1235 1271 return {
1236 1272 'pre_version': get_git_pre_hook_version(path, bare),
1237 1273 'post_version': get_git_post_hook_version(path, bare),
1238 1274 }
1239 1275
1240 1276 @reraise_safe_exceptions
1241 1277 def set_head_ref(self, wire, head_name):
1242 1278 log.debug('Setting HEAD to `%s`', head_name)
1243 1279 cmd = ['symbolic-ref', 'HEAD', 'refs/heads/%s' % head_name]
1244 1280 output, __ = self.run_git_command(wire, cmd)
1245 1281 return [head_name] + output.splitlines()
1246 1282
1247 1283 @reraise_safe_exceptions
1248 1284 def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
1249 1285 archive_dir_name, commit_id):
1250 1286
1251 1287 def file_walker(_commit_id, path):
1252 1288 repo_init = self._factory.repo_libgit2(wire)
1253 1289
1254 1290 with repo_init as repo:
1255 1291 commit = repo[commit_id]
1256 1292
1257 1293 if path in ['', '/']:
1258 1294 tree = commit.tree
1259 1295 else:
1260 1296 tree = commit.tree[path.rstrip('/')]
1261 1297 tree_id = tree.id.hex
1262 1298 try:
1263 1299 tree = repo[tree_id]
1264 1300 except KeyError:
1265 1301 raise ObjectMissing('No tree with id: {}'.format(tree_id))
1266 1302
1267 1303 index = LibGit2Index.Index()
1268 1304 index.read_tree(tree)
1269 1305 file_iter = index
1270 1306
1271 1307 for fn in file_iter:
1272 1308 file_path = fn.path
1273 1309 mode = fn.mode
1274 1310 is_link = stat.S_ISLNK(mode)
1275 1311 if mode == pygit2.GIT_FILEMODE_COMMIT:
1276 1312 log.debug('Skipping path %s as a commit node', file_path)
1277 1313 continue
1278 1314 yield ArchiveNode(file_path, mode, is_link, repo[fn.hex].read_raw)
1279 1315
1280 1316 return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
1281 1317 archive_dir_name, commit_id)
@@ -1,1046 +1,1057 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import logging
20 20 import stat
21 21 import urllib.request, urllib.parse, urllib.error
23 23 import traceback
24 24
25 25 from hgext import largefiles, rebase, purge
26 26
27 27 from mercurial import commands
28 28 from mercurial import unionrepo
29 29 from mercurial import verify
30 30 from mercurial import repair
31 31
32 32 import vcsserver
33 33 from vcsserver import exceptions
34 34 from vcsserver.base import RepoFactory, obfuscate_qs, raise_from_original, archive_repo, ArchiveNode
35 35 from vcsserver.hgcompat import (
36 36 archival, bin, clone, config as hgconfig, diffopts, hex, get_ctx,
37 37 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler,
38 38 makepeer, instance, match, memctx, exchange, memfilectx, nullrev, hg_merge,
39 39 patch, peer, revrange, ui, hg_tag, Abort, LookupError, RepoError,
40 40 RepoLookupError, InterventionRequired, RequirementError,
41 41 alwaysmatcher, patternmatcher, hgutil, hgext_strip)
42 from vcsserver.utils import ascii_bytes, ascii_str, safe_str
42 43 from vcsserver.vcs_base import RemoteBase
43 44
44 45 log = logging.getLogger(__name__)
45 46
46 47
47 48 def make_ui_from_config(repo_config):
48 49
49 50 class LoggingUI(ui.ui):
51
50 52 def status(self, *msg, **opts):
51 log.info(' '.join(msg).rstrip('\n'))
52 super(LoggingUI, self).status(*msg, **opts)
53 str_msg = map(safe_str, msg)
54 log.info(' '.join(str_msg).rstrip('\n'))
55 #super(LoggingUI, self).status(*msg, **opts)
53 56
54 57 def warn(self, *msg, **opts):
55 log.warn(' '.join(msg).rstrip('\n'))
56 super(LoggingUI, self).warn(*msg, **opts)
58 str_msg = map(safe_str, msg)
59 log.warning('ui_logger:'+' '.join(str_msg).rstrip('\n'))
60 #super(LoggingUI, self).warn(*msg, **opts)
57 61
58 62 def error(self, *msg, **opts):
59 log.error(' '.join(msg).rstrip('\n'))
60 super(LoggingUI, self).error(*msg, **opts)
63 str_msg = map(safe_str, msg)
64 log.error('ui_logger:'+' '.join(str_msg).rstrip('\n'))
65 #super(LoggingUI, self).error(*msg, **opts)
61 66
62 67 def note(self, *msg, **opts):
63 log.info(' '.join(msg).rstrip('\n'))
64 super(LoggingUI, self).note(*msg, **opts)
68 str_msg = map(safe_str, msg)
69 log.info('ui_logger:'+' '.join(str_msg).rstrip('\n'))
70 #super(LoggingUI, self).note(*msg, **opts)
65 71
66 72 def debug(self, *msg, **opts):
67 log.debug(' '.join(msg).rstrip('\n'))
68 super(LoggingUI, self).debug(*msg, **opts)
73 str_msg = map(safe_str, msg)
74 log.debug('ui_logger:'+' '.join(str_msg).rstrip('\n'))
75 #super(LoggingUI, self).debug(*msg, **opts)
69 76
70 77 baseui = LoggingUI()
71 78
72 79 # clean the baseui object
73 80 baseui._ocfg = hgconfig.config()
74 81 baseui._ucfg = hgconfig.config()
75 82 baseui._tcfg = hgconfig.config()
76 83
77 84 for section, option, value in repo_config:
78 baseui.setconfig(section, option, value)
85 baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))
79 86
80 87 # make our hgweb quiet so it doesn't print output
81 baseui.setconfig('ui', 'quiet', 'true')
88 baseui.setconfig(b'ui', b'quiet', b'true')
82 89
83 baseui.setconfig('ui', 'paginate', 'never')
90 baseui.setconfig(b'ui', b'paginate', b'never')
84 91 # for better Error reporting of Mercurial
85 baseui.setconfig('ui', 'message-output', 'stderr')
92 baseui.setconfig(b'ui', b'message-output', b'stderr')
86 93
87 94 # force mercurial to only use 1 thread, otherwise it may try to set a
88 95 # signal in a non-main thread, thus generating a ValueError.
89 baseui.setconfig('worker', 'numcpus', 1)
96 baseui.setconfig(b'worker', b'numcpus', 1)
90 97
91 98 # If there is no config for the largefiles extension, we explicitly disable
92 99 # it here. This overrides settings from repositories hgrc file. Recent
93 100 # mercurial versions enable largefiles in hgrc on clone from largefile
94 101 # repo.
95 if not baseui.hasconfig('extensions', 'largefiles'):
102 if not baseui.hasconfig(b'extensions', b'largefiles'):
96 103 log.debug('Explicitly disable largefiles extension for repo.')
97 baseui.setconfig('extensions', 'largefiles', '!')
104 baseui.setconfig(b'extensions', b'largefiles', b'!')
98 105
99 106 return baseui
100 107
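The blanket `ascii_bytes(...)` conversion above reflects that Mercurial's internal API is bytes-only on Python 3: section names, option names, and values must all be `bytes`. A minimal stand-in for the helper (the real one lives in `vcsserver.utils` and is stricter about input types):

def ascii_bytes_sketch(value: str) -> bytes:
    # simplified stand-in for vcsserver.utils.ascii_bytes: str in, ascii bytes out
    return value.encode('ascii')  # UnicodeEncodeError (a ValueError) on non-ascii

assert ascii_bytes_sketch('largefiles') == b'largefiles'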
101 108
102 109 def reraise_safe_exceptions(func):
103 110 """Decorator for converting mercurial exceptions to something neutral."""
104 111
105 112 def wrapper(*args, **kwargs):
106 113 try:
107 114 return func(*args, **kwargs)
108 115 except (Abort, InterventionRequired) as e:
109 raise_from_original(exceptions.AbortException(e))
116 raise_from_original(exceptions.AbortException(e), e)
110 117 except RepoLookupError as e:
111 raise_from_original(exceptions.LookupException(e))
118 raise_from_original(exceptions.LookupException(e), e)
112 119 except RequirementError as e:
113 raise_from_original(exceptions.RequirementException(e))
120 raise_from_original(exceptions.RequirementException(e), e)
114 121 except RepoError as e:
115 raise_from_original(exceptions.VcsException(e))
122 raise_from_original(exceptions.VcsException(e), e)
116 123 except LookupError as e:
117 raise_from_original(exceptions.LookupException(e))
124 raise_from_original(exceptions.LookupException(e), e)
118 125 except Exception as e:
119 126 if not hasattr(e, '_vcs_kind'):
120 127 log.exception("Unhandled exception in hg remote call")
121 raise_from_original(exceptions.UnhandledException(e))
128 raise_from_original(exceptions.UnhandledException(e), e)
122 129
123 130 raise
124 131 return wrapper
125 132
126 133
127 134 class MercurialFactory(RepoFactory):
128 135 repo_type = 'hg'
129 136
130 137 def _create_config(self, config, hooks=True):
131 138 if not hooks:
132 139 hooks_to_clean = frozenset((
133 140 'changegroup.repo_size', 'preoutgoing.pre_pull',
134 141 'outgoing.pull_logger', 'prechangegroup.pre_push'))
135 142 new_config = []
136 143 for section, option, value in config:
137 144 if section == 'hooks' and option in hooks_to_clean:
138 145 continue
139 146 new_config.append((section, option, value))
140 147 config = new_config
141 148
142 149 baseui = make_ui_from_config(config)
143 150 return baseui
144 151
145 152 def _create_repo(self, wire, create):
146 153 baseui = self._create_config(wire["config"])
147 return instance(baseui, wire["path"], create)
154 return instance(baseui, ascii_bytes(wire["path"]), create)
148 155
149 156 def repo(self, wire, create=False):
150 157 """
151 158 Get a repository instance for the given path.
152 159 """
153 160 return self._create_repo(wire, create)
154 161
155 162
156 163 def patch_ui_message_output(baseui):
157 baseui.setconfig('ui', 'quiet', 'false')
164 baseui.setconfig(b'ui', b'quiet', b'false')
158 165 output = io.BytesIO()
159 166
160 167 def write(data, **unused_kwargs):
161 168 output.write(data)
162 169
163 170 baseui.status = write
164 171 baseui.write = write
165 172 baseui.warn = write
166 173 baseui.debug = write
167 174
168 175 return baseui, output
169 176
170 177
171 178 class HgRemote(RemoteBase):
172 179
173 180 def __init__(self, factory):
174 181 self._factory = factory
175 182 self._bulk_methods = {
176 183 "affected_files": self.ctx_files,
177 184 "author": self.ctx_user,
178 185 "branch": self.ctx_branch,
179 186 "children": self.ctx_children,
180 187 "date": self.ctx_date,
181 188 "message": self.ctx_description,
182 189 "parents": self.ctx_parents,
183 190 "status": self.ctx_status,
184 191 "obsolete": self.ctx_obsolete,
185 192 "phase": self.ctx_phase,
186 193 "hidden": self.ctx_hidden,
187 194 "_file_paths": self.ctx_list,
188 195 }
189 196
190 197 def _get_ctx(self, repo, ref):
191 198 return get_ctx(repo, ref)
192 199
193 200 @reraise_safe_exceptions
194 201 def discover_hg_version(self):
195 202 from mercurial import util
196 203 return util.version()
197 204
198 205 @reraise_safe_exceptions
199 206 def is_empty(self, wire):
200 207 repo = self._factory.repo(wire)
201 208
202 209 try:
203 210 return len(repo) == 0
204 211 except Exception:
205 212 log.exception("failed to read object_store")
206 213 return False
207 214
208 215 @reraise_safe_exceptions
209 216 def bookmarks(self, wire):
210 217 cache_on, context_uid, repo_id = self._cache_on(wire)
211 218 region = self._region(wire)
212 219 @region.conditional_cache_on_arguments(condition=cache_on)
213 220 def _bookmarks(_context_uid, _repo_id):
214 221 repo = self._factory.repo(wire)
215 222 return dict(repo._bookmarks)
216 223
217 224 return _bookmarks(context_uid, repo_id)
218 225
219 226 @reraise_safe_exceptions
220 227 def branches(self, wire, normal, closed):
221 228 cache_on, context_uid, repo_id = self._cache_on(wire)
222 229 region = self._region(wire)
223 230 @region.conditional_cache_on_arguments(condition=cache_on)
224 231 def _branches(_context_uid, _repo_id, _normal, _closed):
225 232 repo = self._factory.repo(wire)
226 233 iter_branches = repo.branchmap().iterbranches()
227 234 bt = {}
228 235 for branch_name, _heads, tip, is_closed in iter_branches:
229 236 if normal and not is_closed:
230 237 bt[branch_name] = tip
231 238 if closed and is_closed:
232 239 bt[branch_name] = tip
233 240
234 241 return bt
235 242
236 243 return _branches(context_uid, repo_id, normal, closed)
237 244
238 245 @reraise_safe_exceptions
239 246 def bulk_request(self, wire, commit_id, pre_load):
240 247 cache_on, context_uid, repo_id = self._cache_on(wire)
241 248 region = self._region(wire)
242 249 @region.conditional_cache_on_arguments(condition=cache_on)
243 250 def _bulk_request(_repo_id, _commit_id, _pre_load):
244 251 result = {}
245 252 for attr in pre_load:
246 253 try:
247 254 method = self._bulk_methods[attr]
248 255 result[attr] = method(wire, commit_id)
249 256 except KeyError as e:
250 257 raise exceptions.VcsException(e)(
251 258 'Unknown bulk attribute: "%s"' % attr)
252 259 return result
253 260
254 261 return _bulk_request(repo_id, commit_id, sorted(pre_load))
255 262
256 263 @reraise_safe_exceptions
257 264 def ctx_branch(self, wire, commit_id):
258 265 cache_on, context_uid, repo_id = self._cache_on(wire)
259 266 region = self._region(wire)
260 267 @region.conditional_cache_on_arguments(condition=cache_on)
261 268 def _ctx_branch(_repo_id, _commit_id):
262 269 repo = self._factory.repo(wire)
263 270 ctx = self._get_ctx(repo, commit_id)
264 271 return ctx.branch()
265 272 return _ctx_branch(repo_id, commit_id)
266 273
267 274 @reraise_safe_exceptions
268 275 def ctx_date(self, wire, commit_id):
269 276 cache_on, context_uid, repo_id = self._cache_on(wire)
270 277 region = self._region(wire)
271 278 @region.conditional_cache_on_arguments(condition=cache_on)
272 279 def _ctx_date(_repo_id, _commit_id):
273 280 repo = self._factory.repo(wire)
274 281 ctx = self._get_ctx(repo, commit_id)
275 282 return ctx.date()
276 283 return _ctx_date(repo_id, commit_id)
277 284
278 285 @reraise_safe_exceptions
279 286 def ctx_description(self, wire, revision):
280 287 repo = self._factory.repo(wire)
281 288 ctx = self._get_ctx(repo, revision)
282 289 return ctx.description()
283 290
284 291 @reraise_safe_exceptions
285 292 def ctx_files(self, wire, commit_id):
286 293 cache_on, context_uid, repo_id = self._cache_on(wire)
287 294 region = self._region(wire)
288 295 @region.conditional_cache_on_arguments(condition=cache_on)
289 296 def _ctx_files(_repo_id, _commit_id):
290 297 repo = self._factory.repo(wire)
291 298 ctx = self._get_ctx(repo, commit_id)
292 299 return ctx.files()
293 300
294 301 return _ctx_files(repo_id, commit_id)
295 302
296 303 @reraise_safe_exceptions
297 304 def ctx_list(self, path, revision):
298 305 repo = self._factory.repo(path)
299 306 ctx = self._get_ctx(repo, revision)
300 307 return list(ctx)
301 308
302 309 @reraise_safe_exceptions
303 310 def ctx_parents(self, wire, commit_id):
304 311 cache_on, context_uid, repo_id = self._cache_on(wire)
305 312 region = self._region(wire)
306 313 @region.conditional_cache_on_arguments(condition=cache_on)
307 314 def _ctx_parents(_repo_id, _commit_id):
308 315 repo = self._factory.repo(wire)
309 316 ctx = self._get_ctx(repo, commit_id)
310 317 return [parent.hex() for parent in ctx.parents()
311 318 if not (parent.hidden() or parent.obsolete())]
312 319
313 320 return _ctx_parents(repo_id, commit_id)
314 321
315 322 @reraise_safe_exceptions
316 323 def ctx_children(self, wire, commit_id):
317 324 cache_on, context_uid, repo_id = self._cache_on(wire)
318 325 region = self._region(wire)
319 326 @region.conditional_cache_on_arguments(condition=cache_on)
320 327 def _ctx_children(_repo_id, _commit_id):
321 328 repo = self._factory.repo(wire)
322 329 ctx = self._get_ctx(repo, commit_id)
323 330 return [child.hex() for child in ctx.children()
324 331 if not (child.hidden() or child.obsolete())]
325 332
326 333 return _ctx_children(repo_id, commit_id)
327 334
328 335 @reraise_safe_exceptions
329 336 def ctx_phase(self, wire, commit_id):
330 337 cache_on, context_uid, repo_id = self._cache_on(wire)
331 338 region = self._region(wire)
332 339 @region.conditional_cache_on_arguments(condition=cache_on)
333 340 def _ctx_phase(_context_uid, _repo_id, _commit_id):
334 341 repo = self._factory.repo(wire)
335 342 ctx = self._get_ctx(repo, commit_id)
336 343 # public=0, draft=1, secret=3
337 344 return ctx.phase()
338 345 return _ctx_phase(context_uid, repo_id, commit_id)
339 346
340 347 @reraise_safe_exceptions
341 348 def ctx_obsolete(self, wire, commit_id):
342 349 cache_on, context_uid, repo_id = self._cache_on(wire)
343 350 region = self._region(wire)
344 351 @region.conditional_cache_on_arguments(condition=cache_on)
345 352 def _ctx_obsolete(_context_uid, _repo_id, _commit_id):
346 353 repo = self._factory.repo(wire)
347 354 ctx = self._get_ctx(repo, commit_id)
348 355 return ctx.obsolete()
349 356 return _ctx_obsolete(context_uid, repo_id, commit_id)
350 357
351 358 @reraise_safe_exceptions
352 359 def ctx_hidden(self, wire, commit_id):
353 360 cache_on, context_uid, repo_id = self._cache_on(wire)
354 361 region = self._region(wire)
355 362 @region.conditional_cache_on_arguments(condition=cache_on)
356 363 def _ctx_hidden(_context_uid, _repo_id, _commit_id):
357 364 repo = self._factory.repo(wire)
358 365 ctx = self._get_ctx(repo, commit_id)
359 366 return ctx.hidden()
360 367 return _ctx_hidden(context_uid, repo_id, commit_id)
361 368
362 369 @reraise_safe_exceptions
363 370 def ctx_substate(self, wire, revision):
364 371 repo = self._factory.repo(wire)
365 372 ctx = self._get_ctx(repo, revision)
366 373 return ctx.substate
367 374
368 375 @reraise_safe_exceptions
369 376 def ctx_status(self, wire, revision):
370 377 repo = self._factory.repo(wire)
371 378 ctx = self._get_ctx(repo, revision)
372 379 status = repo[ctx.p1().node()].status(other=ctx.node())
373 380 # object of status (odd, custom named tuple in mercurial) is not
374 381 # correctly serializable, we make it a list, as the underlying
375 382 # API expects this to be a list
376 383 return list(status)
377 384
378 385 @reraise_safe_exceptions
379 386 def ctx_user(self, wire, revision):
380 387 repo = self._factory.repo(wire)
381 388 ctx = self._get_ctx(repo, revision)
382 389 return ctx.user()
383 390
384 391 @reraise_safe_exceptions
385 392 def check_url(self, url, config):
386 393 _proto = None
387 394 if '+' in url[:url.find('://')]:
388 395 _proto = url[0:url.find('+')]
389 396 url = url[url.find('+') + 1:]
390 397 handlers = []
391 398 url_obj = url_parser(url)
392 399 test_uri, authinfo = url_obj.authinfo()
393 400 url_obj.passwd = '*****' if url_obj.passwd else url_obj.passwd
394 401 url_obj.query = obfuscate_qs(url_obj.query)
395 402
396 403 cleaned_uri = str(url_obj)
397 404 log.info("Checking URL for remote cloning/import: %s", cleaned_uri)
398 405
399 406 if authinfo:
400 407 # create a password manager
401 408 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
402 409 passmgr.add_password(*authinfo)
403 410
404 411 handlers.extend((httpbasicauthhandler(passmgr),
405 412 httpdigestauthhandler(passmgr)))
406 413
407 414 o = urllib.request.build_opener(*handlers)
408 415 o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
409 416 ('Accept', 'application/mercurial-0.1')]
410 417
411 418 q = {"cmd": 'between'}
412 419 q.update({'pairs': "%s-%s" % ('0' * 40, '0' * 40)})
413 420 qs = '?%s' % urllib.parse.urlencode(q)
414 421 cu = "%s%s" % (test_uri, qs)
415 422 req = urllib.request.Request(cu, None, {})
416 423
417 424 try:
418 425 log.debug("Trying to open URL %s", cleaned_uri)
419 426 resp = o.open(req)
420 427 if resp.code != 200:
421 428 raise exceptions.URLError()('Return Code is not 200')
422 429 except Exception as e:
423 430 log.warning("URL cannot be opened: %s", cleaned_uri, exc_info=True)
424 431 # means it cannot be cloned
425 432 raise exceptions.URLError(e)("[%s] org_exc: %s" % (cleaned_uri, e))
426 433
427 434 # now check if it's a proper hg repo, but don't do it for svn
428 435 try:
429 436 if _proto == 'svn':
430 437 pass
431 438 else:
432 439 # check for pure hg repos
433 440 log.debug(
434 441 "Verifying if URL is a Mercurial repository: %s",
435 442 cleaned_uri)
436 443 ui = make_ui_from_config(config)
437 444 peer_checker = makepeer(ui, url)
438 445 peer_checker.lookup('tip')
439 446 except Exception as e:
440 447 log.warning("URL is not a valid Mercurial repository: %s",
441 448 cleaned_uri)
442 449 raise exceptions.URLError(e)(
443 450 "url [%s] does not look like an hg repo org_exc: %s"
444 451 % (cleaned_uri, e))
445 452
446 453 log.info("URL is a valid Mercurial repository: %s", cleaned_uri)
447 454 return True
448 455
449 456 @reraise_safe_exceptions
450 457 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_git, opt_ignorews, context):
451 458 repo = self._factory.repo(wire)
452 459
453 460 if file_filter:
454 461 match_filter = match(file_filter[0], '', [file_filter[1]])
455 462 else:
456 463 match_filter = file_filter
457 464 opts = diffopts(git=opt_git, ignorews=opt_ignorews, context=context, showfunc=1)
458 465
459 466 try:
460 467 return "".join(patch.diff(
461 468 repo, node1=commit_id_1, node2=commit_id_2, match=match_filter, opts=opts))
462 469 except RepoLookupError as e:
463 470 raise exceptions.LookupException(e)()
464 471
465 472 @reraise_safe_exceptions
466 473 def node_history(self, wire, revision, path, limit):
467 474 cache_on, context_uid, repo_id = self._cache_on(wire)
468 475 region = self._region(wire)
476
469 477 @region.conditional_cache_on_arguments(condition=cache_on)
470 478 def _node_history(_context_uid, _repo_id, _revision, _path, _limit):
471 479 repo = self._factory.repo(wire)
472 480
473 481 ctx = self._get_ctx(repo, revision)
474 482 fctx = ctx.filectx(path)
475 483
476 484 def history_iter():
477 485 limit_rev = fctx.rev()
478 486 for obj in reversed(list(fctx.filelog())):
479 487 obj = fctx.filectx(obj)
480 488 ctx = obj.changectx()
481 489 if ctx.hidden() or ctx.obsolete():
482 490 continue
483 491
484 492 if limit_rev >= obj.rev():
485 493 yield obj
486 494
487 495 history = []
488 496 for cnt, obj in enumerate(history_iter()):
489 497 if limit and cnt >= limit:
490 498 break
491 499 history.append(hex(obj.node()))
492 500
493 501 return [x for x in history]
494 502 return _node_history(context_uid, repo_id, revision, path, limit)
495 503
496 504 @reraise_safe_exceptions
497 505 def node_history_untill(self, wire, revision, path, limit):
498 506 cache_on, context_uid, repo_id = self._cache_on(wire)
499 507 region = self._region(wire)
508
500 509 @region.conditional_cache_on_arguments(condition=cache_on)
501 510 def _node_history_until(_context_uid, _repo_id):
502 511 repo = self._factory.repo(wire)
503 512 ctx = self._get_ctx(repo, revision)
504 513 fctx = ctx.filectx(path)
505 514
506 515 file_log = list(fctx.filelog())
507 516 if limit:
508 517 # Limit to the last n items
509 518 file_log = file_log[-limit:]
510 519
511 520 return [hex(fctx.filectx(cs).node()) for cs in reversed(file_log)]
512 521 return _node_history_until(context_uid, repo_id, revision, path, limit)
513 522
514 523 @reraise_safe_exceptions
515 524 def fctx_annotate(self, wire, revision, path):
516 525 repo = self._factory.repo(wire)
517 526 ctx = self._get_ctx(repo, revision)
518 527 fctx = ctx.filectx(path)
519 528
520 529 result = []
521 530 for i, annotate_obj in enumerate(fctx.annotate(), 1):
522 531 ln_no = i
523 532 sha = hex(annotate_obj.fctx.node())
524 533 content = annotate_obj.text
525 534 result.append((ln_no, sha, content))
526 535 return result
527 536
528 537 @reraise_safe_exceptions
529 538 def fctx_node_data(self, wire, revision, path):
530 539 repo = self._factory.repo(wire)
531 540 ctx = self._get_ctx(repo, revision)
532 541 fctx = ctx.filectx(path)
533 return fctx.data()
542 return fctx.data()
534 543
535 544 @reraise_safe_exceptions
536 545 def fctx_flags(self, wire, commit_id, path):
537 546 cache_on, context_uid, repo_id = self._cache_on(wire)
538 547 region = self._region(wire)
539 548 @region.conditional_cache_on_arguments(condition=cache_on)
540 549 def _fctx_flags(_repo_id, _commit_id, _path):
541 550 repo = self._factory.repo(wire)
542 551 ctx = self._get_ctx(repo, commit_id)
543 552 fctx = ctx.filectx(path)
544 553 return fctx.flags()
545 554
546 555 return _fctx_flags(repo_id, commit_id, path)
547 556
548 557 @reraise_safe_exceptions
549 558 def fctx_size(self, wire, commit_id, path):
550 559 cache_on, context_uid, repo_id = self._cache_on(wire)
551 560 region = self._region(wire)
552 561 @region.conditional_cache_on_arguments(condition=cache_on)
553 562 def _fctx_size(_repo_id, _revision, _path):
554 563 repo = self._factory.repo(wire)
555 564 ctx = self._get_ctx(repo, commit_id)
556 565 fctx = ctx.filectx(path)
557 566 return fctx.size()
558 567 return _fctx_size(repo_id, commit_id, path)
559 568
560 569 @reraise_safe_exceptions
561 570 def get_all_commit_ids(self, wire, name):
562 571 cache_on, context_uid, repo_id = self._cache_on(wire)
563 572 region = self._region(wire)
573
564 574 @region.conditional_cache_on_arguments(condition=cache_on)
565 575 def _get_all_commit_ids(_context_uid, _repo_id, _name):
566 576 repo = self._factory.repo(wire)
567 577 repo = repo.filtered(name)
568 revs = [hex(x[7]) for x in repo.changelog.index]
578 revs = [ascii_str(repo[x].hex()) for x in repo.filtered(b'visible').changelog.revs()]
569 579 return revs
570 580 return _get_all_commit_ids(context_uid, repo_id, name)
571 581
572 582 @reraise_safe_exceptions
573 583 def get_config_value(self, wire, section, name, untrusted=False):
574 584 repo = self._factory.repo(wire)
575 585 return repo.ui.config(section, name, untrusted=untrusted)
576 586
577 587 @reraise_safe_exceptions
578 588 def is_large_file(self, wire, commit_id, path):
579 589 cache_on, context_uid, repo_id = self._cache_on(wire)
580 590 region = self._region(wire)
591
581 592 @region.conditional_cache_on_arguments(condition=cache_on)
582 593 def _is_large_file(_context_uid, _repo_id, _commit_id, _path):
583 594 return largefiles.lfutil.isstandin(path)
584 595
585 596 return _is_large_file(context_uid, repo_id, commit_id, path)
586 597
587 598 @reraise_safe_exceptions
588 599 def is_binary(self, wire, revision, path):
589 600 cache_on, context_uid, repo_id = self._cache_on(wire)
601 region = self._region(wire)
590 602
591 region = self._region(wire)
592 603 @region.conditional_cache_on_arguments(condition=cache_on)
593 604 def _is_binary(_repo_id, _sha, _path):
594 605 repo = self._factory.repo(wire)
595 606 ctx = self._get_ctx(repo, revision)
596 607 fctx = ctx.filectx(path)
597 608 return fctx.isbinary()
598 609
599 610 return _is_binary(repo_id, revision, path)
600 611
601 612 @reraise_safe_exceptions
602 613 def in_largefiles_store(self, wire, sha):
603 614 repo = self._factory.repo(wire)
604 615 return largefiles.lfutil.instore(repo, sha)
605 616
606 617 @reraise_safe_exceptions
607 618 def in_user_cache(self, wire, sha):
608 619 repo = self._factory.repo(wire)
609 620 return largefiles.lfutil.inusercache(repo.ui, sha)
610 621
611 622 @reraise_safe_exceptions
612 623 def store_path(self, wire, sha):
613 624 repo = self._factory.repo(wire)
614 625 return largefiles.lfutil.storepath(repo, sha)
615 626
616 627 @reraise_safe_exceptions
617 628 def link(self, wire, sha, path):
618 629 repo = self._factory.repo(wire)
619 630 largefiles.lfutil.link(
620 631 largefiles.lfutil.usercachepath(repo.ui, sha), path)
621 632
622 633 @reraise_safe_exceptions
623 634 def localrepository(self, wire, create=False):
624 635 self._factory.repo(wire, create=create)
625 636
626 637 @reraise_safe_exceptions
627 638 def lookup(self, wire, revision, both):
628 639 cache_on, context_uid, repo_id = self._cache_on(wire)
629 640
630 641 region = self._region(wire)
631 642 @region.conditional_cache_on_arguments(condition=cache_on)
632 643 def _lookup(_context_uid, _repo_id, _revision, _both):
633 644
634 645 repo = self._factory.repo(wire)
635 646 rev = _revision
636 647 if isinstance(rev, int):
637 648 # NOTE(marcink):
638 649 # since Mercurial doesn't support negative indexes properly
639 650 # we need to shift accordingly by one to get proper index, e.g
640 651 # repo[-1] => repo[-2]
641 652 # repo[0] => repo[-1]
642 653 if rev <= 0:
643 654 rev = rev + -1
644 655 try:
645 656 ctx = self._get_ctx(repo, rev)
646 657 except (TypeError, RepoLookupError) as e:
647 658 e._org_exc_tb = traceback.format_exc()
648 659 raise exceptions.LookupException(e)(rev)
649 660 except LookupError as e:
650 661 e._org_exc_tb = traceback.format_exc()
651 662 raise exceptions.LookupException(e)(e.name)
652 663
653 664 if not both:
654 665 return ctx.hex()
655 666
656 667 ctx = repo[ctx.hex()]
657 668 return ctx.hex(), ctx.rev()
658 669
659 670 return _lookup(context_uid, repo_id, revision, both)
660 671
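The index shift in the NOTE above is easy to get wrong, so here it is in isolation (pure arithmetic, no Mercurial involved):

def shift_rev(rev: int) -> int:
    # repo[0] => repo[-1], repo[-1] => repo[-2]; positive revs pass through
    return rev + -1 if rev <= 0 else rev

assert shift_rev(0) == -1
assert shift_rev(-1) == -2
assert shift_rev(5) == 5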
661 672 @reraise_safe_exceptions
662 673 def sync_push(self, wire, url):
663 674 if not self.check_url(url, wire['config']):
664 675 return
665 676
666 677 repo = self._factory.repo(wire)
667 678
668 679 # Disable any prompts for this repo
669 repo.ui.setconfig('ui', 'interactive', 'off', '-y')
680 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
670 681
671 682 bookmarks = list(dict(repo._bookmarks).keys())
672 683 remote = peer(repo, {}, url)
673 684 # Disable any prompts for this remote
674 remote.ui.setconfig('ui', 'interactive', 'off', '-y')
685 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
675 686
676 687 return exchange.push(
677 688 repo, remote, newbranch=True, bookmarks=bookmarks).cgresult
678 689
679 690 @reraise_safe_exceptions
680 691 def revision(self, wire, rev):
681 692 repo = self._factory.repo(wire)
682 693 ctx = self._get_ctx(repo, rev)
683 694 return ctx.rev()
684 695
685 696 @reraise_safe_exceptions
686 697 def rev_range(self, wire, commit_filter):
687 698 cache_on, context_uid, repo_id = self._cache_on(wire)
688 699
689 700 region = self._region(wire)
690 701 @region.conditional_cache_on_arguments(condition=cache_on)
691 702 def _rev_range(_context_uid, _repo_id, _filter):
692 703 repo = self._factory.repo(wire)
693 704 revisions = [rev for rev in revrange(repo, commit_filter)]
694 705 return revisions
695 706
696 707 return _rev_range(context_uid, repo_id, sorted(commit_filter))
697 708
698 709 @reraise_safe_exceptions
699 710 def rev_range_hash(self, wire, node):
700 711 repo = self._factory.repo(wire)
701 712
702 713 def get_revs(repo, rev_opt):
703 714 if rev_opt:
704 715 revs = revrange(repo, rev_opt)
705 716 if len(revs) == 0:
706 717 return (nullrev, nullrev)
707 718 return max(revs), min(revs)
708 719 else:
709 720 return len(repo) - 1, 0
710 721
711 722 stop, start = get_revs(repo, [node + ':'])
712 723 revs = [hex(repo[r].node()) for r in range(start, stop + 1)]
713 724 return revs
714 725
715 726 @reraise_safe_exceptions
716 727 def revs_from_revspec(self, wire, rev_spec, *args, **kwargs):
717 728 other_path = kwargs.pop('other_path', None)
718 729
719 730 # case when we want to compare two independent repositories
720 731 if other_path and other_path != wire["path"]:
721 732 baseui = self._factory._create_config(wire["config"])
722 733 repo = unionrepo.makeunionrepository(baseui, other_path, wire["path"])
723 734 else:
724 735 repo = self._factory.repo(wire)
725 736 return list(repo.revs(rev_spec, *args))
726 737
727 738 @reraise_safe_exceptions
728 739 def verify(self, wire,):
729 740 repo = self._factory.repo(wire)
730 741 baseui = self._factory._create_config(wire['config'])
731 742
732 743 baseui, output = patch_ui_message_output(baseui)
733 744
734 745 repo.ui = baseui
735 746 verify.verify(repo)
736 747 return output.getvalue()
737 748
738 749 @reraise_safe_exceptions
739 750 def hg_update_cache(self, wire,):
740 751 repo = self._factory.repo(wire)
741 752 baseui = self._factory._create_config(wire['config'])
742 753 baseui, output = patch_ui_message_output(baseui)
743 754
744 755 repo.ui = baseui
745 756 with repo.wlock(), repo.lock():
746 757 repo.updatecaches(full=True)
747 758
748 759 return output.getvalue()
749 760
750 761 @reraise_safe_exceptions
751 762 def hg_rebuild_fn_cache(self, wire,):
752 763 repo = self._factory.repo(wire)
753 764 baseui = self._factory._create_config(wire['config'])
754 765 baseui, output = patch_ui_message_output(baseui)
755 766
756 767 repo.ui = baseui
757 768
758 769 repair.rebuildfncache(baseui, repo)
759 770
760 771 return output.getvalue()
761 772
762 773 @reraise_safe_exceptions
763 774 def tags(self, wire):
764 775 cache_on, context_uid, repo_id = self._cache_on(wire)
765 776 region = self._region(wire)
766 777 @region.conditional_cache_on_arguments(condition=cache_on)
767 778 def _tags(_context_uid, _repo_id):
768 779 repo = self._factory.repo(wire)
769 780 return repo.tags()
770 781
771 782 return _tags(context_uid, repo_id)
772 783
773 784 @reraise_safe_exceptions
774 785 def update(self, wire, node=None, clean=False):
775 786 repo = self._factory.repo(wire)
776 787 baseui = self._factory._create_config(wire['config'])
777 788 commands.update(baseui, repo, node=node, clean=clean)
778 789
779 790 @reraise_safe_exceptions
780 791 def identify(self, wire):
781 792 repo = self._factory.repo(wire)
782 793 baseui = self._factory._create_config(wire['config'])
783 794 output = io.BytesIO()
784 795 baseui.write = output.write
785 796 # This is required to get a full node id
786 797 baseui.debugflag = True
787 798 commands.identify(baseui, repo, id=True)
788 799
789 800 return output.getvalue()
790 801
791 802 @reraise_safe_exceptions
792 803 def heads(self, wire, branch=None):
793 804 repo = self._factory.repo(wire)
794 805 baseui = self._factory._create_config(wire['config'])
795 806 output = io.BytesIO()
796 807
797 808 def write(data, **unused_kwargs):
798 809 output.write(data)
799 810
800 811 baseui.write = write
801 812 if branch:
802 813 args = [branch]
803 814 else:
804 815 args = []
805 816 commands.heads(baseui, repo, template='{node} ', *args)
806 817
807 818 return output.getvalue()
808 819
809 820 @reraise_safe_exceptions
810 821 def ancestor(self, wire, revision1, revision2):
811 822 repo = self._factory.repo(wire)
812 823 changelog = repo.changelog
813 824 lookup = repo.lookup
814 825 a = changelog.ancestor(lookup(revision1), lookup(revision2))
815 826 return hex(a)
816 827
817 828 @reraise_safe_exceptions
818 829 def clone(self, wire, source, dest, update_after_clone=False, hooks=True):
819 830 baseui = self._factory._create_config(wire["config"], hooks=hooks)
820 831 clone(baseui, source, dest, noupdate=not update_after_clone)
821 832
822 833 @reraise_safe_exceptions
823 834 def commitctx(self, wire, message, parents, commit_time, commit_timezone, user, files, extra, removed, updated):
824 835
825 836 repo = self._factory.repo(wire)
826 837 baseui = self._factory._create_config(wire['config'])
827 838 publishing = baseui.configbool('phases', 'publish')
828 839 if publishing:
829 840 new_commit = 'public'
830 841 else:
831 842 new_commit = 'draft'
832 843
833 844 def _filectxfn(_repo, ctx, path):
834 845 """
835 846 Marks given path as added/changed/removed in a given _repo. This is
836 847 for internal mercurial commit function.
837 848 """
838 849
839 850 # check if this path is removed
840 851 if path in removed:
841 852 # returning None is a way to mark node for removal
842 853 return None
843 854
844 855 # check if this path is added
845 856 for node in updated:
846 857 if node['path'] == path:
847 858 return memfilectx(
848 859 _repo,
849 860 changectx=ctx,
850 861 path=node['path'],
851 862 data=node['content'],
852 863 islink=False,
853 864 isexec=bool(node['mode'] & stat.S_IXUSR),
854 865 copysource=False)
855 866
856 867 raise exceptions.AbortException()(
857 868 "Given path haven't been marked as added, "
858 869 "changed or removed (%s)" % path)
859 870
860 871 with repo.ui.configoverride({('phases', 'new-commit'): new_commit}):
861 872
862 873 commit_ctx = memctx(
863 874 repo=repo,
864 875 parents=parents,
865 876 text=message,
866 877 files=files,
867 878 filectxfn=_filectxfn,
868 879 user=user,
869 880 date=(commit_time, commit_timezone),
870 881 extra=extra)
871 882
872 883 n = repo.commitctx(commit_ctx)
873 884 new_id = hex(n)
874 885
875 886 return new_id
876 887
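One detail of `_filectxfn` worth calling out: the executable bit is derived by masking the stored mode with `stat.S_IXUSR`, as below:

import stat

assert bool(0o755 & stat.S_IXUSR) is True   # owner-executable file
assert bool(0o644 & stat.S_IXUSR) is False  # regular file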
877 888 @reraise_safe_exceptions
878 889 def pull(self, wire, url, commit_ids=None):
879 890 repo = self._factory.repo(wire)
880 891 # Disable any prompts for this repo
881 repo.ui.setconfig('ui', 'interactive', 'off', '-y')
892 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
882 893
883 894 remote = peer(repo, {}, url)
884 895 # Disable any prompts for this remote
885 remote.ui.setconfig('ui', 'interactive', 'off', '-y')
896 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
886 897
887 898 if commit_ids:
888 899 commit_ids = [bin(commit_id) for commit_id in commit_ids]
889 900
890 901 return exchange.pull(
891 902 repo, remote, heads=commit_ids, force=None).cgresult
892 903
893 904 @reraise_safe_exceptions
894 905 def pull_cmd(self, wire, source, bookmark=None, branch=None, revision=None, hooks=True):
895 906 repo = self._factory.repo(wire)
896 907 baseui = self._factory._create_config(wire['config'], hooks=hooks)
897 908
898 909 # Mercurial internally has a lot of logic that checks ONLY if an
899 910 # option is defined, so we pass only the options that are set
900 911 opts = {}
901 912 if bookmark:
902 913 opts['bookmark'] = bookmark
903 914 if branch:
904 915 opts['branch'] = branch
905 916 if revision:
906 917 opts['rev'] = revision
907 918
908 919 commands.pull(baseui, repo, source, **opts)
909 920
910 921 @reraise_safe_exceptions
911 922 def push(self, wire, revisions, dest_path, hooks=True, push_branches=False):
912 923 repo = self._factory.repo(wire)
913 924 baseui = self._factory._create_config(wire['config'], hooks=hooks)
914 925 commands.push(baseui, repo, dest=dest_path, rev=revisions,
915 926 new_branch=push_branches)
916 927
917 928 @reraise_safe_exceptions
918 929 def strip(self, wire, revision, update, backup):
919 930 repo = self._factory.repo(wire)
920 931 ctx = self._get_ctx(repo, revision)
921 932 hgext_strip(
922 933 repo.baseui, repo, ctx.node(), update=update, backup=backup)
923 934
924 935 @reraise_safe_exceptions
925 936 def get_unresolved_files(self, wire):
926 937 repo = self._factory.repo(wire)
927 938
928 939 log.debug('Calculating unresolved files for repo: %s', repo)
929 940 output = io.BytesIO()
930 941
931 942 def write(data, **unused_kwargs):
932 943 output.write(data)
933 944
934 945 baseui = self._factory._create_config(wire['config'])
935 946 baseui.write = write
936 947
937 948 commands.resolve(baseui, repo, list=True)
938 949 unresolved = output.getvalue().splitlines(0)
939 950 return unresolved
940 951
941 952 @reraise_safe_exceptions
942 953 def merge(self, wire, revision):
943 954 repo = self._factory.repo(wire)
944 955 baseui = self._factory._create_config(wire['config'])
945 repo.ui.setconfig('ui', 'merge', 'internal:dump')
956 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
946 957
947 958 # In case sub-repositories are used, Mercurial prompts the user in
948 959 # case of merge conflicts or differing sub-repository sources. By
949 960 # setting the interactive flag to `False` Mercurial doesn't prompt the
950 961 # user but instead uses a default value.
951 repo.ui.setconfig('ui', 'interactive', False)
962 repo.ui.setconfig(b'ui', b'interactive', False)
952 963 commands.merge(baseui, repo, rev=revision)
953 964
954 965 @reraise_safe_exceptions
955 966 def merge_state(self, wire):
956 967 repo = self._factory.repo(wire)
957 repo.ui.setconfig('ui', 'merge', 'internal:dump')
968 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
958 969
959 970 # In case sub-repositories are used, Mercurial prompts the user in
960 971 # case of merge conflicts or differing sub-repository sources. By
961 972 # setting the interactive flag to `False` Mercurial doesn't prompt the
962 973 # user but instead uses a default value.
963 repo.ui.setconfig('ui', 'interactive', False)
974 repo.ui.setconfig(b'ui', b'interactive', False)
964 975 ms = hg_merge.mergestate(repo)
965 976 return [x for x in ms.unresolved()]
966 977
967 978 @reraise_safe_exceptions
968 979 def commit(self, wire, message, username, close_branch=False):
969 980 repo = self._factory.repo(wire)
970 981 baseui = self._factory._create_config(wire['config'])
971 repo.ui.setconfig('ui', 'username', username)
982 repo.ui.setconfig(b'ui', b'username', username)
972 983 commands.commit(baseui, repo, message=message, close_branch=close_branch)
973 984
974 985 @reraise_safe_exceptions
975 986 def rebase(self, wire, source=None, dest=None, abort=False):
976 987 repo = self._factory.repo(wire)
977 988 baseui = self._factory._create_config(wire['config'])
978 repo.ui.setconfig('ui', 'merge', 'internal:dump')
989 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
979 990 # In case sub-repositories are used, Mercurial prompts the user in
980 991 # case of merge conflicts or differing sub-repository sources. By
981 992 # setting the interactive flag to `False` Mercurial doesn't prompt the
982 993 # user but instead uses a default value.
983 repo.ui.setconfig('ui', 'interactive', False)
994 repo.ui.setconfig(b'ui', b'interactive', False)
984 995 rebase.rebase(baseui, repo, base=source, dest=dest, abort=abort, keep=not abort)
985 996
986 997 @reraise_safe_exceptions
987 998 def tag(self, wire, name, revision, message, local, user, tag_time, tag_timezone):
988 999 repo = self._factory.repo(wire)
989 1000 ctx = self._get_ctx(repo, revision)
990 1001 node = ctx.node()
991 1002
992 1003 date = (tag_time, tag_timezone)
993 1004 try:
994 1005 hg_tag.tag(repo, name, node, message, local, user, date)
995 1006 except Abort as e:
996 1007 log.exception("Tag operation aborted")
997 1008 # Exception can contain unicode which we convert
998 1009 raise exceptions.AbortException(e)(repr(e))
999 1010
1000 1011 @reraise_safe_exceptions
1001 1012 def bookmark(self, wire, bookmark, revision=None):
1002 1013 repo = self._factory.repo(wire)
1003 1014 baseui = self._factory._create_config(wire['config'])
1004 1015 commands.bookmark(baseui, repo, bookmark, rev=revision, force=True)
1005 1016
1006 1017 @reraise_safe_exceptions
1007 1018 def install_hooks(self, wire, force=False):
1008 1019 # we don't need any special hooks for Mercurial
1009 1020 pass
1010 1021
1011 1022 @reraise_safe_exceptions
1012 1023 def get_hooks_info(self, wire):
1013 1024 return {
1014 1025 'pre_version': vcsserver.__version__,
1015 1026 'post_version': vcsserver.__version__,
1016 1027 }
1017 1028
1018 1029 @reraise_safe_exceptions
1019 1030 def set_head_ref(self, wire, head_name):
1020 1031 pass
1021 1032
1022 1033 @reraise_safe_exceptions
1023 1034 def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
1024 1035 archive_dir_name, commit_id):
1025 1036
1026 1037 def file_walker(_commit_id, path):
1027 1038 repo = self._factory.repo(wire)
1028 1039 ctx = repo[_commit_id]
1029 1040 is_root = path in ['', '/']
1030 1041 if is_root:
1031 1042 matcher = alwaysmatcher(badfn=None)
1032 1043 else:
1033 1044 matcher = patternmatcher('', [(b'glob', path+'/**', b'')], badfn=None)
1034 1045 file_iter = ctx.manifest().walk(matcher)
1035 1046
1036 1047 for fn in file_iter:
1037 1048 file_path = fn
1038 1049 flags = ctx.flags(fn)
1039 1050 mode = b'x' in flags and 0o755 or 0o644
1040 1051 is_link = b'l' in flags
1041 1052
1042 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1053 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1043 1054
1044 1055 return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
1045 1056 archive_dir_name, commit_id)
1046 1057
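The mode/link derivation in `file_walker` above, shown standalone: Mercurial manifest flags are `b'x'` for executable entries and `b'l'` for symlinks.

def mode_and_link(flags: bytes):
    mode = 0o755 if b'x' in flags else 0o644
    return mode, b'l' in flags

assert mode_and_link(b'x') == (0o755, False)
assert mode_and_link(b'l') == (0o644, True)
assert mode_and_link(b'') == (0o644, False)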
@@ -1,863 +1,864 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18
19 19 import os
20 20 import subprocess
21 21 from urllib.error import URLError
22 22 import urllib.parse
23 23 import logging
24 24 import posixpath as vcspath
25 25 import io
26 26 import urllib.request
27 27 import urllib.parse
28 28 import urllib.error
29 29 import traceback
30 30
31 31 import svn.client
32 32 import svn.core
33 33 import svn.delta
34 34 import svn.diff
35 35 import svn.fs
36 36 import svn.repos
37 37
38 38 from vcsserver import svn_diff, exceptions, subprocessio, settings
39 39 from vcsserver.base import RepoFactory, raise_from_original, ArchiveNode, archive_repo
40 40 from vcsserver.exceptions import NoContentException
41 41 from vcsserver.utils import safe_str
42 42 from vcsserver.vcs_base import RemoteBase
43 43 from vcsserver.lib.svnremoterepo import svnremoterepo
44 44 log = logging.getLogger(__name__)
45 45
46 46
47 47 svn_compatible_versions_map = {
48 48 'pre-1.4-compatible': '1.3',
49 49 'pre-1.5-compatible': '1.4',
50 50 'pre-1.6-compatible': '1.5',
51 51 'pre-1.8-compatible': '1.7',
52 52 'pre-1.9-compatible': '1.8',
53 53 }
54 54
55 55 current_compatible_version = '1.14'
56 56
57 57
58 58 def reraise_safe_exceptions(func):
59 59 """Decorator for converting svn exceptions to something neutral."""
60 60 def wrapper(*args, **kwargs):
61 61 try:
62 62 return func(*args, **kwargs)
63 63 except Exception as e:
64 64 if not hasattr(e, '_vcs_kind'):
65 65 log.exception("Unhandled exception in svn remote call")
66 66 raise_from_original(exceptions.UnhandledException(e))
67 67 raise
68 68 return wrapper
69 69
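The decorator above normalizes unknown errors: anything without a `_vcs_kind` marker becomes a neutral UnhandledException. An illustrative standalone version (RuntimeError stands in for the vcsserver exception factory):

    import functools
    import logging

    log = logging.getLogger(__name__)

    def reraise_safe(func):
        # convert unknown exceptions into a neutral, transport-safe type
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if not hasattr(e, '_vcs_kind'):
                    log.exception('Unhandled exception in remote call')
                    raise RuntimeError(repr(e)) from e
                raise
        return wrapper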
70 70
71 71 class SubversionFactory(RepoFactory):
72 72 repo_type = 'svn'
73 73
74 74 def _create_repo(self, wire, create, compatible_version):
75 75 path = svn.core.svn_path_canonicalize(wire['path'])
76 76 if create:
77 77 fs_config = {'compatible-version': current_compatible_version}
78 78 if compatible_version:
79 79
80 80 compatible_version_string = \
81 81 svn_compatible_versions_map.get(compatible_version) \
82 82 or compatible_version
83 83 fs_config['compatible-version'] = compatible_version_string
84 84
85 85 log.debug('Create SVN repo with config "%s"', fs_config)
86 86 repo = svn.repos.create(path, "", "", None, fs_config)
87 87 else:
88 88 repo = svn.repos.open(path)
89 89
90 90 log.debug('Got SVN object: %s', repo)
91 91 return repo
92 92
93 93 def repo(self, wire, create=False, compatible_version=None):
94 94 """
95 95 Get a repository instance for the given path.
96 96 """
97 97 return self._create_repo(wire, create, compatible_version)
98 98
99 99
100 100 NODE_TYPE_MAPPING = {
101 101 svn.core.svn_node_file: 'file',
102 102 svn.core.svn_node_dir: 'dir',
103 103 }
104 104
105 105
106 106 class SvnRemote(RemoteBase):
107 107
108 108 def __init__(self, factory, hg_factory=None):
109 109 self._factory = factory
110 110
111 111 @reraise_safe_exceptions
112 112 def discover_svn_version(self):
113 113 try:
114 114 import svn.core
115 115 svn_ver = svn.core.SVN_VERSION
116 116 except ImportError:
117 117 svn_ver = None
118 118 return svn_ver
119 119
120 120 @reraise_safe_exceptions
121 121 def is_empty(self, wire):
122 122
123 123 try:
124 124 return self.lookup(wire, -1) == 0
125 125 except Exception:
126 126 log.exception("failed to read object_store")
127 127 return False
128 128
129 129 def check_url(self, url):
130 130
131 131 # the uuid function gets a valid UUID only from a proper repo,
132 132 # otherwise it throws an exception
133 133 username, password, src_url = self.get_url_and_credentials(url)
134 134 try:
135 135 svnremoterepo(username, password, src_url).svn().uuid
136 136 except Exception:
137 137 tb = traceback.format_exc()
138 138 log.debug("Invalid Subversion url: `%s`, tb: %s", url, tb)
139 139 raise URLError(
140 140 '"%s" is not a valid Subversion source url.' % (url, ))
141 141 return True
142 142
143 143 def is_path_valid_repository(self, wire, path):
144 144
145 145 # NOTE(marcink): short-circuit the check for an SVN repo;
146 146 # repos.open might be expensive to call, but we have one cheap
147 147 # precondition we can use: checking for the 'format' file
148 148
149 149 if not os.path.isfile(os.path.join(path, 'format')):
150 150 return False
151 151
152 152 try:
153 153 svn.repos.open(path)
154 154 except svn.core.SubversionException:
155 155 tb = traceback.format_exc()
156 156 log.debug("Invalid Subversion path `%s`, tb: %s", path, tb)
157 157 return False
158 158 return True
159 159
160 160 @reraise_safe_exceptions
161 161 def verify(self, wire):
162 162 repo_path = wire['path']
163 163 if not self.is_path_valid_repository(wire, repo_path):
164 164 raise Exception(
165 165 "Path %s is not a valid Subversion repository." % repo_path)
166 166
167 167 cmd = ['svnadmin', 'info', repo_path]
168 168 stdout, stderr = subprocessio.run_command(cmd)
169 169 return stdout
170 170
171 171 def lookup(self, wire, revision):
172 172 if revision not in [-1, None, 'HEAD']:
173 173 raise NotImplementedError
174 174 repo = self._factory.repo(wire)
175 175 fs_ptr = svn.repos.fs(repo)
176 176 head = svn.fs.youngest_rev(fs_ptr)
177 177 return head
178 178
179 179 def lookup_interval(self, wire, start_ts, end_ts):
180 180 repo = self._factory.repo(wire)
181 181 fsobj = svn.repos.fs(repo)
182 182 start_rev = None
183 183 end_rev = None
184 184 if start_ts:
185 185 start_ts_svn = apr_time_t(start_ts)
186 186 start_rev = svn.repos.dated_revision(repo, start_ts_svn) + 1
187 187 else:
188 188 start_rev = 1
189 189 if end_ts:
190 190 end_ts_svn = apr_time_t(end_ts)
191 191 end_rev = svn.repos.dated_revision(repo, end_ts_svn)
192 192 else:
193 193 end_rev = svn.fs.youngest_rev(fsobj)
194 194 return start_rev, end_rev
195 195
196 196 def revision_properties(self, wire, revision):
197 197
198 198 cache_on, context_uid, repo_id = self._cache_on(wire)
199 199 region = self._region(wire)
200 200 @region.conditional_cache_on_arguments(condition=cache_on)
201 201 def _revision_properties(_repo_id, _revision):
202 202 repo = self._factory.repo(wire)
203 203 fs_ptr = svn.repos.fs(repo)
204 204 return svn.fs.revision_proplist(fs_ptr, revision)
205 205 return _revision_properties(repo_id, revision)
206 206
207 207 def revision_changes(self, wire, revision):
208 208
209 209 repo = self._factory.repo(wire)
210 210 fsobj = svn.repos.fs(repo)
211 211 rev_root = svn.fs.revision_root(fsobj, revision)
212 212
213 213 editor = svn.repos.ChangeCollector(fsobj, rev_root)
214 214 editor_ptr, editor_baton = svn.delta.make_editor(editor)
215 215 base_dir = ""
216 216 send_deltas = False
217 217 svn.repos.replay2(
218 218 rev_root, base_dir, svn.core.SVN_INVALID_REVNUM, send_deltas,
219 219 editor_ptr, editor_baton, None)
220 220
221 221 added = []
222 222 changed = []
223 223 removed = []
224 224
225 225 # TODO: CHANGE_ACTION_REPLACE: Figure out where it belongs
226 226 for path, change in editor.changes.items():
227 227 # TODO: Decide what to do with directory nodes. Subversion can add
228 228 # empty directories.
229 229
230 230 if change.item_kind == svn.core.svn_node_dir:
231 231 continue
232 232 if change.action in [svn.repos.CHANGE_ACTION_ADD]:
233 233 added.append(path)
234 234 elif change.action in [svn.repos.CHANGE_ACTION_MODIFY,
235 235 svn.repos.CHANGE_ACTION_REPLACE]:
236 236 changed.append(path)
237 237 elif change.action in [svn.repos.CHANGE_ACTION_DELETE]:
238 238 removed.append(path)
239 239 else:
240 240 raise NotImplementedError(
241 241 "Action %s not supported on path %s" % (
242 242 change.action, path))
243 243
244 244 changes = {
245 245 'added': added,
246 246 'changed': changed,
247 247 'removed': removed,
248 248 }
249 249 return changes
250 250
251 251 @reraise_safe_exceptions
252 252 def node_history(self, wire, path, revision, limit):
253 253 cache_on, context_uid, repo_id = self._cache_on(wire)
254 254 region = self._region(wire)
255 255 @region.conditional_cache_on_arguments(condition=cache_on)
256 256 def _assert_correct_path(_context_uid, _repo_id, _path, _revision, _limit):
257 257 cross_copies = False
258 258 repo = self._factory.repo(wire)
259 259 fsobj = svn.repos.fs(repo)
260 260 rev_root = svn.fs.revision_root(fsobj, revision)
261 261
262 262 history_revisions = []
263 263 history = svn.fs.node_history(rev_root, path)
264 264 history = svn.fs.history_prev(history, cross_copies)
265 265 while history:
266 266 __, node_revision = svn.fs.history_location(history)
267 267 history_revisions.append(node_revision)
268 268 if limit and len(history_revisions) >= limit:
269 269 break
270 270 history = svn.fs.history_prev(history, cross_copies)
271 271 return history_revisions
272 272 return _assert_correct_path(context_uid, repo_id, path, revision, limit)
273 273
274 274 def node_properties(self, wire, path, revision):
275 275 cache_on, context_uid, repo_id = self._cache_on(wire)
276 276 region = self._region(wire)
277 277 @region.conditional_cache_on_arguments(condition=cache_on)
278 278 def _node_properties(_repo_id, _path, _revision):
279 279 repo = self._factory.repo(wire)
280 280 fsobj = svn.repos.fs(repo)
281 281 rev_root = svn.fs.revision_root(fsobj, revision)
282 282 return svn.fs.node_proplist(rev_root, path)
283 283 return _node_properties(repo_id, path, revision)
284 284
285 285 def file_annotate(self, wire, path, revision):
286 286 abs_path = 'file://' + urllib.request.pathname2url(  # py3: pathname2url lives in urllib.request
287 287 vcspath.join(wire['path'], path))
288 288 file_uri = svn.core.svn_path_canonicalize(abs_path)
289 289
290 290 start_rev = svn_opt_revision_value_t(0)
291 291 peg_rev = svn_opt_revision_value_t(revision)
292 292 end_rev = peg_rev
293 293
294 294 annotations = []
295 295
296 296 def receiver(line_no, revision, author, date, line, pool):
297 297 annotations.append((line_no, revision, line))
298 298
299 299 # TODO: Cannot use blame5, missing typemap function in the swig code
300 300 try:
301 301 svn.client.blame2(
302 302 file_uri, peg_rev, start_rev, end_rev,
303 303 receiver, svn.client.create_context())
304 304 except svn.core.SubversionException as exc:
305 305 log.exception("Error during blame operation.")
306 306 raise Exception(
307 307 "Blame not supported or file does not exist at path %s. "
308 308 "Error %s." % (path, exc))
309 309
310 310 return annotations
311 311
312 312 def get_node_type(self, wire, path, revision=None):
313 313
314 314 cache_on, context_uid, repo_id = self._cache_on(wire)
315 315 region = self._region(wire)
316 316 @region.conditional_cache_on_arguments(condition=cache_on)
317 317 def _get_node_type(_repo_id, _path, _revision):
318 318 repo = self._factory.repo(wire)
319 319 fs_ptr = svn.repos.fs(repo)
320 320 if _revision is None:
321 321 _revision = svn.fs.youngest_rev(fs_ptr)
322 322 root = svn.fs.revision_root(fs_ptr, _revision)
323 323 node = svn.fs.check_path(root, path)
324 324 return NODE_TYPE_MAPPING.get(node, None)
325 325 return _get_node_type(repo_id, path, revision)
326 326
327 327 def get_nodes(self, wire, path, revision=None):
328 328
329 329 cache_on, context_uid, repo_id = self._cache_on(wire)
330 330 region = self._region(wire)
331 331 @region.conditional_cache_on_arguments(condition=cache_on)
332 332 def _get_nodes(_repo_id, _path, _revision):
333 333 repo = self._factory.repo(wire)
334 334 fsobj = svn.repos.fs(repo)
335 335 if _revision is None:
336 336 _revision = svn.fs.youngest_rev(fsobj)
337 337 root = svn.fs.revision_root(fsobj, _revision)
338 338 entries = svn.fs.dir_entries(root, path)
339 339 result = []
340 340 for entry_path, entry_info in entries.items():
341 341 result.append(
342 342 (entry_path, NODE_TYPE_MAPPING.get(entry_info.kind, None)))
343 343 return result
344 344 return _get_nodes(repo_id, path, revision)
345 345
346 346 def get_file_content(self, wire, path, rev=None):
347 347 repo = self._factory.repo(wire)
348 348 fsobj = svn.repos.fs(repo)
349 349 if rev is None:
350 350 rev = svn.fs.youngest_revision(fsobj)
351 351 root = svn.fs.revision_root(fsobj, rev)
352 352 content = svn.core.Stream(svn.fs.file_contents(root, path))
353 353 return content.read()
354 354
355 355 def get_file_size(self, wire, path, revision=None):
356 356
357 357 cache_on, context_uid, repo_id = self._cache_on(wire)
358 358 region = self._region(wire)
359 359
360 360 @region.conditional_cache_on_arguments(condition=cache_on)
361 361 def _get_file_size(_repo_id, _path, _revision):
362 362 repo = self._factory.repo(wire)
363 363 fsobj = svn.repos.fs(repo)
364 364 if _revision is None:
365 365 _revision = svn.fs.youngest_revision(fsobj)
366 366 root = svn.fs.revision_root(fsobj, _revision)
367 367 size = svn.fs.file_length(root, path)
368 368 return size
369 369 return _get_file_size(repo_id, path, revision)
370 370
371 371 def create_repository(self, wire, compatible_version=None):
372 372 log.info('Creating Subversion repository in path "%s"', wire['path'])
373 373 self._factory.repo(wire, create=True,
374 374 compatible_version=compatible_version)
375 375
376 376 def get_url_and_credentials(self, src_url):
377 377 obj = urllib.parse.urlparse(src_url)
378 378 username = obj.username or None
379 379 password = obj.password or None
380 380 return username, password, src_url
381 381
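For reference, urllib.parse extracts embedded credentials straight from the URL netloc:

    from urllib.parse import urlparse

    obj = urlparse('svn://user:secret@svn.example.com/repo')
    assert (obj.username, obj.password) == ('user', 'secret')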
382 382 def import_remote_repository(self, wire, src_url):
383 383 repo_path = wire['path']
384 384 if not self.is_path_valid_repository(wire, repo_path):
385 385 raise Exception(
386 386 "Path %s is not a valid Subversion repository." % repo_path)
387 387
388 388 username, password, src_url = self.get_url_and_credentials(src_url)
389 389 rdump_cmd = ['svnrdump', 'dump', '--non-interactive',
390 390 '--trust-server-cert-failures=unknown-ca']
391 391 if username and password:
392 392 rdump_cmd += ['--username', username, '--password', password]
393 393 rdump_cmd += [src_url]
394 394
395 395 rdump = subprocess.Popen(
396 396 rdump_cmd,
397 397 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
398 398 load = subprocess.Popen(
399 399 ['svnadmin', 'load', repo_path], stdin=rdump.stdout)
400 400
401 401 # TODO: johbo: This can be a very long operation, might be better
402 402 # to track some kind of status and provide an api to check if the
403 403 # import is done.
404 404 rdump.wait()
405 405 load.wait()
406 406
407 407 log.debug('Return process ended with code: %s', rdump.returncode)
408 408 if rdump.returncode != 0:
409 409 errors = rdump.stderr.read()
410 log.error('svnrdump dump failed: statuscode %s: message: %s',
411 rdump.returncode, errors)
410 log.error('svnrdump dump failed: statuscode %s: message: %s', rdump.returncode, errors)
411
412 412 reason = 'UNKNOWN'
413 if 'svnrdump: E230001:' in errors:
413 if b'svnrdump: E230001:' in errors:
414 414 reason = 'INVALID_CERTIFICATE'
415 415
416 416 if reason == 'UNKNOWN':
417 reason = 'UNKNOWN:{}'.format(errors)
417 reason = 'UNKNOWN:{}'.format(safe_str(errors))
418
418 419 raise Exception(
419 420 'Failed to dump the remote repository from %s. Reason:%s' % (
420 421 src_url, reason))
421 422 if load.returncode != 0:
422 423 raise Exception(
423 424 'Failed to load the dump of remote repository from %s.' %
424 425 (src_url, ))
425 426
426 427 def commit(self, wire, message, author, timestamp, updated, removed):
427 428 assert isinstance(message, str)
428 429 assert isinstance(author, str)
429 430
430 431 repo = self._factory.repo(wire)
431 432 fsobj = svn.repos.fs(repo)
432 433
433 434 rev = svn.fs.youngest_rev(fsobj)
434 435 txn = svn.repos.fs_begin_txn_for_commit(repo, rev, author, message)
435 436 txn_root = svn.fs.txn_root(txn)
436 437
437 438 for node in updated:
438 439 TxnNodeProcessor(node, txn_root).update()
439 440 for node in removed:
440 441 TxnNodeProcessor(node, txn_root).remove()
441 442
442 443 commit_id = svn.repos.fs_commit_txn(repo, txn)
443 444
444 445 if timestamp:
445 446 apr_time = apr_time_t(timestamp)
446 447 ts_formatted = svn.core.svn_time_to_cstring(apr_time)
447 448 svn.fs.change_rev_prop(fsobj, commit_id, 'svn:date', ts_formatted)
448 449
449 450 log.debug('Committed revision "%s" to "%s".', commit_id, wire['path'])
450 451 return commit_id
451 452
452 453 def diff(self, wire, rev1, rev2, path1=None, path2=None,
453 454 ignore_whitespace=False, context=3):
454 455
455 456 wire.update(cache=False)
456 457 repo = self._factory.repo(wire)
457 458 diff_creator = SvnDiffer(
458 459 repo, rev1, path1, rev2, path2, ignore_whitespace, context)
459 460 try:
460 461 return diff_creator.generate_diff()
461 462 except svn.core.SubversionException as e:
462 463 log.exception(
463 464 "Error during diff operation operation. "
464 465 "Path might not exist %s, %s" % (path1, path2))
465 466 return ""
466 467
467 468 @reraise_safe_exceptions
468 469 def is_large_file(self, wire, path):
469 470 return False
470 471
471 472 @reraise_safe_exceptions
472 473 def is_binary(self, wire, rev, path):
473 474 cache_on, context_uid, repo_id = self._cache_on(wire)
474 475
475 476 region = self._region(wire)
476 477 @region.conditional_cache_on_arguments(condition=cache_on)
477 478 def _is_binary(_repo_id, _rev, _path):
478 479 raw_bytes = self.get_file_content(wire, path, rev)
479 480 return raw_bytes and b'\0' in raw_bytes  # file content is bytes under py3
480 481
481 482 return _is_binary(repo_id, rev, path)
482 483
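The check above is the classic null-byte heuristic; under py3 the needle has to be bytes, e.g.:

    def looks_binary(raw_bytes: bytes) -> bool:
        # a NUL byte almost never occurs in text content
        return bool(raw_bytes) and b'\0' in raw_bytes

    assert looks_binary(b'\x89PNG\r\n\x1a\n\x00')
    assert not looks_binary(b'plain text\n')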
483 484 @reraise_safe_exceptions
484 485 def run_svn_command(self, wire, cmd, **opts):
485 486 path = wire.get('path', None)
486 487
487 488 if path and os.path.isdir(path):
488 489 opts['cwd'] = path
489 490
490 491 safe_call = opts.pop('_safe', False)
491 492
492 493 svnenv = os.environ.copy()
493 494 svnenv.update(opts.pop('extra_env', {}))
494 495
495 496 _opts = {'env': svnenv, 'shell': False}
496 497
497 498 try:
498 499 _opts.update(opts)
499 p = subprocessio.SubprocessIOChunker(cmd, **_opts)
500 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
500 501
501 return ''.join(p), ''.join(p.error)
502 except (EnvironmentError, OSError) as err:
502 return b''.join(proc), b''.join(proc.stderr)
503 except OSError as err:
503 504 if safe_call:
504 505 return '', safe_str(err).strip()
505 506 else:
506 507 cmd = ' '.join(cmd) # human friendly CMD
507 508 tb_err = ("Couldn't run svn command (%s).\n"
508 509 "Original error was:%s\n"
509 510 "Call options:%s\n"
510 511 % (cmd, err, _opts))
511 512 log.exception(tb_err)
512 513 raise exceptions.VcsException()(tb_err)
513 514
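A hedged usage sketch of run_svn_command (the `remote` object and `wire` dict are assumptions based on the calling convention above); after the py3 change both return values are bytes:

    # hypothetical call; `remote` is an SvnRemote, `wire` carries the repo path
    stdout, stderr = remote.run_svn_command(
        wire, ['svn', 'info', wire['path']], _safe=True)
    print(stdout.decode('utf-8', 'replace'))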
514 515 @reraise_safe_exceptions
515 516 def install_hooks(self, wire, force=False):
516 517 from vcsserver.hook_utils import install_svn_hooks
517 518 repo_path = wire['path']
518 519 binary_dir = settings.BINARY_DIR
519 520 executable = None
520 521 if binary_dir:
521 522 executable = os.path.join(binary_dir, 'python')
522 523 return install_svn_hooks(
523 524 repo_path, executable=executable, force_create=force)
524 525
525 526 @reraise_safe_exceptions
526 527 def get_hooks_info(self, wire):
527 528 from vcsserver.hook_utils import (
528 529 get_svn_pre_hook_version, get_svn_post_hook_version)
529 530 repo_path = wire['path']
530 531 return {
531 532 'pre_version': get_svn_pre_hook_version(repo_path),
532 533 'post_version': get_svn_post_hook_version(repo_path),
533 534 }
534 535
535 536 @reraise_safe_exceptions
536 537 def set_head_ref(self, wire, head_name):
537 538 pass
538 539
539 540 @reraise_safe_exceptions
540 541 def archive_repo(self, wire, archive_dest_path, kind, mtime, archive_at_path,
541 542 archive_dir_name, commit_id):
542 543
543 544 def walk_tree(root, root_dir, _commit_id):
544 545 """
545 546 Special recursive svn repo walker
546 547 """
547 548
548 549 filemode_default = 0o100644
549 550 filemode_executable = 0o100755
550 551
551 552 file_iter = svn.fs.dir_entries(root, root_dir)
552 553 for f_name in file_iter:
553 554 f_type = NODE_TYPE_MAPPING.get(file_iter[f_name].kind, None)
554 555
555 556 if f_type == 'dir':
556 557 # return only DIR, and then all entries in that dir
557 558 yield os.path.join(root_dir, f_name), {'mode': filemode_default}, f_type
558 559 new_root = os.path.join(root_dir, f_name)
559 560 for _f_name, _f_data, _f_type in walk_tree(root, new_root, _commit_id):
560 561 yield _f_name, _f_data, _f_type
561 562 else:
562 563 f_path = os.path.join(root_dir, f_name).rstrip('/')
563 564 prop_list = svn.fs.node_proplist(root, f_path)
564 565
565 566 f_mode = filemode_default
566 567 if prop_list.get('svn:executable'):
567 568 f_mode = filemode_executable
568 569
569 570 f_is_link = False
570 571 if prop_list.get('svn:special'):
571 572 f_is_link = True
572 573
573 574 data = {
574 575 'is_link': f_is_link,
575 576 'mode': f_mode,
576 577 'content_stream': svn.core.Stream(svn.fs.file_contents(root, f_path)).read
577 578 }
578 579
579 580 yield f_path, data, f_type
580 581
581 582 def file_walker(_commit_id, path):
582 583 repo = self._factory.repo(wire)
583 584 root = svn.fs.revision_root(svn.repos.fs(repo), int(commit_id))
584 585
585 586 def no_content():
586 587 raise NoContentException()
587 588
588 589 for f_name, f_data, f_type in walk_tree(root, path, _commit_id):
589 590 file_path = f_name
590 591
591 592 if f_type == 'dir':
592 593 mode = f_data['mode']
593 594 yield ArchiveNode(file_path, mode, False, no_content)
594 595 else:
595 596 mode = f_data['mode']
596 597 is_link = f_data['is_link']
597 598 data_stream = f_data['content_stream']
598 599 yield ArchiveNode(file_path, mode, is_link, data_stream)
599 600
600 601 return archive_repo(file_walker, archive_dest_path, kind, mtime, archive_at_path,
601 602 archive_dir_name, commit_id)
602 603
603 604
604 605 class SvnDiffer(object):
605 606 """
606 607 Utility to create diffs based on difflib and the Subversion api
607 608 """
608 609
609 610 binary_content = False
610 611
611 612 def __init__(
612 613 self, repo, src_rev, src_path, tgt_rev, tgt_path,
613 614 ignore_whitespace, context):
614 615 self.repo = repo
615 616 self.ignore_whitespace = ignore_whitespace
616 617 self.context = context
617 618
618 619 fsobj = svn.repos.fs(repo)
619 620
620 621 self.tgt_rev = tgt_rev
621 622 self.tgt_path = tgt_path or ''
622 623 self.tgt_root = svn.fs.revision_root(fsobj, tgt_rev)
623 624 self.tgt_kind = svn.fs.check_path(self.tgt_root, self.tgt_path)
624 625
625 626 self.src_rev = src_rev
626 627 self.src_path = src_path or self.tgt_path
627 628 self.src_root = svn.fs.revision_root(fsobj, src_rev)
628 629 self.src_kind = svn.fs.check_path(self.src_root, self.src_path)
629 630
630 631 self._validate()
631 632
632 633 def _validate(self):
633 634 if (self.tgt_kind != svn.core.svn_node_none and
634 635 self.src_kind != svn.core.svn_node_none and
635 636 self.src_kind != self.tgt_kind):
636 637 # TODO: johbo: proper error handling
637 638 raise Exception(
638 639 "Source and target are not compatible for diff generation. "
639 640 "Source type: %s, target type: %s" %
640 641 (self.src_kind, self.tgt_kind))
641 642
642 643 def generate_diff(self):
643 644 buf = io.StringIO()
644 645 if self.tgt_kind == svn.core.svn_node_dir:
645 646 self._generate_dir_diff(buf)
646 647 else:
647 648 self._generate_file_diff(buf)
648 649 return buf.getvalue()
649 650
650 651 def _generate_dir_diff(self, buf):
651 652 editor = DiffChangeEditor()
652 653 editor_ptr, editor_baton = svn.delta.make_editor(editor)
653 654 svn.repos.dir_delta2(
654 655 self.src_root,
655 656 self.src_path,
656 657 '', # src_entry
657 658 self.tgt_root,
658 659 self.tgt_path,
659 660 editor_ptr, editor_baton,
660 661 authorization_callback_allow_all,
661 662 False, # text_deltas
662 663 svn.core.svn_depth_infinity, # depth
663 664 False, # entry_props
664 665 False, # ignore_ancestry
665 666 )
666 667
667 668 for path, __, change in sorted(editor.changes):
668 669 self._generate_node_diff(
669 670 buf, change, path, self.tgt_path, path, self.src_path)
670 671
671 672 def _generate_file_diff(self, buf):
672 673 change = None
673 674 if self.src_kind == svn.core.svn_node_none:
674 675 change = "add"
675 676 elif self.tgt_kind == svn.core.svn_node_none:
676 677 change = "delete"
677 678 tgt_base, tgt_path = vcspath.split(self.tgt_path)
678 679 src_base, src_path = vcspath.split(self.src_path)
679 680 self._generate_node_diff(
680 681 buf, change, tgt_path, tgt_base, src_path, src_base)
681 682
682 683 def _generate_node_diff(
683 684 self, buf, change, tgt_path, tgt_base, src_path, src_base):
684 685
685 686 if self.src_rev == self.tgt_rev and tgt_base == src_base:
686 687 # makes consistent behaviour with git/hg to return empty diff if
687 688 # we compare same revisions
688 689 return
689 690
690 691 tgt_full_path = vcspath.join(tgt_base, tgt_path)
691 692 src_full_path = vcspath.join(src_base, src_path)
692 693
693 694 self.binary_content = False
694 695 mime_type = self._get_mime_type(tgt_full_path)
695 696
696 697 if mime_type and not mime_type.startswith('text'):
697 698 self.binary_content = True
698 699 buf.write("=" * 67 + '\n')
699 700 buf.write("Cannot display: file marked as a binary type.\n")
700 701 buf.write("svn:mime-type = %s\n" % mime_type)
701 702 buf.write("Index: %s\n" % (tgt_path, ))
702 703 buf.write("=" * 67 + '\n')
703 704 buf.write("diff --git a/%(tgt_path)s b/%(tgt_path)s\n" % {
704 705 'tgt_path': tgt_path})
705 706
706 707 if change == 'add':
707 708 # TODO: johbo: SVN is missing a zero here compared to git
708 709 buf.write("new file mode 10644\n")
709 710
710 711 # TODO(marcink): introduce binary detection of svn patches
711 712 # if self.binary_content:
712 713 # buf.write('GIT binary patch\n')
713 714
714 715 buf.write("--- /dev/null\t(revision 0)\n")
715 716 src_lines = []
716 717 else:
717 718 if change == 'delete':
718 719 buf.write("deleted file mode 10644\n")
719 720
720 721 # TODO(marcink): introduce binary detection of svn patches
721 722 # if self.binary_content:
722 723 # buf.write('GIT binary patch\n')
723 724
724 725 buf.write("--- a/%s\t(revision %s)\n" % (
725 726 src_path, self.src_rev))
726 727 src_lines = self._svn_readlines(self.src_root, src_full_path)
727 728
728 729 if change == 'delete':
729 730 buf.write("+++ /dev/null\t(revision %s)\n" % (self.tgt_rev, ))
730 731 tgt_lines = []
731 732 else:
732 733 buf.write("+++ b/%s\t(revision %s)\n" % (
733 734 tgt_path, self.tgt_rev))
734 735 tgt_lines = self._svn_readlines(self.tgt_root, tgt_full_path)
735 736
736 737 if not self.binary_content:
737 738 udiff = svn_diff.unified_diff(
738 739 src_lines, tgt_lines, context=self.context,
739 740 ignore_blank_lines=self.ignore_whitespace,
740 741 ignore_case=False,
741 742 ignore_space_changes=self.ignore_whitespace)
742 743 buf.writelines(udiff)
743 744
744 745 def _get_mime_type(self, path):
745 746 try:
746 747 mime_type = svn.fs.node_prop(
747 748 self.tgt_root, path, svn.core.SVN_PROP_MIME_TYPE)
748 749 except svn.core.SubversionException:
749 750 mime_type = svn.fs.node_prop(
750 751 self.src_root, path, svn.core.SVN_PROP_MIME_TYPE)
751 752 return mime_type
752 753
753 754 def _svn_readlines(self, fs_root, node_path):
754 755 if self.binary_content:
755 756 return []
756 757 node_kind = svn.fs.check_path(fs_root, node_path)
757 758 if node_kind not in (
758 759 svn.core.svn_node_file, svn.core.svn_node_symlink):
759 760 return []
760 761 content = svn.core.Stream(
761 762 svn.fs.file_contents(fs_root, node_path)).read()
762 763 return content.splitlines(True)
763 764
764 765
765 766 class DiffChangeEditor(svn.delta.Editor):
766 767 """
767 768 Records changes between two given revisions
768 769 """
769 770
770 771 def __init__(self):
771 772 self.changes = []
772 773
773 774 def delete_entry(self, path, revision, parent_baton, pool=None):
774 775 self.changes.append((path, None, 'delete'))
775 776
776 777 def add_file(
777 778 self, path, parent_baton, copyfrom_path, copyfrom_revision,
778 779 file_pool=None):
779 780 self.changes.append((path, 'file', 'add'))
780 781
781 782 def open_file(self, path, parent_baton, base_revision, file_pool=None):
782 783 self.changes.append((path, 'file', 'change'))
783 784
784 785
785 786 def authorization_callback_allow_all(root, path, pool):
786 787 return True
787 788
788 789
789 790 class TxnNodeProcessor(object):
790 791 """
791 792 Utility to process the change of one node within a transaction root.
792 793
793 794 It encapsulates the knowledge of how to add, update or remove
794 795 a node for a given transaction root. The purpose is to support the method
795 796 `SvnRemote.commit`.
796 797 """
797 798
798 799 def __init__(self, node, txn_root):
799 800 assert isinstance(node['path'], str)
800 801
801 802 self.node = node
802 803 self.txn_root = txn_root
803 804
804 805 def update(self):
805 806 self._ensure_parent_dirs()
806 807 self._add_file_if_node_does_not_exist()
807 808 self._update_file_content()
808 809 self._update_file_properties()
809 810
810 811 def remove(self):
811 812 svn.fs.delete(self.txn_root, self.node['path'])
812 813 # TODO: Clean up directory if empty
813 814
814 815 def _ensure_parent_dirs(self):
815 816 curdir = vcspath.dirname(self.node['path'])
816 817 dirs_to_create = []
817 818 while not self._svn_path_exists(curdir):
818 819 dirs_to_create.append(curdir)
819 820 curdir = vcspath.dirname(curdir)
820 821
821 822 for curdir in reversed(dirs_to_create):
822 823 log.debug('Creating missing directory "%s"', curdir)
823 824 svn.fs.make_dir(self.txn_root, curdir)
824 825
825 826 def _svn_path_exists(self, path):
826 827 path_status = svn.fs.check_path(self.txn_root, path)
827 828 return path_status != svn.core.svn_node_none
828 829
829 830 def _add_file_if_node_does_not_exist(self):
830 831 kind = svn.fs.check_path(self.txn_root, self.node['path'])
831 832 if kind == svn.core.svn_node_none:
832 833 svn.fs.make_file(self.txn_root, self.node['path'])
833 834
834 835 def _update_file_content(self):
835 836 assert isinstance(self.node['content'], str)
836 837 handler, baton = svn.fs.apply_textdelta(
837 838 self.txn_root, self.node['path'], None, None)
838 839 svn.delta.svn_txdelta_send_string(self.node['content'], handler, baton)
839 840
840 841 def _update_file_properties(self):
841 842 properties = self.node.get('properties', {})
842 843 for key, value in properties.items():
843 844 svn.fs.change_node_prop(
844 845 self.txn_root, self.node['path'], key, value)
845 846
846 847
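The `node` mappings consumed by TxnNodeProcessor (and by SvnRemote.commit above) are plain dicts; based on the assertions in __init__ and _update_file_content, their assumed shape is:

    # assumed node shape, inferred from the asserts above
    node = {
        'path': 'trunk/docs/readme.txt',           # str
        'content': 'hello subversion\n',           # str
        'properties': {'svn:eol-style': 'native'},
    }
    # commit_id = remote.commit(wire, 'message', 'author', None, [node], [])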
847 848 def apr_time_t(timestamp):
848 849 """
849 850 Convert a Python timestamp into APR timestamp type apr_time_t
850 851 """
851 852 return timestamp * 1E6
852 853
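APR timestamps are microseconds since the epoch, hence the 1E6 factor:

    assert apr_time_t(1) == 1_000_000          # 1 second -> 1e6 microseconds
    assert apr_time_t(1609459200) == 1609459200 * 1e6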
853 854
854 855 def svn_opt_revision_value_t(num):
855 856 """
856 857 Put `num` into a `svn_opt_revision_value_t` structure.
857 858 """
858 859 value = svn.core.svn_opt_revision_value_t()
859 860 value.number = num
860 861 revision = svn.core.svn_opt_revision_t()
861 862 revision.kind = svn.core.svn_opt_revision_number
862 863 revision.value = value
863 864 return revision
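Usage mirrors file_annotate above: wrap plain revision numbers so the svn.client calls accept them (a sketch; requires the Subversion Python bindings):

    start_rev = svn_opt_revision_value_t(0)
    peg_rev = svn_opt_revision_value_t(42)
    # svn.client.blame2(file_uri, peg_rev, start_rev, peg_rev, receiver, ctx)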
@@ -1,235 +1,235 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import logging
20 20 import itertools
21 21
22 22 import mercurial
23 23 import mercurial.error
24 24 import mercurial.wireprotoserver
25 25 import mercurial.hgweb.common
26 26 import mercurial.hgweb.hgweb_mod
27 27 import webob.exc
28 28
29 29 from vcsserver import pygrack, exceptions, settings, git_lfs
30
30 from vcsserver.utils import ascii_bytes
31 31
32 32 log = logging.getLogger(__name__)
33 33
34 34
35 35 # propagated from mercurial documentation
36 36 HG_UI_SECTIONS = [
37 37 'alias', 'auth', 'decode/encode', 'defaults', 'diff', 'email', 'extensions',
38 38 'format', 'merge-patterns', 'merge-tools', 'hooks', 'http_proxy', 'smtp',
39 39 'patch', 'paths', 'profiling', 'server', 'trusted', 'ui', 'web',
40 40 ]
41 41
42 42
43 43 class HgWeb(mercurial.hgweb.hgweb_mod.hgweb):
44 44 """Extension of hgweb that simplifies some functions."""
45 45
46 46 def _get_view(self, repo):
47 47 """Views are not supported."""
48 48 return repo
49 49
50 50 def loadsubweb(self):
51 51 """The result is only used in the templater method which is not used."""
52 52 return None
53 53
54 54 def run(self):
55 55 """Unused function so raise an exception if accidentally called."""
56 56 raise NotImplementedError
57 57
58 58 def templater(self, req):
59 59 """Function used in an unreachable code path.
60 60
61 61 This code is unreachable because we guarantee that the HTTP request
62 62 corresponds to a Mercurial command. See the is_hg method. So, we are
63 63 never going to get a user-visible url.
64 64 """
65 65 raise NotImplementedError
66 66
67 67 def archivelist(self, nodeid):
68 68 """Unused function so raise an exception if accidentally called."""
69 69 raise NotImplementedError
70 70
71 71 def __call__(self, environ, start_response):
72 72 """Run the WSGI application.
73 73
74 74 This may be called by multiple threads.
75 75 """
76 76 from mercurial.hgweb import request as requestmod
77 77 req = requestmod.parserequestfromenv(environ)
78 78 res = requestmod.wsgiresponse(req, start_response)
79 79 gen = self.run_wsgi(req, res)
80 80
81 81 first_chunk = None
82 82
83 83 try:
84 84 data = next(gen)
85 85
86 86 def first_chunk():
87 87 yield data
88 88 except StopIteration:
89 89 pass
90 90
91 91 if first_chunk:
92 92 return itertools.chain(first_chunk(), gen)
93 93 return gen
94 94
95 95 def _runwsgi(self, req, res, repo):
96 96
97 97 cmd = req.qsparams.get('cmd', '')
98 98 if not mercurial.wireprotoserver.iscmd(cmd):
99 99 # NOTE(marcink): for unsupported commands, we return bad request
100 100 # internally from HG
101 101 from mercurial.hgweb.common import statusmessage
102 102 res.status = statusmessage(mercurial.hgweb.common.HTTP_BAD_REQUEST)
103 103 res.setbodybytes(b'')  # hgweb body must be bytes under py3
104 104 return res.sendresponse()
105 105
106 106 return super(HgWeb, self)._runwsgi(req, res, repo)
107 107
108 108
109 109 def make_hg_ui_from_config(repo_config):
110 110 baseui = mercurial.ui.ui()
111 111
112 112 # clean the baseui object
113 113 baseui._ocfg = mercurial.config.config()
114 114 baseui._ucfg = mercurial.config.config()
115 115 baseui._tcfg = mercurial.config.config()
116 116
117 117 for section, option, value in repo_config:
118 baseui.setconfig(section, option, value)
118 baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))
119 119
120 120 # make our hgweb quiet so it doesn't print output
121 baseui.setconfig('ui', 'quiet', 'true')
121 baseui.setconfig(b'ui', b'quiet', b'true')
122 122
123 123 return baseui
124 124
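make_hg_ui_from_config receives the serialized 3-tuples as str and byte-encodes them via ascii_bytes; a sketch of a call (assumes a local Mercurial installation):

    repo_config = [
        ('ui', 'username', 'RhodeCode'),
        ('extensions', 'largefiles', ''),
    ]
    baseui = make_hg_ui_from_config(repo_config)
    assert baseui.config(b'ui', b'username') == b'RhodeCode'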
125 125
126 126 def update_hg_ui_from_hgrc(baseui, repo_path):
127 127 path = os.path.join(repo_path, '.hg', 'hgrc')
128 128
129 129 if not os.path.isfile(path):
130 130 log.debug('hgrc file is not present at %s, skipping...', path)
131 131 return
132 132 log.debug('reading hgrc from %s', path)
133 133 cfg = mercurial.config.config()
134 134 cfg.read(path)
135 135 for section in HG_UI_SECTIONS:
136 136 for k, v in cfg.items(section):
137 137 log.debug('settings ui from file: [%s] %s=%s', section, k, v)
138 baseui.setconfig(section, k, v)
138 baseui.setconfig(ascii_bytes(section), ascii_bytes(k), ascii_bytes(v))
139 139
140 140
141 141 def create_hg_wsgi_app(repo_path, repo_name, config):
142 142 """
143 143 Prepares a WSGI application to handle Mercurial requests.
144 144
145 145 :param config: is a list of 3-item tuples representing a ConfigObject
146 146 (it is the serialized version of the config object).
147 147 """
148 148 log.debug("Creating Mercurial WSGI application")
149 149
150 150 baseui = make_hg_ui_from_config(config)
151 151 update_hg_ui_from_hgrc(baseui, repo_path)
152 152
153 153 try:
154 154 return HgWeb(repo_path, name=repo_name, baseui=baseui)
155 155 except mercurial.error.RequirementError as e:
156 156 raise exceptions.RequirementException(e)(e)
157 157
158 158
159 159 class GitHandler(object):
160 160 """
161 161 Handler for Git operations like push/pull etc
162 162 """
163 163 def __init__(self, repo_location, repo_name, git_path, update_server_info,
164 164 extras):
165 165 if not os.path.isdir(repo_location):
166 166 raise OSError(repo_location)
167 167 self.content_path = repo_location
168 168 self.repo_name = repo_name
169 169 self.repo_location = repo_location
170 170 self.extras = extras
171 171 self.git_path = git_path
172 172 self.update_server_info = update_server_info
173 173
174 174 def __call__(self, environ, start_response):
175 175 app = webob.exc.HTTPNotFound()
176 176 candidate_paths = (
177 177 self.content_path, os.path.join(self.content_path, '.git'))
178 178
179 179 for content_path in candidate_paths:
180 180 try:
181 181 app = pygrack.GitRepository(
182 182 self.repo_name, content_path, self.git_path,
183 183 self.update_server_info, self.extras)
184 184 break
185 185 except OSError:
186 186 continue
187 187
188 188 return app(environ, start_response)
189 189
190 190
191 191 def create_git_wsgi_app(repo_path, repo_name, config):
192 192 """
193 193 Creates a WSGI application to handle Git requests.
194 194
195 195 :param config: is a dictionary holding the extras.
196 196 """
197 197 git_path = settings.GIT_EXECUTABLE
198 198 update_server_info = config.pop('git_update_server_info')
199 199 app = GitHandler(
200 200 repo_path, repo_name, git_path, update_server_info, config)
201 201
202 202 return app
203 203
204 204
205 205 class GitLFSHandler(object):
206 206 """
207 207 Handler for Git LFS operations
208 208 """
209 209
210 210 def __init__(self, repo_location, repo_name, git_path, update_server_info,
211 211 extras):
212 212 if not os.path.isdir(repo_location):
213 213 raise OSError(repo_location)
214 214 self.content_path = repo_location
215 215 self.repo_name = repo_name
216 216 self.repo_location = repo_location
217 217 self.extras = extras
218 218 self.git_path = git_path
219 219 self.update_server_info = update_server_info
220 220
221 221 def get_app(self, git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
222 222 app = git_lfs.create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
223 223 return app
224 224
225 225
226 226 def create_git_lfs_wsgi_app(repo_path, repo_name, config):
227 227 git_path = settings.GIT_EXECUTABLE
228 update_server_info = config.pop('git_update_server_info')
229 git_lfs_enabled = config.pop('git_lfs_enabled')
230 git_lfs_store_path = config.pop('git_lfs_store_path')
231 git_lfs_http_scheme = config.pop('git_lfs_http_scheme', 'http')
228 update_server_info = config.pop(b'git_update_server_info')
229 git_lfs_enabled = config.pop(b'git_lfs_enabled')
230 git_lfs_store_path = config.pop(b'git_lfs_store_path')
231 git_lfs_http_scheme = config.pop(b'git_lfs_http_scheme', 'http')
232 232 app = GitLFSHandler(
233 233 repo_path, repo_name, git_path, update_server_info, config)
234 234
235 235 return app.get_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme)
@@ -1,519 +1,561 b''
1 1 """
2 2 Module provides a class allowing to wrap communication over subprocess.Popen
3 3 input, output, error streams into a meaningfull, non-blocking, concurrent
4 4 stream processor exposing the output data as an iterator fitting to be a
5 5 return value passed by a WSGI applicaiton to a WSGI server per PEP 3333.
6 6
7 7 Copyright (c) 2011 Daniel Dotsenko <dotsa[at]hotmail.com>
8 8
9 9 This file is part of git_http_backend.py Project.
10 10
11 11 git_http_backend.py Project is free software: you can redistribute it and/or
12 12 modify it under the terms of the GNU Lesser General Public License as
13 13 published by the Free Software Foundation, either version 2.1 of the License,
14 14 or (at your option) any later version.
15 15
16 16 git_http_backend.py Project is distributed in the hope that it will be useful,
17 17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 19 GNU Lesser General Public License for more details.
20 20
21 21 You should have received a copy of the GNU Lesser General Public License
22 22 along with git_http_backend.py Project.
23 23 If not, see <http://www.gnu.org/licenses/>.
24 24 """
25 25 import os
26 import collections
26 27 import logging
27 28 import subprocess
28 from collections import deque
29 from threading import Event, Thread
29 import threading
30 30
31 31 log = logging.getLogger(__name__)
32 32
33 33
34 class StreamFeeder(Thread):
34 class StreamFeeder(threading.Thread):
35 35 """
36 36 Normal writing into a pipe-like object blocks once the buffer is filled.
37 37 This thread feeds data from a file-like object into a pipe
38 38 without blocking the main thread.
39 39 We close inpipe once the end of the source stream is reached.
40 40 """
41 41
42 42 def __init__(self, source):
43 43 super(StreamFeeder, self).__init__()
44 44 self.daemon = True
45 45 filelike = False
46 46 self.bytes = bytes()
47 47 if isinstance(source, (str, bytes, bytearray)): # string-like
48 48 self.bytes = source.encode() if isinstance(source, str) else bytes(source)
49 49 else: # can be either file pointer or file-like
50 if type(source) in (int, int): # file pointer it is
50 if isinstance(source, int): # file pointer it is
51 51 # converting file descriptor (int) stdin into file-like
52 try:
53 source = os.fdopen(source, 'rb', 16384)
54 except Exception:
55 pass
52 source = os.fdopen(source, 'rb', 16384)
56 53 # let's see if source is file-like by now
57 try:
58 filelike = source.read
59 except Exception:
60 pass
54 filelike = hasattr(source, 'read')
61 55 if not filelike and not self.bytes:
62 56 raise TypeError("StreamFeeder's source object must be a readable "
63 57 "file-like, a file descriptor, or a string-like.")
64 58 self.source = source
65 59 self.readiface, self.writeiface = os.pipe()
66 60
67 61 def run(self):
68 t = self.writeiface
62 writer = self.writeiface
69 63 try:
70 64 if self.bytes:
71 os.write(t, self.bytes)
65 os.write(writer, self.bytes)
72 66 else:
73 67 s = self.source
74 b = s.read(4096)
75 while b:
76 os.write(t, b)
77 b = s.read(4096)
68
69 while 1:
70 _bytes = s.read(4096)
71 if not _bytes:
72 break
73 os.write(writer, _bytes)
74
78 75 finally:
79 os.close(t)
76 os.close(writer)
80 77
81 78 @property
82 79 def output(self):
83 80 return self.readiface
84 81
85 82
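A hedged sketch of StreamFeeder in use: feed an in-memory payload into a child process without blocking the caller (POSIX only; assumes a `cat` binary):

    import subprocess

    feeder = StreamFeeder(b'payload sent to the child process\n')
    feeder.start()
    proc = subprocess.Popen(['cat'], stdin=feeder.output, stdout=subprocess.PIPE)
    out, _ = proc.communicate()
    assert out.startswith(b'payload')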
86 class InputStreamChunker(Thread):
83 class InputStreamChunker(threading.Thread):
87 84 def __init__(self, source, target, buffer_size, chunk_size):
88 85
89 86 super(InputStreamChunker, self).__init__()
90 87
91 88 self.daemon = True # die die die.
92 89
93 90 self.source = source
94 91 self.target = target
95 92 self.chunk_count_max = int(buffer_size / chunk_size) + 1
96 93 self.chunk_size = chunk_size
97 94
98 self.data_added = Event()
95 self.data_added = threading.Event()
99 96 self.data_added.clear()
100 97
101 self.keep_reading = Event()
98 self.keep_reading = threading.Event()
102 99 self.keep_reading.set()
103 100
104 self.EOF = Event()
101 self.EOF = threading.Event()
105 102 self.EOF.clear()
106 103
107 self.go = Event()
104 self.go = threading.Event()
108 105 self.go.set()
109 106
110 107 def stop(self):
111 108 self.go.clear()
112 109 self.EOF.set()
113 110 try:
114 111 # this is not proper, but is done to force the reader thread let
115 112 # go of the input because, if successful, .close() will send EOF
116 113 # down the pipe.
117 114 self.source.close()
118 115 except:
119 116 pass
120 117
121 118 def run(self):
122 119 s = self.source
123 120 t = self.target
124 121 cs = self.chunk_size
125 122 chunk_count_max = self.chunk_count_max
126 123 keep_reading = self.keep_reading
127 124 da = self.data_added
128 125 go = self.go
129 126
130 127 try:
131 128 b = s.read(cs)
132 129 except ValueError:
133 130 b = ''
134 131
135 132 timeout_input = 20
136 133 while b and go.is_set():
137 134 if len(t) > chunk_count_max:
138 135 keep_reading.clear()
139 136 keep_reading.wait(timeout_input)
140 137 if len(t) > chunk_count_max + timeout_input:
141 138 log.error("Timed out while waiting for input from subprocess.")
142 139 os._exit(-1) # this will cause the worker to recycle itself
143 140
144 141 t.append(b)
145 142 da.set()
146 143
147 144 try:
148 145 b = s.read(cs)
149 except ValueError:
146 except ValueError: # probably "I/O operation on closed file"
150 147 b = ''
151 148
152 149 self.EOF.set()
153 150 da.set() # for cases when done but there was no input.
154 151
155 152
156 153 class BufferedGenerator(object):
157 154 """
158 155 Class behaves as a non-blocking, buffered pipe reader.
159 156 Reads chunks of data (through a thread)
160 157 from a blocking pipe, and attaches these to an array (Deque) of chunks.
161 158 Reading is halted in the thread when max chunks is internally buffered.
162 159 The .next() may operate in blocking or non-blocking fashion by yielding
163 160 '' if no data is ready
164 161 to be sent, or by not returning until there is some data to send.
165 162 When we get EOF from the underlying source pipe we set the marker to raise
166 163 StopIteration after the last chunk of data is yielded.
167 164 """
168 165
169 def __init__(self, source, buffer_size=65536, chunk_size=4096,
166 def __init__(self, name, source, buffer_size=65536, chunk_size=4096,
170 167 starting_values=None, bottomless=False):
171 168 starting_values = starting_values or []
169 self.name = name
170 self.buffer_size = buffer_size
171 self.chunk_size = chunk_size
172 172
173 173 if bottomless:
174 174 maxlen = int(buffer_size / chunk_size)
175 175 else:
176 176 maxlen = None
177 177
178 self.data = deque(starting_values, maxlen)
179 self.worker = InputStreamChunker(source, self.data, buffer_size,
180 chunk_size)
178 self.data_queue = collections.deque(starting_values, maxlen)
179 self.worker = InputStreamChunker(source, self.data_queue, buffer_size, chunk_size)
181 180 if starting_values:
182 181 self.worker.data_added.set()
183 182 self.worker.start()
184 183
185 184 ####################
186 185 # Generator's methods
187 186 ####################
187 def __str__(self):
188 return f'BufferedGenerator(name={self.name} chunk: {self.chunk_size} on buffer: {self.buffer_size})'
188 189
189 190 def __iter__(self):
190 191 return self
191 192
192 193 def __next__(self):
193 while not len(self.data) and not self.worker.EOF.is_set():
194
195 while not self.length and not self.worker.EOF.is_set():
194 196 self.worker.data_added.clear()
195 197 self.worker.data_added.wait(0.2)
196 if len(self.data):
198
199 if self.length:
197 200 self.worker.keep_reading.set()
198 return bytes(self.data.popleft())
201 return bytes(self.data_queue.popleft())
199 202 elif self.worker.EOF.is_set():
200 203 raise StopIteration
201 204
202 205 def throw(self, exc_type, value=None, traceback=None):
203 206 if not self.worker.EOF.is_set():
204 207 raise exc_type(value)
205 208
206 209 def start(self):
207 210 self.worker.start()
208 211
209 212 def stop(self):
210 213 self.worker.stop()
211 214
212 215 def close(self):
213 216 try:
214 217 self.worker.stop()
215 218 self.throw(GeneratorExit)
216 219 except (GeneratorExit, StopIteration):
217 220 pass
218 221
219 222 ####################
220 223 # Threaded reader's infrastructure.
221 224 ####################
222 225 @property
223 226 def input(self):
224 227 return self.worker.w
225 228
226 229 @property
227 230 def data_added_event(self):
228 231 return self.worker.data_added
229 232
230 233 @property
231 234 def data_added(self):
232 235 return self.worker.data_added.is_set()
233 236
234 237 @property
235 238 def reading_paused(self):
236 239 return not self.worker.keep_reading.is_set()
237 240
238 241 @property
239 242 def done_reading_event(self):
240 243 """
241 244 Done_reading does not mean that the iterator's buffer is empty.
242 245 Iterator might have done reading from underlying source, but the read
243 246 chunks might still be available for serving through .next() method.
244 247
245 248 :returns: An Event class instance.
246 249 """
247 250 return self.worker.EOF
248 251
249 252 @property
250 253 def done_reading(self):
251 254 """
252 Done_reding does not mean that the iterator's buffer is empty.
255 Done_reading does not mean that the iterator's buffer is empty.
253 256 Iterator might have done reading from underlying source, but the read
254 257 chunks might still be available for serving through .next() method.
255 258
256 259 :returns: A bool value.
257 260 """
258 261 return self.worker.EOF.is_set()
259 262
260 263 @property
261 264 def length(self):
262 265 """
263 266 returns int.
264 267
265 This is the lenght of the que of chunks, not the length of
268 This is the length of the queue of chunks, not the length of
266 269 the combined contents in those chunks.
267 270
268 271 __len__() cannot be meaningfully implemented because this
269 reader is just flying throuh a bottomless pit content and
270 can only know the lenght of what it already saw.
272 reader is just flying through a bottomless pit content and
273 can only know the length of what it already saw.
271 274
272 275 If __len__() on WSGI server per PEP 3333 returns a value,
273 the responce's length will be set to that. In order not to
276 the response's length will be set to that. In order not to
274 277 confuse WSGI PEP3333 servers, we will not implement __len__
275 278 at all.
276 279 """
277 return len(self.data)
280 return len(self.data_queue)
278 281
279 282 def prepend(self, x):
280 self.data.appendleft(x)
283 self.data_queue.appendleft(x)
281 284
282 285 def append(self, x):
283 self.data.append(x)
286 self.data_queue.append(x)
284 287
285 288 def extend(self, o):
286 self.data.extend(o)
289 self.data_queue.extend(o)
287 290
288 291 def __getitem__(self, i):
289 return self.data[i]
292 return self.data_queue[i]
290 293
291 294
292 295 class SubprocessIOChunker(object):
293 296 """
294 297 Processor class wrapping handling of subprocess IO.
295 298
296 299 .. important::
297 300
298 301 Watch out for the method `__del__` on this class. If this object
299 302 is deleted, it will kill the subprocess, so avoid to
300 303 return the `output` attribute or usage of it like in the following
301 304 example::
302 305
303 306 # `args` expected to run a program that produces a lot of output
304 307 output = ''.join(SubprocessIOChunker(
305 308 args, shell=False, inputstream=inputstream, env=environ).output)
306 309
307 310 # `output` will not contain all the data, because the __del__ method
308 311 # has already killed the subprocess in this case before all output
309 312 # has been consumed.
310 313
311 314
312 315
313 316 In a way, this is a "communicate()" replacement with a twist.
314 317
315 318 - We are multithreaded. Writing in and reading out, err are all sep threads.
316 319 - We support concurrent (in and out) stream processing.
317 - The output is not a stream. It's a queue of read string (bytes, not unicode)
320 - The output is not a stream. It's a queue of read string (bytes, not str)
318 321 chunks. The object behaves as an iterable. You can "for chunk in obj:" us.
319 322 - We are non-blocking in more respects than communicate()
320 323 (reading from subprocess out pauses when internal buffer is full, but
321 324 does not block the parent calling code. On the flip side, reading from
322 325 slow-yielding subprocess may block the iteration until data shows up. This
323 326 does not block the parallel inpipe reading occurring parallel thread.)
324 327
325 328 The purpose of the object is to allow us to wrap subprocess interactions into
326 and interable that can be passed to a WSGI server as the application's return
329 an iterable that can be passed to a WSGI server as the application's return
327 330 value. Because of stream-processing-ability, WSGI does not have to read ALL
328 331 of the subprocess's output and buffer it, before handing it to WSGI server for
329 332 HTTP response. Instead, the class initializer reads just a bit of the stream
330 to figure out if error ocurred or likely to occur and if not, just hands the
333 to figure out if error occurred or likely to occur and if not, just hands the
331 334 further iteration over subprocess output to the server for completion of HTTP
332 335 response.
333 336
334 337 The real or perceived subprocess error is trapped and raised as one of
335 EnvironmentError family of exceptions
338 OSError family of exceptions
336 339
337 340 Example usage:
338 341 # try:
339 342 # answer = SubprocessIOChunker(
340 343 # cmd,
341 344 # input,
342 345 # buffer_size = 65536,
343 346 # chunk_size = 4096
344 347 # )
345 # except (EnvironmentError) as e:
348 # except (OSError) as e:
346 349 # print(str(e))
347 350 # raise e
348 351 #
349 352 # return answer
350 353
351 354
352 355 """
353 356
354 357 # TODO: johbo: This is used to make sure that the open end of the PIPE
355 358 # is closed in the end. It would be way better to wrap this into an
356 359 # object, so that it is closed automatically once it is consumed or
357 360 # something similar.
358 361 _close_input_fd = None
359 362
360 363 _closed = False
364 _stdout = None
365 _stderr = None
361 366
362 def __init__(self, cmd, inputstream=None, buffer_size=65536,
367 def __init__(self, cmd, input_stream=None, buffer_size=65536,
363 368 chunk_size=4096, starting_values=None, fail_on_stderr=True,
364 369 fail_on_return_code=True, **kwargs):
365 370 """
366 371 Initializes SubprocessIOChunker
367 372
368 373 :param cmd: A Subprocess.Popen style "cmd". Can be string or array of strings
369 :param inputstream: (Default: None) A file-like, string, or file pointer.
374 :param input_stream: (Default: None) A file-like, string, or file pointer.
370 375 :param buffer_size: (Default: 65536) A size of total buffer per stream in bytes.
371 376 :param chunk_size: (Default: 4096) A max size of a chunk. Actual chunk may be smaller.
372 377 :param starting_values: (Default: []) An array of strings to put in front of output que.
373 378 :param fail_on_stderr: (Default: True) Whether to raise an exception in
374 379 case something is written to stderr.
375 380 :param fail_on_return_code: (Default: True) Whether to raise an
376 381 exception if the return code is not 0.
377 382 """
378 383
384 kwargs['shell'] = kwargs.get('shell', True)
385
379 386 starting_values = starting_values or []
380 if inputstream:
381 input_streamer = StreamFeeder(inputstream)
387 if input_stream:
388 input_streamer = StreamFeeder(input_stream)
382 389 input_streamer.start()
383 inputstream = input_streamer.output
384 self._close_input_fd = inputstream
390 input_stream = input_streamer.output
391 self._close_input_fd = input_stream
385 392
386 393 self._fail_on_stderr = fail_on_stderr
387 394 self._fail_on_return_code = fail_on_return_code
388
389 _shell = kwargs.get('shell', True)
390 kwargs['shell'] = _shell
395 self.cmd = cmd
391 396
392 _p = subprocess.Popen(cmd, bufsize=-1,
393 stdin=inputstream,
394 stdout=subprocess.PIPE,
395 stderr=subprocess.PIPE,
397 _p = subprocess.Popen(cmd, bufsize=-1, stdin=input_stream, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
396 398 **kwargs)
399 self.process = _p
397 400
398 bg_out = BufferedGenerator(_p.stdout, buffer_size, chunk_size,
399 starting_values)
400 bg_err = BufferedGenerator(_p.stderr, 16000, 1, bottomless=True)
401 bg_out = BufferedGenerator('stdout', _p.stdout, buffer_size, chunk_size, starting_values)
402 bg_err = BufferedGenerator('stderr', _p.stderr, 10240, 1, bottomless=True)
401 403
402 404 while not bg_out.done_reading and not bg_out.reading_paused and not bg_err.length:
403 405 # doing this until we reach either end of file, or end of buffer.
404 bg_out.data_added_event.wait(1)
406 bg_out.data_added_event.wait(0.2)
405 407 bg_out.data_added_event.clear()
406 408
407 409 # at this point it's still ambiguous if we are done reading or just full buffer.
408 410 # Either way, if error (returned by ended process, or implied based on
409 411 # presence of stuff in stderr output) we error out.
410 412 # Else, we are happy.
411 _returncode = _p.poll()
413 return_code = _p.poll()
414 ret_code_ok = return_code in [None, 0]
415 ret_code_fail = return_code is not None and return_code != 0
416 if (
417 (ret_code_fail and fail_on_return_code) or
418 (ret_code_ok and fail_on_stderr and bg_err.length)
419 ):
412 420
413 if ((_returncode and fail_on_return_code) or
414 (fail_on_stderr and _returncode is None and bg_err.length)):
415 421 try:
416 422 _p.terminate()
417 423 except Exception:
418 424 pass
425
419 426 bg_out.stop()
427 out = b''.join(bg_out)
428 self._stdout = out
429
420 430 bg_err.stop()
421 if fail_on_stderr:
422 err = ''.join(bg_err)
423 raise EnvironmentError(
424 "Subprocess exited due to an error:\n" + err)
425 if _returncode and fail_on_return_code:
426 err = ''.join(bg_err)
431 err = b''.join(bg_err)
432 self._stderr = err
433
434 # code from https://github.com/schacon/grack/pull/7
435 if err.strip() == b'fatal: The remote end hung up unexpectedly' and out.startswith(b'0034shallow '):
436 bg_out = iter([out])
437 _p = None
438 elif err and fail_on_stderr:
439 text_err = err.decode()
440 raise OSError(
441 "Subprocess exited due to an error:\n{}".format(text_err))
442
443 if ret_code_fail and fail_on_return_code:
444 text_err = err.decode()
427 445 if not err:
428 446 # maybe get empty stderr, try stdout instead
429 447 # in many cases git reports the errors on stdout too
430 err = ''.join(bg_out)
431 raise EnvironmentError(
432 "Subprocess exited with non 0 ret code:%s: stderr:%s" % (
433 _returncode, err))
448 text_err = out.decode()
449 raise OSError(
450 "Subprocess exited with non 0 ret code:{}: stderr:{}".format(return_code, text_err))
434 451
435 self.process = _p
436 self.output = bg_out
437 self.error = bg_err
438 self.inputstream = inputstream
452 self.stdout = bg_out
453 self.stderr = bg_err
454 self.inputstream = input_stream
455
456 def __str__(self):
457 proc = getattr(self, 'process', 'NO_PROCESS')
458 return f'SubprocessIOChunker: {proc}'
439 459
440 460 def __iter__(self):
441 461 return self
442 462
443 463 def __next__(self):
444 464 # Note: mikhail: We need to be sure that we are checking the return
445 465 # code after the stdout stream is closed. Some processes, e.g. git
446 466 # are doing some magic in between closing stdout and terminating the
447 467 # process and, as a result, we are not getting return code on "slow"
448 468 # systems.
449 469 result = None
450 470 stop_iteration = None
451 471 try:
452 result = next(self.output)
472 result = next(self.stdout)
453 473 except StopIteration as e:
454 474 stop_iteration = e
455 475
456 if self.process.poll() and self._fail_on_return_code:
457 err = '%s' % ''.join(self.error)
458 raise EnvironmentError(
459 "Subprocess exited due to an error:\n" + err)
476 if self.process:
477 return_code = self.process.poll()
478 ret_code_fail = return_code is not None and return_code != 0
479 if ret_code_fail and self._fail_on_return_code:
480 self.stop_streams()
481 err = self.get_stderr()
482 raise OSError(
483 "Subprocess exited (exit_code:{}) due to an error during iteration:\n{}".format(return_code, err))
460 484
461 485 if stop_iteration:
462 486 raise stop_iteration
463 487 return result
464 488
465 def throw(self, type, value=None, traceback=None):
466 if self.output.length or not self.output.done_reading:
467 raise type(value)
489 def throw(self, exc_type, value=None, traceback=None):
490 if self.stdout.length or not self.stdout.done_reading:
491 raise exc_type(value)
468 492
469 493 def close(self):
470 494 if self._closed:
471 495 return
472 self._closed = True
496
473 497 try:
474 498 self.process.terminate()
475 499 except Exception:
476 500 pass
477 501 if self._close_input_fd:
478 502 os.close(self._close_input_fd)
479 503 try:
480 self.output.close()
504 self.stdout.close()
481 505 except Exception:
482 506 pass
483 507 try:
484 self.error.close()
508 self.stderr.close()
485 509 except Exception:
486 510 pass
487 511 try:
488 512 os.close(self.inputstream)
489 513 except Exception:
490 514 pass
491 515
516 self._closed = True
517
518 def stop_streams(self):
519 getattr(self.stdout, 'stop', lambda: None)()
520 getattr(self.stderr, 'stop', lambda: None)()
521
522 def get_stdout(self):
523 if self._stdout:
524 return self._stdout
525 else:
526 return b''.join(self.stdout)
527
528 def get_stderr(self):
529 if self._stderr:
530 return self._stderr
531 else:
532 return b''.join(self.stderr)
533
492 534
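
For orientation, a minimal consumption sketch of the ported chunker: under Python 3 every chunk it yields is bytes, so callers join and decode explicitly. The command, flag values, and decode policy below are illustrative assumptions, not part of this change.

    from vcsserver import subprocessio

    def read_git_version():
        # shell=False and fail_on_stderr=False are illustrative choices here
        chunker = subprocessio.SubprocessIOChunker(
            ['git', '--version'], shell=False, fail_on_stderr=False)
        try:
            # iterating the chunker yields bytes chunks from subprocess stdout
            return b''.join(chunker).decode('utf-8', 'replace')
        finally:
            chunker.close()
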
493 535 def run_command(arguments, env=None):
494 536 """
495 537 Run the specified command and return the stdout.
496 538
497 539 :param arguments: sequence of program arguments (including the program name)
498 540 :type arguments: list[str]
499 541 """
500 542
501 543 cmd = arguments
502 544 log.debug('Running subprocessio command %s', cmd)
503 545 proc = None
504 546 try:
505 547 _opts = {'shell': False, 'fail_on_stderr': False}
506 548 if env:
507 549 _opts.update({'env': env})
508 550 proc = SubprocessIOChunker(cmd, **_opts)
509 return ''.join(proc), ''.join(proc.error)
510 except (EnvironmentError, OSError) as err:
551 return b''.join(proc), b''.join(proc.stderr)
552 except OSError as err:
511 553 cmd = ' '.join(cmd) # human friendly CMD
512 554 tb_err = ("Couldn't run subprocessio command (%s).\n"
513 555 "Original error was:%s\n" % (cmd, err))
514 556 log.exception(tb_err)
515 557 raise Exception(tb_err)
516 558 finally:
517 559 if proc:
518 560 proc.close()
519 561
@@ -1,159 +1,162 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import inspect
19 19
20 20 import pytest
21 21 import dulwich.errors
22 22 from mock import Mock, patch
23 23
24 24 from vcsserver.remote import git
25 25
26 26 SAMPLE_REFS = {
27 27 'HEAD': 'fd627b9e0dd80b47be81af07c4a98518244ed2f7',
28 28 'refs/tags/v0.1.9': '341d28f0eec5ddf0b6b77871e13c2bbd6bec685c',
29 29 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
30 30 'refs/tags/v0.1.1': 'e6ea6d16e2f26250124a1f4b4fe37a912f9d86a0',
31 31 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
32 32 }
33 33
34 34
35 35 @pytest.fixture
36 36 def git_remote():
37 37 """
38 38 A GitRemote instance with a mock factory.
39 39 """
40 40 factory = Mock()
41 41 remote = git.GitRemote(factory)
42 42 return remote
43 43
44 44
45 45 def test_discover_git_version(git_remote):
46 46 version = git_remote.discover_git_version()
47 47 assert version
48 48
49 49
50 50 class TestGitFetch(object):
51 def setup(self):
51 def setup_method(self):
52 52 self.mock_repo = Mock()
53 53 factory = Mock()
54 54 factory.repo = Mock(return_value=self.mock_repo)
55 55 self.remote_git = git.GitRemote(factory)
56 56
57 57 def test_fetches_all_when_no_commit_ids_specified(self):
58 58 def side_effect(determine_wants, *args, **kwargs):
59 59 determine_wants(SAMPLE_REFS)
60 60
61 61 with patch('dulwich.client.LocalGitClient.fetch') as mock_fetch:
62 62 mock_fetch.side_effect = side_effect
63 63 self.remote_git.pull(wire={}, url='/tmp/', apply_refs=False)
64 64 determine_wants = self.mock_repo.object_store.determine_wants_all
65 65 determine_wants.assert_called_once_with(SAMPLE_REFS)
66 66
67 67 def test_fetches_specified_commits(self):
68 68 selected_refs = {
69 69 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
70 70 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
71 71 }
72 72
73 73 def side_effect(determine_wants, *args, **kwargs):
74 74 result = determine_wants(SAMPLE_REFS)
75 75 assert sorted(result) == sorted(selected_refs.values())
76 76 return result
77 77
78 78 with patch('dulwich.client.LocalGitClient.fetch') as mock_fetch:
79 79 mock_fetch.side_effect = side_effect
80 80 self.remote_git.pull(
81 81 wire={}, url='/tmp/', apply_refs=False,
82 82 refs=selected_refs.keys())
83 83 determine_wants = self.mock_repo.object_store.determine_wants_all
84 84 assert determine_wants.call_count == 0
85 85
86 86 def test_get_remote_refs(self):
87 87 factory = Mock()
88 88 remote_git = git.GitRemote(factory)
89 89 url = 'http://example.com/test/test.git'
90 90 sample_refs = {
91 91 'refs/tags/v0.1.8': '74ebce002c088b8a5ecf40073db09375515ecd68',
92 92 'refs/tags/v0.1.3': '5a3a8fb005554692b16e21dee62bf02667d8dc3e',
93 93 }
94 94
95 with patch('vcsserver.git.Repo', create=False) as mock_repo:
95 with patch('vcsserver.remote.git.Repo', create=False) as mock_repo:
96 96 mock_repo().get_refs.return_value = sample_refs
97 97 remote_refs = remote_git.get_remote_refs(wire={}, url=url)
98 98 mock_repo().get_refs.assert_called_once_with()
99 99 assert remote_refs == sample_refs
100 100
101 101
102 102 class TestReraiseSafeExceptions(object):
103 103
104 104 def test_method_decorated_with_reraise_safe_exceptions(self):
105 105 factory = Mock()
106 106 git_remote = git.GitRemote(factory)
107 107
108 108 def fake_function():
109 109 return None
110 110
111 111 decorator = git.reraise_safe_exceptions(fake_function)
112 112
113 113 methods = inspect.getmembers(git_remote, predicate=inspect.ismethod)
114 114 for method_name, method in methods:
115 115 if not method_name.startswith('_') and method_name not in ['vcsserver_invalidate_cache']:
116 116 assert method.__func__.__code__ == decorator.__code__
117 117
118 118 @pytest.mark.parametrize('side_effect, expected_type', [
119 119 (dulwich.errors.ChecksumMismatch('0000000', 'deadbeef'), 'lookup'),
120 120 (dulwich.errors.NotCommitError('deadbeef'), 'lookup'),
121 121 (dulwich.errors.MissingCommitError('deadbeef'), 'lookup'),
122 122 (dulwich.errors.ObjectMissing('deadbeef'), 'lookup'),
123 123 (dulwich.errors.HangupException(), 'error'),
124 124 (dulwich.errors.UnexpectedCommandError('test-cmd'), 'error'),
125 125 ])
126 126 def test_safe_exceptions_reraised(self, side_effect, expected_type):
127 127 @git.reraise_safe_exceptions
128 128 def fake_method():
129 129 raise side_effect
130 130
131 131 with pytest.raises(Exception) as exc_info:
132 132 fake_method()
133 133 assert type(exc_info.value) == Exception
134 134 assert exc_info.value._vcs_kind == expected_type
135 135
136 136
137 137 class TestDulwichRepoWrapper(object):
138 138 def test_calls_close_on_delete(self):
139 139 isdir_patcher = patch('dulwich.repo.os.path.isdir', return_value=True)
140 with isdir_patcher:
141 repo = git.Repo('/tmp/abcde')
142 with patch.object(git.DulwichRepo, 'close') as close_mock:
143 del repo
144 close_mock.assert_called_once_with()
140 with patch.object(git.Repo, 'close') as close_mock:
141 with isdir_patcher:
142 repo = git.Repo('/tmp/abcde')
143 assert repo is not None
144 repo.__del__()
145 # can't use del repo as in python3 this isn't always calling .__del__()
146
147 close_mock.assert_called_once_with()
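
The comment above points at a real Python 3 pitfall: del only drops one name binding, and __del__ runs when the last reference dies; non-refcounting interpreters may defer finalization further. A tiny illustration with throwaway names:

    class Probe:
        def __del__(self):
            print('finalized')

    p = Probe()
    alias = p
    del p      # no output: alias still holds a reference
    del alias  # 'finalized' printed on CPython (refcounting); other
               # interpreters may defer this to garbage collection
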
145 148
146 149
147 150 class TestGitFactory(object):
148 151 def test_create_repo_returns_dulwich_wrapper(self):
149 152
150 153 with patch('vcsserver.lib.rc_cache.region_meta.dogpile_cache_regions') as mock:
151 154 mock.side_effect = {'repo_objects': ''}
152 155 factory = git.GitFactory()
153 156 wire = {
154 157 'path': '/tmp/abcde'
155 158 }
156 159 isdir_patcher = patch('dulwich.repo.os.path.isdir', return_value=True)
157 160 with isdir_patcher:
158 161 result = factory._create_repo(wire, True)
159 162 assert isinstance(result, git.Repo)
@@ -1,109 +1,108 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import inspect
19 19 import sys
20 20 import traceback
21 21
22 22 import pytest
23 23 from mercurial.error import LookupError
24 24 from mock import Mock, patch
25 25
26 26 from vcsserver import exceptions, hgcompat
27 27 from vcsserver.remote import hg
28 28
29 29
30 30 class TestDiff(object):
31 31 def test_raising_safe_exception_when_lookup_failed(self):
32 32
33 33 factory = Mock()
34 34 hg_remote = hg.HgRemote(factory)
35 35 with patch('mercurial.patch.diff') as diff_mock:
36 diff_mock.side_effect = LookupError(
37 'deadbeef', 'index', 'message')
36 diff_mock.side_effect = LookupError(b'deadbeef', b'index', b'message')
37
38 38 with pytest.raises(Exception) as exc_info:
39 39 hg_remote.diff(
40 40 wire={}, commit_id_1='deadbeef', commit_id_2='deadbee1',
41 41 file_filter=None, opt_git=True, opt_ignorews=True,
42 42 context=3)
43 43 assert type(exc_info.value) == Exception
44 44 assert exc_info.value._vcs_kind == 'lookup'
45 45
46 46
47 47 class TestReraiseSafeExceptions(object):
48 48 def test_method_decorated_with_reraise_safe_exceptions(self):
49 49 factory = Mock()
50 50 hg_remote = hg.HgRemote(factory)
51 51 methods = inspect.getmembers(hg_remote, predicate=inspect.ismethod)
52 52 decorator = hg.reraise_safe_exceptions(None)
53 53 for method_name, method in methods:
54 54 if not method_name.startswith('_') and method_name not in ['vcsserver_invalidate_cache']:
55 55 assert method.__func__.__code__ == decorator.__code__
56 56
57 57 @pytest.mark.parametrize('side_effect, expected_type', [
58 (hgcompat.Abort(), 'abort'),
59 (hgcompat.InterventionRequired(), 'abort'),
58 (hgcompat.Abort('failed-abort'), 'abort'),
59 (hgcompat.InterventionRequired('intervention-required'), 'abort'),
60 60 (hgcompat.RepoLookupError(), 'lookup'),
61 (hgcompat.LookupError('deadbeef', 'index', 'message'), 'lookup'),
61 (hgcompat.LookupError(b'deadbeef', b'index', b'message'), 'lookup'),
62 62 (hgcompat.RepoError(), 'error'),
63 63 (hgcompat.RequirementError(), 'requirement'),
64 64 ])
65 65 def test_safe_exceptions_reraised(self, side_effect, expected_type):
66 66 @hg.reraise_safe_exceptions
67 67 def fake_method():
68 68 raise side_effect
69 69
70 70 with pytest.raises(Exception) as exc_info:
71 71 fake_method()
72 72 assert type(exc_info.value) == Exception
73 73 assert exc_info.value._vcs_kind == expected_type
74 74
75 75 def test_keeps_original_traceback(self):
76 76 @hg.reraise_safe_exceptions
77 77 def fake_method():
78 78 try:
79 raise hgcompat.Abort()
79 raise hgcompat.Abort('test-abort')
80 80 except:
81 self.original_traceback = traceback.format_tb(
82 sys.exc_info()[2])
81 self.original_traceback = traceback.format_tb(sys.exc_info()[2])
83 82 raise
84 83
85 84 try:
86 85 fake_method()
87 86 except Exception:
88 87 new_traceback = traceback.format_tb(sys.exc_info()[2])
89 88
90 89 new_traceback_tail = new_traceback[-len(self.original_traceback):]
91 90 assert new_traceback_tail == self.original_traceback
92 91
93 92 def test_maps_unknow_exceptions_to_unhandled(self):
94 93 @hg.reraise_safe_exceptions
95 94 def stub_method():
96 95 raise ValueError('stub')
97 96
98 97 with pytest.raises(Exception) as exc_info:
99 98 stub_method()
100 99 assert exc_info.value._vcs_kind == 'unhandled'
101 100
102 101 def test_does_not_map_known_exceptions(self):
103 102 @hg.reraise_safe_exceptions
104 103 def stub_method():
105 104 raise exceptions.LookupException()('stub')
106 105
107 106 with pytest.raises(Exception) as exc_info:
108 107 stub_method()
109 108 assert exc_info.value._vcs_kind == 'lookup'
@@ -1,124 +1,119 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import mock
19 19 import pytest
20 20
21 21 from vcsserver import hgcompat, hgpatches
22 22
23 23
24 LARGEFILES_CAPABILITY = 'largefiles=serve'
24 LARGEFILES_CAPABILITY = b'largefiles=serve'
25 25
26 26
27 27 def test_patch_largefiles_capabilities_applies_patch(
28 28 patched_capabilities):
29 29 lfproto = hgcompat.largefiles.proto
30 30 hgpatches.patch_largefiles_capabilities()
31 31 assert lfproto._capabilities.__name__ == '_dynamic_capabilities'
32 32
33 33
34 34 def test_dynamic_capabilities_uses_original_function_if_not_enabled(
35 35 stub_repo, stub_proto, stub_ui, stub_extensions, patched_capabilities,
36 36 orig_capabilities):
37 37 dynamic_capabilities = hgpatches._dynamic_capabilities_wrapper(
38 38 hgcompat.largefiles.proto, stub_extensions)
39 39
40 40 caps = dynamic_capabilities(orig_capabilities, stub_repo, stub_proto)
41 41
42 42 stub_extensions.assert_called_once_with(stub_ui)
43 43 assert LARGEFILES_CAPABILITY not in caps
44 44
45 45
46 46 def test_dynamic_capabilities_ignores_updated_capabilities(
47 47 stub_repo, stub_proto, stub_ui, stub_extensions, patched_capabilities,
48 48 orig_capabilities):
49 49 stub_extensions.return_value = [('largefiles', mock.Mock())]
50 50 dynamic_capabilities = hgpatches._dynamic_capabilities_wrapper(
51 51 hgcompat.largefiles.proto, stub_extensions)
52 52
 53 53 # This happens when the extension is loaded for the first time; it is
 54 54 # important to ensure that an updated function is correctly picked up.
55 55 hgcompat.largefiles.proto._capabilities = mock.Mock(
56 56 side_effect=Exception('Must not be called'))
57 57
58 58 dynamic_capabilities(orig_capabilities, stub_repo, stub_proto)
59 59
60 60
61 61 def test_dynamic_capabilities_uses_largefiles_if_enabled(
62 62 stub_repo, stub_proto, stub_ui, stub_extensions, patched_capabilities,
63 63 orig_capabilities):
64 64 stub_extensions.return_value = [('largefiles', mock.Mock())]
65 65
66 66 dynamic_capabilities = hgpatches._dynamic_capabilities_wrapper(
67 67 hgcompat.largefiles.proto, stub_extensions)
68 68
69 69 caps = dynamic_capabilities(orig_capabilities, stub_repo, stub_proto)
70 70
71 71 stub_extensions.assert_called_once_with(stub_ui)
72 72 assert LARGEFILES_CAPABILITY in caps
73 73
74 74
75 def test_hgsubversion_import():
76 from hgsubversion import svnrepo
77 assert svnrepo
78
79
80 75 @pytest.fixture
81 76 def patched_capabilities(request):
82 77 """
83 78 Patch in `capabilitiesorig` and restore both capability functions.
84 79 """
85 80 lfproto = hgcompat.largefiles.proto
86 81 orig_capabilities = lfproto._capabilities
87 82
88 83 @request.addfinalizer
89 84 def restore():
90 85 lfproto._capabilities = orig_capabilities
91 86
92 87
93 88 @pytest.fixture
94 89 def stub_repo(stub_ui):
95 90 repo = mock.Mock()
96 91 repo.ui = stub_ui
97 92 return repo
98 93
99 94
100 95 @pytest.fixture
101 96 def stub_proto(stub_ui):
102 97 proto = mock.Mock()
103 98 proto.ui = stub_ui
104 99 return proto
105 100
106 101
107 102 @pytest.fixture
108 103 def orig_capabilities():
109 104 from mercurial.wireprotov1server import wireprotocaps
110 105
111 106 def _capabilities(repo, proto):
112 107 return wireprotocaps
113 108 return _capabilities
114 109
115 110
116 111 @pytest.fixture
117 112 def stub_ui():
118 113 return hgcompat.ui.ui()
119 114
120 115
121 116 @pytest.fixture
122 117 def stub_extensions():
123 118 extensions = mock.Mock(return_value=tuple())
124 119 return extensions
@@ -1,241 +1,245 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 import contextlib
19 import io
20 18 import threading
19 import msgpack
20
21 21 from http.server import BaseHTTPRequestHandler
22 22 from socketserver import TCPServer
23 23
24 24 import mercurial.ui
25 25 import mock
26 26 import pytest
27 import simplejson as json
28 27
28 from vcsserver.lib.rc_json import json
29 29 from vcsserver import hooks
30 30
31 31
32 32 def get_hg_ui(extras=None):
33 33 """Create a Config object with a valid RC_SCM_DATA entry."""
34 34 extras = extras or {}
35 35 required_extras = {
36 36 'username': '',
37 37 'repository': '',
38 38 'locked_by': '',
39 39 'scm': '',
40 40 'make_lock': '',
41 41 'action': '',
42 42 'ip': '',
43 43 'hooks_uri': 'fake_hooks_uri',
44 44 }
45 45 required_extras.update(extras)
46 46 hg_ui = mercurial.ui.ui()
47 hg_ui.setconfig('rhodecode', 'RC_SCM_DATA', json.dumps(required_extras))
47 hg_ui.setconfig(b'rhodecode', b'RC_SCM_DATA', json.dumps(required_extras))
48 48
49 49 return hg_ui
50 50
51 51
52 52 def test_git_pre_receive_is_disabled():
53 53 extras = {'hooks': ['pull']}
54 54 response = hooks.git_pre_receive(None, None,
55 55 {'RC_SCM_DATA': json.dumps(extras)})
56 56
57 57 assert response == 0
58 58
59 59
60 60 def test_git_post_receive_is_disabled():
61 61 extras = {'hooks': ['pull']}
62 62 response = hooks.git_post_receive(None, '',
63 63 {'RC_SCM_DATA': json.dumps(extras)})
64 64
65 65 assert response == 0
66 66
67 67
68 68 def test_git_post_receive_calls_repo_size():
69 69 extras = {'hooks': ['push', 'repo_size']}
70
70 71 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
71 72 hooks.git_post_receive(
72 73 None, '', {'RC_SCM_DATA': json.dumps(extras)})
73 74 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
74 75 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
75 76 expected_calls = [
76 77 mock.call('repo_size', extras, mock.ANY),
77 78 mock.call('post_push', extras, mock.ANY),
78 79 ]
79 80 assert call_hook_mock.call_args_list == expected_calls
80 81
81 82
82 83 def test_git_post_receive_does_not_call_disabled_repo_size():
83 84 extras = {'hooks': ['push']}
85
84 86 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
85 87 hooks.git_post_receive(
86 88 None, '', {'RC_SCM_DATA': json.dumps(extras)})
87 89 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
88 90 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
89 91 expected_calls = [
90 92 mock.call('post_push', extras, mock.ANY)
91 93 ]
92 94 assert call_hook_mock.call_args_list == expected_calls
93 95
94 96
95 97 def test_repo_size_exception_does_not_affect_git_post_receive():
96 98 extras = {'hooks': ['push', 'repo_size']}
97 99 status = 0
98 100
99 101 def side_effect(name, *args, **kwargs):
100 102 if name == 'repo_size':
101 103 raise Exception('Fake exception')
102 104 else:
103 105 return status
104 106
105 107 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
106 108 call_hook_mock.side_effect = side_effect
107 109 result = hooks.git_post_receive(
108 110 None, '', {'RC_SCM_DATA': json.dumps(extras)})
109 111 assert result == status
110 112
111 113
112 114 def test_git_pre_pull_is_disabled():
113 115 assert hooks.git_pre_pull({'hooks': ['push']}) == hooks.HookResponse(0, '')
114 116
115 117
116 118 def test_git_post_pull_is_disabled():
117 119 assert (
118 120 hooks.git_post_pull({'hooks': ['push']}) == hooks.HookResponse(0, ''))
119 121
120 122
121 123 class TestGetHooksClient(object):
122 124
123 125 def test_returns_http_client_when_protocol_matches(self):
124 126 hooks_uri = 'localhost:8000'
125 127 result = hooks._get_hooks_client({
126 128 'hooks_uri': hooks_uri,
127 129 'hooks_protocol': 'http'
128 130 })
129 131 assert isinstance(result, hooks.HooksHttpClient)
130 132 assert result.hooks_uri == hooks_uri
131 133
132 134 def test_returns_dummy_client_when_hooks_uri_not_specified(self):
133 135 fake_module = mock.Mock()
134 136 import_patcher = mock.patch.object(
135 137 hooks.importlib, 'import_module', return_value=fake_module)
136 138 fake_module_name = 'fake.module'
137 139 with import_patcher as import_mock:
138 140 result = hooks._get_hooks_client(
139 141 {'hooks_module': fake_module_name})
140 142
141 143 import_mock.assert_called_once_with(fake_module_name)
142 144 assert isinstance(result, hooks.HooksDummyClient)
143 145 assert result._hooks_module == fake_module
144 146
145 147
146 148 class TestHooksHttpClient(object):
147 149 def test_init_sets_hooks_uri(self):
148 150 uri = 'localhost:3000'
149 151 client = hooks.HooksHttpClient(uri)
150 152 assert client.hooks_uri == uri
151 153
152 def test_serialize_returns_json_string(self):
154 def test_serialize_returns_serialized_string(self):
153 155 client = hooks.HooksHttpClient('localhost:3000')
154 156 hook_name = 'test'
155 157 extras = {
156 158 'first': 1,
157 159 'second': 'two'
158 160 }
159 result = client._serialize(hook_name, extras)
160 expected_result = json.dumps({
161 hooks_proto, result = client._serialize(hook_name, extras)
162 expected_result = msgpack.packb({
161 163 'method': hook_name,
162 'extras': extras
164 'extras': extras,
163 165 })
166 assert hooks_proto == {'rc-hooks-protocol': 'msgpack.v1'}
164 167 assert result == expected_result
165 168
166 169 def test_call_queries_http_server(self, http_mirror):
167 170 client = hooks.HooksHttpClient(http_mirror.uri)
168 171 hook_name = 'test'
169 172 extras = {
170 173 'first': 1,
171 174 'second': 'two'
172 175 }
173 176 result = client(hook_name, extras)
174 expected_result = {
177 expected_result = msgpack.unpackb(msgpack.packb({
175 178 'method': hook_name,
176 179 'extras': extras
177 }
180 }), raw=False)
178 181 assert result == expected_result
179 182
180 183
181 184 class TestHooksDummyClient(object):
182 185 def test_init_imports_hooks_module(self):
183 186 hooks_module_name = 'rhodecode.fake.module'
184 187 hooks_module = mock.MagicMock()
185 188
186 189 import_patcher = mock.patch.object(
187 190 hooks.importlib, 'import_module', return_value=hooks_module)
188 191 with import_patcher as import_mock:
189 192 client = hooks.HooksDummyClient(hooks_module_name)
190 193 import_mock.assert_called_once_with(hooks_module_name)
191 194 assert client._hooks_module == hooks_module
192 195
193 196 def test_call_returns_hook_result(self):
194 197 hooks_module_name = 'rhodecode.fake.module'
195 198 hooks_module = mock.MagicMock()
196 199 import_patcher = mock.patch.object(
197 200 hooks.importlib, 'import_module', return_value=hooks_module)
198 201 with import_patcher:
199 202 client = hooks.HooksDummyClient(hooks_module_name)
200 203
201 204 result = client('post_push', {})
202 205 hooks_module.Hooks.assert_called_once_with()
203 206 assert result == hooks_module.Hooks().__enter__().post_push()
204 207
205 208
206 209 @pytest.fixture
207 210 def http_mirror(request):
208 211 server = MirrorHttpServer()
209 212 request.addfinalizer(server.stop)
210 213 return server
211 214
212 215
213 216 class MirrorHttpHandler(BaseHTTPRequestHandler):
217
214 218 def do_POST(self):
215 219 length = int(self.headers['Content-Length'])
216 body = self.rfile.read(length).decode('utf-8')
220 body = self.rfile.read(length)
217 221 self.send_response(200)
218 222 self.end_headers()
219 223 self.wfile.write(body)
220 224
221 225
222 226 class MirrorHttpServer(object):
223 227 ip_address = '127.0.0.1'
224 228 port = 0
225 229
226 230 def __init__(self):
227 231 self._daemon = TCPServer((self.ip_address, 0), MirrorHttpHandler)
228 232 _, self.port = self._daemon.server_address
229 233 self._thread = threading.Thread(target=self._daemon.serve_forever)
230 234 self._thread.daemon = True
231 235 self._thread.start()
232 236
233 237 def stop(self):
234 238 self._daemon.shutdown()
235 239 self._thread.join()
236 240 self._daemon = None
237 241 self._thread = None
238 242
239 243 @property
240 244 def uri(self):
241 245 return '{}:{}'.format(self.ip_address, self.port)
@@ -1,42 +1,42 b''
1 1 """
2 2 Tests used to profile the HTTP based implementation.
3 3 """
4 4
5 5 import pytest
6 6 import webtest
7 7
8 8 from vcsserver.http_main import main
9 9
10 10
11 11 @pytest.fixture
12 12 def vcs_app():
13 13 stub_settings = {
14 14 'dev.use_echo_app': 'true',
15 15 'locale': 'en_US.UTF-8',
16 16 }
17 17 stub_global_conf = {
18 18 '__file__': ''
19 19 }
20 20 vcs_app = main(stub_global_conf, **stub_settings)
21 21 app = webtest.TestApp(vcs_app)
22 22 return app
23 23
24 24
25 25 @pytest.fixture(scope='module')
26 26 def data():
27 27 one_kb = 'x' * 1024
28 28 return one_kb * 1024 * 10
29 29
30 30
31 31 def test_http_app_streaming_with_data(data, repeat, vcs_app):
32 32 app = vcs_app
33 for x in range(repeat / 10):
33 for x in range(repeat // 10):
34 34 response = app.post('/stream/git/', params=data)
35 35 assert response.status_code == 200
36 36
37 37
38 38 def test_http_app_streaming_no_data(repeat, vcs_app):
39 39 app = vcs_app
40 for x in range(repeat / 10):
40 for x in range(repeat // 10):
41 41 response = app.post('/stream/git/')
42 42 assert response.status_code == 200
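
The switch to repeat // 10 above is a classic py2-to-py3 fix: in Python 3 the / operator always returns a float and range() rejects floats, while floor division keeps the quotient an integer. A quick illustration:

    repeat = 100
    # range(repeat / 10) raises TypeError on Python 3:
    #   'float' object cannot be interpreted as an integer
    assert isinstance(repeat // 10, int)
    for x in range(repeat // 10):
        pass
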
@@ -1,206 +1,205 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import sys
20 20 import stat
21 21 import pytest
22 22 import vcsserver
23 23 import tempfile
24 24 from vcsserver import hook_utils
25 25 from vcsserver.tests.fixture import no_newline_id_generator
26 from vcsserver.utils import AttributeDict
26 from vcsserver.utils import AttributeDict, safe_bytes, safe_str
27 27
28 28
29 29 class TestCheckRhodecodeHook(object):
30 30
31 31 def test_returns_false_when_hook_file_is_wrong_found(self, tmpdir):
32 32 hook = os.path.join(str(tmpdir), 'fake_hook_file.py')
33 33 with open(hook, 'wb') as f:
34 f.write('dummy test')
34 f.write(b'dummy test')
35 35 result = hook_utils.check_rhodecode_hook(hook)
36 36 assert result is False
37 37
38 38 def test_returns_true_when_no_hook_file_found(self, tmpdir):
39 39 hook = os.path.join(str(tmpdir), 'fake_hook_file_not_existing.py')
40 40 result = hook_utils.check_rhodecode_hook(hook)
41 41 assert result
42 42
43 43 @pytest.mark.parametrize("file_content, expected_result", [
44 44 ("RC_HOOK_VER = '3.3.3'\n", True),
45 45 ("RC_HOOK = '3.3.3'\n", False),
46 46 ], ids=no_newline_id_generator)
47 47 def test_signatures(self, file_content, expected_result, tmpdir):
48 48 hook = os.path.join(str(tmpdir), 'fake_hook_file_1.py')
49 49 with open(hook, 'wb') as f:
50 f.write(file_content)
50 f.write(safe_bytes(file_content))
51 51
52 52 result = hook_utils.check_rhodecode_hook(hook)
53 53
54 54 assert result is expected_result
55 55
56 56
57 57 class BaseInstallHooks(object):
58 58 HOOK_FILES = ()
59 59
60 60 def _check_hook_file_mode(self, file_path):
61 61 assert os.path.exists(file_path), 'path %s missing' % file_path
62 62 stat_info = os.stat(file_path)
63 63
64 64 file_mode = stat.S_IMODE(stat_info.st_mode)
65 65 expected_mode = int('755', 8)
66 66 assert expected_mode == file_mode
67 67
68 68 def _check_hook_file_content(self, file_path, executable):
69 69 executable = executable or sys.executable
70 70 with open(file_path, 'rt') as hook_file:
71 71 content = hook_file.read()
72 72
73 73 expected_env = '#!{}'.format(executable)
74 expected_rc_version = "\nRC_HOOK_VER = '{}'\n".format(
75 vcsserver.__version__)
74 expected_rc_version = "\nRC_HOOK_VER = '{}'\n".format(safe_str(vcsserver.__version__))
76 75 assert content.strip().startswith(expected_env)
77 76 assert expected_rc_version in content
78 77
79 78 def _create_fake_hook(self, file_path, content):
80 79 with open(file_path, 'w') as hook_file:
81 80 hook_file.write(content)
82 81
83 82 def create_dummy_repo(self, repo_type):
84 83 tmpdir = tempfile.mkdtemp()
85 84 repo = AttributeDict()
86 85 if repo_type == 'git':
87 86 repo.path = os.path.join(tmpdir, 'test_git_hooks_installation_repo')
88 87 os.makedirs(repo.path)
89 88 os.makedirs(os.path.join(repo.path, 'hooks'))
90 89 repo.bare = True
91 90
92 91 elif repo_type == 'svn':
93 92 repo.path = os.path.join(tmpdir, 'test_svn_hooks_installation_repo')
94 93 os.makedirs(repo.path)
95 94 os.makedirs(os.path.join(repo.path, 'hooks'))
96 95
97 96 return repo
98 97
99 98 def check_hooks(self, repo_path, repo_bare=True):
100 99 for file_name in self.HOOK_FILES:
101 100 if repo_bare:
102 101 file_path = os.path.join(repo_path, 'hooks', file_name)
103 102 else:
104 103 file_path = os.path.join(repo_path, '.git', 'hooks', file_name)
105 104 self._check_hook_file_mode(file_path)
106 105 self._check_hook_file_content(file_path, sys.executable)
107 106
108 107
109 108 class TestInstallGitHooks(BaseInstallHooks):
110 109 HOOK_FILES = ('pre-receive', 'post-receive')
111 110
112 111 def test_hooks_are_installed(self):
113 112 repo = self.create_dummy_repo('git')
114 113 result = hook_utils.install_git_hooks(repo.path, repo.bare)
115 114 assert result
116 115 self.check_hooks(repo.path, repo.bare)
117 116
118 117 def test_hooks_are_replaced(self):
119 118 repo = self.create_dummy_repo('git')
120 119 hooks_path = os.path.join(repo.path, 'hooks')
121 120 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
122 121 self._create_fake_hook(
123 122 file_path, content="RC_HOOK_VER = 'abcde'\n")
124 123
125 124 result = hook_utils.install_git_hooks(repo.path, repo.bare)
126 125 assert result
127 126 self.check_hooks(repo.path, repo.bare)
128 127
129 128 def test_non_rc_hooks_are_not_replaced(self):
130 129 repo = self.create_dummy_repo('git')
131 130 hooks_path = os.path.join(repo.path, 'hooks')
132 131 non_rc_content = 'echo "non rc hook"\n'
133 132 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
134 133 self._create_fake_hook(
135 134 file_path, content=non_rc_content)
136 135
137 136 result = hook_utils.install_git_hooks(repo.path, repo.bare)
138 137 assert result
139 138
140 139 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
141 140 with open(file_path, 'rt') as hook_file:
142 141 content = hook_file.read()
143 142 assert content == non_rc_content
144 143
145 144 def test_non_rc_hooks_are_replaced_with_force_flag(self):
146 145 repo = self.create_dummy_repo('git')
147 146 hooks_path = os.path.join(repo.path, 'hooks')
148 147 non_rc_content = 'echo "non rc hook"\n'
149 148 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
150 149 self._create_fake_hook(
151 150 file_path, content=non_rc_content)
152 151
153 152 result = hook_utils.install_git_hooks(
154 153 repo.path, repo.bare, force_create=True)
155 154 assert result
156 155 self.check_hooks(repo.path, repo.bare)
157 156
158 157
159 158 class TestInstallSvnHooks(BaseInstallHooks):
160 159 HOOK_FILES = ('pre-commit', 'post-commit')
161 160
162 161 def test_hooks_are_installed(self):
163 162 repo = self.create_dummy_repo('svn')
164 163 result = hook_utils.install_svn_hooks(repo.path)
165 164 assert result
166 165 self.check_hooks(repo.path)
167 166
168 167 def test_hooks_are_replaced(self):
169 168 repo = self.create_dummy_repo('svn')
170 169 hooks_path = os.path.join(repo.path, 'hooks')
171 170 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
172 171 self._create_fake_hook(
173 172 file_path, content="RC_HOOK_VER = 'abcde'\n")
174 173
175 174 result = hook_utils.install_svn_hooks(repo.path)
176 175 assert result
177 176 self.check_hooks(repo.path)
178 177
179 178 def test_non_rc_hooks_are_not_replaced(self):
180 179 repo = self.create_dummy_repo('svn')
181 180 hooks_path = os.path.join(repo.path, 'hooks')
182 181 non_rc_content = 'echo "non rc hook"\n'
183 182 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
184 183 self._create_fake_hook(
185 184 file_path, content=non_rc_content)
186 185
187 186 result = hook_utils.install_svn_hooks(repo.path)
188 187 assert result
189 188
190 189 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
191 190 with open(file_path, 'rt') as hook_file:
192 191 content = hook_file.read()
193 192 assert content == non_rc_content
194 193
195 194 def test_non_rc_hooks_are_replaced_with_force_flag(self):
196 195 repo = self.create_dummy_repo('svn')
197 196 hooks_path = os.path.join(repo.path, 'hooks')
198 197 non_rc_content = 'echo "non rc hook"\n'
199 198 for file_path in [os.path.join(hooks_path, f) for f in self.HOOK_FILES]:
200 199 self._create_fake_hook(
201 200 file_path, content=non_rc_content)
202 201
203 202 result = hook_utils.install_svn_hooks(
204 203 repo.path, force_create=True)
205 204 assert result
206 205 self.check_hooks(repo.path, )
@@ -1,57 +1,56 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import mock
19 19 import pytest
20 20
21 21 from vcsserver import http_main
22 22 from vcsserver.base import obfuscate_qs
23 23
24 24
25 25 @mock.patch('vcsserver.http_main.VCS', mock.Mock())
26 26 @mock.patch('vcsserver.hgpatches.patch_largefiles_capabilities')
27 27 def test_applies_largefiles_patch(patch_largefiles_capabilities):
28 28 http_main.main({'__file__': ''})
29 29 patch_largefiles_capabilities.assert_called_once_with()
30 30
31 31
32 32 @mock.patch('vcsserver.http_main.VCS', mock.Mock())
33 33 @mock.patch('vcsserver.http_main.MercurialFactory', None)
34 34 @mock.patch(
35 35 'vcsserver.hgpatches.patch_largefiles_capabilities',
36 36 mock.Mock(side_effect=Exception("Must not be called")))
37 37 def test_applies_largefiles_patch_only_if_mercurial_is_available():
38 38 http_main.main({'__file__': ''})
39 39
40 40
41 41 @pytest.mark.parametrize('given, expected', [
42 42 ('bad', 'bad'),
43 43 ('query&foo=bar', 'query&foo=bar'),
44 44 ('equery&auth_token=bar', 'equery&auth_token=*****'),
45 ('a;b;c;query&foo=bar&auth_token=secret',
46 'a&b&c&query&foo=bar&auth_token=*****'),
45 ('a;b;c;query&foo=bar&auth_token=secret', 'a;b;c;query&foo=bar&auth_token=*****'),
47 46 ('', ''),
48 47 (None, None),
49 48 ('foo=bar', 'foo=bar'),
50 49 ('auth_token=secret', 'auth_token=*****'),
51 50 ('auth_token=secret&api_key=secret2',
52 51 'auth_token=*****&api_key=*****'),
53 52 ('auth_token=secret&api_key=secret2&param=value',
54 53 'auth_token=*****&api_key=*****&param=value'),
55 54 ])
56 55 def test_obfuscate_qs(given, expected):
57 56 assert expected == obfuscate_qs(given)
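
A minimal sketch consistent with the parametrized expectations above; the actual vcsserver.base implementation may differ, and the secret-key list here is assumed from the test data:

    def obfuscate_qs_sketch(qs, secret_keys=('auth_token', 'api_key')):
        # falsy inputs (None, '') pass through untouched
        if not qs:
            return qs
        parts = []
        for pair in qs.split('&'):
            key, sep, _value = pair.partition('=')
            if sep and key in secret_keys:
                parts.append(key + sep + '*****')  # mask the secret value
            else:
                parts.append(pair)
        return '&'.join(parts)

    assert obfuscate_qs_sketch('auth_token=secret&param=value') == 'auth_token=*****&param=value'
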
@@ -1,249 +1,288 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 import more_itertools
19 20
20 21 import dulwich.protocol
21 22 import mock
22 23 import pytest
23 24 import webob
24 25 import webtest
25 26
26 27 from vcsserver import hooks, pygrack
27 28
28 29 # pylint: disable=redefined-outer-name,protected-access
30 from vcsserver.utils import ascii_bytes
29 31
30 32
31 33 @pytest.fixture()
32 34 def pygrack_instance(tmpdir):
33 35 """
34 36 Creates a pygrack app instance.
35 37
 36 38 Right now, it is not particularly helpful regarding the passed directory.
37 39 It just contains the required folders to pass the signature test.
38 40 """
39 41 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
40 42 tmpdir.mkdir(dir_name)
41 43
42 44 return pygrack.GitRepository('repo_name', str(tmpdir), 'git', False, {})
43 45
44 46
45 47 @pytest.fixture()
46 48 def pygrack_app(pygrack_instance):
47 49 """
48 50 Creates a pygrack app wrapped in webtest.TestApp.
49 51 """
50 52 return webtest.TestApp(pygrack_instance)
51 53
52 54
53 55 def test_invalid_service_info_refs_returns_403(pygrack_app):
54 56 response = pygrack_app.get('/info/refs?service=git-upload-packs',
55 57 expect_errors=True)
56 58
57 59 assert response.status_int == 403
58 60
59 61
60 62 def test_invalid_endpoint_returns_403(pygrack_app):
61 63 response = pygrack_app.post('/git-upload-packs', expect_errors=True)
62 64
63 65 assert response.status_int == 403
64 66
65 67
66 68 @pytest.mark.parametrize('sideband', [
67 69 'side-band-64k',
68 70 'side-band',
69 71 'side-band no-progress',
70 72 ])
71 73 def test_pre_pull_hook_fails_with_sideband(pygrack_app, sideband):
72 74 request = ''.join([
73 75 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ',
74 76 'multi_ack %s ofs-delta\n' % sideband,
75 77 '0000',
76 78 '0009done\n',
77 79 ])
78 with mock.patch('vcsserver.hooks.git_pre_pull',
79 return_value=hooks.HookResponse(1, 'foo')):
80 with mock.patch('vcsserver.hooks.git_pre_pull', return_value=hooks.HookResponse(1, 'foo')):
80 81 response = pygrack_app.post(
81 82 '/git-upload-pack', params=request,
82 83 content_type='application/x-git-upload-pack')
83 84
84 85 data = io.BytesIO(response.body)
85 86 proto = dulwich.protocol.Protocol(data.read, None)
86 87 packets = list(proto.read_pkt_seq())
87 88
88 89 expected_packets = [
89 'NAK\n', '\x02foo', '\x02Pre pull hook failed: aborting\n',
90 '\x01' + pygrack.GitRepository.EMPTY_PACK,
90 b'NAK\n', b'\x02foo', b'\x02Pre pull hook failed: aborting\n',
91 b'\x01' + pygrack.GitRepository.EMPTY_PACK,
91 92 ]
92 93 assert packets == expected_packets
93 94
94 95
95 96 def test_pre_pull_hook_fails_no_sideband(pygrack_app):
96 97 request = ''.join([
97 98 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
98 99 'multi_ack ofs-delta\n'
99 100 '0000',
100 101 '0009done\n',
101 102 ])
102 103 with mock.patch('vcsserver.hooks.git_pre_pull',
103 104 return_value=hooks.HookResponse(1, 'foo')):
104 105 response = pygrack_app.post(
105 106 '/git-upload-pack', params=request,
106 107 content_type='application/x-git-upload-pack')
107 108
108 109 assert response.body == pygrack.GitRepository.EMPTY_PACK
109 110
110 111
111 112 def test_pull_has_hook_messages(pygrack_app):
112 113 request = ''.join([
113 114 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
114 115 'multi_ack side-band-64k ofs-delta\n'
115 116 '0000',
116 117 '0009done\n',
117 118 ])
118 119 with mock.patch('vcsserver.hooks.git_pre_pull',
119 120 return_value=hooks.HookResponse(0, 'foo')):
120 121 with mock.patch('vcsserver.hooks.git_post_pull',
121 122 return_value=hooks.HookResponse(1, 'bar')):
122 123 with mock.patch('vcsserver.subprocessio.SubprocessIOChunker',
123 return_value=['0008NAK\n0009subp\n0000']):
124 return_value=more_itertools.always_iterable([b'0008NAK\n0009subp\n0000'])):
124 125 response = pygrack_app.post(
125 126 '/git-upload-pack', params=request,
126 127 content_type='application/x-git-upload-pack')
127 128
128 129 data = io.BytesIO(response.body)
129 130 proto = dulwich.protocol.Protocol(data.read, None)
130 131 packets = list(proto.read_pkt_seq())
131 132
132 assert packets == ['NAK\n', '\x02foo', 'subp\n', '\x02bar']
133 assert packets == [b'NAK\n', b'\x02foo', b'subp\n', b'\x02bar']
133 134
134 135
135 136 def test_get_want_capabilities(pygrack_instance):
136 137 data = io.BytesIO(
137 '0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
138 'multi_ack side-band-64k ofs-delta\n00000009done\n')
138 b'0054want 74730d410fcb6603ace96f1dc55ea6196122532d ' +
139 b'multi_ack side-band-64k ofs-delta\n00000009done\n')
139 140
140 141 request = webob.Request({
141 142 'wsgi.input': data,
142 143 'REQUEST_METHOD': 'POST',
143 144 'webob.is_body_seekable': True
144 145 })
145 146
146 147 capabilities = pygrack_instance._get_want_capabilities(request)
147 148
148 149 assert capabilities == frozenset(
149 ('ofs-delta', 'multi_ack', 'side-band-64k'))
150 (b'ofs-delta', b'multi_ack', b'side-band-64k'))
150 151 assert data.tell() == 0
151 152
152 153
153 154 @pytest.mark.parametrize('data,capabilities,expected', [
154 155 ('foo', [], []),
155 ('', ['side-band-64k'], []),
156 ('', ['side-band'], []),
157 ('foo', ['side-band-64k'], ['0008\x02foo']),
158 ('foo', ['side-band'], ['0008\x02foo']),
159 ('f'*1000, ['side-band-64k'], ['03ed\x02' + 'f' * 1000]),
160 ('f'*1000, ['side-band'], ['03e8\x02' + 'f' * 995, '000a\x02fffff']),
161 ('f'*65520, ['side-band-64k'], ['fff0\x02' + 'f' * 65515, '000a\x02fffff']),
162 ('f'*65520, ['side-band'], ['03e8\x02' + 'f' * 995] * 65 + ['0352\x02' + 'f' * 845]),
156 ('', [pygrack.CAPABILITY_SIDE_BAND_64K], []),
157 ('', [pygrack.CAPABILITY_SIDE_BAND], []),
158 ('foo', [pygrack.CAPABILITY_SIDE_BAND_64K], [b'0008\x02foo']),
159 ('foo', [pygrack.CAPABILITY_SIDE_BAND], [b'0008\x02foo']),
160 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'03ed\x02' + b'f' * 1000]),
161 ('f'*1000, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995, b'000a\x02fffff']),
162 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND_64K], [b'fff0\x02' + b'f' * 65515, b'000a\x02fffff']),
163 ('f'*65520, [pygrack.CAPABILITY_SIDE_BAND], [b'03e8\x02' + b'f' * 995] * 65 + [b'0352\x02' + b'f' * 845]),
163 164 ], ids=[
164 165 'foo-empty',
165 166 'empty-64k', 'empty',
166 167 'foo-64k', 'foo',
167 168 'f-1000-64k', 'f-1000',
168 169 'f-65520-64k', 'f-65520'])
169 170 def test_get_messages(pygrack_instance, data, capabilities, expected):
170 171 messages = pygrack_instance._get_messages(data, capabilities)
171 172
172 173 assert messages == expected
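
The byte strings above follow git's pkt-line framing: a four-hex-digit length that counts the length field itself, then a one-byte side-band channel (\x02 carries progress messages). A hedged re-derivation of two of the expected values:

    def pkt_line(channel: bytes, payload: bytes) -> bytes:
        # 4 length bytes + 1 channel byte + payload; the length is inclusive
        return b'%04x' % (4 + len(channel) + len(payload)) + channel + payload

    assert pkt_line(b'\x02', b'foo') == b'0008\x02foo'
    assert pkt_line(b'\x02', b'f' * 1000) == b'03ed\x02' + b'f' * 1000
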
173 174
174 175
175 176 @pytest.mark.parametrize('response,capabilities,pre_pull_messages,post_pull_messages', [
176 177 # Unexpected response
177 ('unexpected_response', ['side-band-64k'], 'foo', 'bar'),
178 ([b'unexpected_response[no_initial_header]'], [pygrack.CAPABILITY_SIDE_BAND_64K], 'foo', 'bar'),
178 179 # No sideband
179 ('no-sideband', [], 'foo', 'bar'),
180 ([b'no-sideband'], [], 'foo', 'bar'),
180 181 # No messages
181 ('no-messages', ['side-band-64k'], '', ''),
182 ([b'no-messages'], [pygrack.CAPABILITY_SIDE_BAND_64K], '', ''),
182 183 ])
183 184 def test_inject_messages_to_response_nothing_to_do(
184 pygrack_instance, response, capabilities, pre_pull_messages,
185 post_pull_messages):
186 new_response = pygrack_instance._inject_messages_to_response(
187 response, capabilities, pre_pull_messages, post_pull_messages)
185 pygrack_instance, response, capabilities, pre_pull_messages, post_pull_messages):
188 186
189 assert new_response == response
187 new_response = pygrack_instance._build_post_pull_response(
188 more_itertools.always_iterable(response), capabilities, pre_pull_messages, post_pull_messages)
189
190 assert list(new_response) == response
190 191
191 192
192 193 @pytest.mark.parametrize('capabilities', [
193 ['side-band'],
194 ['side-band-64k'],
194 [pygrack.CAPABILITY_SIDE_BAND],
195 [pygrack.CAPABILITY_SIDE_BAND_64K],
195 196 ])
196 def test_inject_messages_to_response_single_element(pygrack_instance,
197 capabilities):
198 response = ['0008NAK\n0009subp\n0000']
199 new_response = pygrack_instance._inject_messages_to_response(
200 response, capabilities, 'foo', 'bar')
197 def test_inject_messages_to_response_single_element(pygrack_instance, capabilities):
198 response = [b'0008NAK\n0009subp\n0000']
199 new_response = pygrack_instance._build_post_pull_response(
200 more_itertools.always_iterable(response), capabilities, 'foo', 'bar')
201 201
202 expected_response = [
203 '0008NAK\n', '0008\x02foo', '0009subp\n', '0008\x02bar', '0000']
202 expected_response = b''.join([
203 b'0008NAK\n',
204 b'0008\x02foo',
205 b'0009subp\n',
206 b'0008\x02bar',
207 b'0000'])
204 208
205 assert new_response == expected_response
209 assert b''.join(new_response) == expected_response
206 210
207 211
208 212 @pytest.mark.parametrize('capabilities', [
209 ['side-band'],
210 ['side-band-64k'],
213 [pygrack.CAPABILITY_SIDE_BAND],
214 [pygrack.CAPABILITY_SIDE_BAND_64K],
211 215 ])
212 def test_inject_messages_to_response_multi_element(pygrack_instance,
213 capabilities):
214 response = [
215 '0008NAK\n000asubp1\n', '000asubp2\n', '000asubp3\n', '000asubp4\n0000']
216 new_response = pygrack_instance._inject_messages_to_response(
217 response, capabilities, 'foo', 'bar')
216 def test_inject_messages_to_response_multi_element(pygrack_instance, capabilities):
217 response = more_itertools.always_iterable([
218 b'0008NAK\n000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n0000'
219 ])
220 new_response = pygrack_instance._build_post_pull_response(response, capabilities, 'foo', 'bar')
218 221
219 expected_response = [
220 '0008NAK\n', '0008\x02foo', '000asubp1\n', '000asubp2\n', '000asubp3\n',
221 '000asubp4\n', '0008\x02bar', '0000'
222 ]
222 expected_response = b''.join([
223 b'0008NAK\n',
224 b'0008\x02foo',
225 b'000asubp1\n', b'000asubp2\n', b'000asubp3\n', b'000asubp4\n',
226 b'0008\x02bar',
227 b'0000'
228 ])
223 229
224 assert new_response == expected_response
230 assert b''.join(new_response) == expected_response
225 231
226 232
227 233 def test_build_failed_pre_pull_response_no_sideband(pygrack_instance):
228 234 response = pygrack_instance._build_failed_pre_pull_response([], 'foo')
229 235
230 236 assert response == [pygrack.GitRepository.EMPTY_PACK]
231 237
232 238
233 239 @pytest.mark.parametrize('capabilities', [
234 ['side-band'],
235 ['side-band-64k'],
236 ['side-band-64k', 'no-progress'],
240 [pygrack.CAPABILITY_SIDE_BAND],
241 [pygrack.CAPABILITY_SIDE_BAND_64K],
242 [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'],
237 243 ])
238 244 def test_build_failed_pre_pull_response(pygrack_instance, capabilities):
239 response = pygrack_instance._build_failed_pre_pull_response(
240 capabilities, 'foo')
245 response = pygrack_instance._build_failed_pre_pull_response(capabilities, 'foo')
241 246
242 247 expected_response = [
243 '0008NAK\n', '0008\x02foo', '0024\x02Pre pull hook failed: aborting\n',
244 '%04x\x01%s' % (len(pygrack.GitRepository.EMPTY_PACK) + 5,
245 pygrack.GitRepository.EMPTY_PACK),
246 '0000',
248 b'0008NAK\n', b'0008\x02foo', b'0024\x02Pre pull hook failed: aborting\n',
249 b'%04x\x01%s' % (len(pygrack.GitRepository.EMPTY_PACK) + 5, pygrack.GitRepository.EMPTY_PACK),
250 pygrack.GitRepository.FLUSH_PACKET,
247 251 ]
248 252
249 253 assert response == expected_response
254
255
256 def test_inject_messages_to_response_generator(pygrack_instance):
257
258 def response_generator():
259 response = [
260 # protocol start
261 b'0008NAK\n',
262 ]
263 response += [ascii_bytes(f'000asubp{x}\n') for x in range(1000)]
264 response += [
265 # protocol end
266 pygrack.GitRepository.FLUSH_PACKET
267 ]
268 for elem in response:
269 yield elem
270
271 new_response = pygrack_instance._build_post_pull_response(
272 response_generator(), [pygrack.CAPABILITY_SIDE_BAND_64K, b'no-progress'], 'PRE_PULL_MSG\n', 'POST_PULL_MSG\n')
273
274 assert iter(new_response)
275
276 expected_response = b''.join([
277 # start
278 b'0008NAK\n0012\x02PRE_PULL_MSG\n',
279 ] + [
280 # ... rest
281 ascii_bytes(f'000asubp{x}\n') for x in range(1000)
282 ] + [
283 # final message,
284 b'0013\x02POST_PULL_MSG\n0000',
285
286 ])
287
288 assert b''.join(new_response) == expected_response
@@ -1,86 +1,87 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19
20 20 import mercurial.hg
21 21 import mercurial.ui
22 22 import mercurial.error
23 23 import mock
24 24 import pytest
25 25 import webtest
26 26
27 27 from vcsserver import scm_app
28 from vcsserver.utils import ascii_bytes
28 29
29 30
30 31 def test_hg_does_not_accept_invalid_cmd(tmpdir):
31 repo = mercurial.hg.repository(mercurial.ui.ui(), str(tmpdir), create=True)
32 repo = mercurial.hg.repository(mercurial.ui.ui(), ascii_bytes(str(tmpdir)), create=True)
32 33 app = webtest.TestApp(scm_app.HgWeb(repo))
33 34
34 35 response = app.get('/repo?cmd=invalidcmd', expect_errors=True)
35 36
36 37 assert response.status_int == 400
37 38
38 39
39 40 def test_create_hg_wsgi_app_requirement_error(tmpdir):
40 repo = mercurial.hg.repository(mercurial.ui.ui(), str(tmpdir), create=True)
41 repo = mercurial.hg.repository(mercurial.ui.ui(), ascii_bytes(str(tmpdir)), create=True)
41 42 config = (
42 43 ('paths', 'default', ''),
43 44 )
44 45 with mock.patch('vcsserver.scm_app.HgWeb') as hgweb_mock:
45 46 hgweb_mock.side_effect = mercurial.error.RequirementError()
46 47 with pytest.raises(Exception):
47 48 scm_app.create_hg_wsgi_app(str(tmpdir), repo, config)
48 49
49 50
50 51 def test_git_returns_not_found(tmpdir):
51 52 app = webtest.TestApp(
52 53 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
53 54
54 55 response = app.get('/repo_name/inforefs?service=git-upload-pack',
55 56 expect_errors=True)
56 57
57 58 assert response.status_int == 404
58 59
59 60
60 61 def test_git(tmpdir):
61 62 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
62 63 tmpdir.mkdir(dir_name)
63 64
64 65 app = webtest.TestApp(
65 66 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
66 67
67 68 # We set service to git-upload-packs to trigger a 403
68 69 response = app.get('/repo_name/inforefs?service=git-upload-packs',
69 70 expect_errors=True)
70 71
71 72 assert response.status_int == 403
72 73
73 74
74 75 def test_git_fallbacks_to_git_folder(tmpdir):
75 76 tmpdir.mkdir('.git')
76 77 for dir_name in ('config', 'head', 'info', 'objects', 'refs'):
77 78 tmpdir.mkdir(os.path.join('.git', dir_name))
78 79
79 80 app = webtest.TestApp(
80 81 scm_app.GitHandler(str(tmpdir), 'repo_name', 'git', False, {}))
81 82
82 83 # We set service to git-upload-packs to trigger a 403
83 84 response = app.get('/repo_name/inforefs?service=git-upload-packs',
84 85 expect_errors=True)
85 86
86 87 assert response.status_int == 403
@@ -1,155 +1,155 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21
22 22 import pytest
23 23
24 24 from vcsserver import subprocessio
25 from vcsserver.utils import ascii_bytes
25 26
26 27
27 class KindaFilelike(object): # pragma: no cover
28 class FileLikeObj(object): # pragma: no cover
28 29
29 def __init__(self, data, size):
30 chunks = size / len(data)
30 def __init__(self, data: bytes, size):
31 chunks = size // len(data)
31 32
32 33 self.stream = self._get_stream(data, chunks)
33 34
34 35 def _get_stream(self, data, chunks):
35 36 for x in range(chunks):
36 37 yield data
37 38
38 39 def read(self, n):
39 40
40 buffer_stream = ''
41 buffer_stream = b''
41 42 for chunk in self.stream:
42 43 buffer_stream += chunk
43 44 if len(buffer_stream) >= n:
44 45 break
45 46
46 47 # self.stream = self.bytes[n:]
47 48 return buffer_stream
48 49
49 50
50 51 @pytest.fixture(scope='module')
51 52 def environ():
52 53 """Delete coverage variables, as they make the tests fail."""
53 54 env = dict(os.environ)
54 55 for key in list(env.keys()):
55 56 if key.startswith('COV_CORE_'):
56 57 del env[key]
57 58
58 59 return env
59 60
60 61
61 62 def _get_python_args(script):
62 63 return [sys.executable, '-c', 'import sys; import time; import shutil; ' + script]
63 64
64 65
65 66 def test_raise_exception_on_non_zero_return_code(environ):
66 args = _get_python_args('sys.exit(1)')
67 with pytest.raises(EnvironmentError):
68 list(subprocessio.SubprocessIOChunker(args, shell=False, env=environ))
67 call_args = _get_python_args('raise ValueError("fail")')
68 with pytest.raises(OSError):
69 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
69 70
70 71
71 72 def test_does_not_fail_on_non_zero_return_code(environ):
72 args = _get_python_args('sys.exit(1)')
73 output = ''.join(
74 subprocessio.SubprocessIOChunker(
75 args, shell=False, fail_on_return_code=False, env=environ
76 )
77 )
73 call_args = _get_python_args('sys.stdout.write("hello"); sys.exit(1)')
74 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
75 output = b''.join(proc)
78 76
79 assert output == ''
77 assert output == b'hello'
80 78
81 79
82 80 def test_raise_exception_on_stderr(environ):
83 args = _get_python_args('sys.stderr.write("X"); time.sleep(1);')
84 with pytest.raises(EnvironmentError) as excinfo:
85 list(subprocessio.SubprocessIOChunker(args, shell=False, env=environ))
81 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); time.sleep(1);')
86 82
87 assert 'exited due to an error:\nX' in str(excinfo.value)
83 with pytest.raises(OSError) as excinfo:
84 b''.join(subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ))
85
86 assert 'exited due to an error:\nWRITE_TO_STDERR' in str(excinfo.value)
88 87
89 88
90 89 def test_does_not_fail_on_stderr(environ):
91 args = _get_python_args('sys.stderr.write("X"); time.sleep(1);')
92 output = ''.join(
93 subprocessio.SubprocessIOChunker(
94 args, shell=False, fail_on_stderr=False, env=environ
95 )
96 )
90 call_args = _get_python_args('sys.stderr.write("WRITE_TO_STDERR"); sys.stderr.flush(); time.sleep(2);')
91 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_stderr=False, env=environ)
92 output = b''.join(proc)
97 93
98 assert output == ''
94 assert output == b''
99 95
100 96
101 @pytest.mark.parametrize('size', [1, 10 ** 5])
97 @pytest.mark.parametrize('size', [
98 1,
99 10 ** 5
100 ])
102 101 def test_output_with_no_input(size, environ):
103 print((type(environ)))
104 data = 'X'
105 args = _get_python_args('sys.stdout.write("%s" * %d)' % (data, size))
106 output = ''.join(subprocessio.SubprocessIOChunker(args, shell=False, env=environ))
102 call_args = _get_python_args(f'sys.stdout.write("X" * {size});')
103 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, env=environ)
104 output = b''.join(proc)
107 105
108 assert output == data * size
106 assert output == ascii_bytes("X" * size)
109 107
110 108
111 @pytest.mark.parametrize('size', [1, 10 ** 5])
109 @pytest.mark.parametrize('size', [
110 1,
111 10 ** 5
112 ])
112 113 def test_output_with_no_input_does_not_fail(size, environ):
113 data = 'X'
114 args = _get_python_args('sys.stdout.write("%s" * %d); sys.exit(1)' % (data, size))
115 output = ''.join(
116 subprocessio.SubprocessIOChunker(
117 args, shell=False, fail_on_return_code=False, env=environ
118 )
119 )
120 114
121 print(("{} {}".format(len(data * size), len(output))))
122 assert output == data * size
115 call_args = _get_python_args(f'sys.stdout.write("X" * {size}); sys.exit(1)')
116 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, fail_on_return_code=False, env=environ)
117 output = b''.join(proc)
118
119 assert output == ascii_bytes("X" * size)
123 120
124 121
125 @pytest.mark.parametrize('size', [1, 10 ** 5])
122 @pytest.mark.parametrize('size', [
123 1,
124 10 ** 5
125 ])
126 126 def test_output_with_input(size, environ):
127 127 data_len = size
128 inputstream = KindaFilelike('X', size)
128 inputstream = FileLikeObj(b'X', size)
129 129
130 130 # This acts like the cat command.
131 args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
132 output = ''.join(
133 subprocessio.SubprocessIOChunker(
134 args, shell=False, inputstream=inputstream, env=environ
135 )
131 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
132 # note: in this test we explicitly don't assign the chunker to a variable and let it stream directly
133 output = b''.join(
134 subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
136 135 )
137 136
138 137 assert len(output) == data_len
139 138
140 139
141 @pytest.mark.parametrize('size', [1, 10 ** 5])
140 @pytest.mark.parametrize('size', [
141 1,
142 10 ** 5
143 ])
142 144 def test_output_with_input_skipping_iterator(size, environ):
143 145 data_len = size
144 inputstream = KindaFilelike('X', size)
146 inputstream = FileLikeObj(b'X', size)
145 147
146 148 # This acts like the cat command.
147 args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
149 call_args = _get_python_args('shutil.copyfileobj(sys.stdin, sys.stdout)')
148 150
149 151 # Note: assigning the chunker makes sure that it is not deleted too early
150 chunker = subprocessio.SubprocessIOChunker(
151 args, shell=False, inputstream=inputstream, env=environ
152 )
153 output = ''.join(chunker.output)
152 proc = subprocessio.SubprocessIOChunker(call_args, shell=False, input_stream=inputstream, env=environ)
153 output = b''.join(proc.stdout)
154 154
155 155 assert len(output) == data_len
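Taken together, these tests pin down the Python 3 contract of SubprocessIOChunker: it iterates over bytes chunks, failures surface as OSError, and stdin is supplied through the input_stream keyword. A minimal usage sketch under those assumptions; the inline Python command is illustrative only, and the env/coverage handling from the fixture is omitted:

import sys
from vcsserver import subprocessio

chunker = subprocessio.SubprocessIOChunker(
    [sys.executable, '-c', 'import sys; sys.stdout.write("ok")'],
    shell=False, fail_on_return_code=False)
assert b''.join(chunker) == b'ok'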
@@ -1,86 +1,103 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import mock
20 20 import pytest
21 21 import sys
22 22
23 from vcsserver.utils import ascii_bytes
24
23 25
24 26 class MockPopen(object):
25 27 def __init__(self, stderr):
26 self.stdout = io.BytesIO('')
28 self.stdout = io.BytesIO(b'')
27 29 self.stderr = io.BytesIO(stderr)
28 30 self.returncode = 1
29 31
30 32 def wait(self):
31 33 pass
32 34
33 35
34 36 INVALID_CERTIFICATE_STDERR = '\n'.join([
35 37 'svnrdump: E230001: Unable to connect to a repository at URL url',
36 38 'svnrdump: E230001: Server SSL certificate verification failed: issuer is not trusted',
37 39 ])
38 40
39 41
40 42 @pytest.mark.parametrize('stderr,expected_reason', [
41 43 (INVALID_CERTIFICATE_STDERR, 'INVALID_CERTIFICATE'),
42 44 ('svnrdump: E123456', 'UNKNOWN:svnrdump: E123456'),
43 45 ], ids=['invalid-cert-stderr', 'svnrdump-err-123456'])
44 46 @pytest.mark.xfail(sys.platform == "cygwin",
45 47 reason="SVN not packaged for Cygwin")
46 48 def test_import_remote_repository_certificate_error(stderr, expected_reason):
47 49 from vcsserver.remote import svn
48 50 factory = mock.Mock()
49 51 factory.repo = mock.Mock(return_value=mock.Mock())
50 52
51 53 remote = svn.SvnRemote(factory)
52 54 remote.is_path_valid_repository = lambda wire, path: True
53 55
54 56 with mock.patch('subprocess.Popen',
55 return_value=MockPopen(stderr)):
57 return_value=MockPopen(ascii_bytes(stderr))):
56 58 with pytest.raises(Exception) as excinfo:
57 59 remote.import_remote_repository({'path': 'path'}, 'url')
58 60
59 expected_error_args = (
60 'Failed to dump the remote repository from url. Reason:{}'.format(expected_reason),)
61 expected_error_args = 'Failed to dump the remote repository from url. Reason:{}'.format(expected_reason)
61 62
62 assert excinfo.value.args == expected_error_args
63 assert excinfo.value.args[0] == expected_error_args
63 64
64 65
65 66 def test_svn_libraries_can_be_imported():
66 67 import svn.client
67 68 assert svn.client is not None
68 69
69 70
70 71 @pytest.mark.parametrize('example_url, parts', [
71 72 ('http://server.com', (None, None, 'http://server.com')),
72 73 ('http://user@server.com', ('user', None, 'http://user@server.com')),
73 74 ('http://user:pass@server.com', ('user', 'pass', 'http://user:pass@server.com')),
74 75 ('<script>', (None, None, '<script>')),
75 76 ('http://', (None, None, 'http://')),
76 77 ])
77 78 def test_username_password_extraction_from_url(example_url, parts):
78 79 from vcsserver.remote import svn
79 80
80 81 factory = mock.Mock()
81 82 factory.repo = mock.Mock(return_value=mock.Mock())
82 83
83 84 remote = svn.SvnRemote(factory)
84 85 remote.is_path_valid_repository = lambda wire, path: True
85 86
86 87 assert remote.get_url_and_credentials(example_url) == parts
88
89
90 @pytest.mark.parametrize('call_url', [
91 b'https://svn.code.sf.net/p/svnbook/source/trunk/',
92 b'https://marcink@svn.code.sf.net/p/svnbook/source/trunk/',
93 b'https://marcink:qweqwe@svn.code.sf.net/p/svnbook/source/trunk/',
94 ])
95 def test_check_url(call_url):
96 from vcsserver.remote import svn
97 factory = mock.Mock()
98 factory.repo = mock.Mock(return_value=mock.Mock())
99
100 remote = svn.SvnRemote(factory)
101 remote.is_path_valid_repository = lambda wire, path: True
102 assert remote.check_url(call_url)
103
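As pinned down by test_username_password_extraction_from_url, get_url_and_credentials returns (username, password, original_url) and never rewrites the URL itself. A behaviour-equivalent sketch using urllib.parse, not vcsserver's actual implementation, that satisfies the same parametrized cases:

import urllib.parse

def split_credentials(url):
    # urlparse exposes the userinfo part; the URL we hand back is untouched
    parsed = urllib.parse.urlparse(url)
    return parsed.username, parsed.password, url

assert split_credentials('http://user:pass@server.com') == ('user', 'pass', 'http://user:pass@server.com')
assert split_credentials('<script>') == (None, None, '<script>')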
@@ -1,96 +1,98 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import wsgiref.simple_server
19 19 import wsgiref.validate
20 20
21 21 from vcsserver import wsgi_app_caller
22
23
24 # pylint: disable=protected-access,too-many-public-methods
22 from vcsserver.utils import ascii_bytes, safe_str
25 23
26 24
27 25 @wsgiref.validate.validator
28 26 def demo_app(environ, start_response):
29 27 """WSGI app used for testing."""
28
29 input_data = safe_str(environ['wsgi.input'].read(1024))
30
30 31 data = [
31 'Hello World!\n',
32 'input_data=%s\n' % environ['wsgi.input'].read(),
32 'Hello World!\n',
33 f'input_data={input_data}\n',
33 34 ]
34 35 for key, value in sorted(environ.items()):
35 data.append('%s=%s\n' % (key, value))
36 data.append(f'{key}={value}\n')
36 37
37 38 write = start_response("200 OK", [('Content-Type', 'text/plain')])
38 write('Old school write method\n')
39 write('***********************\n')
40 return data
39 write(b'Old school write method\n')
40 write(b'***********************\n')
41 return list(map(ascii_bytes, data))
41 42
42 43
43 44 BASE_ENVIRON = {
44 45 'REQUEST_METHOD': 'GET',
45 46 'SERVER_NAME': 'localhost',
46 47 'SERVER_PORT': '80',
47 48 'SCRIPT_NAME': '',
48 49 'PATH_INFO': '/',
49 50 'QUERY_STRING': '',
50 51 'foo.var': 'bla',
51 52 }
52 53
53 54
54 55 def test_complete_environ():
55 56 environ = dict(BASE_ENVIRON)
56 data = "data"
57 data = b"data"
57 58 wsgi_app_caller._complete_environ(environ, data)
58 59 wsgiref.validate.check_environ(environ)
59 60
60 assert data == environ['wsgi.input'].read()
61 assert data == environ['wsgi.input'].read(1024)
61 62
62 63
63 64 def test_start_response():
64 65 start_response = wsgi_app_caller._StartResponse()
65 66 status = '200 OK'
66 67 headers = [('Content-Type', 'text/plain')]
67 68 start_response(status, headers)
68 69
69 70 assert status == start_response.status
70 71 assert headers == start_response.headers
71 72
72 73
73 74 def test_start_response_with_error():
74 75 start_response = wsgi_app_caller._StartResponse()
75 76 status = '500 Internal Server Error'
76 77 headers = [('Content-Type', 'text/plain')]
77 78 start_response(status, headers, (None, None, None))
78 79
79 80 assert status == start_response.status
80 81 assert headers == start_response.headers
81 82
82 83
83 84 def test_wsgi_app_caller():
84 caller = wsgi_app_caller.WSGIAppCaller(demo_app)
85 85 environ = dict(BASE_ENVIRON)
86 86 input_data = 'some text'
87
88 caller = wsgi_app_caller.WSGIAppCaller(demo_app)
87 89 responses, status, headers = caller.handle(environ, input_data)
88 response = ''.join(responses)
90 response = b''.join(responses)
89 91
90 92 assert status == '200 OK'
91 93 assert headers == [('Content-Type', 'text/plain')]
92 assert response.startswith(
93 'Old school write method\n***********************\n')
94 assert 'Hello World!\n' in response
95 assert 'foo.var=bla\n' in response
96 assert 'input_data=%s\n' % input_data in response
94 assert response.startswith(b'Old school write method\n***********************\n')
95 assert b'Hello World!\n' in response
96 assert b'foo.var=bla\n' in response
97
98 assert ascii_bytes(f'input_data={input_data}\n') in response
@@ -1,107 +1,106 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import time
19 19 import logging
20 20
21 21 import vcsserver
22 from vcsserver.utils import safe_str
23
22 from vcsserver.utils import safe_str, ascii_str
24 23
25 24 log = logging.getLogger(__name__)
26 25
27 26
28 27 def get_access_path(environ):
29 28 path = environ.get('PATH_INFO')
30 29 return path
31 30
32 31
33 32 def get_user_agent(environ):
34 33 return environ.get('HTTP_USER_AGENT')
35 34
36 35
37 36 def get_vcs_method(environ):
38 37 return environ.get('HTTP_X_RC_METHOD')
39 38
40 39
41 40 def get_vcs_repo(environ):
42 41 return environ.get('HTTP_X_RC_REPO_NAME')
43 42
44 43
45 44 class RequestWrapperTween(object):
46 45 def __init__(self, handler, registry):
47 46 self.handler = handler
48 47 self.registry = registry
49 48
50 49 # one-time configuration code goes here
51 50
52 51 def __call__(self, request):
53 52 start = time.time()
54 53 log.debug('Starting request time measurement')
55 54 response = None
56 55
57 56 ua = get_user_agent(request.environ)
58 57 vcs_method = get_vcs_method(request.environ)
59 58 repo_name = get_vcs_repo(request.environ)
60 59
61 60 try:
62 61 response = self.handler(request)
63 62 finally:
64 63 count = request.request_count()
65 _ver_ = vcsserver.__version__
64 _ver_ = ascii_str(vcsserver.__version__)
66 65 _path = safe_str(get_access_path(request.environ))
67 66 ip = '127.0.0.1'
68 67 match_route = request.matched_route.name if request.matched_route else "NOT_FOUND"
69 68 resp_code = getattr(response, 'status_code', 'UNDEFINED')
70 69
71 70 total = time.time() - start
72 71
73 _view_path = "{}/{}@{}".format(_path, vcs_method, repo_name)
72 _view_path = f"{repo_name}@{_path}/{vcs_method}"
74 73 log.info(
75 74 'Req[%4s] IP: %s %s Request to %s time: %.4fs [%s], VCSServer %s',
76 75 count, ip, request.environ.get('REQUEST_METHOD'),
77 76 _view_path, total, ua, _ver_,
78 77 extra={"time": total, "ver": _ver_, "code": resp_code,
79 78 "path": _path, "view_name": match_route, "user_agent": ua,
80 79 "vcs_method": vcs_method, "repo_name": repo_name}
81 80 )
82 81
83 82 statsd = request.registry.statsd
84 83 if statsd:
85 84 match_route = request.matched_route.name if request.matched_route else _path
86 85 elapsed_time_ms = round(1000.0 * total) # use ms only
87 86 statsd.timing(
88 87 "vcsserver_req_timing.histogram", elapsed_time_ms,
89 88 tags=[
90 89 "view_name:{}".format(match_route),
91 90 "code:{}".format(resp_code)
92 91 ],
93 92 use_decimals=False
94 93 )
95 94 statsd.incr(
96 95 "vcsserver_req_total", tags=[
97 96 "view_name:{}".format(match_route),
98 97 "code:{}".format(resp_code)
99 98 ])
100 99
101 100 return response
102 101
103 102
104 103 def includeme(config):
105 104 config.add_tween(
106 105 'vcsserver.tweens.request_wrapper.RequestWrapperTween',
107 106 )
@@ -1,107 +1,137 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17 import logging
18 18 import hashlib
19 19
20 20 log = logging.getLogger(__name__)
21 21
22 22
23 23 def safe_int(val, default=None):
24 24 """
25 25 Returns int() of val if val is not convertable to int use default
26 26 instead
27 27
28 28 :param val:
29 29 :param default:
30 30 """
31 31
32 32 try:
33 33 val = int(val)
34 34 except (ValueError, TypeError):
35 35 val = default
36 36
37 37 return val
38 38
39 39
40 40 def safe_str(str_, to_encoding=None) -> str:
41 41 """
42 42 safe str function. Does a few tricks to turn str_ into a str
43 43
44 44 :param str_: str or bytes to decode
45 45 :param to_encoding: encoding(s) to try when decoding, UTF-8 by default
46 46 :rtype: str
47 47 :returns: str object
48 48 """
49 49 if isinstance(str_, str):
50 50 return str_
51 51
52 52 # if it's bytes cast to str
53 53 if not isinstance(str_, bytes):
54 54 return str(str_)
55 55
56 56 to_encoding = to_encoding or ['utf8']
57 57 if not isinstance(to_encoding, (list, tuple)):
58 58 to_encoding = [to_encoding]
59 59
60 60 for enc in to_encoding:
61 61 try:
62 62 return str(str_, enc)
63 63 except UnicodeDecodeError:
64 64 pass
65 65
66 66 return str(str_, to_encoding[0], 'replace')
67 67
68 68
69 69 def safe_bytes(str_, from_encoding=None) -> bytes:
70 70 """
71 71 safe bytes function. Does a few tricks to turn str_ into bytes:
72 72
73 73 :param str_: str to encode
74 74 :param from_encoding: encoding(s) to try when encoding, UTF-8 by default
75 75 :rtype: bytes
76 76 :returns: bytes object
77 77 """
78 78 if isinstance(str_, bytes):
79 79 return str_
80 80
81 81 if not isinstance(str_, str):
82 82 raise ValueError('safe_bytes cannot convert other types than str: got: {}'.format(type(str_)))
83 83
84 84 from_encoding = from_encoding or ['utf8']
85 85 if not isinstance(from_encoding, (list, tuple)):
86 86 from_encoding = [from_encoding]
87 87
88 88 for enc in from_encoding:
89 89 try:
90 90 return str_.encode(enc)
91 91 except UnicodeEncodeError:
92 92 pass
93 93
94 return unicode(str_, from_encoding[0], 'replace')
94 return str_.encode(from_encoding[0], 'replace')
95
96
97 def ascii_bytes(str_, allow_bytes=False) -> bytes:
98 """
99 Simple conversion from str to bytes, with assumption that str_ is pure ASCII.
100 Fails with UnicodeError on invalid input.
101 This should be used where encoding and "safe" ambiguity should be avoided:
102 typically for strings that have already been encoded some other way (hex,
103 base64, json, urlencoding) but are still str objects, or that are known
104 to be plain ASCII identifiers.
105 """
106 if allow_bytes and isinstance(str_, bytes):
107 return str_
108
109 if not isinstance(str_, str):
110 raise ValueError('ascii_bytes cannot convert other types than str: got: {}'.format(type(str_)))
111 return str_.encode('ascii')
112
113
114 def ascii_str(str_):
115 """
116 Simple conversion from bytes to str, with assumption that str_ is pure ASCII.
117 Fails with UnicodeError on invalid input.
118 This should be used where encoding and "safe" ambiguity should be avoided:
119 typically for bytes that are known to be pure ASCII (hex, base64,
120 urlencoding, identifiers) and where a str is wanted without caring about
121 the original encoding.
122 """
123
124 if not isinstance(str_, bytes):
125 raise ValueError('ascii_str cannot convert other types than bytes: got: {}'.format(type(str_)))
126 return str_.decode('ascii')
95 127
96 128
97 129 class AttributeDict(dict):
98 130 def __getattr__(self, attr):
99 131 return self.get(attr, None)
100 132 __setattr__ = dict.__setitem__
101 133 __delattr__ = dict.__delitem__
102 134
103 135
104 136 def sha1(val):
105 137 return hashlib.sha1(val).hexdigest()
106
107
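The strict helpers above complement safe_str/safe_bytes: rather than guessing encodings, they fail loudly on anything outside ASCII, which is the right trade-off for identifiers. A small sketch tying them to the sha1 helper defined in the same module:

digest = sha1(b'payload')  # hexdigest() returns a pure-ASCII str
assert ascii_str(ascii_bytes(digest)) == digest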
@@ -1,46 +1,47 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 from vcsserver.lib import rc_cache
19 19
20
20 21 class RemoteBase(object):
21 22 EMPTY_COMMIT = '0' * 40
22 23
23 24 def _region(self, wire):
24 25 cache_repo_id = wire.get('cache_repo_id', '')
25 26 cache_namespace_uid = 'cache_repo.{}'.format(cache_repo_id)
26 27 return rc_cache.get_or_create_region('repo_object', cache_namespace_uid)
27 28
28 29 def _cache_on(self, wire):
29 30 context = wire.get('context', '')
30 31 context_uid = '{}'.format(context)
31 32 repo_id = wire.get('repo_id', '')
32 33 cache = wire.get('cache', True)
33 34 cache_on = context and cache
34 35 return cache_on, context_uid, repo_id
35 36
36 37 def vcsserver_invalidate_cache(self, wire, delete):
37 38 from vcsserver.lib import rc_cache
38 39 repo_id = wire.get('repo_id', '')
39 40 cache_repo_id = wire.get('cache_repo_id', '')
40 41 cache_namespace_uid = 'cache_repo.{}'.format(cache_repo_id)
41 42
42 43 if delete:
43 44 rc_cache.clear_cache_namespace(
44 45 'repo_object', cache_namespace_uid, invalidate=True)
45 46
46 47 return {'invalidated': {'repo_id': repo_id, 'delete': delete}}
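_cache_on gates caching on two things: a non-empty call context in the wire payload and the cache flag, which defaults to on. A sketch with a hypothetical wire dict, using the RemoteBase class defined above:

wire = {'context': 'ctx-uuid', 'repo_id': 'r1', 'cache': True}
cache_on, context_uid, repo_id = RemoteBase()._cache_on(wire)
assert cache_on and context_uid == 'ctx-uuid' and repo_id == 'r1'

# without a context there is nothing to key the cache on, so caching stays off
cache_on, _, _ = RemoteBase()._cache_on({'repo_id': 'r1'})
assert not cache_on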
@@ -1,116 +1,116 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 """Extract the responses of a WSGI app."""
19 19
20 20 __all__ = ('WSGIAppCaller',)
21 21
22 22 import io
23 23 import logging
24 24 import os
25 25
26 from vcsserver.utils import ascii_bytes
26 27
27 28 log = logging.getLogger(__name__)
28 29
29 30 DEV_NULL = open(os.devnull)
30 31
31 32
32 def _complete_environ(environ, input_data):
33 def _complete_environ(environ, input_data: bytes):
33 34 """Update the missing wsgi.* variables of a WSGI environment.
34 35
35 36 :param environ: WSGI environment to update
36 37 :type environ: dict
37 38 :param input_data: data to be read by the app
38 :type input_data: str
39 :type input_data: bytes
39 40 """
40 41 environ.update({
41 42 'wsgi.version': (1, 0),
42 43 'wsgi.url_scheme': 'http',
43 44 'wsgi.multithread': True,
44 45 'wsgi.multiprocess': True,
45 46 'wsgi.run_once': False,
46 47 'wsgi.input': io.BytesIO(input_data),
47 48 'wsgi.errors': DEV_NULL,
48 49 })
49 50
50 51
51 52 # pylint: disable=too-few-public-methods
52 53 class _StartResponse(object):
53 54 """Save the arguments of a start_response call."""
54 55
55 56 __slots__ = ['status', 'headers', 'content']
56 57
57 58 def __init__(self):
58 59 self.status = None
59 60 self.headers = None
60 61 self.content = []
61 62
62 63 def __call__(self, status, headers, exc_info=None):
63 64 # TODO(skreft): do something meaningful with the exc_info
64 65 exc_info = None # avoid dangling circular reference
65 66 self.status = status
66 67 self.headers = headers
67 68
68 69 return self.write
69 70
70 71 def write(self, content):
71 72 """Write method returning when calling this object.
72 73
73 74 All the data written is then available in content.
74 75 """
75 76 self.content.append(content)
76 77
77 78
78 79 class WSGIAppCaller(object):
79 80 """Calls a WSGI app."""
80 81
81 82 def __init__(self, app):
82 83 """
83 84 :param app: WSGI app to call
84 85 """
85 86 self.app = app
86 87
87 88 def handle(self, environ, input_data):
88 89 """Process a request with the WSGI app.
89 90
90 91 The returned data of the app is fully consumed into a list.
91 92
92 93 :param environ: WSGI environment to update
93 94 :type environ: dict
94 95 :param input_data: data to be read by the app
95 :type input_data: str
96 :type input_data: str/bytes
96 97
97 98 :returns: a tuple with the contents, status and headers
98 99 :rtype: (list<str>, str, list<(str, str)>)
99 100 """
100 _complete_environ(environ, input_data)
101 _complete_environ(environ, ascii_bytes(input_data, allow_bytes=True))
101 102 start_response = _StartResponse()
102 103 log.debug("Calling wrapped WSGI application")
103 104 responses = self.app(environ, start_response)
104 105 responses_list = list(responses)
105 106 existing_responses = start_response.content
106 107 if existing_responses:
107 log.debug(
108 "Adding returned response to response written via write()")
108 log.debug("Adding returned response to response written via write()")
109 109 existing_responses.extend(responses_list)
110 110 responses_list = existing_responses
111 111 if hasattr(responses, 'close'):
112 112 log.debug("Closing iterator from WSGI application")
113 113 responses.close()
114 114
115 115 log.debug("Handling of WSGI request done, returning response")
116 116 return responses_list, start_response.status, start_response.headers
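A minimal usage sketch of WSGIAppCaller under the py3 contract above: handle() coerces str input to bytes via ascii_bytes(..., allow_bytes=True) and returns the app's byte chunks with status and headers. The demo app here is ours, not part of vcsserver:

from vcsserver import wsgi_app_caller

def app(environ, start_response):
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'hello ', environ['wsgi.input'].read(1024)]

environ = {'REQUEST_METHOD': 'GET', 'SERVER_NAME': 'localhost',
           'SERVER_PORT': '80', 'SCRIPT_NAME': '', 'PATH_INFO': '/',
           'QUERY_STRING': ''}
chunks, status, headers = wsgi_app_caller.WSGIAppCaller(app).handle(environ, 'world')
assert b''.join(chunks) == b'hello world' and status == '200 OK'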