vcsserver: modernize code for python3
super-admin
r1130:d68a72e0 python3
@@ -1,168 +1,168 @@
 # Copyright (C) 2010-2023 RhodeCode GmbH
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License, version 3
 # (only), as published by the Free Software Foundation.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 # This program is dual-licensed. If you wish to learn more about the
 # RhodeCode Enterprise Edition, including its added features, Support services,
 # and proprietary license terms, please see https://rhodecode.com/licenses/

 import os
 import textwrap
 import string
 import functools
 import logging
 import tempfile
 import logging.config

 from vcsserver.type_utils import str2bool, aslist

 log = logging.getLogger(__name__)

 # skip keys, that are set here, so we don't double process those
 set_keys = {
     '__file__': ''
 }


 class SettingsMaker(object):

     def __init__(self, app_settings):
         self.settings = app_settings

     @classmethod
     def _bool_func(cls, input_val):
         if isinstance(input_val, bytes):
             # decode to str
             input_val = input_val.decode('utf8')
         return str2bool(input_val)

     @classmethod
     def _int_func(cls, input_val):
         return int(input_val)

     @classmethod
     def _list_func(cls, input_val, sep=','):
         return aslist(input_val, sep=sep)

     @classmethod
     def _string_func(cls, input_val, lower=True):
         if lower:
             input_val = input_val.lower()
         return input_val

     @classmethod
     def _float_func(cls, input_val):
         return float(input_val)

     @classmethod
     def _dir_func(cls, input_val, ensure_dir=False, mode=0o755):

         # ensure we have our dir created
         if not os.path.isdir(input_val) and ensure_dir:
             os.makedirs(input_val, mode=mode, exist_ok=True)

         if not os.path.isdir(input_val):
-            raise Exception('Dir at {} does not exist'.format(input_val))
+            raise Exception(f'Dir at {input_val} does not exist')
         return input_val

     @classmethod
     def _file_path_func(cls, input_val, ensure_dir=False, mode=0o755):
         dirname = os.path.dirname(input_val)
         cls._dir_func(dirname, ensure_dir=ensure_dir)
         return input_val

     @classmethod
     def _key_transformator(cls, key):
         return "{}_{}".format('RC'.upper(), key.upper().replace('.', '_').replace('-', '_'))

     def maybe_env_key(self, key):
         # now maybe we have this KEY in env, search and use the value with higher priority.
         transformed_key = self._key_transformator(key)
         envvar_value = os.environ.get(transformed_key)
         if envvar_value:
             log.debug('using `%s` key instead of `%s` key for config', transformed_key, key)

         return envvar_value

     def env_expand(self):
         replaced = {}
         for k, v in self.settings.items():
             if k not in set_keys:
                 envvar_value = self.maybe_env_key(k)
                 if envvar_value:
                     replaced[k] = envvar_value
                     set_keys[k] = envvar_value

         # replace ALL keys updated
         self.settings.update(replaced)

     def enable_logging(self, logging_conf=None, level='INFO', formatter='generic'):
         """
         Helper to enable debug on running instance
         :return:
         """

         if not str2bool(self.settings.get('logging.autoconfigure')):
             log.info('logging configuration based on main .ini file')
             return

         if logging_conf is None:
             logging_conf = self.settings.get('logging.logging_conf_file') or ''

         if not os.path.isfile(logging_conf):
             log.error('Unable to setup logging based on %s, '
                       'file does not exist.... specify path using logging.logging_conf_file= config setting. ', logging_conf)
             return

         with open(logging_conf, 'rt') as f:
             ini_template = textwrap.dedent(f.read())
             ini_template = string.Template(ini_template).safe_substitute(
                 RC_LOGGING_LEVEL=os.environ.get('RC_LOGGING_LEVEL', '') or level,
                 RC_LOGGING_FORMATTER=os.environ.get('RC_LOGGING_FORMATTER', '') or formatter
             )

         with tempfile.NamedTemporaryFile(prefix='rc_logging_', suffix='.ini', delete=False) as f:
             log.info('Saved Temporary LOGGING config at %s', f.name)
             f.write(ini_template)

         logging.config.fileConfig(f.name)
         os.remove(f.name)

     def make_setting(self, key, default, lower=False, default_when_empty=False, parser=None):
         input_val = self.settings.get(key, default)

         if default_when_empty and not input_val:
             # use default value when value is set in the config but it is empty
             input_val = default

         parser_func = {
             'bool': self._bool_func,
             'int': self._int_func,
             'list': self._list_func,
             'list:newline': functools.partial(self._list_func, sep='\n'),
             'list:spacesep': functools.partial(self._list_func, sep=' '),
             'string': functools.partial(self._string_func, lower=lower),
             'dir': self._dir_func,
             'dir:ensured': functools.partial(self._dir_func, ensure_dir=True),
             'file': self._file_path_func,
             'file:ensured': functools.partial(self._file_path_func, ensure_dir=True),
             None: lambda i: i
         }[parser]

         envvar_value = self.maybe_env_key(key)
         if envvar_value:
             input_val = envvar_value
             set_keys[key] = input_val

         self.settings[key] = parser_func(input_val)
         return self.settings[key]
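For orientation, a minimal sketch of how SettingsMaker is typically driven from an application factory. The module path, key names and defaults below are illustrative assumptions, not part of this diff:

    # Hypothetical usage sketch; the import path and keys are assumptions.
    from vcsserver.config.settings_maker import SettingsMaker

    ini_settings = {
        'locale': 'en_US.UTF-8',
        'startup.import_repos': 'false',
    }
    maker = SettingsMaker(ini_settings)

    # Each make_setting() call normalizes one key in place and returns it.
    # A matching RC_* environment variable (e.g. RC_LOCALE) takes priority
    # over the .ini value, per maybe_env_key() above.
    maker.make_setting('locale', default='en_US.UTF-8', parser='string')
    maker.make_setting('startup.import_repos', default=False, parser='bool')

    # Expand any remaining keys from RC_* environment variables.
    maker.env_expand()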
@@ -1,292 +1,292 @@
 # RhodeCode VCSServer provides access to different vcs backends via network.
 # Copyright (C) 2014-2023 RhodeCode GmbH
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 # the Free Software Foundation; either version 3 of the License, or
 # (at your option) any later version.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software Foundation,
 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

 import re
 import logging
 from wsgiref.util import FileWrapper

 from pyramid.config import Configurator
 from pyramid.response import Response, FileIter
 from pyramid.httpexceptions import (
     HTTPBadRequest, HTTPNotImplemented, HTTPNotFound, HTTPForbidden,
     HTTPUnprocessableEntity)

 from vcsserver.lib.rc_json import json
 from vcsserver.git_lfs.lib import OidHandler, LFSOidStore
 from vcsserver.git_lfs.utils import safe_result, get_cython_compat_decorator
 from vcsserver.str_utils import safe_int

 log = logging.getLogger(__name__)


 GIT_LFS_CONTENT_TYPE = 'application/vnd.git-lfs'  # +json ?
 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')


 def write_response_error(http_exception, text=None):
     content_type = GIT_LFS_CONTENT_TYPE + '+json'
     _exception = http_exception(content_type=content_type)
     _exception.content_type = content_type
     if text:
         _exception.body = json.dumps({'message': text})
     log.debug('LFS: writing response of type %s to client with text:%s',
               http_exception, text)
     return _exception


 class AuthHeaderRequired(object):
     """
     Decorator to check if request has proper auth-header
     """

     def __call__(self, func):
         return get_cython_compat_decorator(self.__wrapper, func)

     def __wrapper(self, func, *fargs, **fkwargs):
         request = fargs[1]
         auth = request.authorization
         if not auth:
             return write_response_error(HTTPForbidden)
         return func(*fargs[1:], **fkwargs)


 # views

 def lfs_objects(request):
     # indicate not supported, V1 API
     log.warning('LFS: v1 api not supported, reporting it back to client')
     return write_response_error(HTTPNotImplemented, 'LFS: v1 api not supported')


 @AuthHeaderRequired()
 def lfs_objects_batch(request):
     """
     The client sends the following information to the Batch endpoint to transfer some objects:

         operation - Should be download or upload.
         transfers - An optional Array of String identifiers for transfer
             adapters that the client has configured. If omitted, the basic
             transfer adapter MUST be assumed by the server.
         objects - An Array of objects to download.
             oid - String OID of the LFS object.
             size - Integer byte size of the LFS object. Must be at least zero.
     """
     request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
     auth = request.authorization
     repo = request.matchdict.get('repo')
     data = request.json
     operation = data.get('operation')
     http_scheme = request.registry.git_lfs_http_scheme

     if operation not in ('download', 'upload'):
         log.debug('LFS: unsupported operation:%s', operation)
         return write_response_error(
             HTTPBadRequest, 'unsupported operation mode: `%s`' % operation)

     if 'objects' not in data:
         log.debug('LFS: missing objects data')
         return write_response_error(
             HTTPBadRequest, 'missing objects data')

     log.debug('LFS: handling operation of type: %s', operation)

     objects = []
     for o in data['objects']:
         try:
             oid = o['oid']
             obj_size = o['size']
         except KeyError:
             log.exception('LFS, failed to extract data')
             return write_response_error(
                 HTTPBadRequest, 'unsupported data in objects')

         obj_data = {'oid': oid}

         obj_href = request.route_url('lfs_objects_oid', repo=repo, oid=oid,
                                      _scheme=http_scheme)
         obj_verify_href = request.route_url('lfs_objects_verify', repo=repo,
                                             _scheme=http_scheme)
         store = LFSOidStore(
             oid, repo, store_location=request.registry.git_lfs_store_path)
         handler = OidHandler(
             store, repo, auth, oid, obj_size, obj_data,
             obj_href, obj_verify_href)

         # this verifies also OIDs
         actions, errors = handler.exec_operation(operation)
         if errors:
             log.warning('LFS: got following errors: %s', errors)
             obj_data['errors'] = errors

         if actions:
             obj_data['actions'] = actions

         obj_data['size'] = obj_size
         obj_data['authenticated'] = True
         objects.append(obj_data)

     result = {'objects': objects, 'transfer': 'basic'}
     log.debug('LFS Response %s', safe_result(result))

     return result


 def lfs_objects_oid_upload(request):
     request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
     repo = request.matchdict.get('repo')
     oid = request.matchdict.get('oid')
     store = LFSOidStore(
         oid, repo, store_location=request.registry.git_lfs_store_path)
     engine = store.get_engine(mode='wb')
     log.debug('LFS: starting chunked write of LFS oid: %s to storage', oid)

     body = request.environ['wsgi.input']

     with engine as f:
         blksize = 64 * 1024  # 64kb
         while True:
             # read in chunks as stream comes in from Gunicorn
             # this is a specific Gunicorn support function.
             # might work differently on waitress
             chunk = body.read(blksize)
             if not chunk:
                 break
             f.write(chunk)

     return {'upload': 'ok'}


 def lfs_objects_oid_download(request):
     repo = request.matchdict.get('repo')
     oid = request.matchdict.get('oid')

     store = LFSOidStore(
         oid, repo, store_location=request.registry.git_lfs_store_path)
     if not store.has_oid():
         log.debug('LFS: oid %s does not exists in store', oid)
         return write_response_error(
             HTTPNotFound, 'requested file with oid `%s` not found in store' % oid)

     # TODO(marcink): support range header ?
     # Range: bytes=0-, `bytes=(\d+)\-.*`

     f = open(store.oid_path, 'rb')
     response = Response(
         content_type='application/octet-stream', app_iter=FileIter(f))
     response.headers.add('X-RC-LFS-Response-Oid', str(oid))
     return response


 def lfs_objects_verify(request):
     request.response.content_type = GIT_LFS_CONTENT_TYPE + '+json'
     repo = request.matchdict.get('repo')

     data = request.json
     oid = data.get('oid')
     size = safe_int(data.get('size'))

     if not (oid and size):
         return write_response_error(
             HTTPBadRequest, 'missing oid and size in request data')

     store = LFSOidStore(
         oid, repo, store_location=request.registry.git_lfs_store_path)
     if not store.has_oid():
         log.debug('LFS: oid %s does not exists in store', oid)
         return write_response_error(
             HTTPNotFound, 'oid `%s` does not exists in store' % oid)

     store_size = store.size_oid()
     if store_size != size:
-        msg = 'requested file size mismatch store size:%s requested:%s' % (
+        msg = 'requested file size mismatch store size:{} requested:{}'.format(
             store_size, size)
         return write_response_error(
             HTTPUnprocessableEntity, msg)

     return {'message': {'size': 'ok', 'in_store': 'ok'}}


 def lfs_objects_lock(request):
     return write_response_error(
         HTTPNotImplemented, 'GIT LFS locking api not supported')


 def not_found(request):
     return write_response_error(
         HTTPNotFound, 'request path not found')


 def lfs_disabled(request):
     return write_response_error(
         HTTPNotImplemented, 'GIT LFS disabled for this repo')


 def git_lfs_app(config):

     # v1 API deprecation endpoint
     config.add_route('lfs_objects',
                      '/{repo:.*?[^/]}/info/lfs/objects')
     config.add_view(lfs_objects, route_name='lfs_objects',
                     request_method='POST', renderer='json')

     # locking API
     config.add_route('lfs_objects_lock',
                      '/{repo:.*?[^/]}/info/lfs/locks')
     config.add_view(lfs_objects_lock, route_name='lfs_objects_lock',
                     request_method=('POST', 'GET'), renderer='json')

     config.add_route('lfs_objects_lock_verify',
                      '/{repo:.*?[^/]}/info/lfs/locks/verify')
     config.add_view(lfs_objects_lock, route_name='lfs_objects_lock_verify',
                     request_method=('POST', 'GET'), renderer='json')

     # batch API
     config.add_route('lfs_objects_batch',
                      '/{repo:.*?[^/]}/info/lfs/objects/batch')
     config.add_view(lfs_objects_batch, route_name='lfs_objects_batch',
                     request_method='POST', renderer='json')

     # oid upload/download API
     config.add_route('lfs_objects_oid',
                      '/{repo:.*?[^/]}/info/lfs/objects/{oid}')
     config.add_view(lfs_objects_oid_upload, route_name='lfs_objects_oid',
                     request_method='PUT', renderer='json')
     config.add_view(lfs_objects_oid_download, route_name='lfs_objects_oid',
                     request_method='GET', renderer='json')

     # verification API
     config.add_route('lfs_objects_verify',
                      '/{repo:.*?[^/]}/info/lfs/verify')
     config.add_view(lfs_objects_verify, route_name='lfs_objects_verify',
                     request_method='POST', renderer='json')

     # not found handler for API
     config.add_notfound_view(not_found, renderer='json')


 def create_app(git_lfs_enabled, git_lfs_store_path, git_lfs_http_scheme):
     config = Configurator()
     if git_lfs_enabled:
         config.include(git_lfs_app)
         config.registry.git_lfs_store_path = git_lfs_store_path
         config.registry.git_lfs_http_scheme = git_lfs_http_scheme
     else:
         # not found handler for API, reporting disabled LFS support
         config.add_notfound_view(lfs_disabled, renderer='json')

     app = config.make_wsgi_app()
     return app
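create_app() above yields a self-contained WSGI application. A rough sketch of serving it standalone, assuming this file lives at vcsserver.git_lfs.app (the stdlib server and paths are for illustration only; in production this runs under Gunicorn inside RhodeCode):

    # Illustrative only: serve the LFS app with the stdlib WSGI server.
    from wsgiref.simple_server import make_server
    from vcsserver.git_lfs.app import create_app  # module path assumed

    app = create_app(
        git_lfs_enabled=True,
        git_lfs_store_path='/tmp/lfs-store',  # any writable directory
        git_lfs_http_scheme='http',
    )
    make_server('127.0.0.1', 8080, app).serve_forever()

A git-lfs client would then POST to http://127.0.0.1:8080/<repo>/info/lfs/objects/batch with an Authorization header, since lfs_objects_batch is wrapped in AuthHeaderRequired.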
@@ -1,175 +1,175 @@
 # RhodeCode VCSServer provides access to different vcs backends via network.
 # Copyright (C) 2014-2023 RhodeCode GmbH
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 # the Free Software Foundation; either version 3 of the License, or
 # (at your option) any later version.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software Foundation,
 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

 import os
 import shutil
 import logging
 from collections import OrderedDict

 log = logging.getLogger(__name__)


 class OidHandler(object):

     def __init__(self, store, repo_name, auth, oid, obj_size, obj_data, obj_href,
                  obj_verify_href=None):
         self.current_store = store
         self.repo_name = repo_name
         self.auth = auth
         self.oid = oid
         self.obj_size = obj_size
         self.obj_data = obj_data
         self.obj_href = obj_href
         self.obj_verify_href = obj_verify_href

     def get_store(self, mode=None):
         return self.current_store

     def get_auth(self):
         """returns auth header for re-use in upload/download"""
         return " ".join(self.auth)

     def download(self):

         store = self.get_store()
         response = None
         has_errors = None

         if not store.has_oid():
             # error reply back to client that something is wrong with dl
-            err_msg = 'object: {} does not exist in store'.format(store.oid)
+            err_msg = f'object: {store.oid} does not exist in store'
             has_errors = OrderedDict(
                 error=OrderedDict(
                     code=404,
                     message=err_msg
                 )
             )

         download_action = OrderedDict(
             href=self.obj_href,
             header=OrderedDict([("Authorization", self.get_auth())])
         )
         if not has_errors:
             response = OrderedDict(download=download_action)
         return response, has_errors

     def upload(self, skip_existing=True):
         """
         Write upload action for git-lfs server
         """

         store = self.get_store()
         response = None
         has_errors = None

         # verify if we have the OID before, if we do, reply with empty
         if store.has_oid():
             log.debug('LFS: store already has oid %s', store.oid)

             # validate size
             store_size = store.size_oid()
             size_match = store_size == self.obj_size
             if not size_match:
                 log.warning(
                     'LFS: size mismatch for oid:%s, in store:%s expected: %s',
                     self.oid, store_size, self.obj_size)
             elif skip_existing:
                 log.debug('LFS: skipping further action as oid is existing')
                 return response, has_errors

         chunked = ("Transfer-Encoding", "chunked")
         upload_action = OrderedDict(
             href=self.obj_href,
             header=OrderedDict([("Authorization", self.get_auth()), chunked])
         )
         if not has_errors:
             response = OrderedDict(upload=upload_action)
             # if specified in handler, return the verification endpoint
             if self.obj_verify_href:
                 verify_action = OrderedDict(
                     href=self.obj_verify_href,
                     header=OrderedDict([("Authorization", self.get_auth())])
                 )
                 response['verify'] = verify_action
         return response, has_errors

     def exec_operation(self, operation, *args, **kwargs):
         handler = getattr(self, operation)
         log.debug('LFS: handling request using %s handler', handler)
         return handler(*args, **kwargs)


 class LFSOidStore(object):

     def __init__(self, oid, repo, store_location=None):
         self.oid = oid
         self.repo = repo
         self.store_path = store_location or self.get_default_store()
         self.tmp_oid_path = os.path.join(self.store_path, oid + '.tmp')
         self.oid_path = os.path.join(self.store_path, oid)
         self.fd = None

     def get_engine(self, mode):
         """
         engine = .get_engine(mode='wb')
         with engine as f:
             f.write('...')
         """

         class StoreEngine(object):
             def __init__(self, mode, store_path, oid_path, tmp_oid_path):
                 self.mode = mode
                 self.store_path = store_path
                 self.oid_path = oid_path
                 self.tmp_oid_path = tmp_oid_path

             def __enter__(self):
                 if not os.path.isdir(self.store_path):
                     os.makedirs(self.store_path)

                 # TODO(marcink): maybe write metadata here with size/oid ?
                 fd = open(self.tmp_oid_path, self.mode)
                 self.fd = fd
                 return fd

             def __exit__(self, exc_type, exc_value, traceback):
                 # close tmp file, and rename to final destination
                 self.fd.close()
                 shutil.move(self.tmp_oid_path, self.oid_path)

         return StoreEngine(
             mode, self.store_path, self.oid_path, self.tmp_oid_path)

     def get_default_store(self):
         """
         Default store, consistent with defaults of Mercurial large files store
         which is /home/username/.cache/largefiles
         """
         user_home = os.path.expanduser("~")
         return os.path.join(user_home, '.cache', 'lfs-store')

     def has_oid(self):
         return os.path.exists(os.path.join(self.store_path, self.oid))

     def size_oid(self):
         size = -1

         if self.has_oid():
             oid = os.path.join(self.store_path, self.oid)
             size = os.stat(oid).st_size

         return size
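A short round-trip may help here: get_engine() writes to '<oid>.tmp' and renames it to the final path on context exit, so a partially written object never shows up as present in the store. The oid and store location below are made up:

    # Illustrative round-trip through LFSOidStore; values are made up.
    from vcsserver.git_lfs.lib import LFSOidStore

    store = LFSOidStore(oid='deadbeef' * 8, repo='my-repo',
                        store_location='/tmp/lfs-store')

    with store.get_engine(mode='wb') as f:
        f.write(b'large file content')  # goes to <oid>.tmp, renamed on exit

    assert store.has_oid()
    assert store.size_oid() == len(b'large file content')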
@@ -1,202 +1,202 @@
 # RhodeCode VCSServer provides access to different vcs backends via network.
 # Copyright (C) 2014-2023 RhodeCode GmbH
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 # the Free Software Foundation; either version 3 of the License, or
 # (at your option) any later version.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software Foundation,
 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

 import re
 import os
 import sys
 import datetime
 import logging
 import pkg_resources

 import vcsserver
 from vcsserver.str_utils import safe_bytes

 log = logging.getLogger(__name__)


 def get_git_hooks_path(repo_path, bare):
     hooks_path = os.path.join(repo_path, 'hooks')
     if not bare:
         hooks_path = os.path.join(repo_path, '.git', 'hooks')

     return hooks_path


 def install_git_hooks(repo_path, bare, executable=None, force_create=False):
     """
     Creates a RhodeCode hook inside a git repository

     :param repo_path: path to repository
     :param executable: binary executable to put in the hooks
     :param force_create: Create even if same name hook exists
     """
     executable = executable or sys.executable
     hooks_path = get_git_hooks_path(repo_path, bare)

     if not os.path.isdir(hooks_path):
         os.makedirs(hooks_path, mode=0o777, exist_ok=True)

     tmpl_post = pkg_resources.resource_string(
         'vcsserver', '/'.join(
             ('hook_utils', 'hook_templates', 'git_post_receive.py.tmpl')))
     tmpl_pre = pkg_resources.resource_string(
         'vcsserver', '/'.join(
             ('hook_utils', 'hook_templates', 'git_pre_receive.py.tmpl')))

     path = ''  # not used for now
     timestamp = datetime.datetime.utcnow().isoformat()

     for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
         log.debug('Installing git hook in repo %s', repo_path)
         _hook_file = os.path.join(hooks_path, '%s-receive' % h_type)
         _rhodecode_hook = check_rhodecode_hook(_hook_file)

         if _rhodecode_hook or force_create:
             log.debug('writing git %s hook file at %s !', h_type, _hook_file)
             try:
                 with open(_hook_file, 'wb') as f:
                     template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
                     template = template.replace(b'_DATE_', safe_bytes(timestamp))
                     template = template.replace(b'_ENV_', safe_bytes(executable))
                     template = template.replace(b'_PATH_', safe_bytes(path))
                     f.write(template)
                 os.chmod(_hook_file, 0o755)
-            except IOError:
+            except OSError:
                 log.exception('error writing hook file %s', _hook_file)
         else:
             log.debug('skipping writing hook file')

     return True


 def get_svn_hooks_path(repo_path):
     hooks_path = os.path.join(repo_path, 'hooks')

     return hooks_path


 def install_svn_hooks(repo_path, executable=None, force_create=False):
     """
     Creates RhodeCode hooks inside a svn repository

     :param repo_path: path to repository
     :param executable: binary executable to put in the hooks
     :param force_create: Create even if same name hook exists
     """
     executable = executable or sys.executable
     hooks_path = get_svn_hooks_path(repo_path)
     if not os.path.isdir(hooks_path):
         os.makedirs(hooks_path, mode=0o777, exist_ok=True)

     tmpl_post = pkg_resources.resource_string(
         'vcsserver', '/'.join(
             ('hook_utils', 'hook_templates', 'svn_post_commit_hook.py.tmpl')))
     tmpl_pre = pkg_resources.resource_string(
         'vcsserver', '/'.join(
             ('hook_utils', 'hook_templates', 'svn_pre_commit_hook.py.tmpl')))

     path = ''  # not used for now
     timestamp = datetime.datetime.utcnow().isoformat()

     for h_type, template in [('pre', tmpl_pre), ('post', tmpl_post)]:
         log.debug('Installing svn hook in repo %s', repo_path)
         _hook_file = os.path.join(hooks_path, '%s-commit' % h_type)
         _rhodecode_hook = check_rhodecode_hook(_hook_file)

         if _rhodecode_hook or force_create:
             log.debug('writing svn %s hook file at %s !', h_type, _hook_file)

             try:
                 with open(_hook_file, 'wb') as f:
                     template = template.replace(b'_TMPL_', safe_bytes(vcsserver.__version__))
                     template = template.replace(b'_DATE_', safe_bytes(timestamp))
                     template = template.replace(b'_ENV_', safe_bytes(executable))
                     template = template.replace(b'_PATH_', safe_bytes(path))

                     f.write(template)
                 os.chmod(_hook_file, 0o755)
-            except IOError:
+            except OSError:
                 log.exception('error writing hook file %s', _hook_file)
         else:
             log.debug('skipping writing hook file')

     return True


 def get_version_from_hook(hook_path):
     version = b''
     hook_content = read_hook_content(hook_path)
     matches = re.search(rb'RC_HOOK_VER\s*=\s*(.*)', hook_content)
     if matches:
         try:
             version = matches.groups()[0]
             log.debug('got version %s from hooks.', version)
         except Exception:
             log.exception("Exception while reading the hook version.")
     return version.replace(b"'", b"")


 def check_rhodecode_hook(hook_path):
     """
     Check if the hook was created by RhodeCode
     """
     if not os.path.exists(hook_path):
         return True

     log.debug('hook exists, checking if it is from RhodeCode')

     version = get_version_from_hook(hook_path)
     if version:
         return True

     return False


 def read_hook_content(hook_path) -> bytes:
     content = b''
     if os.path.isfile(hook_path):
         with open(hook_path, 'rb') as f:
             content = f.read()
     return content


 def get_git_pre_hook_version(repo_path, bare):
     hooks_path = get_git_hooks_path(repo_path, bare)
     _hook_file = os.path.join(hooks_path, 'pre-receive')
     version = get_version_from_hook(_hook_file)
     return version


 def get_git_post_hook_version(repo_path, bare):
     hooks_path = get_git_hooks_path(repo_path, bare)
     _hook_file = os.path.join(hooks_path, 'post-receive')
     version = get_version_from_hook(_hook_file)
     return version


 def get_svn_pre_hook_version(repo_path):
     hooks_path = get_svn_hooks_path(repo_path)
     _hook_file = os.path.join(hooks_path, 'pre-commit')
     version = get_version_from_hook(_hook_file)
     return version


 def get_svn_post_hook_version(repo_path):
     hooks_path = get_svn_hooks_path(repo_path)
     _hook_file = os.path.join(hooks_path, 'post-commit')
     version = get_version_from_hook(_hook_file)
     return version
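To make the flow concrete, a hedged sketch of installing and then fingerprinting the git hooks; the repository path is hypothetical:

    # Illustrative only: install RhodeCode hooks into a bare git repository.
    from vcsserver.hook_utils import install_git_hooks, get_git_pre_hook_version

    repo_path = '/srv/repos/example.git'  # hypothetical path
    install_git_hooks(repo_path, bare=True, force_create=True)

    # The templates embed the vcsserver version as RC_HOOK_VER, which is what
    # check_rhodecode_hook()/get_version_from_hook() read back later to decide
    # whether an existing hook may be safely overwritten.
    print(get_git_pre_hook_version(repo_path, bare=True))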
@@ -1,243 +1,243 b''
1 '''
1 '''
2 This library is provided to allow standard python logging
2 This library is provided to allow standard python logging
3 to output log data as JSON formatted strings
3 to output log data as JSON formatted strings
4 '''
4 '''
5 import logging
5 import logging
6 import json
6 import json
7 import re
7 import re
8 from datetime import date, datetime, time, tzinfo, timedelta
8 from datetime import date, datetime, time, tzinfo, timedelta
9 import traceback
9 import traceback
10 import importlib
10 import importlib
11
11
12 from inspect import istraceback
12 from inspect import istraceback
13
13
14 from collections import OrderedDict
14 from collections import OrderedDict
15
15
16
16
17 def _inject_req_id(record, *args, **kwargs):
17 def _inject_req_id(record, *args, **kwargs):
18 return record
18 return record
19
19
20
20
21 ExceptionAwareFormatter = logging.Formatter
21 ExceptionAwareFormatter = logging.Formatter
22
22
23
23
24 ZERO = timedelta(0)
24 ZERO = timedelta(0)
25 HOUR = timedelta(hours=1)
25 HOUR = timedelta(hours=1)
26
26
27
27
28 class UTC(tzinfo):
28 class UTC(tzinfo):
29 """UTC"""
29 """UTC"""
30
30
31 def utcoffset(self, dt):
31 def utcoffset(self, dt):
32 return ZERO
32 return ZERO
33
33
34 def tzname(self, dt):
34 def tzname(self, dt):
35 return "UTC"
35 return "UTC"
36
36
37 def dst(self, dt):
37 def dst(self, dt):
38 return ZERO
38 return ZERO
39
39
40 utc = UTC()
40 utc = UTC()
41
41
42
42
43 # skip natural LogRecord attributes
43 # skip natural LogRecord attributes
44 # http://docs.python.org/library/logging.html#logrecord-attributes
44 # http://docs.python.org/library/logging.html#logrecord-attributes
45 RESERVED_ATTRS = (
45 RESERVED_ATTRS = (
46 'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
46 'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
47 'funcName', 'levelname', 'levelno', 'lineno', 'module',
47 'funcName', 'levelname', 'levelno', 'lineno', 'module',
48 'msecs', 'message', 'msg', 'name', 'pathname', 'process',
48 'msecs', 'message', 'msg', 'name', 'pathname', 'process',
49 'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName')
49 'processName', 'relativeCreated', 'stack_info', 'thread', 'threadName')
50
50
51
51
52 def merge_record_extra(record, target, reserved):
52 def merge_record_extra(record, target, reserved):
53 """
53 """
54 Merges extra attributes from LogRecord object into target dictionary
54 Merges extra attributes from LogRecord object into target dictionary
55
55
56 :param record: logging.LogRecord
56 :param record: logging.LogRecord
57 :param target: dict to update
57 :param target: dict to update
58 :param reserved: dict or list with reserved keys to skip
58 :param reserved: dict or list with reserved keys to skip
59 """
59 """
60 for key, value in record.__dict__.items():
60 for key, value in record.__dict__.items():
61 # this also allows numeric keys
61 # this also allows numeric keys
62 if (key not in reserved
62 if (key not in reserved
63 and not (hasattr(key, "startswith")
63 and not (hasattr(key, "startswith")
64 and key.startswith('_'))):
64 and key.startswith('_'))):
65 target[key] = value
65 target[key] = value
66 return target
66 return target
67
67
68
68
69 class JsonEncoder(json.JSONEncoder):
69 class JsonEncoder(json.JSONEncoder):
70 """
70 """
71 A custom encoder extending the default JSONEncoder
71 A custom encoder extending the default JSONEncoder
72 """
72 """
73
73
74 def default(self, obj):
74 def default(self, obj):
75 if isinstance(obj, (date, datetime, time)):
75 if isinstance(obj, (date, datetime, time)):
76 return self.format_datetime_obj(obj)
76 return self.format_datetime_obj(obj)
77
77
78 elif istraceback(obj):
78 elif istraceback(obj):
79 return ''.join(traceback.format_tb(obj)).strip()
79 return ''.join(traceback.format_tb(obj)).strip()
80
80
81 elif type(obj) == Exception \
81 elif type(obj) == Exception \
82 or isinstance(obj, Exception) \
82 or isinstance(obj, Exception) \
83 or type(obj) == type:
83 or type(obj) == type:
84 return str(obj)
84 return str(obj)
85
85
86 try:
86 try:
87 return super(JsonEncoder, self).default(obj)
87 return super().default(obj)
88
88
89 except TypeError:
89 except TypeError:
90 try:
90 try:
91 return str(obj)
91 return str(obj)
92
92
93 except Exception:
93 except Exception:
94 return None
94 return None
95
95
96 def format_datetime_obj(self, obj):
96 def format_datetime_obj(self, obj):
97 return obj.isoformat()
97 return obj.isoformat()
98
98
99
99
100 class JsonFormatter(ExceptionAwareFormatter):
100 class JsonFormatter(ExceptionAwareFormatter):
101 """
101 """
102 A custom formatter to format logging records as json strings.
102 A custom formatter to format logging records as json strings.
103 Extra values will be formatted with str() if not supported by
103 Extra values will be formatted with str() if not supported by
104 the default JSON encoder.
104 the default JSON encoder.
105 """
105 """
106
106
107 def __init__(self, *args, **kwargs):
107 def __init__(self, *args, **kwargs):
108 """
108 """
109 :param json_default: a function for encoding non-standard objects
109 :param json_default: a function for encoding non-standard objects
110 as outlined in https://docs.python.org/3/library/json.html
110 as outlined in https://docs.python.org/3/library/json.html
111 :param json_encoder: optional custom encoder
111 :param json_encoder: optional custom encoder
112 :param json_serializer: a :meth:`json.dumps`-compatible callable
112 :param json_serializer: a :meth:`json.dumps`-compatible callable
113 that will be used to serialize the log record.
113 that will be used to serialize the log record.
114 :param json_indent: an optional :meth:`json.dumps`-compatible numeric value
114 :param json_indent: an optional :meth:`json.dumps`-compatible numeric value
115 that will be used to customize the indent of the output json.
115 that will be used to customize the indent of the output json.
116 :param prefix: an optional string prefix added at the beginning of
116 :param prefix: an optional string prefix added at the beginning of
117 the formatted string
117 the formatted string
118 :param json_indent: indent parameter for json.dumps
118 :param json_indent: indent parameter for json.dumps
119 :param json_ensure_ascii: ensure_ascii parameter for json.dumps
119 :param json_ensure_ascii: ensure_ascii parameter for json.dumps
120 :param reserved_attrs: an optional list of fields that will be skipped when
120 :param reserved_attrs: an optional list of fields that will be skipped when
121 outputting json log record. Defaults to all log record attributes:
121 outputting json log record. Defaults to all log record attributes:
122 https://docs.python.org/3/library/logging.html#logrecord-attributes
122 https://docs.python.org/3/library/logging.html#logrecord-attributes
123 :param timestamp: an optional string/boolean field to add a timestamp when
123 :param timestamp: an optional string/boolean field to add a timestamp when
124 outputting the json log record. If string is passed, timestamp will be added
124 outputting the json log record. If string is passed, timestamp will be added
125 to log record using string as key. If True boolean is passed, timestamp key
125 to log record using string as key. If True boolean is passed, timestamp key
126 will be "timestamp". Defaults to False/off.
126 will be "timestamp". Defaults to False/off.
127 """
127 """
128 self.json_default = self._str_to_fn(kwargs.pop("json_default", None))
128 self.json_default = self._str_to_fn(kwargs.pop("json_default", None))
129 self.json_encoder = self._str_to_fn(kwargs.pop("json_encoder", None))
129 self.json_encoder = self._str_to_fn(kwargs.pop("json_encoder", None))
130 self.json_serializer = self._str_to_fn(kwargs.pop("json_serializer", json.dumps))
130 self.json_serializer = self._str_to_fn(kwargs.pop("json_serializer", json.dumps))
131 self.json_indent = kwargs.pop("json_indent", None)
131 self.json_indent = kwargs.pop("json_indent", None)
132 self.json_ensure_ascii = kwargs.pop("json_ensure_ascii", True)
132 self.json_ensure_ascii = kwargs.pop("json_ensure_ascii", True)
133 self.prefix = kwargs.pop("prefix", "")
133 self.prefix = kwargs.pop("prefix", "")
134 reserved_attrs = kwargs.pop("reserved_attrs", RESERVED_ATTRS)
134 reserved_attrs = kwargs.pop("reserved_attrs", RESERVED_ATTRS)
135 self.reserved_attrs = dict(zip(reserved_attrs, reserved_attrs))
135 self.reserved_attrs = dict(zip(reserved_attrs, reserved_attrs))
136 self.timestamp = kwargs.pop("timestamp", True)
136 self.timestamp = kwargs.pop("timestamp", True)
137
137
138 # super(JsonFormatter, self).__init__(*args, **kwargs)
138 # super(JsonFormatter, self).__init__(*args, **kwargs)
139 logging.Formatter.__init__(self, *args, **kwargs)
139 logging.Formatter.__init__(self, *args, **kwargs)
140 if not self.json_encoder and not self.json_default:
140 if not self.json_encoder and not self.json_default:
141 self.json_encoder = JsonEncoder
141 self.json_encoder = JsonEncoder
142
142
143 self._required_fields = self.parse()
143 self._required_fields = self.parse()
144 self._skip_fields = dict(zip(self._required_fields,
144 self._skip_fields = dict(zip(self._required_fields,
145 self._required_fields))
145 self._required_fields))
146 self._skip_fields.update(self.reserved_attrs)
146 self._skip_fields.update(self.reserved_attrs)
147
147
148 def _str_to_fn(self, fn_as_str):
148 def _str_to_fn(self, fn_as_str):
149 """
149 """
150 If the argument is not a string, return whatever was passed in.
150 If the argument is not a string, return whatever was passed in.
151 Parses a string such as package.module.function, imports the module
151 Parses a string such as package.module.function, imports the module
152 and returns the function.
152 and returns the function.
153
153
154 :param fn_as_str: The string to parse. If not a string, return it.
154 :param fn_as_str: The string to parse. If not a string, return it.
155 """
155 """
156 if not isinstance(fn_as_str, str):
156 if not isinstance(fn_as_str, str):
157 return fn_as_str
157 return fn_as_str
158
158
159 path, _, function = fn_as_str.rpartition('.')
159 path, _, function = fn_as_str.rpartition('.')
160 module = importlib.import_module(path)
160 module = importlib.import_module(path)
161 return getattr(module, function)
161 return getattr(module, function)
162
162
163 def parse(self):
163 def parse(self):
164 """
164 """
165 Parses format string looking for substitutions
165 Parses format string looking for substitutions
166
166
167 This method is responsible for returning a list of fields (as strings)
167 This method is responsible for returning a list of fields (as strings)
168 to include in all log messages.
168 to include in all log messages.
169 """
169 """
170 standard_formatters = re.compile(r'\((.+?)\)', re.IGNORECASE)
170 standard_formatters = re.compile(r'\((.+?)\)', re.IGNORECASE)
171 return standard_formatters.findall(self._fmt)
171 return standard_formatters.findall(self._fmt)
172
172
173 def add_fields(self, log_record, record, message_dict):
173 def add_fields(self, log_record, record, message_dict):
174 """
174 """
175 Override this method to implement custom logic for adding fields.
175 Override this method to implement custom logic for adding fields.
176 """
176 """
177 for field in self._required_fields:
177 for field in self._required_fields:
178 log_record[field] = record.__dict__.get(field)
178 log_record[field] = record.__dict__.get(field)
179 log_record.update(message_dict)
179 log_record.update(message_dict)
180 merge_record_extra(record, log_record, reserved=self._skip_fields)
180 merge_record_extra(record, log_record, reserved=self._skip_fields)
181
181
182 if self.timestamp:
182 if self.timestamp:
183 key = self.timestamp if isinstance(self.timestamp, str) else 'timestamp'
183 key = self.timestamp if isinstance(self.timestamp, str) else 'timestamp'
184 log_record[key] = datetime.fromtimestamp(record.created, tz=utc)
184 log_record[key] = datetime.fromtimestamp(record.created, tz=utc)
185
185
186 def process_log_record(self, log_record):
186 def process_log_record(self, log_record):
187 """
187 """
188 Override this method to implement custom logic
188 Override this method to implement custom logic
189 on the possibly ordered dictionary.
189 on the possibly ordered dictionary.
190 """
190 """
191 return log_record
191 return log_record
192
192
193 def jsonify_log_record(self, log_record):
193 def jsonify_log_record(self, log_record):
194 """Returns a json string of the log record."""
194 """Returns a json string of the log record."""
195 return self.json_serializer(log_record,
195 return self.json_serializer(log_record,
196 default=self.json_default,
196 default=self.json_default,
197 cls=self.json_encoder,
197 cls=self.json_encoder,
198 indent=self.json_indent,
198 indent=self.json_indent,
199 ensure_ascii=self.json_ensure_ascii)
199 ensure_ascii=self.json_ensure_ascii)
200
200
201 def serialize_log_record(self, log_record):
201 def serialize_log_record(self, log_record):
202 """Returns the final representation of the log record."""
202 """Returns the final representation of the log record."""
203 return "%s%s" % (self.prefix, self.jsonify_log_record(log_record))
203 return "{}{}".format(self.prefix, self.jsonify_log_record(log_record))
204
204
205 def format(self, record):
205 def format(self, record):
206 """Formats a log record and serializes to json"""
206 """Formats a log record and serializes to json"""
207 message_dict = {}
207 message_dict = {}
208 # FIXME: logging.LogRecord.msg and logging.LogRecord.message in typeshed
208 # FIXME: logging.LogRecord.msg and logging.LogRecord.message in typeshed
209 # are always of type str. We shouldn't need to override that.
209 # are always of type str. We shouldn't need to override that.
210 if isinstance(record.msg, dict):
210 if isinstance(record.msg, dict):
211 message_dict = record.msg
211 message_dict = record.msg
212 record.message = None
212 record.message = None
213 else:
213 else:
214 record.message = record.getMessage()
214 record.message = record.getMessage()
215 # only format time if needed
215 # only format time if needed
216 if "asctime" in self._required_fields:
216 if "asctime" in self._required_fields:
217 record.asctime = self.formatTime(record, self.datefmt)
217 record.asctime = self.formatTime(record, self.datefmt)
218
218
219 # Display formatted exception, but allow overriding it in the
219 # Display formatted exception, but allow overriding it in the
220 # user-supplied dict.
220 # user-supplied dict.
221 if record.exc_info and not message_dict.get('exc_info'):
221 if record.exc_info and not message_dict.get('exc_info'):
222 message_dict['exc_info'] = self.formatException(record.exc_info)
222 message_dict['exc_info'] = self.formatException(record.exc_info)
223 if not message_dict.get('exc_info') and record.exc_text:
223 if not message_dict.get('exc_info') and record.exc_text:
224 message_dict['exc_info'] = record.exc_text
224 message_dict['exc_info'] = record.exc_text
225 # Display formatted record of stack frames
225 # Display formatted record of stack frames
226 # default format is a string returned from :func:`traceback.print_stack`
226 # default format is a string returned from :func:`traceback.print_stack`
227 try:
227 try:
228 if record.stack_info and not message_dict.get('stack_info'):
228 if record.stack_info and not message_dict.get('stack_info'):
229 message_dict['stack_info'] = self.formatStack(record.stack_info)
229 message_dict['stack_info'] = self.formatStack(record.stack_info)
230 except AttributeError:
230 except AttributeError:
231 # Python2.7 doesn't have stack_info.
231 # Python2.7 doesn't have stack_info.
232 pass
232 pass
233
233
234 try:
234 try:
235 log_record = OrderedDict()
235 log_record = OrderedDict()
236 except NameError:
236 except NameError:
237 log_record = {}
237 log_record = {}
238
238
239 _inject_req_id(record, with_prefix=False)
239 _inject_req_id(record, with_prefix=False)
240 self.add_fields(log_record, record, message_dict)
240 self.add_fields(log_record, record, message_dict)
241 log_record = self.process_log_record(log_record)
241 log_record = self.process_log_record(log_record)
242
242
243 return self.serialize_log_record(log_record)
243 return self.serialize_log_record(log_record)
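A minimal sketch of wiring the formatter above into standard logging. The vendored import path is an assumption; the rest follows directly from the constructor and format() logic shown in this file:

import logging

# Assumption: the module above is importable under this vendored path.
from vcsserver.lib._vendor.jsonlogger import JsonFormatter

handler = logging.StreamHandler()
# Every %(field)s named in the format string becomes a key of the JSON record;
# with timestamp=True (the default in this copy) a "timestamp" key is added too.
handler.setFormatter(JsonFormatter('%(asctime)s %(levelname)s %(name)s %(message)s'))

log = logging.getLogger('demo')
log.addHandler(handler)
log.setLevel(logging.INFO)

# Values passed via `extra` are merged into the output by merge_record_extra().
log.info('user logged in', extra={'username': 'alice'})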
@@ -1,384 +1,384 b''
1 import sys
1 import sys
2 import threading
2 import threading
3 import weakref
3 import weakref
4 from base64 import b64encode
4 from base64 import b64encode
5 from logging import getLogger
5 from logging import getLogger
6 from os import urandom
6 from os import urandom
7
7
8 from redis import StrictRedis
8 from redis import StrictRedis
9
9
10 __version__ = '3.7.0'
10 __version__ = '3.7.0'
11
11
12 loggers = {
12 loggers = {
13 k: getLogger("vcsserver." + ".".join((__name__, k)))
13 k: getLogger("vcsserver." + ".".join((__name__, k)))
14 for k in [
14 for k in [
15 "acquire",
15 "acquire",
16 "refresh.thread.start",
16 "refresh.thread.start",
17 "refresh.thread.stop",
17 "refresh.thread.stop",
18 "refresh.thread.exit",
18 "refresh.thread.exit",
19 "refresh.start",
19 "refresh.start",
20 "refresh.shutdown",
20 "refresh.shutdown",
21 "refresh.exit",
21 "refresh.exit",
22 "release",
22 "release",
23 ]
23 ]
24 }
24 }
25
25
26 text_type = str
26 text_type = str
27 binary_type = bytes
27 binary_type = bytes
28
28
29
29
30 # Check if the id matches. If not, return an error code.
30 # Check if the id matches. If not, return an error code.
31 UNLOCK_SCRIPT = b"""
31 UNLOCK_SCRIPT = b"""
32 if redis.call("get", KEYS[1]) ~= ARGV[1] then
32 if redis.call("get", KEYS[1]) ~= ARGV[1] then
33 return 1
33 return 1
34 else
34 else
35 redis.call("del", KEYS[2])
35 redis.call("del", KEYS[2])
36 redis.call("lpush", KEYS[2], 1)
36 redis.call("lpush", KEYS[2], 1)
37 redis.call("pexpire", KEYS[2], ARGV[2])
37 redis.call("pexpire", KEYS[2], ARGV[2])
38 redis.call("del", KEYS[1])
38 redis.call("del", KEYS[1])
39 return 0
39 return 0
40 end
40 end
41 """
41 """
42
42
43 # Covers both cases: the key doesn't exist, or it doesn't equal the lock's id
43 # Covers both cases: the key doesn't exist, or it doesn't equal the lock's id
44 EXTEND_SCRIPT = b"""
44 EXTEND_SCRIPT = b"""
45 if redis.call("get", KEYS[1]) ~= ARGV[1] then
45 if redis.call("get", KEYS[1]) ~= ARGV[1] then
46 return 1
46 return 1
47 elseif redis.call("ttl", KEYS[1]) < 0 then
47 elseif redis.call("ttl", KEYS[1]) < 0 then
48 return 2
48 return 2
49 else
49 else
50 redis.call("expire", KEYS[1], ARGV[2])
50 redis.call("expire", KEYS[1], ARGV[2])
51 return 0
51 return 0
52 end
52 end
53 """
53 """
54
54
55 RESET_SCRIPT = b"""
55 RESET_SCRIPT = b"""
56 redis.call('del', KEYS[2])
56 redis.call('del', KEYS[2])
57 redis.call('lpush', KEYS[2], 1)
57 redis.call('lpush', KEYS[2], 1)
58 redis.call('pexpire', KEYS[2], ARGV[2])
58 redis.call('pexpire', KEYS[2], ARGV[2])
59 return redis.call('del', KEYS[1])
59 return redis.call('del', KEYS[1])
60 """
60 """
61
61
62 RESET_ALL_SCRIPT = b"""
62 RESET_ALL_SCRIPT = b"""
63 local locks = redis.call('keys', 'lock:*')
63 local locks = redis.call('keys', 'lock:*')
64 local signal
64 local signal
65 for _, lock in pairs(locks) do
65 for _, lock in pairs(locks) do
66 signal = 'lock-signal:' .. string.sub(lock, 6)
66 signal = 'lock-signal:' .. string.sub(lock, 6)
67 redis.call('del', signal)
67 redis.call('del', signal)
68 redis.call('lpush', signal, 1)
68 redis.call('lpush', signal, 1)
69 redis.call('expire', signal, 1)
69 redis.call('expire', signal, 1)
70 redis.call('del', lock)
70 redis.call('del', lock)
71 end
71 end
72 return #locks
72 return #locks
73 """
73 """
74
74
75
75
76 class AlreadyAcquired(RuntimeError):
76 class AlreadyAcquired(RuntimeError):
77 pass
77 pass
78
78
79
79
80 class NotAcquired(RuntimeError):
80 class NotAcquired(RuntimeError):
81 pass
81 pass
82
82
83
83
84 class AlreadyStarted(RuntimeError):
84 class AlreadyStarted(RuntimeError):
85 pass
85 pass
86
86
87
87
88 class TimeoutNotUsable(RuntimeError):
88 class TimeoutNotUsable(RuntimeError):
89 pass
89 pass
90
90
91
91
92 class InvalidTimeout(RuntimeError):
92 class InvalidTimeout(RuntimeError):
93 pass
93 pass
94
94
95
95
96 class TimeoutTooLarge(RuntimeError):
96 class TimeoutTooLarge(RuntimeError):
97 pass
97 pass
98
98
99
99
100 class NotExpirable(RuntimeError):
100 class NotExpirable(RuntimeError):
101 pass
101 pass
102
102
103
103
104 class Lock(object):
104 class Lock(object):
105 """
105 """
106 A Lock context manager implemented via redis SETNX/BLPOP.
106 A Lock context manager implemented via redis SETNX/BLPOP.
107 """
107 """
108 unlock_script = None
108 unlock_script = None
109 extend_script = None
109 extend_script = None
110 reset_script = None
110 reset_script = None
111 reset_all_script = None
111 reset_all_script = None
112
112
113 def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000):
113 def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000):
114 """
114 """
115 :param redis_client:
115 :param redis_client:
116 An instance of :class:`~StrictRedis`.
116 An instance of :class:`~StrictRedis`.
117 :param name:
117 :param name:
118 The name (redis key) the lock should have.
118 The name (redis key) the lock should have.
119 :param expire:
119 :param expire:
120 The lock expiry time in seconds. If left at the default (None)
120 The lock expiry time in seconds. If left at the default (None)
121 the lock will not expire.
121 the lock will not expire.
122 :param id:
122 :param id:
123 The ID (redis value) the lock should have. A random value is
123 The ID (redis value) the lock should have. A random value is
124 generated when left at the default.
124 generated when left at the default.
125
125
126 Note that if you specify this, the lock is marked as "held"; acquires
126 Note that if you specify this, the lock is marked as "held"; acquires
127 won't be possible.
127 won't be possible.
128 :param auto_renewal:
128 :param auto_renewal:
129 If set to ``True``, Lock will automatically renew the lock so that it
129 If set to ``True``, Lock will automatically renew the lock so that it
130 doesn't expire for as long as the lock is held (acquire() called
130 doesn't expire for as long as the lock is held (acquire() called
131 or running in a context manager).
131 or running in a context manager).
132
132
133 Implementation note: Renewal will happen using a daemon thread with
133 Implementation note: Renewal will happen using a daemon thread with
134 an interval of ``expire*2/3``. If wishing to use a different renewal
134 an interval of ``expire*2/3``. If wishing to use a different renewal
135 time, subclass Lock, call ``super().__init__()`` then set
135 time, subclass Lock, call ``super().__init__()`` then set
136 ``self._lock_renewal_interval`` to your desired interval.
136 ``self._lock_renewal_interval`` to your desired interval.
137 :param strict:
137 :param strict:
138 If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
138 If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
139 :param signal_expire:
139 :param signal_expire:
140 Advanced option to override signal list expiration in milliseconds. Increase it for very slow clients. Default: ``1000``.
140 Advanced option to override signal list expiration in milliseconds. Increase it for very slow clients. Default: ``1000``.
141 """
141 """
142 if strict and not isinstance(redis_client, StrictRedis):
142 if strict and not isinstance(redis_client, StrictRedis):
143 raise ValueError("redis_client must be instance of StrictRedis. "
143 raise ValueError("redis_client must be instance of StrictRedis. "
144 "Use strict=False if you know what you're doing.")
144 "Use strict=False if you know what you're doing.")
145 if auto_renewal and expire is None:
145 if auto_renewal and expire is None:
146 raise ValueError("Expire may not be None when auto_renewal is set")
146 raise ValueError("Expire may not be None when auto_renewal is set")
147
147
148 self._client = redis_client
148 self._client = redis_client
149
149
150 if expire:
150 if expire:
151 expire = int(expire)
151 expire = int(expire)
152 if expire < 0:
152 if expire < 0:
153 raise ValueError("A negative expire is not acceptable.")
153 raise ValueError("A negative expire is not acceptable.")
154 else:
154 else:
155 expire = None
155 expire = None
156 self._expire = expire
156 self._expire = expire
157
157
158 self._signal_expire = signal_expire
158 self._signal_expire = signal_expire
159 if id is None:
159 if id is None:
160 self._id = b64encode(urandom(18)).decode('ascii')
160 self._id = b64encode(urandom(18)).decode('ascii')
161 elif isinstance(id, binary_type):
161 elif isinstance(id, binary_type):
162 try:
162 try:
163 self._id = id.decode('ascii')
163 self._id = id.decode('ascii')
164 except UnicodeDecodeError:
164 except UnicodeDecodeError:
165 self._id = b64encode(id).decode('ascii')
165 self._id = b64encode(id).decode('ascii')
166 elif isinstance(id, text_type):
166 elif isinstance(id, text_type):
167 self._id = id
167 self._id = id
168 else:
168 else:
169 raise TypeError("Incorrect type for `id`. Must be bytes/str not %s." % type(id))
169 raise TypeError("Incorrect type for `id`. Must be bytes/str not %s." % type(id))
170 self._name = 'lock:' + name
170 self._name = 'lock:' + name
171 self._signal = 'lock-signal:' + name
171 self._signal = 'lock-signal:' + name
172 self._lock_renewal_interval = (float(expire) * 2 / 3
172 self._lock_renewal_interval = (float(expire) * 2 / 3
173 if auto_renewal
173 if auto_renewal
174 else None)
174 else None)
175 self._lock_renewal_thread = None
175 self._lock_renewal_thread = None
176
176
177 self.register_scripts(redis_client)
177 self.register_scripts(redis_client)
178
178
179 @classmethod
179 @classmethod
180 def register_scripts(cls, redis_client):
180 def register_scripts(cls, redis_client):
181 global reset_all_script
181 global reset_all_script
182 if reset_all_script is None:
182 if reset_all_script is None:
183 reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
183 reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
184 cls.unlock_script = redis_client.register_script(UNLOCK_SCRIPT)
184 cls.unlock_script = redis_client.register_script(UNLOCK_SCRIPT)
185 cls.extend_script = redis_client.register_script(EXTEND_SCRIPT)
185 cls.extend_script = redis_client.register_script(EXTEND_SCRIPT)
186 cls.reset_script = redis_client.register_script(RESET_SCRIPT)
186 cls.reset_script = redis_client.register_script(RESET_SCRIPT)
187 cls.reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
187 cls.reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
188
188
189 @property
189 @property
190 def _held(self):
190 def _held(self):
191 return self.id == self.get_owner_id()
191 return self.id == self.get_owner_id()
192
192
193 def reset(self):
193 def reset(self):
194 """
194 """
195 Forcibly deletes the lock. Use this with care.
195 Forcibly deletes the lock. Use this with care.
196 """
196 """
197 self.reset_script(client=self._client, keys=(self._name, self._signal), args=(self.id, self._signal_expire))
197 self.reset_script(client=self._client, keys=(self._name, self._signal), args=(self.id, self._signal_expire))
198
198
199 @property
199 @property
200 def id(self):
200 def id(self):
201 return self._id
201 return self._id
202
202
203 def get_owner_id(self):
203 def get_owner_id(self):
204 owner_id = self._client.get(self._name)
204 owner_id = self._client.get(self._name)
205 if isinstance(owner_id, binary_type):
205 if isinstance(owner_id, binary_type):
206 owner_id = owner_id.decode('ascii', 'replace')
206 owner_id = owner_id.decode('ascii', 'replace')
207 return owner_id
207 return owner_id
208
208
209 def acquire(self, blocking=True, timeout=None):
209 def acquire(self, blocking=True, timeout=None):
210 """
210 """
211 :param blocking:
211 :param blocking:
212 Boolean value specifying whether the acquire should block.
212 Boolean value specifying whether the acquire should block.
213 :param timeout:
213 :param timeout:
214 An integer value specifying the maximum number of seconds to block.
214 An integer value specifying the maximum number of seconds to block.
215 """
215 """
216 logger = loggers["acquire"]
216 logger = loggers["acquire"]
217
217
218 logger.debug("Getting blocking: %s acquire on %r ...", blocking, self._name)
218 logger.debug("Getting blocking: %s acquire on %r ...", blocking, self._name)
219
219
220 if self._held:
220 if self._held:
221 owner_id = self.get_owner_id()
221 owner_id = self.get_owner_id()
222 raise AlreadyAcquired("Already acquired from this Lock instance. Lock id: {}".format(owner_id))
222 raise AlreadyAcquired(f"Already acquired from this Lock instance. Lock id: {owner_id}")
223
223
224 if not blocking and timeout is not None:
224 if not blocking and timeout is not None:
225 raise TimeoutNotUsable("Timeout cannot be used if blocking=False")
225 raise TimeoutNotUsable("Timeout cannot be used if blocking=False")
226
226
227 if timeout:
227 if timeout:
228 timeout = int(timeout)
228 timeout = int(timeout)
229 if timeout < 0:
229 if timeout < 0:
230 raise InvalidTimeout("Timeout (%d) cannot be less than 0" % timeout)
230 raise InvalidTimeout("Timeout (%d) cannot be less than 0" % timeout)
231
231
232 if self._expire and not self._lock_renewal_interval and timeout > self._expire:
232 if self._expire and not self._lock_renewal_interval and timeout > self._expire:
233 raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))
233 raise TimeoutTooLarge("Timeout (%d) cannot be greater than expire (%d)" % (timeout, self._expire))
234
234
235 busy = True
235 busy = True
236 blpop_timeout = timeout or self._expire or 0
236 blpop_timeout = timeout or self._expire or 0
237 timed_out = False
237 timed_out = False
238 while busy:
238 while busy:
239 busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
239 busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
240 if busy:
240 if busy:
241 if timed_out:
241 if timed_out:
242 return False
242 return False
243 elif blocking:
243 elif blocking:
244 timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
244 timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
245 else:
245 else:
246 logger.warning("Failed to get %r.", self._name)
246 logger.warning("Failed to get %r.", self._name)
247 return False
247 return False
248
248
249 logger.debug("Got lock for %r.", self._name)
249 logger.debug("Got lock for %r.", self._name)
250 if self._lock_renewal_interval is not None:
250 if self._lock_renewal_interval is not None:
251 self._start_lock_renewer()
251 self._start_lock_renewer()
252 return True
252 return True
253
253
254 def extend(self, expire=None):
254 def extend(self, expire=None):
255 """Extends expiration time of the lock.
255 """Extends expiration time of the lock.
256
256
257 :param expire:
257 :param expire:
258 New expiration time. If ``None`` - `expire` provided during
258 New expiration time. If ``None`` - `expire` provided during
259 lock initialization will be taken.
259 lock initialization will be taken.
260 """
260 """
261 if expire:
261 if expire:
262 expire = int(expire)
262 expire = int(expire)
263 if expire < 0:
263 if expire < 0:
264 raise ValueError("A negative expire is not acceptable.")
264 raise ValueError("A negative expire is not acceptable.")
265 elif self._expire is not None:
265 elif self._expire is not None:
266 expire = self._expire
266 expire = self._expire
267 else:
267 else:
268 raise TypeError(
268 raise TypeError(
269 "To extend a lock 'expire' must be provided as an "
269 "To extend a lock 'expire' must be provided as an "
270 "argument to extend() method or at initialization time."
270 "argument to extend() method or at initialization time."
271 )
271 )
272
272
273 error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
273 error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
274 if error == 1:
274 if error == 1:
275 raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
275 raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
276 elif error == 2:
276 elif error == 2:
277 raise NotExpirable("Lock %s has no assigned expiration time" % self._name)
277 raise NotExpirable("Lock %s has no assigned expiration time" % self._name)
278 elif error:
278 elif error:
279 raise RuntimeError("Unsupported error code %s from EXTEND script" % error)
279 raise RuntimeError("Unsupported error code %s from EXTEND script" % error)
280
280
281 @staticmethod
281 @staticmethod
282 def _lock_renewer(lockref, interval, stop):
282 def _lock_renewer(lockref, interval, stop):
283 """
283 """
284 Renew the lock key in redis every `interval` seconds for as long
284 Renew the lock key in redis every `interval` seconds for as long
285 as `self._lock_renewal_thread.should_exit` is False.
285 as the passed-in `stop` event is not set.
285 as the passed-in `stop` event is not set.
286 """
287 while not stop.wait(timeout=interval):
287 while not stop.wait(timeout=interval):
288 loggers["refresh.thread.start"].debug("Refreshing lock")
288 loggers["refresh.thread.start"].debug("Refreshing lock")
289 lock = lockref()
289 lock = lockref()
290 if lock is None:
290 if lock is None:
291 loggers["refresh.thread.stop"].debug(
291 loggers["refresh.thread.stop"].debug(
292 "The lock no longer exists, stopping lock refreshing"
292 "The lock no longer exists, stopping lock refreshing"
293 )
293 )
294 break
294 break
295 lock.extend(expire=lock._expire)
295 lock.extend(expire=lock._expire)
296 del lock
296 del lock
297 loggers["refresh.thread.exit"].debug("Exit requested, stopping lock refreshing")
297 loggers["refresh.thread.exit"].debug("Exit requested, stopping lock refreshing")
298
298
299 def _start_lock_renewer(self):
299 def _start_lock_renewer(self):
300 """
300 """
301 Starts the lock refresher thread.
301 Starts the lock refresher thread.
302 """
302 """
303 if self._lock_renewal_thread is not None:
303 if self._lock_renewal_thread is not None:
304 raise AlreadyStarted("Lock refresh thread already started")
304 raise AlreadyStarted("Lock refresh thread already started")
305
305
306 loggers["refresh.start"].debug(
306 loggers["refresh.start"].debug(
307 "Starting thread to refresh lock every %s seconds",
307 "Starting thread to refresh lock every %s seconds",
308 self._lock_renewal_interval
308 self._lock_renewal_interval
309 )
309 )
310 self._lock_renewal_stop = threading.Event()
310 self._lock_renewal_stop = threading.Event()
311 self._lock_renewal_thread = threading.Thread(
311 self._lock_renewal_thread = threading.Thread(
312 group=None,
312 group=None,
313 target=self._lock_renewer,
313 target=self._lock_renewer,
314 kwargs={'lockref': weakref.ref(self),
314 kwargs={'lockref': weakref.ref(self),
315 'interval': self._lock_renewal_interval,
315 'interval': self._lock_renewal_interval,
316 'stop': self._lock_renewal_stop}
316 'stop': self._lock_renewal_stop}
317 )
317 )
318 self._lock_renewal_thread.daemon = True
318 self._lock_renewal_thread.daemon = True
319 self._lock_renewal_thread.start()
319 self._lock_renewal_thread.start()
320
320
321 def _stop_lock_renewer(self):
321 def _stop_lock_renewer(self):
322 """
322 """
323 Stop the lock renewer.
323 Stop the lock renewer.
324
324
325 This signals the renewal thread and waits for its exit.
325 This signals the renewal thread and waits for its exit.
326 """
326 """
327 if self._lock_renewal_thread is None or not self._lock_renewal_thread.is_alive():
327 if self._lock_renewal_thread is None or not self._lock_renewal_thread.is_alive():
328 return
328 return
329 loggers["refresh.shutdown"].debug("Signalling the lock refresher to stop")
329 loggers["refresh.shutdown"].debug("Signalling the lock refresher to stop")
330 self._lock_renewal_stop.set()
330 self._lock_renewal_stop.set()
331 self._lock_renewal_thread.join()
331 self._lock_renewal_thread.join()
332 self._lock_renewal_thread = None
332 self._lock_renewal_thread = None
333 loggers["refresh.exit"].debug("Lock refresher has stopped")
333 loggers["refresh.exit"].debug("Lock refresher has stopped")
334
334
335 def __enter__(self):
335 def __enter__(self):
336 acquired = self.acquire(blocking=True)
336 acquired = self.acquire(blocking=True)
337 assert acquired, "Lock wasn't acquired, but blocking=True"
337 assert acquired, "Lock wasn't acquired, but blocking=True"
338 return self
338 return self
339
339
340 def __exit__(self, exc_type=None, exc_value=None, traceback=None):
340 def __exit__(self, exc_type=None, exc_value=None, traceback=None):
341 self.release()
341 self.release()
342
342
343 def release(self):
343 def release(self):
344 """Releases the lock, that was acquired with the same object.
344 """Releases the lock, that was acquired with the same object.
345
345
346 .. note::
346 .. note::
347
347
348 If you want to release a lock that you acquired in a different place you have two choices:
348 If you want to release a lock that you acquired in a different place you have two choices:
349
349
350 * Use ``Lock("name", id=id_from_other_place).release()``
350 * Use ``Lock("name", id=id_from_other_place).release()``
351 * Use ``Lock("name").reset()``
351 * Use ``Lock("name").reset()``
352 """
352 """
353 if self._lock_renewal_thread is not None:
353 if self._lock_renewal_thread is not None:
354 self._stop_lock_renewer()
354 self._stop_lock_renewer()
355 loggers["release"].debug("Releasing %r.", self._name)
355 loggers["release"].debug("Releasing %r.", self._name)
356 error = self.unlock_script(client=self._client, keys=(self._name, self._signal), args=(self._id, self._signal_expire))
356 error = self.unlock_script(client=self._client, keys=(self._name, self._signal), args=(self._id, self._signal_expire))
357 if error == 1:
357 if error == 1:
358 raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
358 raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
359 elif error:
359 elif error:
360 raise RuntimeError("Unsupported error code %s from EXTEND script." % error)
360 raise RuntimeError("Unsupported error code %s from EXTEND script." % error)
361
361
362 def locked(self):
362 def locked(self):
363 """
363 """
364 Return true if the lock is acquired.
364 Return true if the lock is acquired.
365
365
366 Checks whether a lock with the same name already exists. This method returns
366 Checks whether a lock with the same name already exists. This method returns
367 true even if the lock has a different id.
367 true even if the lock has a different id.
368 """
368 """
369 return self._client.exists(self._name) == 1
369 return self._client.exists(self._name) == 1
370
370
371
371
372 reset_all_script = None
372 reset_all_script = None
373
373
374
374
375 def reset_all(redis_client):
375 def reset_all(redis_client):
376 """
376 """
377 Forcibly deletes all locks that remain (for example, after a crash). Use this with care.
377 Forcibly deletes all locks that remain (for example, after a crash). Use this with care.
378
378
379 :param redis_client:
379 :param redis_client:
380 An instance of :class:`~StrictRedis`.
380 An instance of :class:`~StrictRedis`.
381 """
381 """
382 Lock.register_scripts(redis_client)
382 Lock.register_scripts(redis_client)
383
383
384 reset_all_script(client=redis_client) # noqa
384 reset_all_script(client=redis_client) # noqa
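A minimal sketch of the typical usage of Lock, assuming a Redis server on localhost; the vendored import path is an assumption:

from redis import StrictRedis

# Assumption: the module above is importable under this vendored path.
from vcsserver.lib._vendor.redis_lock import Lock

client = StrictRedis()  # connects to localhost:6379 by default

# __enter__ acquires with blocking=True; __exit__ releases. With
# auto_renewal=True a daemon thread keeps extending the 60s expiry
# (at interval expire*2/3) for as long as the lock is held.
with Lock(client, 'repo:example', expire=60, auto_renewal=True):
    pass  # critical section goes here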
@@ -1,52 +1,50 b''
1
2
3 import logging
1 import logging
4
2
5 from .stream import TCPStatsClient, UnixSocketStatsClient # noqa
3 from .stream import TCPStatsClient, UnixSocketStatsClient # noqa
6 from .udp import StatsClient # noqa
4 from .udp import StatsClient # noqa
7
5
8 HOST = 'localhost'
6 HOST = 'localhost'
9 PORT = 8125
7 PORT = 8125
10 IPV6 = False
8 IPV6 = False
11 PREFIX = None
9 PREFIX = None
12 MAXUDPSIZE = 512
10 MAXUDPSIZE = 512
13
11
14 log = logging.getLogger('rhodecode.statsd')
12 log = logging.getLogger('rhodecode.statsd')
15
13
16
14
17 def statsd_config(config, prefix='statsd.'):
15 def statsd_config(config, prefix='statsd.'):
18 _config = {}
16 _config = {}
19 for key in config.keys():
17 for key in config.keys():
20 if key.startswith(prefix):
18 if key.startswith(prefix):
21 _config[key[len(prefix):]] = config[key]
19 _config[key[len(prefix):]] = config[key]
22 return _config
20 return _config
23
21
24
22
25 def client_from_config(configuration, prefix='statsd.', **kwargs):
23 def client_from_config(configuration, prefix='statsd.', **kwargs):
26 from pyramid.settings import asbool
24 from pyramid.settings import asbool
27
25
28 _config = statsd_config(configuration, prefix)
26 _config = statsd_config(configuration, prefix)
29 statsd_enabled = asbool(_config.pop('enabled', False))
27 statsd_enabled = asbool(_config.pop('enabled', False))
30 if not statsd_enabled:
28 if not statsd_enabled:
29 log.debug('statsd client not enabled by the statsd.enabled flag, skipping...')
29 log.debug('statsd client not enabled by the statsd.enabled flag, skipping...')
32 return
30 return
33
31
34 host = _config.pop('statsd_host', HOST)
32 host = _config.pop('statsd_host', HOST)
35 port = _config.pop('statsd_port', PORT)
33 port = _config.pop('statsd_port', PORT)
36 prefix = _config.pop('statsd_prefix', PREFIX)
34 prefix = _config.pop('statsd_prefix', PREFIX)
37 maxudpsize = _config.pop('statsd_maxudpsize', MAXUDPSIZE)
35 maxudpsize = _config.pop('statsd_maxudpsize', MAXUDPSIZE)
38 ipv6 = asbool(_config.pop('statsd_ipv6', IPV6))
36 ipv6 = asbool(_config.pop('statsd_ipv6', IPV6))
39 log.debug('configured statsd client %s:%s', host, port)
37 log.debug('configured statsd client %s:%s', host, port)
40
38
41 try:
39 try:
42 client = StatsClient(
40 client = StatsClient(
43 host=host, port=port, prefix=prefix, maxudpsize=maxudpsize, ipv6=ipv6)
41 host=host, port=port, prefix=prefix, maxudpsize=maxudpsize, ipv6=ipv6)
44 except Exception:
42 except Exception:
43 log.exception('statsd is enabled, but connecting to the statsd server failed; falling back to disabled statsd')
43 log.exception('statsd is enabled, but connecting to the statsd server failed; falling back to disabled statsd')
46 client = None
44 client = None
47
45
48 return client
46 return client
49
47
50
48
51 def get_statsd_client(request):
49 def get_statsd_client(request):
52 return client_from_config(request.registry.settings)
50 return client_from_config(request.registry.settings)
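A minimal sketch of feeding an .ini-style settings dict through client_from_config(); the key names mirror the statsd_* options popped above, and the import path is an assumption:

from vcsserver.lib.statsd_client import client_from_config  # path assumed

# Values are strings, as they would come from an .ini file.
settings = {
    'statsd.enabled': 'true',
    'statsd.statsd_host': '127.0.0.1',
    'statsd.statsd_port': '8125',
    'statsd.statsd_prefix': 'vcsserver',
}

client = client_from_config(settings)
if client is not None:  # None when disabled or when building the client failed
    client.incr('requests_total')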
@@ -1,156 +1,154 b''
1
2
3 import re
1 import re
4 import random
2 import random
5 from collections import deque
3 from collections import deque
6 from datetime import timedelta
4 from datetime import timedelta
7 from repoze.lru import lru_cache
5 from repoze.lru import lru_cache
8
6
9 from .timer import Timer
7 from .timer import Timer
10
8
11 TAG_INVALID_CHARS_RE = re.compile(
9 TAG_INVALID_CHARS_RE = re.compile(
12 r"[^\w\d_\-:/\.]",
10 r"[^\w\d_\-:/\.]",
13 #re.UNICODE
11 #re.UNICODE
14 )
12 )
15 TAG_INVALID_CHARS_SUBS = "_"
13 TAG_INVALID_CHARS_SUBS = "_"
16
14
17 # we save and expose methods called by statsd for discovery
15 # we save and expose methods called by statsd for discovery
18 buckets_dict = {
16 buckets_dict = {
19
17
20 }
18 }
21
19
22
20
23 @lru_cache(maxsize=500)
21 @lru_cache(maxsize=500)
24 def _normalize_tags_with_cache(tag_list):
22 def _normalize_tags_with_cache(tag_list):
25 return [TAG_INVALID_CHARS_RE.sub(TAG_INVALID_CHARS_SUBS, tag) for tag in tag_list]
23 return [TAG_INVALID_CHARS_RE.sub(TAG_INVALID_CHARS_SUBS, tag) for tag in tag_list]
26
24
27
25
28 def normalize_tags(tag_list):
26 def normalize_tags(tag_list):
29 # We have to turn our input tag list into a non-mutable tuple for it to
27 # We have to turn our input tag list into a non-mutable tuple for it to
30 # be hashable (and thus usable) by the @lru_cache decorator.
28 # be hashable (and thus usable) by the @lru_cache decorator.
31 return _normalize_tags_with_cache(tuple(tag_list))
29 return _normalize_tags_with_cache(tuple(tag_list))
32
30
33
31
34 class StatsClientBase(object):
32 class StatsClientBase(object):
35 """A Base class for various statsd clients."""
33 """A Base class for various statsd clients."""
36
34
37 def close(self):
35 def close(self):
38 """Used to close and clean up any underlying resources."""
36 """Used to close and clean up any underlying resources."""
39 raise NotImplementedError()
37 raise NotImplementedError()
40
38
41 def _send(self):
39 def _send(self):
42 raise NotImplementedError()
40 raise NotImplementedError()
43
41
44 def pipeline(self):
42 def pipeline(self):
45 raise NotImplementedError()
43 raise NotImplementedError()
46
44
47 def timer(self, stat, rate=1, tags=None, auto_send=True):
45 def timer(self, stat, rate=1, tags=None, auto_send=True):
48 """
46 """
49 statsd = StatsdClient.statsd
47 statsd = StatsdClient.statsd
50 with statsd.timer('bucket_name', auto_send=True) as tmr:
48 with statsd.timer('bucket_name', auto_send=True) as tmr:
51 # This block will be timed.
49 # This block will be timed.
52 for i in range(0, 100000):
50 for i in range(0, 100000):
53 i ** 2
51 i ** 2
54 # you can access time here...
52 # you can access time here...
55 elapsed_ms = tmr.ms
53 elapsed_ms = tmr.ms
56 """
54 """
57 return Timer(self, stat, rate, tags, auto_send=auto_send)
55 return Timer(self, stat, rate, tags, auto_send=auto_send)
58
56
59 def timing(self, stat, delta, rate=1, tags=None, use_decimals=True):
57 def timing(self, stat, delta, rate=1, tags=None, use_decimals=True):
60 """
58 """
61 Send new timing information.
59 Send new timing information.
62
60
63 `delta` can be either a number of milliseconds or a timedelta.
61 `delta` can be either a number of milliseconds or a timedelta.
64 """
62 """
65 if isinstance(delta, timedelta):
63 if isinstance(delta, timedelta):
66 # Convert timedelta to number of milliseconds.
64 # Convert timedelta to number of milliseconds.
67 delta = delta.total_seconds() * 1000.
65 delta = delta.total_seconds() * 1000.
68 if use_decimals:
66 if use_decimals:
69 fmt = '%0.6f|ms'
67 fmt = '%0.6f|ms'
70 else:
68 else:
71 fmt = '%s|ms'
69 fmt = '%s|ms'
72 self._send_stat(stat, fmt % delta, rate, tags)
70 self._send_stat(stat, fmt % delta, rate, tags)
73
71
74 def incr(self, stat, count=1, rate=1, tags=None):
72 def incr(self, stat, count=1, rate=1, tags=None):
75 """Increment a stat by `count`."""
73 """Increment a stat by `count`."""
76 self._send_stat(stat, '%s|c' % count, rate, tags)
74 self._send_stat(stat, '%s|c' % count, rate, tags)
77
75
78 def decr(self, stat, count=1, rate=1, tags=None):
76 def decr(self, stat, count=1, rate=1, tags=None):
79 """Decrement a stat by `count`."""
77 """Decrement a stat by `count`."""
80 self.incr(stat, -count, rate, tags)
78 self.incr(stat, -count, rate, tags)
81
79
82 def gauge(self, stat, value, rate=1, delta=False, tags=None):
80 def gauge(self, stat, value, rate=1, delta=False, tags=None):
83 """Set a gauge value."""
81 """Set a gauge value."""
84 if value < 0 and not delta:
82 if value < 0 and not delta:
85 if rate < 1:
83 if rate < 1:
86 if random.random() > rate:
84 if random.random() > rate:
87 return
85 return
88 with self.pipeline() as pipe:
86 with self.pipeline() as pipe:
89 pipe._send_stat(stat, '0|g', 1)
87 pipe._send_stat(stat, '0|g', 1)
90 pipe._send_stat(stat, '%s|g' % value, 1)
88 pipe._send_stat(stat, '%s|g' % value, 1)
91 else:
89 else:
92 prefix = '+' if delta and value >= 0 else ''
90 prefix = '+' if delta and value >= 0 else ''
93 self._send_stat(stat, '%s%s|g' % (prefix, value), rate, tags)
91 self._send_stat(stat, '%s%s|g' % (prefix, value), rate, tags)
94
92
95 def set(self, stat, value, rate=1):
93 def set(self, stat, value, rate=1):
96 """Set a set value."""
94 """Set a set value."""
97 self._send_stat(stat, '%s|s' % value, rate)
95 self._send_stat(stat, '%s|s' % value, rate)
98
96
99 def histogram(self, stat, value, rate=1, tags=None):
97 def histogram(self, stat, value, rate=1, tags=None):
100 """Set a histogram"""
98 """Set a histogram"""
101 self._send_stat(stat, '%s|h' % value, rate, tags)
99 self._send_stat(stat, '%s|h' % value, rate, tags)
102
100
103 def _send_stat(self, stat, value, rate, tags=None):
101 def _send_stat(self, stat, value, rate, tags=None):
104 self._after(self._prepare(stat, value, rate, tags))
102 self._after(self._prepare(stat, value, rate, tags))
105
103
106 def _prepare(self, stat, value, rate, tags=None):
104 def _prepare(self, stat, value, rate, tags=None):
107 global buckets_dict
105 global buckets_dict
108 buckets_dict[stat] = 1
106 buckets_dict[stat] = 1
109
107
110 if rate < 1:
108 if rate < 1:
111 if random.random() > rate:
109 if random.random() > rate:
112 return
110 return
113 value = '%s|@%s' % (value, rate)
111 value = '%s|@%s' % (value, rate)
114
112
115 if self._prefix:
113 if self._prefix:
116 stat = '%s.%s' % (self._prefix, stat)
114 stat = '%s.%s' % (self._prefix, stat)
117
115
118 res = '%s:%s%s' % (
116 res = '%s:%s%s' % (
119 stat,
117 stat,
120 value,
118 value,
121 ("|#" + ",".join(normalize_tags(tags))) if tags else "",
119 ("|#" + ",".join(normalize_tags(tags))) if tags else "",
122 )
120 )
123 return res
121 return res
124
122
125 def _after(self, data):
123 def _after(self, data):
126 if data:
124 if data:
127 self._send(data)
125 self._send(data)
128
126
129
127
130 class PipelineBase(StatsClientBase):
128 class PipelineBase(StatsClientBase):
131
129
132 def __init__(self, client):
130 def __init__(self, client):
133 self._client = client
131 self._client = client
134 self._prefix = client._prefix
132 self._prefix = client._prefix
135 self._stats = deque()
133 self._stats = deque()
136
134
137 def _send(self):
135 def _send(self):
138 raise NotImplementedError()
136 raise NotImplementedError()
139
137
140 def _after(self, data):
138 def _after(self, data):
141 if data is not None:
139 if data is not None:
142 self._stats.append(data)
140 self._stats.append(data)
143
141
144 def __enter__(self):
142 def __enter__(self):
145 return self
143 return self
146
144
147 def __exit__(self, typ, value, tb):
145 def __exit__(self, typ, value, tb):
148 self.send()
146 self.send()
149
147
150 def send(self):
148 def send(self):
151 if not self._stats:
149 if not self._stats:
152 return
150 return
153 self._send()
151 self._send()
154
152
155 def pipeline(self):
153 def pipeline(self):
156 return self.__class__(self)
154 return self.__class__(self)
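The base class builds the plain statsd wire format. An illustrative sketch of what _prepare() puts on the wire for a client whose _prefix is 'vcsserver' (values are examples; sampled stats may also be dropped entirely when rate < 1):

# client.incr('requests_total')               -> 'vcsserver.requests_total:1|c'
# client.timing('resp_time', 12.5)            -> 'vcsserver.resp_time:12.500000|ms'
# client.timing('resp_time', 12.5, 0.1)       -> 'vcsserver.resp_time:12.500000|ms|@0.1'
# client.gauge('workers', 3, delta=True)      -> 'vcsserver.workers:+3|g'
# client.incr('pulls', tags=['repo:my repo']) -> 'vcsserver.pulls:1|c|#repo:my_repo'
#                                                (the space was normalized to '_')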
@@ -1,75 +1,73 b''
1
2
3 import socket
1 import socket
4
2
5 from .base import StatsClientBase, PipelineBase
3 from .base import StatsClientBase, PipelineBase
6
4
7
5
8 class StreamPipeline(PipelineBase):
6 class StreamPipeline(PipelineBase):
9 def _send(self):
7 def _send(self):
10 self._client._after('\n'.join(self._stats))
8 self._client._after('\n'.join(self._stats))
11 self._stats.clear()
9 self._stats.clear()
12
10
13
11
14 class StreamClientBase(StatsClientBase):
12 class StreamClientBase(StatsClientBase):
15 def connect(self):
13 def connect(self):
16 raise NotImplementedError()
14 raise NotImplementedError()
17
15
18 def close(self):
16 def close(self):
19 if self._sock and hasattr(self._sock, 'close'):
17 if self._sock and hasattr(self._sock, 'close'):
20 self._sock.close()
18 self._sock.close()
21 self._sock = None
19 self._sock = None
22
20
23 def reconnect(self):
21 def reconnect(self):
24 self.close()
22 self.close()
25 self.connect()
23 self.connect()
26
24
27 def pipeline(self):
25 def pipeline(self):
28 return StreamPipeline(self)
26 return StreamPipeline(self)
29
27
30 def _send(self, data):
28 def _send(self, data):
31 """Send data to statsd."""
29 """Send data to statsd."""
32 if not self._sock:
30 if not self._sock:
33 self.connect()
31 self.connect()
34 self._do_send(data)
32 self._do_send(data)
35
33
36 def _do_send(self, data):
34 def _do_send(self, data):
37 self._sock.sendall(data.encode('ascii') + b'\n')
35 self._sock.sendall(data.encode('ascii') + b'\n')
38
36
39
37
40 class TCPStatsClient(StreamClientBase):
38 class TCPStatsClient(StreamClientBase):
41 """TCP version of StatsClient."""
39 """TCP version of StatsClient."""
42
40
43 def __init__(self, host='localhost', port=8125, prefix=None,
41 def __init__(self, host='localhost', port=8125, prefix=None,
44 timeout=None, ipv6=False):
42 timeout=None, ipv6=False):
45 """Create a new client."""
43 """Create a new client."""
46 self._host = host
44 self._host = host
47 self._port = port
45 self._port = port
48 self._ipv6 = ipv6
46 self._ipv6 = ipv6
49 self._timeout = timeout
47 self._timeout = timeout
50 self._prefix = prefix
48 self._prefix = prefix
51 self._sock = None
49 self._sock = None
52
50
53 def connect(self):
51 def connect(self):
54 fam = socket.AF_INET6 if self._ipv6 else socket.AF_INET
52 fam = socket.AF_INET6 if self._ipv6 else socket.AF_INET
55 family, _, _, _, addr = socket.getaddrinfo(
53 family, _, _, _, addr = socket.getaddrinfo(
56 self._host, self._port, fam, socket.SOCK_STREAM)[0]
54 self._host, self._port, fam, socket.SOCK_STREAM)[0]
57 self._sock = socket.socket(family, socket.SOCK_STREAM)
55 self._sock = socket.socket(family, socket.SOCK_STREAM)
58 self._sock.settimeout(self._timeout)
56 self._sock.settimeout(self._timeout)
59 self._sock.connect(addr)
57 self._sock.connect(addr)
60
58
61
59
62 class UnixSocketStatsClient(StreamClientBase):
60 class UnixSocketStatsClient(StreamClientBase):
63 """Unix domain socket version of StatsClient."""
61 """Unix domain socket version of StatsClient."""
64
62
65 def __init__(self, socket_path, prefix=None, timeout=None):
63 def __init__(self, socket_path, prefix=None, timeout=None):
66 """Create a new client."""
64 """Create a new client."""
67 self._socket_path = socket_path
65 self._socket_path = socket_path
68 self._timeout = timeout
66 self._timeout = timeout
69 self._prefix = prefix
67 self._prefix = prefix
70 self._sock = None
68 self._sock = None
71
69
72 def connect(self):
70 def connect(self):
73 self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
71 self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
74 self._sock.settimeout(self._timeout)
72 self._sock.settimeout(self._timeout)
75 self._sock.connect(self._socket_path)
73 self._sock.connect(self._socket_path)
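A short sketch of the lazy-connect behaviour of the TCP client above; the import path is an assumption:

from vcsserver.lib.statsd_client.stream import TCPStatsClient  # path assumed

tcp = TCPStatsClient(host='127.0.0.1', port=8125, prefix='vcsserver', timeout=3.0)
tcp.incr('startup')   # _send() sees no open socket yet and calls connect() first
tcp.close()           # closes and clears the socket; the next send reconnects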
@@ -1,68 +1,66 b''
1
2
3 import functools
1 import functools
4 from time import perf_counter as time_now
2 from time import perf_counter as time_now
5
3
6
4
7 def safe_wraps(wrapper, *args, **kwargs):
5 def safe_wraps(wrapper, *args, **kwargs):
8 """Safely wraps partial functions."""
6 """Safely wraps partial functions."""
9 while isinstance(wrapper, functools.partial):
7 while isinstance(wrapper, functools.partial):
10 wrapper = wrapper.func
8 wrapper = wrapper.func
11 return functools.wraps(wrapper, *args, **kwargs)
9 return functools.wraps(wrapper, *args, **kwargs)
12
10
13
11
14 class Timer(object):
12 class Timer(object):
15 """A context manager/decorator for statsd.timing()."""
13 """A context manager/decorator for statsd.timing()."""
16
14
17 def __init__(self, client, stat, rate=1, tags=None, use_decimals=True, auto_send=True):
15 def __init__(self, client, stat, rate=1, tags=None, use_decimals=True, auto_send=True):
18 self.client = client
16 self.client = client
19 self.stat = stat
17 self.stat = stat
20 self.rate = rate
18 self.rate = rate
21 self.tags = tags
19 self.tags = tags
22 self.ms = None
20 self.ms = None
23 self._sent = False
21 self._sent = False
24 self._start_time = None
22 self._start_time = None
25 self.use_decimals = use_decimals
23 self.use_decimals = use_decimals
26 self.auto_send = auto_send
24 self.auto_send = auto_send
27
25
28 def __call__(self, f):
26 def __call__(self, f):
29 """Thread-safe timing function decorator."""
27 """Thread-safe timing function decorator."""
30 @safe_wraps(f)
28 @safe_wraps(f)
31 def _wrapped(*args, **kwargs):
29 def _wrapped(*args, **kwargs):
32 start_time = time_now()
30 start_time = time_now()
33 try:
31 try:
34 return f(*args, **kwargs)
32 return f(*args, **kwargs)
35 finally:
33 finally:
36 elapsed_time_ms = 1000.0 * (time_now() - start_time)
34 elapsed_time_ms = 1000.0 * (time_now() - start_time)
37 self.client.timing(self.stat, elapsed_time_ms, self.rate, self.tags, self.use_decimals)
35 self.client.timing(self.stat, elapsed_time_ms, self.rate, self.tags, self.use_decimals)
38 self._sent = True
36 self._sent = True
39 return _wrapped
37 return _wrapped
40
38
41 def __enter__(self):
39 def __enter__(self):
42 return self.start()
40 return self.start()
43
41
44 def __exit__(self, typ, value, tb):
42 def __exit__(self, typ, value, tb):
45 self.stop(send=self.auto_send)
43 self.stop(send=self.auto_send)
46
44
47 def start(self):
45 def start(self):
48 self.ms = None
46 self.ms = None
49 self._sent = False
47 self._sent = False
50 self._start_time = time_now()
48 self._start_time = time_now()
51 return self
49 return self
52
50
53 def stop(self, send=True):
51 def stop(self, send=True):
54 if self._start_time is None:
52 if self._start_time is None:
55 raise RuntimeError('Timer has not started.')
53 raise RuntimeError('Timer has not started.')
56 dt = time_now() - self._start_time
54 dt = time_now() - self._start_time
57 self.ms = 1000.0 * dt # Convert to milliseconds.
55 self.ms = 1000.0 * dt # Convert to milliseconds.
58 if send:
56 if send:
59 self.send()
57 self.send()
60 return self
58 return self
61
59
62 def send(self):
60 def send(self):
63 if self.ms is None:
61 if self.ms is None:
64 raise RuntimeError('No data recorded.')
62 raise RuntimeError('No data recorded.')
65 if self._sent:
63 if self._sent:
66 raise RuntimeError('Already sent data.')
64 raise RuntimeError('Already sent data.')
67 self._sent = True
65 self._sent = True
68 self.client.timing(self.stat, self.ms, self.rate, self.tags, self.use_decimals)
66 self.client.timing(self.stat, self.ms, self.rate, self.tags, self.use_decimals)
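A minimal sketch of both Timer forms, assuming `statsd` is an already-configured StatsClient instance (see the UDP client in the next file):

# Assumption: `statsd` is a configured StatsClient.

@statsd.timer('jobs.index_repo')        # decorator form: times every call
def index_repo():
    return sum(i * i for i in range(10000))

index_repo()

with statsd.timer('jobs.gc', auto_send=False) as tmr:
    pass                                # timed block
tmr.send()                              # explicit send, since auto_send=False
print(tmr.ms)                           # elapsed milliseconds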
@@ -1,55 +1,53 b''
1
2
3 import socket
1 import socket
4
2
5 from .base import StatsClientBase, PipelineBase
3 from .base import StatsClientBase, PipelineBase
6
4
7
5
8 class Pipeline(PipelineBase):
6 class Pipeline(PipelineBase):
9
7
10 def __init__(self, client):
8 def __init__(self, client):
11 super(Pipeline, self).__init__(client)
9 super().__init__(client)
12 self._maxudpsize = client._maxudpsize
10 self._maxudpsize = client._maxudpsize
13
11
14 def _send(self):
12 def _send(self):
15 data = self._stats.popleft()
13 data = self._stats.popleft()
16 while self._stats:
14 while self._stats:
17 # Use popleft to preserve the order of the stats.
15 # Use popleft to preserve the order of the stats.
18 stat = self._stats.popleft()
16 stat = self._stats.popleft()
19 if len(stat) + len(data) + 1 >= self._maxudpsize:
17 if len(stat) + len(data) + 1 >= self._maxudpsize:
20 self._client._after(data)
18 self._client._after(data)
21 data = stat
19 data = stat
22 else:
20 else:
23 data += '\n' + stat
21 data += '\n' + stat
24 self._client._after(data)
22 self._client._after(data)
25
23
26
24
27 class StatsClient(StatsClientBase):
25 class StatsClient(StatsClientBase):
28 """A client for statsd."""
26 """A client for statsd."""
29
27
30 def __init__(self, host='localhost', port=8125, prefix=None,
28 def __init__(self, host='localhost', port=8125, prefix=None,
31 maxudpsize=512, ipv6=False):
29 maxudpsize=512, ipv6=False):
32 """Create a new client."""
30 """Create a new client."""
33 fam = socket.AF_INET6 if ipv6 else socket.AF_INET
31 fam = socket.AF_INET6 if ipv6 else socket.AF_INET
34 family, _, _, _, addr = socket.getaddrinfo(
32 family, _, _, _, addr = socket.getaddrinfo(
35 host, port, fam, socket.SOCK_DGRAM)[0]
33 host, port, fam, socket.SOCK_DGRAM)[0]
36 self._addr = addr
34 self._addr = addr
37 self._sock = socket.socket(family, socket.SOCK_DGRAM)
35 self._sock = socket.socket(family, socket.SOCK_DGRAM)
38 self._prefix = prefix
36 self._prefix = prefix
39 self._maxudpsize = maxudpsize
37 self._maxudpsize = maxudpsize
40
38
41 def _send(self, data):
39 def _send(self, data):
42 """Send data to statsd."""
40 """Send data to statsd."""
43 try:
41 try:
44 self._sock.sendto(data.encode('ascii'), self._addr)
42 self._sock.sendto(data.encode('ascii'), self._addr)
45 except (socket.error, RuntimeError):
43 except (socket.error, RuntimeError):
46 # No time for love, Dr. Jones!
44 # No time for love, Dr. Jones!
47 pass
45 pass
48
46
49 def close(self):
47 def close(self):
50 if self._sock and hasattr(self._sock, 'close'):
48 if self._sock and hasattr(self._sock, 'close'):
51 self._sock.close()
49 self._sock.close()
52 self._sock = None
50 self._sock = None
53
51
54 def pipeline(self):
52 def pipeline(self):
55 return Pipeline(self)
53 return Pipeline(self)
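A short illustration of the batching behavior: the pipeline queues stats, and _send() packs as many as fit into one datagram of at most maxudpsize bytes before flushing. The incr()/timing() helpers are assumed to come from StatsClientBase:

    client = StatsClient(host='localhost', port=8125, prefix='vcsserver', maxudpsize=512)

    pipe = client.pipeline()
    pipe.incr('requests')              # queued, not sent yet
    pipe.timing('request.time', 12.5)  # queued
    pipe.send()                        # flushed as few newline-joined datagrams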
@@ -1,53 +1,53 b''
1 # Copyright (C) 2010-2023 RhodeCode GmbH
1 # Copyright (C) 2010-2023 RhodeCode GmbH
2 #
2 #
3 # This program is free software: you can redistribute it and/or modify
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU Affero General Public License, version 3
4 # it under the terms of the GNU Affero General Public License, version 3
5 # (only), as published by the Free Software Foundation.
5 # (only), as published by the Free Software Foundation.
6 #
6 #
7 # This program is distributed in the hope that it will be useful,
7 # This program is distributed in the hope that it will be useful,
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # GNU General Public License for more details.
10 # GNU General Public License for more details.
11 #
11 #
12 # You should have received a copy of the GNU Affero General Public License
12 # You should have received a copy of the GNU Affero General Public License
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 #
14 #
15 # This program is dual-licensed. If you wish to learn more about the
15 # This program is dual-licensed. If you wish to learn more about the
16 # RhodeCode Enterprise Edition, including its added features, Support services,
16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18
18
19 import sys
19 import sys
20 import logging
20 import logging
21
21
22
22
23 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = list(range(30, 38))
23 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = list(range(30, 38))
24
24
25 # Sequences
25 # Sequences
26 RESET_SEQ = "\033[0m"
26 RESET_SEQ = "\033[0m"
27 COLOR_SEQ = "\033[0;%dm"
27 COLOR_SEQ = "\033[0;%dm"
28 BOLD_SEQ = "\033[1m"
28 BOLD_SEQ = "\033[1m"
29
29
30 COLORS = {
30 COLORS = {
31 'CRITICAL': MAGENTA,
31 'CRITICAL': MAGENTA,
32 'ERROR': RED,
32 'ERROR': RED,
33 'WARNING': CYAN,
33 'WARNING': CYAN,
34 'INFO': GREEN,
34 'INFO': GREEN,
35 'DEBUG': BLUE,
35 'DEBUG': BLUE,
36 'SQL': YELLOW
36 'SQL': YELLOW
37 }
37 }
38
38
39
39
40 class ColorFormatter(logging.Formatter):
40 class ColorFormatter(logging.Formatter):
41
41
42 def format(self, record):
42 def format(self, record):
43 """
43 """
44 Change record's levelname to use with COLORS enum
44 Change record's levelname to use with COLORS enum
45 """
45 """
46 def_record = super(ColorFormatter, self).format(record)
46 def_record = super().format(record)
47
47
48 levelname = record.levelname
48 levelname = record.levelname
49 start = COLOR_SEQ % (COLORS[levelname])
49 start = COLOR_SEQ % (COLORS[levelname])
50 end = RESET_SEQ
50 end = RESET_SEQ
51
51
52 colored_record = ''.join([start, def_record, end])
52 colored_record = ''.join([start, def_record, end])
53 return colored_record
53 return colored_record
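The formatter wraps each rendered record in the ANSI sequence for its level. A minimal wiring sketch (handler and format string are illustrative):

    import logging

    handler = logging.StreamHandler()
    handler.setFormatter(ColorFormatter('%(asctime)s %(levelname)s %(message)s'))

    logger = logging.getLogger('demo')
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    logger.error('shown in red')  # wrapped in COLOR_SEQ % RED ... RESET_SEQ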
@@ -1,63 +1,63 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2023 RhodeCode GmbH
2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
18
18
19 import logging
19 import logging
20
20
21 from repoze.lru import LRUCache
21 from repoze.lru import LRUCache
22
22
23 from vcsserver.str_utils import safe_str
23 from vcsserver.str_utils import safe_str
24
24
25 log = logging.getLogger(__name__)
25 log = logging.getLogger(__name__)
26
26
27
27
28 class LRUDict(LRUCache):
28 class LRUDict(LRUCache):
29 """
29 """
30 Wrapper to provide partial dict access
30 Wrapper to provide partial dict access
31 """
31 """
32
32
33 def __setitem__(self, key, value):
33 def __setitem__(self, key, value):
34 return self.put(key, value)
34 return self.put(key, value)
35
35
36 def __getitem__(self, key):
36 def __getitem__(self, key):
37 return self.get(key)
37 return self.get(key)
38
38
39 def __contains__(self, key):
39 def __contains__(self, key):
40 return bool(self.get(key))
40 return bool(self.get(key))
41
41
42 def __delitem__(self, key):
42 def __delitem__(self, key):
43 del self.data[key]
43 del self.data[key]
44
44
45 def keys(self):
45 def keys(self):
46 return list(self.data.keys())
46 return list(self.data.keys())
47
47
48
48
49 class LRUDictDebug(LRUDict):
49 class LRUDictDebug(LRUDict):
50 """
50 """
51 Wrapper to provide some debug options
51 Wrapper to provide some debug options
52 """
52 """
53 def _report_keys(self):
53 def _report_keys(self):
54 elems_cnt = '{}/{}'.format(len(list(self.keys())), self.size)
54 elems_cnt = f'{len(list(self.keys()))}/{self.size}'
55 # trick for pformat print it more nicely
55 # trick for pformat print it more nicely
56 fmt = '\n'
56 fmt = '\n'
57 for cnt, elem in enumerate(self.keys()):
57 for cnt, elem in enumerate(self.keys()):
58 fmt += '{} - {}\n'.format(cnt+1, safe_str(elem))
58 fmt += f'{cnt+1} - {safe_str(elem)}\n'
59 log.debug('current LRU keys (%s):%s', elems_cnt, fmt)
59 log.debug('current LRU keys (%s):%s', elems_cnt, fmt)
60
60
61 def __getitem__(self, key):
61 def __getitem__(self, key):
62 self._report_keys()
62 self._report_keys()
63 return self.get(key)
63 return self.get(key)
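Usage sketch for the wrapper (the max-size argument comes from repoze.lru's LRUCache). Note that __contains__ is based on bool(self.get(key)), so a key holding a falsy value will appear absent:

    cache = LRUDict(100)                    # keep at most 100 entries, LRU-evicted
    cache['commit:abc'] = {'author': 'dev'}
    if 'commit:abc' in cache:
        data = cache['commit:abc']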
@@ -1,247 +1,247 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2023 RhodeCode GmbH
2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
18 import functools
18 import functools
19 import logging
19 import logging
20 import os
20 import os
21 import threading
21 import threading
22 import time
22 import time
23
23
24 import decorator
24 import decorator
25 from dogpile.cache import CacheRegion
25 from dogpile.cache import CacheRegion
26
26
27
27
28 from vcsserver.utils import sha1
28 from vcsserver.utils import sha1
29 from vcsserver.str_utils import safe_bytes
29 from vcsserver.str_utils import safe_bytes
30 from vcsserver.type_utils import str2bool
30 from vcsserver.type_utils import str2bool
31
31
32 from . import region_meta
32 from . import region_meta
33
33
34 log = logging.getLogger(__name__)
34 log = logging.getLogger(__name__)
35
35
36
36
37 class RhodeCodeCacheRegion(CacheRegion):
37 class RhodeCodeCacheRegion(CacheRegion):
38
38
39 def __repr__(self):
39 def __repr__(self):
40 return f'{self.__class__}(name={self.name})'
40 return f'{self.__class__}(name={self.name})'
41
41
42 def conditional_cache_on_arguments(
42 def conditional_cache_on_arguments(
43 self, namespace=None,
43 self, namespace=None,
44 expiration_time=None,
44 expiration_time=None,
45 should_cache_fn=None,
45 should_cache_fn=None,
46 to_str=str,
46 to_str=str,
47 function_key_generator=None,
47 function_key_generator=None,
48 condition=True):
48 condition=True):
49 """
49 """
50 Custom conditional decorator that will not touch any dogpile internals if
50 Custom conditional decorator that will not touch any dogpile internals if
51 the condition isn't met. This works a bit differently from should_cache_fn,
51 the condition isn't met. This works a bit differently from should_cache_fn,
52 and it's faster in cases where we never want to compute cached values.
52 and it's faster in cases where we never want to compute cached values.
53 """
53 """
54 expiration_time_is_callable = callable(expiration_time)
54 expiration_time_is_callable = callable(expiration_time)
55 if not namespace:
55 if not namespace:
56 namespace = getattr(self, '_default_namespace', None)
56 namespace = getattr(self, '_default_namespace', None)
57
57
58 if function_key_generator is None:
58 if function_key_generator is None:
59 function_key_generator = self.function_key_generator
59 function_key_generator = self.function_key_generator
60
60
61 def get_or_create_for_user_func(func_key_generator, user_func, *arg, **kw):
61 def get_or_create_for_user_func(func_key_generator, user_func, *arg, **kw):
62
62
63 if not condition:
63 if not condition:
64 log.debug('Calling un-cached method:%s', user_func.__name__)
64 log.debug('Calling un-cached method:%s', user_func.__name__)
65 start = time.time()
65 start = time.time()
66 result = user_func(*arg, **kw)
66 result = user_func(*arg, **kw)
67 total = time.time() - start
67 total = time.time() - start
68 log.debug('un-cached method:%s took %.4fs', user_func.__name__, total)
68 log.debug('un-cached method:%s took %.4fs', user_func.__name__, total)
69 return result
69 return result
70
70
71 key = func_key_generator(*arg, **kw)
71 key = func_key_generator(*arg, **kw)
72
72
73 timeout = expiration_time() if expiration_time_is_callable \
73 timeout = expiration_time() if expiration_time_is_callable \
74 else expiration_time
74 else expiration_time
75
75
76 log.debug('Calling cached method:`%s`', user_func.__name__)
76 log.debug('Calling cached method:`%s`', user_func.__name__)
77 return self.get_or_create(key, user_func, timeout, should_cache_fn, (arg, kw))
77 return self.get_or_create(key, user_func, timeout, should_cache_fn, (arg, kw))
78
78
79 def cache_decorator(user_func):
79 def cache_decorator(user_func):
80 if to_str is str:
80 if to_str is str:
81 # backwards compatible
81 # backwards compatible
82 key_generator = function_key_generator(namespace, user_func)
82 key_generator = function_key_generator(namespace, user_func)
83 else:
83 else:
84 key_generator = function_key_generator(namespace, user_func, to_str=to_str)
84 key_generator = function_key_generator(namespace, user_func, to_str=to_str)
85
85
86 def refresh(*arg, **kw):
86 def refresh(*arg, **kw):
87 """
87 """
88 Like invalidate, but regenerates the value instead
88 Like invalidate, but regenerates the value instead
89 """
89 """
90 key = key_generator(*arg, **kw)
90 key = key_generator(*arg, **kw)
91 value = user_func(*arg, **kw)
91 value = user_func(*arg, **kw)
92 self.set(key, value)
92 self.set(key, value)
93 return value
93 return value
94
94
95 def invalidate(*arg, **kw):
95 def invalidate(*arg, **kw):
96 key = key_generator(*arg, **kw)
96 key = key_generator(*arg, **kw)
97 self.delete(key)
97 self.delete(key)
98
98
99 def set_(value, *arg, **kw):
99 def set_(value, *arg, **kw):
100 key = key_generator(*arg, **kw)
100 key = key_generator(*arg, **kw)
101 self.set(key, value)
101 self.set(key, value)
102
102
103 def get(*arg, **kw):
103 def get(*arg, **kw):
104 key = key_generator(*arg, **kw)
104 key = key_generator(*arg, **kw)
105 return self.get(key)
105 return self.get(key)
106
106
107 user_func.set = set_
107 user_func.set = set_
108 user_func.invalidate = invalidate
108 user_func.invalidate = invalidate
109 user_func.get = get
109 user_func.get = get
110 user_func.refresh = refresh
110 user_func.refresh = refresh
111 user_func.key_generator = key_generator
111 user_func.key_generator = key_generator
112 user_func.original = user_func
112 user_func.original = user_func
113
113
114 # Use `decorate` to preserve the signature of :param:`user_func`.
114 # Use `decorate` to preserve the signature of :param:`user_func`.
115 return decorator.decorate(user_func, functools.partial(
115 return decorator.decorate(user_func, functools.partial(
116 get_or_create_for_user_func, key_generator))
116 get_or_create_for_user_func, key_generator))
117
117
118 return cache_decorator
118 return cache_decorator
119
119
120
120
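A sketch of how the conditional decorator is applied in practice (it mirrors the GitRemote methods later in this commit; the region name and cached function are illustrative, and expensive_lookup is a hypothetical helper):

    region = get_or_create_region('repo_object', region_namespace='repo_id_1')

    @region.conditional_cache_on_arguments(condition=True)
    def _heavy(_repo_id, _rev):
        return expensive_lookup(_repo_id, _rev)  # hypothetical helper

    # With condition=False the decorator bypasses dogpile entirely and just
    # calls the function, which is what the un-cached branch above implements.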
121 def make_region(*arg, **kw):
121 def make_region(*arg, **kw):
122 return RhodeCodeCacheRegion(*arg, **kw)
122 return RhodeCodeCacheRegion(*arg, **kw)
123
123
124
124
125 def get_default_cache_settings(settings, prefixes=None):
125 def get_default_cache_settings(settings, prefixes=None):
126 prefixes = prefixes or []
126 prefixes = prefixes or []
127 cache_settings = {}
127 cache_settings = {}
128 for key in settings.keys():
128 for key in settings.keys():
129 for prefix in prefixes:
129 for prefix in prefixes:
130 if key.startswith(prefix):
130 if key.startswith(prefix):
131 name = key.split(prefix)[1].strip()
131 name = key.split(prefix)[1].strip()
132 val = settings[key]
132 val = settings[key]
133 if isinstance(val, str):
133 if isinstance(val, str):
134 val = val.strip()
134 val = val.strip()
135 cache_settings[name] = val
135 cache_settings[name] = val
136 return cache_settings
136 return cache_settings
137
137
138
138
139 def compute_key_from_params(*args):
139 def compute_key_from_params(*args):
140 """
140 """
141 Helper to compute key from given params to be used in cache manager
141 Helper to compute key from given params to be used in cache manager
142 """
142 """
143 return sha1(safe_bytes("_".join(map(str, args))))
143 return sha1(safe_bytes("_".join(map(str, args))))
144
144
145
145
146 def custom_key_generator(backend, namespace, fn):
146 def custom_key_generator(backend, namespace, fn):
147 func_name = fn.__name__
147 func_name = fn.__name__
148
148
149 def generate_key(*args):
149 def generate_key(*args):
150 backend_pref = getattr(backend, 'key_prefix', None) or 'backend_prefix'
150 backend_pref = getattr(backend, 'key_prefix', None) or 'backend_prefix'
151 namespace_pref = namespace or 'default_namespace'
151 namespace_pref = namespace or 'default_namespace'
152 arg_key = compute_key_from_params(*args)
152 arg_key = compute_key_from_params(*args)
153 final_key = f"{backend_pref}:{namespace_pref}:{func_name}_{arg_key}"
153 final_key = f"{backend_pref}:{namespace_pref}:{func_name}_{arg_key}"
154
154
155 return final_key
155 return final_key
156
156
157 return generate_key
157 return generate_key
158
158
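A worked example of the key shape produced by the generator above (values illustrative):

    # With backend.key_prefix == 'file_namespace', namespace == 'repo_id_1'
    # and fn.__name__ == '_blob_raw_length', calling generate_key('id1', 'abc')
    # yields:
    #   f"file_namespace:repo_id_1:_blob_raw_length_{compute_key_from_params('id1', 'abc')}"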
159
159
160 def backend_key_generator(backend):
160 def backend_key_generator(backend):
161 """
161 """
162 Special wrapper that also sends over the backend to the key generator
162 Special wrapper that also sends over the backend to the key generator
163 """
163 """
164 def wrapper(namespace, fn):
164 def wrapper(namespace, fn):
165 return custom_key_generator(backend, namespace, fn)
165 return custom_key_generator(backend, namespace, fn)
166 return wrapper
166 return wrapper
167
167
168
168
169 def get_or_create_region(region_name, region_namespace: str = None, use_async_runner=False):
169 def get_or_create_region(region_name, region_namespace: str = None, use_async_runner=False):
170 from .backends import FileNamespaceBackend
170 from .backends import FileNamespaceBackend
171 from . import async_creation_runner
171 from . import async_creation_runner
172
172
173 region_obj = region_meta.dogpile_cache_regions.get(region_name)
173 region_obj = region_meta.dogpile_cache_regions.get(region_name)
174 if not region_obj:
174 if not region_obj:
175 reg_keys = list(region_meta.dogpile_cache_regions.keys())
175 reg_keys = list(region_meta.dogpile_cache_regions.keys())
176 raise EnvironmentError(f'Region `{region_name}` not found in configured regions: {reg_keys}.')
176 raise OSError(f'Region `{region_name}` not found in configured regions: {reg_keys}.')
177
177
178 region_uid_name = f'{region_name}:{region_namespace}'
178 region_uid_name = f'{region_name}:{region_namespace}'
179
179
180 if isinstance(region_obj.actual_backend, FileNamespaceBackend):
180 if isinstance(region_obj.actual_backend, FileNamespaceBackend):
181 if not region_namespace:
181 if not region_namespace:
182 raise ValueError(f'{FileNamespaceBackend} backend requires the region_namespace param to be specified')
182 raise ValueError(f'{FileNamespaceBackend} backend requires the region_namespace param to be specified')
183
183
184 region_exist = region_meta.dogpile_cache_regions.get(region_namespace)
184 region_exist = region_meta.dogpile_cache_regions.get(region_namespace)
185 if region_exist:
185 if region_exist:
186 log.debug('Using already configured region: %s', region_namespace)
186 log.debug('Using already configured region: %s', region_namespace)
187 return region_exist
187 return region_exist
188
188
189 expiration_time = region_obj.expiration_time
189 expiration_time = region_obj.expiration_time
190
190
191 cache_dir = region_meta.dogpile_config_defaults['cache_dir']
191 cache_dir = region_meta.dogpile_config_defaults['cache_dir']
192 namespace_cache_dir = cache_dir
192 namespace_cache_dir = cache_dir
193
193
194 # we default the namespace_cache_dir to our default cache dir.
194 # we default the namespace_cache_dir to our default cache dir.
195 # however, if this backend is configured with a filename= param, we prioritize that,
195 # however, if this backend is configured with a filename= param, we prioritize that,
196 # so all caches within that particular region, even namespaced ones, end up in the same path
196 # so all caches within that particular region, even namespaced ones, end up in the same path
197 if region_obj.actual_backend.filename:
197 if region_obj.actual_backend.filename:
198 namespace_cache_dir = os.path.dirname(region_obj.actual_backend.filename)
198 namespace_cache_dir = os.path.dirname(region_obj.actual_backend.filename)
199
199
200 if not os.path.isdir(namespace_cache_dir):
200 if not os.path.isdir(namespace_cache_dir):
201 os.makedirs(namespace_cache_dir)
201 os.makedirs(namespace_cache_dir)
202 new_region = make_region(
202 new_region = make_region(
203 name=region_uid_name,
203 name=region_uid_name,
204 function_key_generator=backend_key_generator(region_obj.actual_backend)
204 function_key_generator=backend_key_generator(region_obj.actual_backend)
205 )
205 )
206
206
207 namespace_filename = os.path.join(
207 namespace_filename = os.path.join(
208 namespace_cache_dir, f"{region_name}_{region_namespace}.cache_db")
208 namespace_cache_dir, f"{region_name}_{region_namespace}.cache_db")
209 # special backend type that allows one db file per namespace
209 # special backend type that allows one db file per namespace
210 new_region.configure(
210 new_region.configure(
211 backend='dogpile.cache.rc.file_namespace',
211 backend='dogpile.cache.rc.file_namespace',
212 expiration_time=expiration_time,
212 expiration_time=expiration_time,
213 arguments={"filename": namespace_filename}
213 arguments={"filename": namespace_filename}
214 )
214 )
215
215
216 # create and save in region caches
216 # create and save in region caches
217 log.debug('configuring new region: %s', region_uid_name)
217 log.debug('configuring new region: %s', region_uid_name)
218 region_obj = region_meta.dogpile_cache_regions[region_namespace] = new_region
218 region_obj = region_meta.dogpile_cache_regions[region_namespace] = new_region
219
219
220 region_obj._default_namespace = region_namespace
220 region_obj._default_namespace = region_namespace
221 if use_async_runner:
221 if use_async_runner:
222 region_obj.async_creation_runner = async_creation_runner
222 region_obj.async_creation_runner = async_creation_runner
223 return region_obj
223 return region_obj
224
224
225
225
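Typical call, matching how the remote backends below resolve their cache region (the region name must already exist in region_meta.dogpile_cache_regions):

    region = get_or_create_region('repo_object', region_namespace='repo_id_1')
    # With a FileNamespaceBackend this creates or reuses a per-namespace
    # 'repo_object_repo_id_1.cache_db' file under the configured cache_dir.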
226 def clear_cache_namespace(cache_region: str | RhodeCodeCacheRegion, cache_namespace_uid: str, method: str):
226 def clear_cache_namespace(cache_region: str | RhodeCodeCacheRegion, cache_namespace_uid: str, method: str):
227 from . import CLEAR_DELETE, CLEAR_INVALIDATE
227 from . import CLEAR_DELETE, CLEAR_INVALIDATE
228
228
229 if not isinstance(cache_region, RhodeCodeCacheRegion):
229 if not isinstance(cache_region, RhodeCodeCacheRegion):
230 cache_region = get_or_create_region(cache_region, cache_namespace_uid)
230 cache_region = get_or_create_region(cache_region, cache_namespace_uid)
231 log.debug('clearing cache region: %s with method=%s', cache_region, method)
231 log.debug('clearing cache region: %s with method=%s', cache_region, method)
232
232
233 num_affected_keys = None
233 num_affected_keys = None
234
234
235 if method == CLEAR_INVALIDATE:
235 if method == CLEAR_INVALIDATE:
236 # NOTE: The CacheRegion.invalidate() method’s default mode of
236 # NOTE: The CacheRegion.invalidate() method’s default mode of
237 # operation is to set a timestamp local to this CacheRegion in this Python process only.
237 # operation is to set a timestamp local to this CacheRegion in this Python process only.
238 # It does not impact other Python processes or regions as the timestamp is only stored locally in memory.
238 # It does not impact other Python processes or regions as the timestamp is only stored locally in memory.
239 cache_region.invalidate(hard=True)
239 cache_region.invalidate(hard=True)
240
240
241 if method == CLEAR_DELETE:
241 if method == CLEAR_DELETE:
242 cache_keys = cache_region.backend.list_keys(prefix=cache_namespace_uid)
242 cache_keys = cache_region.backend.list_keys(prefix=cache_namespace_uid)
243 num_affected_keys = len(cache_keys)
243 num_affected_keys = len(cache_keys)
244 if num_affected_keys:
244 if num_affected_keys:
245 cache_region.delete_multi(cache_keys)
245 cache_region.delete_multi(cache_keys)
246
246
247 return num_affected_keys
247 return num_affected_keys
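Illustrative calls for the two clearing modes; CLEAR_DELETE and CLEAR_INVALIDATE come from the rc_cache package __init__ (the import path is an assumption):

    from vcsserver.lib.rc_cache import CLEAR_DELETE, CLEAR_INVALIDATE  # path assumed

    # Soft: per-process timestamp invalidation only, returns None.
    clear_cache_namespace('repo_object', 'repo_id_1', CLEAR_INVALIDATE)

    # Hard: physically deletes matching keys and reports how many were removed.
    removed = clear_cache_namespace('repo_object', 'repo_id_1', CLEAR_DELETE)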
@@ -1,1463 +1,1463 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2023 RhodeCode GmbH
2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
18 import collections
18 import collections
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import traceback
23 import traceback
24 import urllib.request
24 import urllib.request
25 import urllib.parse
25 import urllib.parse
26 import urllib.error
26 import urllib.error
27 from functools import wraps
27 from functools import wraps
28
28
29 import more_itertools
29 import more_itertools
30 import pygit2
30 import pygit2
31 from pygit2 import Repository as LibGit2Repo
31 from pygit2 import Repository as LibGit2Repo
32 from pygit2 import index as LibGit2Index
32 from pygit2 import index as LibGit2Index
33 from dulwich import index, objects
33 from dulwich import index, objects
34 from dulwich.client import HttpGitClient, LocalGitClient, FetchPackResult
34 from dulwich.client import HttpGitClient, LocalGitClient, FetchPackResult
35 from dulwich.errors import (
35 from dulwich.errors import (
36 NotGitRepository, ChecksumMismatch, WrongObjectException,
36 NotGitRepository, ChecksumMismatch, WrongObjectException,
37 MissingCommitError, ObjectMissing, HangupException,
37 MissingCommitError, ObjectMissing, HangupException,
38 UnexpectedCommandError)
38 UnexpectedCommandError)
39 from dulwich.repo import Repo as DulwichRepo
39 from dulwich.repo import Repo as DulwichRepo
40 from dulwich.server import update_server_info
40 from dulwich.server import update_server_info
41
41
42 from vcsserver import exceptions, settings, subprocessio
42 from vcsserver import exceptions, settings, subprocessio
43 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes
43 from vcsserver.str_utils import safe_str, safe_int, safe_bytes, ascii_bytes
44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, store_archive_in_cache, BytesEnvelope, BinaryEnvelope
44 from vcsserver.base import RepoFactory, obfuscate_qs, ArchiveNode, store_archive_in_cache, BytesEnvelope, BinaryEnvelope
45 from vcsserver.hgcompat import (
45 from vcsserver.hgcompat import (
46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
46 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler)
47 from vcsserver.git_lfs.lib import LFSOidStore
47 from vcsserver.git_lfs.lib import LFSOidStore
48 from vcsserver.vcs_base import RemoteBase
48 from vcsserver.vcs_base import RemoteBase
49
49
50 DIR_STAT = stat.S_IFDIR
50 DIR_STAT = stat.S_IFDIR
51 FILE_MODE = stat.S_IFMT
51 FILE_MODE = stat.S_IFMT
52 GIT_LINK = objects.S_IFGITLINK
52 GIT_LINK = objects.S_IFGITLINK
53 PEELED_REF_MARKER = b'^{}'
53 PEELED_REF_MARKER = b'^{}'
54 HEAD_MARKER = b'HEAD'
54 HEAD_MARKER = b'HEAD'
55
55
56 log = logging.getLogger(__name__)
56 log = logging.getLogger(__name__)
57
57
58
58
59 def reraise_safe_exceptions(func):
59 def reraise_safe_exceptions(func):
60 """Converts Dulwich exceptions to something neutral."""
60 """Converts Dulwich exceptions to something neutral."""
61
61
62 @wraps(func)
62 @wraps(func)
63 def wrapper(*args, **kwargs):
63 def wrapper(*args, **kwargs):
64 try:
64 try:
65 return func(*args, **kwargs)
65 return func(*args, **kwargs)
66 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
66 except (ChecksumMismatch, WrongObjectException, MissingCommitError, ObjectMissing,) as e:
67 exc = exceptions.LookupException(org_exc=e)
67 exc = exceptions.LookupException(org_exc=e)
68 raise exc(safe_str(e))
68 raise exc(safe_str(e))
69 except (HangupException, UnexpectedCommandError) as e:
69 except (HangupException, UnexpectedCommandError) as e:
70 exc = exceptions.VcsException(org_exc=e)
70 exc = exceptions.VcsException(org_exc=e)
71 raise exc(safe_str(e))
71 raise exc(safe_str(e))
72 except Exception:
72 except Exception:
73 # NOTE(marcink): because of how dulwich handles some exceptions
73 # NOTE(marcink): because of how dulwich handles some exceptions
74 # (KeyError on empty repos), we cannot track this and catch all
74 # (KeyError on empty repos), we cannot track this and catch all
75 # exceptions; some of them are exceptions raised by other handlers
75 # exceptions; some of them are exceptions raised by other handlers
76 #if not hasattr(e, '_vcs_kind'):
76 #if not hasattr(e, '_vcs_kind'):
77 #log.exception("Unhandled exception in git remote call")
77 #log.exception("Unhandled exception in git remote call")
78 #raise_from_original(exceptions.UnhandledException)
78 #raise_from_original(exceptions.UnhandledException)
79 raise
79 raise
80 return wrapper
80 return wrapper
81
81
82
82
83 class Repo(DulwichRepo):
83 class Repo(DulwichRepo):
84 """
84 """
85 A wrapper for dulwich Repo class.
85 A wrapper for dulwich Repo class.
86
86
87 Since dulwich is sometimes keeping .idx file descriptors open, it leads to
87 Since dulwich is sometimes keeping .idx file descriptors open, it leads to
88 "Too many open files" error. We need to close all opened file descriptors
88 "Too many open files" error. We need to close all opened file descriptors
89 once the repo object is destroyed.
89 once the repo object is destroyed.
90 """
90 """
91 def __del__(self):
91 def __del__(self):
92 if hasattr(self, 'object_store'):
92 if hasattr(self, 'object_store'):
93 self.close()
93 self.close()
94
94
95
95
96 class Repository(LibGit2Repo):
96 class Repository(LibGit2Repo):
97
97
98 def __enter__(self):
98 def __enter__(self):
99 return self
99 return self
100
100
101 def __exit__(self, exc_type, exc_val, exc_tb):
101 def __exit__(self, exc_type, exc_val, exc_tb):
102 self.free()
102 self.free()
103
103
104
104
105 class GitFactory(RepoFactory):
105 class GitFactory(RepoFactory):
106 repo_type = 'git'
106 repo_type = 'git'
107
107
108 def _create_repo(self, wire, create, use_libgit2=False):
108 def _create_repo(self, wire, create, use_libgit2=False):
109 if use_libgit2:
109 if use_libgit2:
110 repo = Repository(safe_bytes(wire['path']))
110 repo = Repository(safe_bytes(wire['path']))
111 else:
111 else:
112 # dulwich mode
112 # dulwich mode
113 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
113 repo_path = safe_str(wire['path'], to_encoding=settings.WIRE_ENCODING)
114 repo = Repo(repo_path)
114 repo = Repo(repo_path)
115
115
116 log.debug('repository created: got GIT object: %s', repo)
116 log.debug('repository created: got GIT object: %s', repo)
117 return repo
117 return repo
118
118
119 def repo(self, wire, create=False, use_libgit2=False):
119 def repo(self, wire, create=False, use_libgit2=False):
120 """
120 """
121 Get a repository instance for the given path.
121 Get a repository instance for the given path.
122 """
122 """
123 return self._create_repo(wire, create, use_libgit2)
123 return self._create_repo(wire, create, use_libgit2)
124
124
125 def repo_libgit2(self, wire):
125 def repo_libgit2(self, wire):
126 return self.repo(wire, use_libgit2=True)
126 return self.repo(wire, use_libgit2=True)
127
127
128
128
129 def create_signature_from_string(author_str, **kwargs):
129 def create_signature_from_string(author_str, **kwargs):
130 """
130 """
131 Creates a pygit2.Signature object from a string of the format 'Name <email>'.
131 Creates a pygit2.Signature object from a string of the format 'Name <email>'.
132
132
133 :param author_str: String of the format 'Name <email>'
133 :param author_str: String of the format 'Name <email>'
134 :return: pygit2.Signature object
134 :return: pygit2.Signature object
135 """
135 """
136 match = re.match(r'^(.+) <(.+)>$', author_str)
136 match = re.match(r'^(.+) <(.+)>$', author_str)
137 if match is None:
137 if match is None:
138 raise ValueError(f"Invalid format: {author_str}")
138 raise ValueError(f"Invalid format: {author_str}")
139
139
140 name, email = match.groups()
140 name, email = match.groups()
141 return pygit2.Signature(name, email, **kwargs)
141 return pygit2.Signature(name, email, **kwargs)
142
142
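Usage sketch for the signature helper; extra kwargs pass straight through to pygit2.Signature (e.g. a fixed commit time):

    sig = create_signature_from_string('Jane Doe <jane@example.com>')
    assert sig.name == 'Jane Doe' and sig.email == 'jane@example.com'

    sig_at = create_signature_from_string(
        'Jane Doe <jane@example.com>', time=1700000000, offset=0)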
143
143
144 def get_obfuscated_url(url_obj):
144 def get_obfuscated_url(url_obj):
145 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
145 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
146 url_obj.query = obfuscate_qs(url_obj.query)
146 url_obj.query = obfuscate_qs(url_obj.query)
147 obfuscated_uri = str(url_obj)
147 obfuscated_uri = str(url_obj)
148 return obfuscated_uri
148 return obfuscated_uri
149
149
150
150
151 class GitRemote(RemoteBase):
151 class GitRemote(RemoteBase):
152
152
153 def __init__(self, factory):
153 def __init__(self, factory):
154 self._factory = factory
154 self._factory = factory
155 self._bulk_methods = {
155 self._bulk_methods = {
156 "date": self.date,
156 "date": self.date,
157 "author": self.author,
157 "author": self.author,
158 "branch": self.branch,
158 "branch": self.branch,
159 "message": self.message,
159 "message": self.message,
160 "parents": self.parents,
160 "parents": self.parents,
161 "_commit": self.revision,
161 "_commit": self.revision,
162 }
162 }
163 self._bulk_file_methods = {
163 self._bulk_file_methods = {
164 "size": self.get_node_size,
164 "size": self.get_node_size,
165 "data": self.get_node_data,
165 "data": self.get_node_data,
166 "flags": self.get_node_flags,
166 "flags": self.get_node_flags,
167 "is_binary": self.get_node_is_binary,
167 "is_binary": self.get_node_is_binary,
168 "md5": self.md5_hash
168 "md5": self.md5_hash
169 }
169 }
170
170
171 def _wire_to_config(self, wire):
171 def _wire_to_config(self, wire):
172 if 'config' in wire:
172 if 'config' in wire:
173 return {x[0] + '_' + x[1]: x[2] for x in wire['config']}
173 return {x[0] + '_' + x[1]: x[2] for x in wire['config']}
174 return {}
174 return {}
175
175
176 def _remote_conf(self, config):
176 def _remote_conf(self, config):
177 params = [
177 params = [
178 '-c', 'core.askpass=""',
178 '-c', 'core.askpass=""',
179 ]
179 ]
180 ssl_cert_dir = config.get('vcs_ssl_dir')
180 ssl_cert_dir = config.get('vcs_ssl_dir')
181 if ssl_cert_dir:
181 if ssl_cert_dir:
182 params.extend(['-c', f'http.sslCAinfo={ssl_cert_dir}'])
182 params.extend(['-c', f'http.sslCAinfo={ssl_cert_dir}'])
183 return params
183 return params
184
184
185 @reraise_safe_exceptions
185 @reraise_safe_exceptions
186 def discover_git_version(self):
186 def discover_git_version(self):
187 stdout, _ = self.run_git_command(
187 stdout, _ = self.run_git_command(
188 {}, ['--version'], _bare=True, _safe=True)
188 {}, ['--version'], _bare=True, _safe=True)
189 prefix = b'git version'
189 prefix = b'git version'
190 if stdout.startswith(prefix):
190 if stdout.startswith(prefix):
191 stdout = stdout[len(prefix):]
191 stdout = stdout[len(prefix):]
192 return safe_str(stdout.strip())
192 return safe_str(stdout.strip())
193
193
194 @reraise_safe_exceptions
194 @reraise_safe_exceptions
195 def is_empty(self, wire):
195 def is_empty(self, wire):
196 repo_init = self._factory.repo_libgit2(wire)
196 repo_init = self._factory.repo_libgit2(wire)
197 with repo_init as repo:
197 with repo_init as repo:
198
198
199 try:
199 try:
200 has_head = repo.head.name
200 has_head = repo.head.name
201 if has_head:
201 if has_head:
202 return False
202 return False
203
203
204 # NOTE(marcink): check again using more expensive method
204 # NOTE(marcink): check again using more expensive method
205 return repo.is_empty
205 return repo.is_empty
206 except Exception:
206 except Exception:
207 pass
207 pass
208
208
209 return True
209 return True
210
210
211 @reraise_safe_exceptions
211 @reraise_safe_exceptions
212 def assert_correct_path(self, wire):
212 def assert_correct_path(self, wire):
213 cache_on, context_uid, repo_id = self._cache_on(wire)
213 cache_on, context_uid, repo_id = self._cache_on(wire)
214 region = self._region(wire)
214 region = self._region(wire)
215
215
216 @region.conditional_cache_on_arguments(condition=cache_on)
216 @region.conditional_cache_on_arguments(condition=cache_on)
217 def _assert_correct_path(_context_uid, _repo_id, fast_check):
217 def _assert_correct_path(_context_uid, _repo_id, fast_check):
218 if fast_check:
218 if fast_check:
219 path = safe_str(wire['path'])
219 path = safe_str(wire['path'])
220 if pygit2.discover_repository(path):
220 if pygit2.discover_repository(path):
221 return True
221 return True
222 return False
222 return False
223 else:
223 else:
224 try:
224 try:
225 repo_init = self._factory.repo_libgit2(wire)
225 repo_init = self._factory.repo_libgit2(wire)
226 with repo_init:
226 with repo_init:
227 pass
227 pass
228 except pygit2.GitError:
228 except pygit2.GitError:
229 path = wire.get('path')
229 path = wire.get('path')
230 tb = traceback.format_exc()
230 tb = traceback.format_exc()
231 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
231 log.debug("Invalid Git path `%s`, tb: %s", path, tb)
232 return False
232 return False
233 return True
233 return True
234
234
235 return _assert_correct_path(context_uid, repo_id, True)
235 return _assert_correct_path(context_uid, repo_id, True)
236
236
237 @reraise_safe_exceptions
237 @reraise_safe_exceptions
238 def bare(self, wire):
238 def bare(self, wire):
239 repo_init = self._factory.repo_libgit2(wire)
239 repo_init = self._factory.repo_libgit2(wire)
240 with repo_init as repo:
240 with repo_init as repo:
241 return repo.is_bare
241 return repo.is_bare
242
242
243 @reraise_safe_exceptions
243 @reraise_safe_exceptions
244 def get_node_data(self, wire, commit_id, path):
244 def get_node_data(self, wire, commit_id, path):
245 repo_init = self._factory.repo_libgit2(wire)
245 repo_init = self._factory.repo_libgit2(wire)
246 with repo_init as repo:
246 with repo_init as repo:
247 commit = repo[commit_id]
247 commit = repo[commit_id]
248 blob_obj = commit.tree[path]
248 blob_obj = commit.tree[path]
249
249
250 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
250 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
251 raise exceptions.LookupException()(
251 raise exceptions.LookupException()(
252 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
252 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
253
253
254 return BytesEnvelope(blob_obj.data)
254 return BytesEnvelope(blob_obj.data)
255
255
256 @reraise_safe_exceptions
256 @reraise_safe_exceptions
257 def get_node_size(self, wire, commit_id, path):
257 def get_node_size(self, wire, commit_id, path):
258 repo_init = self._factory.repo_libgit2(wire)
258 repo_init = self._factory.repo_libgit2(wire)
259 with repo_init as repo:
259 with repo_init as repo:
260 commit = repo[commit_id]
260 commit = repo[commit_id]
261 blob_obj = commit.tree[path]
261 blob_obj = commit.tree[path]
262
262
263 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
263 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
264 raise exceptions.LookupException()(
264 raise exceptions.LookupException()(
265 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
265 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
266
266
267 return blob_obj.size
267 return blob_obj.size
268
268
269 @reraise_safe_exceptions
269 @reraise_safe_exceptions
270 def get_node_flags(self, wire, commit_id, path):
270 def get_node_flags(self, wire, commit_id, path):
271 repo_init = self._factory.repo_libgit2(wire)
271 repo_init = self._factory.repo_libgit2(wire)
272 with repo_init as repo:
272 with repo_init as repo:
273 commit = repo[commit_id]
273 commit = repo[commit_id]
274 blob_obj = commit.tree[path]
274 blob_obj = commit.tree[path]
275
275
276 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
276 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
277 raise exceptions.LookupException()(
277 raise exceptions.LookupException()(
278 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
278 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
279
279
280 return blob_obj.filemode
280 return blob_obj.filemode
281
281
282 @reraise_safe_exceptions
282 @reraise_safe_exceptions
283 def get_node_is_binary(self, wire, commit_id, path):
283 def get_node_is_binary(self, wire, commit_id, path):
284 repo_init = self._factory.repo_libgit2(wire)
284 repo_init = self._factory.repo_libgit2(wire)
285 with repo_init as repo:
285 with repo_init as repo:
286 commit = repo[commit_id]
286 commit = repo[commit_id]
287 blob_obj = commit.tree[path]
287 blob_obj = commit.tree[path]
288
288
289 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
289 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
290 raise exceptions.LookupException()(
290 raise exceptions.LookupException()(
291 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
291 f'Tree for commit_id:{commit_id} is not a blob: {blob_obj.type_str}')
292
292
293 return blob_obj.is_binary
293 return blob_obj.is_binary
294
294
295 @reraise_safe_exceptions
295 @reraise_safe_exceptions
296 def blob_as_pretty_string(self, wire, sha):
296 def blob_as_pretty_string(self, wire, sha):
297 repo_init = self._factory.repo_libgit2(wire)
297 repo_init = self._factory.repo_libgit2(wire)
298 with repo_init as repo:
298 with repo_init as repo:
299 blob_obj = repo[sha]
299 blob_obj = repo[sha]
300 return BytesEnvelope(blob_obj.data)
300 return BytesEnvelope(blob_obj.data)
301
301
302 @reraise_safe_exceptions
302 @reraise_safe_exceptions
303 def blob_raw_length(self, wire, sha):
303 def blob_raw_length(self, wire, sha):
304 cache_on, context_uid, repo_id = self._cache_on(wire)
304 cache_on, context_uid, repo_id = self._cache_on(wire)
305 region = self._region(wire)
305 region = self._region(wire)
306
306
307 @region.conditional_cache_on_arguments(condition=cache_on)
307 @region.conditional_cache_on_arguments(condition=cache_on)
308 def _blob_raw_length(_repo_id, _sha):
308 def _blob_raw_length(_repo_id, _sha):
309
309
310 repo_init = self._factory.repo_libgit2(wire)
310 repo_init = self._factory.repo_libgit2(wire)
311 with repo_init as repo:
311 with repo_init as repo:
312 blob = repo[sha]
312 blob = repo[sha]
313 return blob.size
313 return blob.size
314
314
315 return _blob_raw_length(repo_id, sha)
315 return _blob_raw_length(repo_id, sha)
316
316
317 def _parse_lfs_pointer(self, raw_content):
317 def _parse_lfs_pointer(self, raw_content):
318 spec_string = b'version https://git-lfs.github.com/spec'
318 spec_string = b'version https://git-lfs.github.com/spec'
319 if raw_content and raw_content.startswith(spec_string):
319 if raw_content and raw_content.startswith(spec_string):
320
320
321 pattern = re.compile(rb"""
321 pattern = re.compile(rb"""
322 (?:\n)?
322 (?:\n)?
323 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
323 ^version[ ]https://git-lfs\.github\.com/spec/(?P<spec_ver>v\d+)\n
324 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
324 ^oid[ ] sha256:(?P<oid_hash>[0-9a-f]{64})\n
325 ^size[ ](?P<oid_size>[0-9]+)\n
325 ^size[ ](?P<oid_size>[0-9]+)\n
326 (?:\n)?
326 (?:\n)?
327 """, re.VERBOSE | re.MULTILINE)
327 """, re.VERBOSE | re.MULTILINE)
328 match = pattern.match(raw_content)
328 match = pattern.match(raw_content)
329 if match:
329 if match:
330 return match.groupdict()
330 return match.groupdict()
331
331
332 return {}
332 return {}
333
333
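For reference, a pointer payload the pattern above matches (example values; the oid is a dummy 64-char hex digest):

    raw = (
        b'version https://git-lfs.github.com/spec/v1\n'
        b'oid sha256:4d7a214614ab2935c943f9e0ff69d22eadbb8f32b1258daaa5e2ca24d17e2393\n'
        b'size 12345\n'
    )
    # _parse_lfs_pointer(raw) returns
    # {'spec_ver': b'v1', 'oid_hash': b'4d7a...2393', 'oid_size': b'12345'}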
334 @reraise_safe_exceptions
334 @reraise_safe_exceptions
335 def is_large_file(self, wire, commit_id):
335 def is_large_file(self, wire, commit_id):
336 cache_on, context_uid, repo_id = self._cache_on(wire)
336 cache_on, context_uid, repo_id = self._cache_on(wire)
337 region = self._region(wire)
337 region = self._region(wire)
338
338
339 @region.conditional_cache_on_arguments(condition=cache_on)
339 @region.conditional_cache_on_arguments(condition=cache_on)
340 def _is_large_file(_repo_id, _sha):
340 def _is_large_file(_repo_id, _sha):
341 repo_init = self._factory.repo_libgit2(wire)
341 repo_init = self._factory.repo_libgit2(wire)
342 with repo_init as repo:
342 with repo_init as repo:
343 blob = repo[commit_id]
343 blob = repo[commit_id]
344 if blob.is_binary:
344 if blob.is_binary:
345 return {}
345 return {}
346
346
347 return self._parse_lfs_pointer(blob.data)
347 return self._parse_lfs_pointer(blob.data)
348
348
349 return _is_large_file(repo_id, commit_id)
349 return _is_large_file(repo_id, commit_id)
350
350
351 @reraise_safe_exceptions
351 @reraise_safe_exceptions
352 def is_binary(self, wire, tree_id):
352 def is_binary(self, wire, tree_id):
353 cache_on, context_uid, repo_id = self._cache_on(wire)
353 cache_on, context_uid, repo_id = self._cache_on(wire)
354 region = self._region(wire)
354 region = self._region(wire)
355
355
356 @region.conditional_cache_on_arguments(condition=cache_on)
356 @region.conditional_cache_on_arguments(condition=cache_on)
357 def _is_binary(_repo_id, _tree_id):
357 def _is_binary(_repo_id, _tree_id):
358 repo_init = self._factory.repo_libgit2(wire)
358 repo_init = self._factory.repo_libgit2(wire)
359 with repo_init as repo:
359 with repo_init as repo:
360 blob_obj = repo[tree_id]
360 blob_obj = repo[tree_id]
361 return blob_obj.is_binary
361 return blob_obj.is_binary
362
362
363 return _is_binary(repo_id, tree_id)
363 return _is_binary(repo_id, tree_id)
364
364
365 @reraise_safe_exceptions
365 @reraise_safe_exceptions
366 def md5_hash(self, wire, commit_id, path):
366 def md5_hash(self, wire, commit_id, path):
367 cache_on, context_uid, repo_id = self._cache_on(wire)
367 cache_on, context_uid, repo_id = self._cache_on(wire)
368 region = self._region(wire)
368 region = self._region(wire)
369
369
370 @region.conditional_cache_on_arguments(condition=cache_on)
370 @region.conditional_cache_on_arguments(condition=cache_on)
371 def _md5_hash(_repo_id, _commit_id, _path):
371 def _md5_hash(_repo_id, _commit_id, _path):
372 repo_init = self._factory.repo_libgit2(wire)
372 repo_init = self._factory.repo_libgit2(wire)
373 with repo_init as repo:
373 with repo_init as repo:
374 commit = repo[_commit_id]
374 commit = repo[_commit_id]
375 blob_obj = commit.tree[_path]
375 blob_obj = commit.tree[_path]
376
376
377 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
377 if blob_obj.type != pygit2.GIT_OBJ_BLOB:
378 raise exceptions.LookupException()(
378 raise exceptions.LookupException()(
379 f'Tree for commit_id:{_commit_id} is not a blob: {blob_obj.type_str}')
379 f'Tree for commit_id:{_commit_id} is not a blob: {blob_obj.type_str}')
380
380
381 return ''
381 return ''
382
382
383 return _md5_hash(repo_id, commit_id, path)
383 return _md5_hash(repo_id, commit_id, path)
384
384
385 @reraise_safe_exceptions
385 @reraise_safe_exceptions
386 def in_largefiles_store(self, wire, oid):
386 def in_largefiles_store(self, wire, oid):
387 conf = self._wire_to_config(wire)
387 conf = self._wire_to_config(wire)
388 repo_init = self._factory.repo_libgit2(wire)
388 repo_init = self._factory.repo_libgit2(wire)
389 with repo_init as repo:
389 with repo_init as repo:
390 repo_name = repo.path
390 repo_name = repo.path
391
391
392 store_location = conf.get('vcs_git_lfs_store_location')
392 store_location = conf.get('vcs_git_lfs_store_location')
393 if store_location:
393 if store_location:
394
394
395 store = LFSOidStore(
395 store = LFSOidStore(
396 oid=oid, repo=repo_name, store_location=store_location)
396 oid=oid, repo=repo_name, store_location=store_location)
397 return store.has_oid()
397 return store.has_oid()
398
398
399 return False
399 return False
400
400
401 @reraise_safe_exceptions
401 @reraise_safe_exceptions
402 def store_path(self, wire, oid):
402 def store_path(self, wire, oid):
403 conf = self._wire_to_config(wire)
403 conf = self._wire_to_config(wire)
404 repo_init = self._factory.repo_libgit2(wire)
404 repo_init = self._factory.repo_libgit2(wire)
405 with repo_init as repo:
405 with repo_init as repo:
406 repo_name = repo.path
406 repo_name = repo.path
407
407
408 store_location = conf.get('vcs_git_lfs_store_location')
408 store_location = conf.get('vcs_git_lfs_store_location')
409 if store_location:
409 if store_location:
410 store = LFSOidStore(
410 store = LFSOidStore(
411 oid=oid, repo=repo_name, store_location=store_location)
411 oid=oid, repo=repo_name, store_location=store_location)
412 return store.oid_path
412 return store.oid_path
413 raise ValueError(f'Unable to fetch oid {oid}: no store_location configured')
413 raise ValueError(f'Unable to fetch oid {oid}: no store_location configured')
414
414
415 @reraise_safe_exceptions
415 @reraise_safe_exceptions
416 def bulk_request(self, wire, rev, pre_load):
416 def bulk_request(self, wire, rev, pre_load):
417 cache_on, context_uid, repo_id = self._cache_on(wire)
417 cache_on, context_uid, repo_id = self._cache_on(wire)
418 region = self._region(wire)
418 region = self._region(wire)
419
419
420 @region.conditional_cache_on_arguments(condition=cache_on)
420 @region.conditional_cache_on_arguments(condition=cache_on)
421 def _bulk_request(_repo_id, _rev, _pre_load):
421 def _bulk_request(_repo_id, _rev, _pre_load):
422 result = {}
422 result = {}
423 for attr in pre_load:
423 for attr in pre_load:
424 try:
424 try:
425 method = self._bulk_methods[attr]
425 method = self._bulk_methods[attr]
426 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
426 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
427 args = [wire, rev]
427 args = [wire, rev]
428 result[attr] = method(*args)
428 result[attr] = method(*args)
429 except KeyError as e:
429 except KeyError as e:
430 raise exceptions.VcsException(e)(f"Unknown bulk attribute: {attr}")
430 raise exceptions.VcsException(e)(f"Unknown bulk attribute: {attr}")
431 return result
431 return result
432
432
433 return _bulk_request(repo_id, rev, sorted(pre_load))
433 return _bulk_request(repo_id, rev, sorted(pre_load))
434
434
435 @reraise_safe_exceptions
435 @reraise_safe_exceptions
436 def bulk_file_request(self, wire, commit_id, path, pre_load):
436 def bulk_file_request(self, wire, commit_id, path, pre_load):
437 cache_on, context_uid, repo_id = self._cache_on(wire)
437 cache_on, context_uid, repo_id = self._cache_on(wire)
438 region = self._region(wire)
438 region = self._region(wire)
439
439
440 @region.conditional_cache_on_arguments(condition=cache_on)
440 @region.conditional_cache_on_arguments(condition=cache_on)
441 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
441 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
442 result = {}
442 result = {}
443 for attr in pre_load:
443 for attr in pre_load:
444 try:
444 try:
445 method = self._bulk_file_methods[attr]
445 method = self._bulk_file_methods[attr]
446 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
446 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
447 result[attr] = method(wire, _commit_id, _path)
447 result[attr] = method(wire, _commit_id, _path)
448 except KeyError as e:
448 except KeyError as e:
449 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
449 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
450 return BinaryEnvelope(result)
450 return BinaryEnvelope(result)
451
451
452 return _bulk_file_request(repo_id, commit_id, path, sorted(pre_load))
452 return _bulk_file_request(repo_id, commit_id, path, sorted(pre_load))
453
453
454 def _build_opener(self, url: str):
454 def _build_opener(self, url: str):
455 handlers = []
455 handlers = []
456 url_obj = url_parser(safe_bytes(url))
456 url_obj = url_parser(safe_bytes(url))
457 authinfo = url_obj.authinfo()[1]
457 authinfo = url_obj.authinfo()[1]
458
458
459 if authinfo:
459 if authinfo:
460 # create a password manager
460 # create a password manager
461 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
461 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
462 passmgr.add_password(*authinfo)
462 passmgr.add_password(*authinfo)
463
463
464 handlers.extend((httpbasicauthhandler(passmgr),
464 handlers.extend((httpbasicauthhandler(passmgr),
465 httpdigestauthhandler(passmgr)))
465 httpdigestauthhandler(passmgr)))
466
466
467 return urllib.request.build_opener(*handlers)
467 return urllib.request.build_opener(*handlers)
468
468
    @reraise_safe_exceptions
    def check_url(self, url, config):
        url_obj = url_parser(safe_bytes(url))

        test_uri = safe_str(url_obj.authinfo()[0])
        obfuscated_uri = get_obfuscated_url(url_obj)

        log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)

        if not test_uri.endswith('info/refs'):
            test_uri = test_uri.rstrip('/') + '/info/refs'

        o = self._build_opener(test_uri)
        o.addheaders = [('User-Agent', 'git/1.7.8.0')]  # fake some git

        q = {"service": 'git-upload-pack'}
        qs = '?%s' % urllib.parse.urlencode(q)
        cu = f"{test_uri}{qs}"
        req = urllib.request.Request(cu, None, {})

        try:
            log.debug("Trying to open URL %s", obfuscated_uri)
            resp = o.open(req)
            if resp.code != 200:
                raise exceptions.URLError()('Return Code is not 200')
        except Exception as e:
            log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
            # means it cannot be cloned
            raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")

        # now detect if it's a proper git repo
        gitdata: bytes = resp.read()

        if b'service=git-upload-pack' in gitdata:
            pass
        elif re.findall(br'[0-9a-fA-F]{40}\s+refs', gitdata):
            # old style git can return some other format!
            pass
        else:
            e = None
            raise exceptions.URLError(e)(
                "url [%s] does not look like a git repo org_exc: %s"
                % (obfuscated_uri, e))

        return True

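    # For reference: probing "https://host/repo" requests
    # "https://host/repo/info/refs?service=git-upload-pack" -- the same
    # endpoint git's smart HTTP protocol uses for clone/fetch -- so a
    # healthy server answers with a payload advertising
    # "service=git-upload-pack".
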
    @reraise_safe_exceptions
    def clone(self, wire, url, deferred, valid_refs, update_after_clone):
        # TODO(marcink): deprecate this method. Last I checked, we don't use it anymore
        remote_refs = self.pull(wire, url, apply_refs=False)
        repo = self._factory.repo(wire)
        if isinstance(valid_refs, list):
            valid_refs = tuple(valid_refs)

        for k in remote_refs:
            # only parse heads/tags and skip so-called deferred tags
            if k.startswith(valid_refs) and not k.endswith(deferred):
                repo[k] = remote_refs[k]

        if update_after_clone:
            # we want to check out HEAD
            repo["HEAD"] = remote_refs["HEAD"]
            index.build_index_from_tree(repo.path, repo.index_path(),
                                        repo.object_store, repo["HEAD"].tree)

    @reraise_safe_exceptions
    def branch(self, wire, commit_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _branch(_context_uid, _repo_id, _commit_id):
            regex = re.compile('^refs/heads')

            def filter_with(ref):
                return regex.match(ref[0]) and ref[1] == _commit_id

            branches = list(filter(filter_with, list(self.get_refs(wire).items())))
            return [x[0].split('refs/heads/')[-1] for x in branches]

        return _branch(context_uid, repo_id, commit_id)

    @reraise_safe_exceptions
    def commit_branches(self, wire, commit_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _commit_branches(_context_uid, _repo_id, _commit_id):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                branches = list(repo.branches.with_commit(_commit_id))
                return branches

        return _commit_branches(context_uid, repo_id, commit_id)

    @reraise_safe_exceptions
    def add_object(self, wire, content):
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            blob = objects.Blob()
            blob.set_raw_string(content)
            repo.object_store.add_object(blob)
            return blob.id

    @reraise_safe_exceptions
    def create_commit(self, wire, author, committer, message, branch, new_tree_id, date_args: list[int] | None = None):
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:

            if date_args:
                current_time, offset = date_args

                kw = {
                    'time': current_time,
                    'offset': offset
                }
                author = create_signature_from_string(author, **kw)
                committer = create_signature_from_string(committer, **kw)

            tree = new_tree_id
            if isinstance(tree, (bytes, str)):
                # validate this tree is in the repo...
                tree = repo[safe_str(tree)].id

            parents = []
            # ensure we COMMIT on top of the given branch head
            # check if this repo has ANY branches, otherwise it's a new-branch case
            if branch in repo.branches.local:
                parents += [repo.branches[branch].target]
            elif list(repo.branches.local):
                parents += [repo.head.target]
            #else:
            # in case we want to commit on a new branch we create it on top of HEAD
            #repo.branches.local.create(branch, repo.revparse_single('HEAD'))

            # Create a new commit
            commit_oid = repo.create_commit(
                f'refs/heads/{branch}',  # the name of the reference to update
                author,  # the author of the commit
                committer,  # the committer of the commit
                message,  # the commit message
                tree,  # the tree produced by the index
                parents  # list of parents for the new commit, usually just one
            )

            new_commit_id = safe_str(commit_oid)

            return new_commit_id

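    # Minimal usage sketch (hypothetical values): date_args carries the
    # commit time as [unix_timestamp, utc_offset_in_minutes], matching the
    # time/offset fields of a pygit2 Signature:
    #
    #   self.create_commit(wire, 'Joe <joe@example.com>',
    #                      'Joe <joe@example.com>', 'initial commit',
    #                      'master', new_tree_id, date_args=[1700000000, 0])
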
    @reraise_safe_exceptions
    def commit(self, wire, commit_data, branch, commit_tree, updated, removed):

        def mode2pygit(mode):
            """
            git only supports two file modes for blobs, 644 and 755

            0o100755 -> 33261
            0o100644 -> 33188
            """
            return {
                0o100644: pygit2.GIT_FILEMODE_BLOB,
                0o100755: pygit2.GIT_FILEMODE_BLOB_EXECUTABLE,
                0o120000: pygit2.GIT_FILEMODE_LINK
            }.get(mode) or pygit2.GIT_FILEMODE_BLOB

        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            repo_index = repo.index

            for pathspec in updated:
                blob_id = repo.create_blob(pathspec['content'])
                ie = pygit2.IndexEntry(pathspec['path'], blob_id, mode2pygit(pathspec['mode']))
                repo_index.add(ie)

            for pathspec in removed:
                repo_index.remove(pathspec)

            # Write changes to the index
            repo_index.write()

            # Create a tree from the updated index
            commit_tree = repo_index.write_tree()

            new_tree_id = commit_tree

            author = commit_data['author']
            committer = commit_data['committer']
            message = commit_data['message']

            date_args = [int(commit_data['commit_time']), int(commit_data['commit_timezone'])]

            new_commit_id = self.create_commit(wire, author, committer, message, branch,
                                               new_tree_id, date_args=date_args)

            # libgit2, ensure the branch is there and exists
            self.create_branch(wire, branch, new_commit_id)

            # libgit2, set new ref to this created commit
            self.set_refs(wire, f'refs/heads/{branch}', new_commit_id)

            return new_commit_id

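    # Note on the mode mapping above: git records regular blobs as 0o100644
    # (decimal 33188) and executables as 0o100755 (33261); mode2pygit()
    # translates those, plus 0o120000 for symlinks, into the corresponding
    # pygit2 GIT_FILEMODE_* constants, defaulting to a plain blob.
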
    @reraise_safe_exceptions
    def pull(self, wire, url, apply_refs=True, refs=None, update_after=False):
        if url != 'default' and '://' not in url:
            client = LocalGitClient(url)
        else:
            url_obj = url_parser(safe_bytes(url))
            o = self._build_opener(url)
            url = url_obj.authinfo()[0]
            client = HttpGitClient(base_url=url, opener=o)
        repo = self._factory.repo(wire)

        determine_wants = repo.object_store.determine_wants_all
        if refs:
            refs = [ascii_bytes(x) for x in refs]

            def determine_wants_requested(remote_refs):
                determined = []
                for ref_name, ref_hash in remote_refs.items():
                    bytes_ref_name = safe_bytes(ref_name)

                    if bytes_ref_name in refs:
                        bytes_ref_hash = safe_bytes(ref_hash)
                        determined.append(bytes_ref_hash)
                return determined

            # swap with our custom requested wants
            determine_wants = determine_wants_requested

        try:
            remote_refs = client.fetch(
                path=url, target=repo, determine_wants=determine_wants)

        except NotGitRepository as e:
            log.warning(
                'Trying to fetch from "%s" failed, not a Git repository.', url)
            # Exception can contain unicode which we convert
            raise exceptions.AbortException(e)(repr(e))

        # mikhail: client.fetch() returns all the remote refs, but fetches only
        # refs filtered by `determine_wants` function. We need to filter result
        # as well
        if refs:
            remote_refs = {k: remote_refs[k] for k in remote_refs if k in refs}

        if apply_refs:
            # TODO: johbo: Needs proper test coverage with a git repository
            # that contains a tag object, so that we would end up with
            # a peeled ref at this point.
            for k in remote_refs:
                if k.endswith(PEELED_REF_MARKER):
                    log.debug("Skipping peeled reference %s", k)
                    continue
                repo[k] = remote_refs[k]

            if refs and not update_after:
                # mikhail: explicitly set the head to the last ref.
                repo[HEAD_MARKER] = remote_refs[refs[-1]]

        if update_after:
            # we want to check out HEAD
            repo[HEAD_MARKER] = remote_refs[HEAD_MARKER]
            index.build_index_from_tree(repo.path, repo.index_path(),
                                        repo.object_store, repo[HEAD_MARKER].tree)

        if isinstance(remote_refs, FetchPackResult):
            return remote_refs.refs
        return remote_refs

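    # Background (sketch, not original commentary): dulwich calls
    # determine_wants(remote_refs) to decide which object ids to download,
    # so swapping in determine_wants_requested() narrows a pull to the
    # requested refs only, e.g. refs=['refs/heads/stable'] fetches just that
    # branch's objects while remote_refs still lists everything advertised.
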
    @reraise_safe_exceptions
    def sync_fetch(self, wire, url, refs=None, all_refs=False):
        self._factory.repo(wire)
        if refs and not isinstance(refs, (list, tuple)):
            refs = [refs]

        config = self._wire_to_config(wire)
        # get all remote refs we'll use to fetch later
        cmd = ['ls-remote']
        if not all_refs:
            cmd += ['--heads', '--tags']
        cmd += [url]
        output, __ = self.run_git_command(
            wire, cmd, fail_on_stderr=False,
            _copts=self._remote_conf(config),
            extra_env={'GIT_TERMINAL_PROMPT': '0'})

        remote_refs = collections.OrderedDict()
        fetch_refs = []

        for ref_line in output.splitlines():
            sha, ref = ref_line.split(b'\t')
            sha = sha.strip()
            if ref in remote_refs:
                # duplicate, skip
                continue
            if ref.endswith(PEELED_REF_MARKER):
                log.debug("Skipping peeled reference %s", ref)
                continue
            # don't sync HEAD
            if ref in [HEAD_MARKER]:
                continue

            remote_refs[ref] = sha

            if refs and sha in refs:
                # we filter fetch using our specified refs
                fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
            elif not refs:
                fetch_refs.append(f'{safe_str(ref)}:{safe_str(ref)}')
        log.debug('Finished obtaining fetch refs, total: %s', len(fetch_refs))

        if fetch_refs:
            for chunk in more_itertools.chunked(fetch_refs, 1024 * 4):
                fetch_refs_chunks = list(chunk)
                log.debug('Fetching %s refs from import url', len(fetch_refs_chunks))
                self.run_git_command(
                    wire, ['fetch', url, '--force', '--prune', '--'] + fetch_refs_chunks,
                    fail_on_stderr=False,
                    _copts=self._remote_conf(config),
                    extra_env={'GIT_TERMINAL_PROMPT': '0'})

        return remote_refs

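    # Each fetch_refs entry is a "<ref>:<ref>" refspec, e.g.
    # "refs/heads/main:refs/heads/main", mirroring the remote ref under the
    # same local name; chunking at 4096 refspecs per `git fetch` invocation
    # keeps the command line for very large repositories within OS
    # argument-length limits.
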
    @reraise_safe_exceptions
    def sync_push(self, wire, url, refs=None):
        if not self.check_url(url, wire):
            return
        config = self._wire_to_config(wire)
        self._factory.repo(wire)
        self.run_git_command(
            wire, ['push', url, '--mirror'], fail_on_stderr=False,
            _copts=self._remote_conf(config),
            extra_env={'GIT_TERMINAL_PROMPT': '0'})

    @reraise_safe_exceptions
    def get_remote_refs(self, wire, url):
        repo = Repo(url)
        return repo.get_refs()

    @reraise_safe_exceptions
    def get_description(self, wire):
        repo = self._factory.repo(wire)
        return repo.get_description()

    @reraise_safe_exceptions
    def get_missing_revs(self, wire, rev1, rev2, path2):
        repo = self._factory.repo(wire)
        LocalGitClient(thin_packs=False).fetch(path2, repo)

        wire_remote = wire.copy()
        wire_remote['path'] = path2
        repo_remote = self._factory.repo(wire_remote)
        LocalGitClient(thin_packs=False).fetch(path2, repo_remote)

        revs = [
            x.commit.id
            for x in repo_remote.get_walker(include=[safe_bytes(rev2)], exclude=[safe_bytes(rev1)])]
        return revs

    @reraise_safe_exceptions
    def get_object(self, wire, sha, maybe_unreachable=False):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _get_object(_context_uid, _repo_id, _sha):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:

                missing_commit_err = 'Commit {} does not exist for `{}`'.format(sha, wire['path'])
                try:
                    commit = repo.revparse_single(sha)
                except KeyError:
                    # NOTE(marcink): KeyError doesn't give us any meaningful information
                    # here, we instead give something more explicit
                    e = exceptions.RefNotFoundException('SHA: %s not found', sha)
                    raise exceptions.LookupException(e)(missing_commit_err)
                except ValueError as e:
                    raise exceptions.LookupException(e)(missing_commit_err)

                is_tag = False
                if isinstance(commit, pygit2.Tag):
                    commit = repo.get(commit.target)
                    is_tag = True

                check_dangling = True
                if is_tag:
                    check_dangling = False

                if check_dangling and maybe_unreachable:
                    check_dangling = False

                # if we used a reference and it parsed, we don't have a dangling commit
                if sha != commit.hex:
                    check_dangling = False

                if check_dangling:
                    # check for dangling commit
                    for branch in repo.branches.with_commit(commit.hex):
                        if branch:
                            break
                    else:
                        # NOTE(marcink): Empty error doesn't give us any meaningful information
                        # here, we instead give something more explicit
                        e = exceptions.RefNotFoundException('SHA: %s not found in branches', sha)
                        raise exceptions.LookupException(e)(missing_commit_err)

                commit_id = commit.hex
                type_str = commit.type_str

                return {
                    'id': commit_id,
                    'type': type_str,
                    'commit_id': commit_id,
                    'idx': 0
                }

        return _get_object(context_uid, repo_id, sha)

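    # Terminology note: a "dangling" commit here is one not reachable from
    # any branch. The check is skipped for annotated tags and for lookups
    # done through a symbolic name (sha != commit.hex), since those were
    # resolved via an existing ref and are reachable by definition.
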
    @reraise_safe_exceptions
    def get_refs(self, wire):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _get_refs(_context_uid, _repo_id):

            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                regex = re.compile('^refs/(heads|tags)/')
                return {x.name: x.target.hex for x in
                        [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]}

        return _get_refs(context_uid, repo_id)

    @reraise_safe_exceptions
    def get_branch_pointers(self, wire):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _get_branch_pointers(_context_uid, _repo_id):

            repo_init = self._factory.repo_libgit2(wire)
            regex = re.compile('^refs/heads')
            with repo_init as repo:
                branches = [ref for ref in repo.listall_reference_objects() if regex.match(ref.name)]
                return {x.target.hex: x.shorthand for x in branches}

        return _get_branch_pointers(context_uid, repo_id)

    @reraise_safe_exceptions
    def head(self, wire, show_exc=True):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _head(_context_uid, _repo_id, _show_exc):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                try:
                    return repo.head.peel().hex
                except Exception:
                    if show_exc:
                        raise
        return _head(context_uid, repo_id, show_exc)

    @reraise_safe_exceptions
    def init(self, wire):
        repo_path = safe_str(wire['path'])
        self.repo = Repo.init(repo_path)

    @reraise_safe_exceptions
    def init_bare(self, wire):
        repo_path = safe_str(wire['path'])
        self.repo = Repo.init_bare(repo_path)

    @reraise_safe_exceptions
    def revision(self, wire, rev):

        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _revision(_context_uid, _repo_id, _rev):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                commit = repo[rev]
                obj_data = {
                    'id': commit.id.hex,
                }
                # tree objects themselves don't have a tree_id attribute
                if hasattr(commit, 'tree_id'):
                    obj_data['tree'] = commit.tree_id.hex

                return obj_data
        return _revision(context_uid, repo_id, rev)

    @reraise_safe_exceptions
    def date(self, wire, commit_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _date(_repo_id, _commit_id):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                commit = repo[commit_id]

                if hasattr(commit, 'commit_time'):
                    commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset
                else:
                    commit = commit.get_object()
                    commit_time, commit_time_offset = commit.commit_time, commit.commit_time_offset

                # TODO(marcink): check dulwich difference of offset vs timezone
                return [commit_time, commit_time_offset]
        return _date(repo_id, commit_id)

    @reraise_safe_exceptions
    def author(self, wire, commit_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _author(_repo_id, _commit_id):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                commit = repo[commit_id]

                if hasattr(commit, 'author'):
                    author = commit.author
                else:
                    author = commit.get_object().author

                if author.email:
                    return f"{author.name} <{author.email}>"

                try:
                    return f"{author.name}"
                except Exception:
                    return f"{safe_str(author.raw_name)}"

        return _author(repo_id, commit_id)

    @reraise_safe_exceptions
    def message(self, wire, commit_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _message(_repo_id, _commit_id):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                commit = repo[commit_id]
                return commit.message
        return _message(repo_id, commit_id)

    @reraise_safe_exceptions
    def parents(self, wire, commit_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _parents(_repo_id, _commit_id):
            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                commit = repo[commit_id]
                if hasattr(commit, 'parent_ids'):
                    parent_ids = commit.parent_ids
                else:
                    parent_ids = commit.get_object().parent_ids

                return [x.hex for x in parent_ids]
        return _parents(repo_id, commit_id)

    @reraise_safe_exceptions
    def children(self, wire, commit_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        head = self.head(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _children(_repo_id, _commit_id):

            output, __ = self.run_git_command(
                wire, ['rev-list', '--all', '--children', f'{commit_id}^..{head}'])

            child_ids = []
            pat = re.compile(fr'^{commit_id}')
            for line in output.splitlines():
                line = safe_str(line)
                if pat.match(line):
                    found_ids = line.split(' ')[1:]
                    child_ids.extend(found_ids)
                    break

            return child_ids
        return _children(repo_id, commit_id)

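    # Parsing note: `git rev-list --children` emits one line per commit in
    # the form "<commit> <child-1> <child-2> ...", so matching the line that
    # starts with commit_id and taking the remaining fields yields the
    # direct children of that commit.
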
    @reraise_safe_exceptions
    def set_refs(self, wire, key, value):
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            repo.references.create(key, value, force=True)

    @reraise_safe_exceptions
    def create_branch(self, wire, branch_name, commit_id, force=False):
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            if commit_id:
                commit = repo[commit_id]
            else:
                # if commit is not given just use the HEAD
                commit = repo.head()

            if force:
                repo.branches.local.create(branch_name, commit, force=force)
            elif not repo.branches.get(branch_name):
                # create only if the branch doesn't already exist
                repo.branches.local.create(branch_name, commit, force=force)

    @reraise_safe_exceptions
    def remove_ref(self, wire, key):
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            repo.references.delete(key)

    @reraise_safe_exceptions
    def tag_remove(self, wire, tag_name):
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            key = f'refs/tags/{tag_name}'
            repo.references.delete(key)

    @reraise_safe_exceptions
    def tree_changes(self, wire, source_id, target_id):
        repo = self._factory.repo(wire)
        # source can be empty
        source_id = safe_bytes(source_id if source_id else b'')
        target_id = safe_bytes(target_id)

        source = repo[source_id].tree if source_id else None
        target = repo[target_id].tree
        result = repo.object_store.tree_changes(source, target)

        added = set()
        modified = set()
        deleted = set()
        for (old_path, new_path), (_, _), (_, _) in list(result):
            if new_path and old_path:
                modified.add(new_path)
            elif new_path and not old_path:
                added.add(new_path)
            elif not new_path and old_path:
                deleted.add(old_path)

        return list(added), list(modified), list(deleted)

    @reraise_safe_exceptions
    def tree_and_type_for_path(self, wire, commit_id, path):

        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _tree_and_type_for_path(_context_uid, _repo_id, _commit_id, _path):
            repo_init = self._factory.repo_libgit2(wire)

            with repo_init as repo:
                commit = repo[commit_id]
                try:
                    tree = commit.tree[path]
                except KeyError:
                    return None, None, None

                return tree.id.hex, tree.type_str, tree.filemode
        return _tree_and_type_for_path(context_uid, repo_id, commit_id, path)

    @reraise_safe_exceptions
    def tree_items(self, wire, tree_id):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _tree_items(_repo_id, _tree_id):

            repo_init = self._factory.repo_libgit2(wire)
            with repo_init as repo:
                try:
                    tree = repo[tree_id]
                except KeyError:
                    raise ObjectMissing(f'No tree with id: {tree_id}')

                result = []
                for item in tree:
                    item_sha = item.hex
                    item_mode = item.filemode
                    item_type = item.type_str

                    if item_type == 'commit':
                        # NOTE(marcink): we translate submodules to 'link' for backward compat
                        item_type = 'link'

                    result.append((item.name, item_mode, item_sha, item_type))
                return result
        return _tree_items(repo_id, tree_id)

    @reraise_safe_exceptions
    def diff_2(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
        """
        Old version that uses subprocess to call diff
        """

        flags = [
            f'-U{context}', '--patch',
            '--binary',
            '--find-renames',
            '--no-indent-heuristic',
            # '--indent-heuristic',
            #'--full-index',
            #'--abbrev=40'
        ]

        if opt_ignorews:
            flags.append('--ignore-all-space')

        if commit_id_1 == self.EMPTY_COMMIT:
            cmd = ['show'] + flags + [commit_id_2]
        else:
            cmd = ['diff'] + flags + [commit_id_1, commit_id_2]

        if file_filter:
            cmd.extend(['--', file_filter])

        diff, __ = self.run_git_command(wire, cmd)
        # If we used the 'show' command, strip the first few lines (until the
        # actual diff starts)
        if commit_id_1 == self.EMPTY_COMMIT:
            lines = diff.splitlines()
            x = 0
            for line in lines:
                if line.startswith(b'diff'):
                    break
                x += 1
            # Append a trailing newline, just like the 'diff' command does;
            # the output lines are bytes, so join with a bytes separator
            diff = b'\n'.join(lines[x:]) + b'\n'
        return diff

    @reraise_safe_exceptions
    def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_ignorews, context):
        repo_init = self._factory.repo_libgit2(wire)

        with repo_init as repo:
            swap = True
            flags = 0
            flags |= pygit2.GIT_DIFF_SHOW_BINARY

            if opt_ignorews:
                flags |= pygit2.GIT_DIFF_IGNORE_WHITESPACE

            if commit_id_1 == self.EMPTY_COMMIT:
                comm1 = repo[commit_id_2]
                diff_obj = comm1.tree.diff_to_tree(
                    flags=flags, context_lines=context, swap=swap)

            else:
                comm1 = repo[commit_id_2]
                comm2 = repo[commit_id_1]
                diff_obj = comm1.tree.diff_to_tree(
                    comm2.tree, flags=flags, context_lines=context, swap=swap)
                similar_flags = 0
                similar_flags |= pygit2.GIT_DIFF_FIND_RENAMES
                diff_obj.find_similar(flags=similar_flags)

            if file_filter:
                for p in diff_obj:
                    if p.delta.old_file.path == file_filter:
                        return BytesEnvelope(p.data) or BytesEnvelope(b'')
                # no matching path == no diff
                return BytesEnvelope(b'')

            return BytesEnvelope(safe_bytes(diff_obj.patch)) or BytesEnvelope(b'')

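    # Design note (interpretation, not original commentary): diff_to_tree()
    # diffs from the receiving tree towards its argument; passing swap=True
    # reverses that direction, which is why the trees are looked up as
    # (commit_id_2, commit_id_1) yet the resulting patch should read
    # commit_id_1 -> commit_id_2, consistent with the subprocess-based
    # diff_2() above.
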
    @reraise_safe_exceptions
    def node_history(self, wire, commit_id, path, limit):
        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _node_history(_context_uid, _repo_id, _commit_id, _path, _limit):
            # optimize for n==1, rev-list is much faster for that use-case
            if limit == 1:
                cmd = ['rev-list', '-1', commit_id, '--', path]
            else:
                cmd = ['log']
                if limit:
                    cmd.extend(['-n', str(safe_int(limit, 0))])
                cmd.extend(['--pretty=format: %H', '-s', commit_id, '--', path])

            output, __ = self.run_git_command(wire, cmd)
            commit_ids = re.findall(rb'[0-9a-fA-F]{40}', output)

            return list(commit_ids)
        return _node_history(context_uid, repo_id, commit_id, path, limit)

    @reraise_safe_exceptions
    def node_annotate_legacy(self, wire, commit_id, path):
        # note: replaced by pygit2 implementation
        cmd = ['blame', '-l', '--root', '-r', commit_id, '--', path]
        # -l ==> outputs long shas (and we need all 40 characters)
        # --root ==> doesn't put '^' character for boundaries
        # -r commit_id ==> blames for the given commit
        output, __ = self.run_git_command(wire, cmd)

        result = []
        for i, blame_line in enumerate(output.splitlines()[:-1]):
            line_no = i + 1
            blame_commit_id, line = re.split(rb' ', blame_line, 1)
            result.append((line_no, blame_commit_id, line))

        return result

    @reraise_safe_exceptions
    def node_annotate(self, wire, commit_id, path):

        result_libgit = []
        repo_init = self._factory.repo_libgit2(wire)
        with repo_init as repo:
            commit = repo[commit_id]
            blame_obj = repo.blame(path, newest_commit=commit_id)
            for i, line in enumerate(commit.tree[path].data.splitlines()):
                line_no = i + 1
                hunk = blame_obj.for_line(line_no)
                blame_commit_id = hunk.final_commit_id.hex

                result_libgit.append((line_no, blame_commit_id, line))

        return result_libgit

    @reraise_safe_exceptions
    def update_server_info(self, wire):
        repo = self._factory.repo(wire)
        update_server_info(repo)

    @reraise_safe_exceptions
    def get_all_commit_ids(self, wire):

        cache_on, context_uid, repo_id = self._cache_on(wire)
        region = self._region(wire)

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _get_all_commit_ids(_context_uid, _repo_id):

            cmd = ['rev-list', '--reverse', '--date-order', '--branches', '--tags']
            try:
                output, __ = self.run_git_command(wire, cmd)
                return output.splitlines()
            except Exception:
                # Can be raised for empty repositories
                return []

        @region.conditional_cache_on_arguments(condition=cache_on)
        def _get_all_commit_ids_pygit2(_context_uid, _repo_id):
            repo_init = self._factory.repo_libgit2(wire)
            from pygit2 import GIT_SORT_REVERSE, GIT_SORT_TIME, GIT_BRANCH_ALL
            results = []
            with repo_init as repo:
                for commit in repo.walk(repo.head.target, GIT_SORT_TIME | GIT_BRANCH_ALL | GIT_SORT_REVERSE):
                    results.append(commit.id.hex)
            # the collected ids were never returned before; return them so this
            # pygit2 variant works as a drop-in alternative
            return results

        return _get_all_commit_ids(context_uid, repo_id)

1345 @reraise_safe_exceptions
1345 @reraise_safe_exceptions
1346 def run_git_command(self, wire, cmd, **opts):
1346 def run_git_command(self, wire, cmd, **opts):
1347 path = wire.get('path', None)
1347 path = wire.get('path', None)
1348
1348
1349 if path and os.path.isdir(path):
1349 if path and os.path.isdir(path):
1350 opts['cwd'] = path
1350 opts['cwd'] = path
1351
1351
1352 if '_bare' in opts:
1352 if '_bare' in opts:
1353 _copts = []
1353 _copts = []
1354 del opts['_bare']
1354 del opts['_bare']
1355 else:
1355 else:
1356 _copts = ['-c', 'core.quotepath=false',]
1356 _copts = ['-c', 'core.quotepath=false',]
1357 safe_call = False
1357 safe_call = False
1358 if '_safe' in opts:
1358 if '_safe' in opts:
1359 # no exc on failure
1359 # no exc on failure
1360 del opts['_safe']
1360 del opts['_safe']
1361 safe_call = True
1361 safe_call = True
1362
1362
1363 if '_copts' in opts:
1363 if '_copts' in opts:
1364 _copts.extend(opts['_copts'] or [])
1364 _copts.extend(opts['_copts'] or [])
1365 del opts['_copts']
1365 del opts['_copts']
1366
1366
1367 gitenv = os.environ.copy()
1367 gitenv = os.environ.copy()
1368 gitenv.update(opts.pop('extra_env', {}))
1368 gitenv.update(opts.pop('extra_env', {}))
1369 # drop any inherited GIT_DIR, it would point git at the wrong repository
1369 # drop any inherited GIT_DIR, it would point git at the wrong repository
1370 if 'GIT_DIR' in gitenv:
1370 if 'GIT_DIR' in gitenv:
1371 del gitenv['GIT_DIR']
1371 del gitenv['GIT_DIR']
1372 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1372 gitenv['GIT_CONFIG_NOGLOBAL'] = '1'
1373 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1373 gitenv['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
1374
1374
1375 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1375 cmd = [settings.GIT_EXECUTABLE] + _copts + cmd
1376 _opts = {'env': gitenv, 'shell': False}
1376 _opts = {'env': gitenv, 'shell': False}
1377
1377
1378 proc = None
1378 proc = None
1379 try:
1379 try:
1380 _opts.update(opts)
1380 _opts.update(opts)
1381 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1381 proc = subprocessio.SubprocessIOChunker(cmd, **_opts)
1382
1382
1383 return b''.join(proc), b''.join(proc.stderr)
1383 return b''.join(proc), b''.join(proc.stderr)
1384 except OSError as err:
1384 except OSError as err:
1385 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
1385 cmd = ' '.join(map(safe_str, cmd)) # human friendly CMD
1386 tb_err = ("Couldn't run git command (%s).\n"
1386 tb_err = ("Couldn't run git command (%s).\n"
1387 "Original error was:%s\n"
1387 "Original error was:%s\n"
1388 "Call options:%s\n"
1388 "Call options:%s\n"
1389 % (cmd, err, _opts))
1389 % (cmd, err, _opts))
1390 log.exception(tb_err)
1390 log.exception(tb_err)
1391 if safe_call:
1391 if safe_call:
1392 return '', err
1392 return '', err
1393 else:
1393 else:
1394 raise exceptions.VcsException()(tb_err)
1394 raise exceptions.VcsException()(tb_err)
1395 finally:
1395 finally:
1396 if proc:
1396 if proc:
1397 proc.close()
1397 proc.close()
1398
1398
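The environment handling is the crux of run_git_command; a minimal standalone sketch of the same pattern (run_git and GIT_BINARY are illustrative names, not vcsserver API):

import os
import subprocess

GIT_BINARY = 'git'  # assumption: git available on PATH

def run_git(args, cwd=None, extra_env=None):
    env = os.environ.copy()
    env.update(extra_env or {})
    # an inherited GIT_DIR would make git operate on the wrong repository
    env.pop('GIT_DIR', None)
    # skip the user's global config; allow repo discovery across filesystems
    env['GIT_CONFIG_NOGLOBAL'] = '1'
    env['GIT_DISCOVERY_ACROSS_FILESYSTEM'] = '1'
    cmd = [GIT_BINARY, '-c', 'core.quotepath=false'] + list(args)
    proc = subprocess.run(cmd, cwd=cwd, env=env, capture_output=True)
    return proc.stdout, proc.stderr

out, _ = run_git(['--version'])
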
1399 @reraise_safe_exceptions
1399 @reraise_safe_exceptions
1400 def install_hooks(self, wire, force=False):
1400 def install_hooks(self, wire, force=False):
1401 from vcsserver.hook_utils import install_git_hooks
1401 from vcsserver.hook_utils import install_git_hooks
1402 bare = self.bare(wire)
1402 bare = self.bare(wire)
1403 path = wire['path']
1403 path = wire['path']
1404 binary_dir = settings.BINARY_DIR
1404 binary_dir = settings.BINARY_DIR
1405 if binary_dir:
1405 if binary_dir:
1406 os.path.join(binary_dir, 'python3')  # NOTE: result is discarded, this statement is a no-op
1406 os.path.join(binary_dir, 'python3')  # NOTE: result is discarded, this statement is a no-op
1407 return install_git_hooks(path, bare, force_create=force)
1407 return install_git_hooks(path, bare, force_create=force)
1408
1408
1409 @reraise_safe_exceptions
1409 @reraise_safe_exceptions
1410 def get_hooks_info(self, wire):
1410 def get_hooks_info(self, wire):
1411 from vcsserver.hook_utils import (
1411 from vcsserver.hook_utils import (
1412 get_git_pre_hook_version, get_git_post_hook_version)
1412 get_git_pre_hook_version, get_git_post_hook_version)
1413 bare = self.bare(wire)
1413 bare = self.bare(wire)
1414 path = wire['path']
1414 path = wire['path']
1415 return {
1415 return {
1416 'pre_version': get_git_pre_hook_version(path, bare),
1416 'pre_version': get_git_pre_hook_version(path, bare),
1417 'post_version': get_git_post_hook_version(path, bare),
1417 'post_version': get_git_post_hook_version(path, bare),
1418 }
1418 }
1419
1419
1420 @reraise_safe_exceptions
1420 @reraise_safe_exceptions
1421 def set_head_ref(self, wire, head_name):
1421 def set_head_ref(self, wire, head_name):
1422 log.debug('Setting refs/head to `%s`', head_name)
1422 log.debug('Setting refs/head to `%s`', head_name)
1423 repo_init = self._factory.repo_libgit2(wire)
1423 repo_init = self._factory.repo_libgit2(wire)
1424 with repo_init as repo:
1424 with repo_init as repo:
1425 repo.set_head(f'refs/heads/{head_name}')
1425 repo.set_head(f'refs/heads/{head_name}')
1426
1426
1427 return [head_name] + [f'set HEAD to refs/heads/{head_name}']
1427 return [head_name] + [f'set HEAD to refs/heads/{head_name}']
1428
1428
1429 @reraise_safe_exceptions
1429 @reraise_safe_exceptions
1430 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1430 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1431 archive_dir_name, commit_id, cache_config):
1431 archive_dir_name, commit_id, cache_config):
1432
1432
1433 def file_walker(_commit_id, path):
1433 def file_walker(_commit_id, path):
1434 repo_init = self._factory.repo_libgit2(wire)
1434 repo_init = self._factory.repo_libgit2(wire)
1435
1435
1436 with repo_init as repo:
1436 with repo_init as repo:
1437 commit = repo[commit_id]
1437 commit = repo[commit_id]
1438
1438
1439 if path in ['', '/']:
1439 if path in ['', '/']:
1440 tree = commit.tree
1440 tree = commit.tree
1441 else:
1441 else:
1442 tree = commit.tree[path.rstrip('/')]
1442 tree = commit.tree[path.rstrip('/')]
1443 tree_id = tree.id.hex
1443 tree_id = tree.id.hex
1444 try:
1444 try:
1445 tree = repo[tree_id]
1445 tree = repo[tree_id]
1446 except KeyError:
1446 except KeyError:
1447 raise ObjectMissing(f'No tree with id: {tree_id}')
1447 raise ObjectMissing(f'No tree with id: {tree_id}')
1448
1448
1449 index = LibGit2Index.Index()
1449 index = LibGit2Index.Index()
1450 index.read_tree(tree)
1450 index.read_tree(tree)
1451 file_iter = index
1451 file_iter = index
1452
1452
1453 for file_node in file_iter:
1453 for file_node in file_iter:
1454 file_path = file_node.path
1454 file_path = file_node.path
1455 mode = file_node.mode
1455 mode = file_node.mode
1456 is_link = stat.S_ISLNK(mode)
1456 is_link = stat.S_ISLNK(mode)
1457 if mode == pygit2.GIT_FILEMODE_COMMIT:
1457 if mode == pygit2.GIT_FILEMODE_COMMIT:
1458 log.debug('Skipping path %s as a commit node', file_path)
1458 log.debug('Skipping path %s as a commit node', file_path)
1459 continue
1459 continue
1460 yield ArchiveNode(file_path, mode, is_link, repo[file_node.hex].read_raw)
1460 yield ArchiveNode(file_path, mode, is_link, repo[file_node.hex].read_raw)
1461
1461
1462 return store_archive_in_cache(
1462 return store_archive_in_cache(
1463 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
1463 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
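
A standalone sketch of the file_walker technique: flattening a commit's tree by reading it into an in-memory index, mirroring the pygit2 calls above (REPO_PATH is a placeholder; requires pygit2):

import stat
import pygit2

REPO_PATH = '/path/to/repo'  # assumption: any local git repository

def walk_commit_files(repo_path, ref='HEAD'):
    repo = pygit2.Repository(repo_path)
    commit = repo.revparse_single(ref)
    index = pygit2.Index()
    index.read_tree(commit.tree)  # recursively flattens the whole tree
    for entry in index:
        if entry.mode == pygit2.GIT_FILEMODE_COMMIT:
            continue  # submodule pointer, skipped as in archive_repo above
        yield entry.path, entry.mode, stat.S_ISLNK(entry.mode)
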
@@ -1,1159 +1,1159 b''
1 # RhodeCode VCSServer provides access to different vcs backends via network.
1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 # Copyright (C) 2014-2023 RhodeCode GmbH
2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 #
3 #
4 # This program is free software; you can redistribute it and/or modify
4 # This program is free software; you can redistribute it and/or modify
5 # it under the terms of the GNU General Public License as published by
5 # it under the terms of the GNU General Public License as published by
6 # the Free Software Foundation; either version 3 of the License, or
6 # the Free Software Foundation; either version 3 of the License, or
7 # (at your option) any later version.
7 # (at your option) any later version.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU General Public License
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software Foundation,
15 # along with this program; if not, write to the Free Software Foundation,
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 import binascii
17 import binascii
18 import io
18 import io
19 import logging
19 import logging
20 import stat
20 import stat
21 import urllib.request
21 import urllib.request
22 import urllib.parse
22 import urllib.parse
23 import traceback
23 import traceback
24 import hashlib
24 import hashlib
25
25
26 from hgext import largefiles, rebase, purge
26 from hgext import largefiles, rebase, purge
27
27
28 from mercurial import commands
28 from mercurial import commands
29 from mercurial import unionrepo
29 from mercurial import unionrepo
30 from mercurial import verify
30 from mercurial import verify
31 from mercurial import repair
31 from mercurial import repair
32
32
33 import vcsserver
33 import vcsserver
34 from vcsserver import exceptions
34 from vcsserver import exceptions
35 from vcsserver.base import RepoFactory, obfuscate_qs, raise_from_original, store_archive_in_cache, ArchiveNode, BytesEnvelope, \
35 from vcsserver.base import RepoFactory, obfuscate_qs, raise_from_original, store_archive_in_cache, ArchiveNode, BytesEnvelope, \
36 BinaryEnvelope
36 BinaryEnvelope
37 from vcsserver.hgcompat import (
37 from vcsserver.hgcompat import (
38 archival, bin, clone, config as hgconfig, diffopts, hex, get_ctx,
38 archival, bin, clone, config as hgconfig, diffopts, hex, get_ctx,
39 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler,
39 hg_url as url_parser, httpbasicauthhandler, httpdigestauthhandler,
40 makepeer, instance, match, memctx, exchange, memfilectx, nullrev, hg_merge,
40 makepeer, instance, match, memctx, exchange, memfilectx, nullrev, hg_merge,
41 patch, peer, revrange, ui, hg_tag, Abort, LookupError, RepoError,
41 patch, peer, revrange, ui, hg_tag, Abort, LookupError, RepoError,
42 RepoLookupError, InterventionRequired, RequirementError,
42 RepoLookupError, InterventionRequired, RequirementError,
43 alwaysmatcher, patternmatcher, hgutil, hgext_strip)
43 alwaysmatcher, patternmatcher, hgutil, hgext_strip)
44 from vcsserver.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes
44 from vcsserver.str_utils import ascii_bytes, ascii_str, safe_str, safe_bytes
45 from vcsserver.vcs_base import RemoteBase
45 from vcsserver.vcs_base import RemoteBase
46 from vcsserver.config import hooks as hooks_config
46 from vcsserver.config import hooks as hooks_config
47
47
48
48
49 log = logging.getLogger(__name__)
49 log = logging.getLogger(__name__)
50
50
51
51
52 def make_ui_from_config(repo_config):
52 def make_ui_from_config(repo_config):
53
53
54 class LoggingUI(ui.ui):
54 class LoggingUI(ui.ui):
55
55
56 def status(self, *msg, **opts):
56 def status(self, *msg, **opts):
57 str_msg = map(safe_str, msg)
57 str_msg = map(safe_str, msg)
58 log.info(' '.join(str_msg).rstrip('\n'))
58 log.info(' '.join(str_msg).rstrip('\n'))
59 #super(LoggingUI, self).status(*msg, **opts)
59 #super(LoggingUI, self).status(*msg, **opts)
60
60
61 def warn(self, *msg, **opts):
61 def warn(self, *msg, **opts):
62 str_msg = map(safe_str, msg)
62 str_msg = map(safe_str, msg)
63 log.warning('ui_logger:'+' '.join(str_msg).rstrip('\n'))
63 log.warning('ui_logger:'+' '.join(str_msg).rstrip('\n'))
64 #super(LoggingUI, self).warn(*msg, **opts)
64 #super(LoggingUI, self).warn(*msg, **opts)
65
65
66 def error(self, *msg, **opts):
66 def error(self, *msg, **opts):
67 str_msg = map(safe_str, msg)
67 str_msg = map(safe_str, msg)
68 log.error('ui_logger:'+' '.join(str_msg).rstrip('\n'))
68 log.error('ui_logger:'+' '.join(str_msg).rstrip('\n'))
69 #super(LoggingUI, self).error(*msg, **opts)
69 #super(LoggingUI, self).error(*msg, **opts)
70
70
71 def note(self, *msg, **opts):
71 def note(self, *msg, **opts):
72 str_msg = map(safe_str, msg)
72 str_msg = map(safe_str, msg)
73 log.info('ui_logger:'+' '.join(str_msg).rstrip('\n'))
73 log.info('ui_logger:'+' '.join(str_msg).rstrip('\n'))
74 #super(LoggingUI, self).note(*msg, **opts)
74 #super(LoggingUI, self).note(*msg, **opts)
75
75
76 def debug(self, *msg, **opts):
76 def debug(self, *msg, **opts):
77 str_msg = map(safe_str, msg)
77 str_msg = map(safe_str, msg)
78 log.debug('ui_logger:'+' '.join(str_msg).rstrip('\n'))
78 log.debug('ui_logger:'+' '.join(str_msg).rstrip('\n'))
79 #super(LoggingUI, self).debug(*msg, **opts)
79 #super(LoggingUI, self).debug(*msg, **opts)
80
80
81 baseui = LoggingUI()
81 baseui = LoggingUI()
82
82
83 # clean the baseui object
83 # clean the baseui object
84 baseui._ocfg = hgconfig.config()
84 baseui._ocfg = hgconfig.config()
85 baseui._ucfg = hgconfig.config()
85 baseui._ucfg = hgconfig.config()
86 baseui._tcfg = hgconfig.config()
86 baseui._tcfg = hgconfig.config()
87
87
88 for section, option, value in repo_config:
88 for section, option, value in repo_config:
89 baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))
89 baseui.setconfig(ascii_bytes(section), ascii_bytes(option), ascii_bytes(value))
90
90
91 # make our hgweb quiet so it doesn't print output
91 # make our hgweb quiet so it doesn't print output
92 baseui.setconfig(b'ui', b'quiet', b'true')
92 baseui.setconfig(b'ui', b'quiet', b'true')
93
93
94 baseui.setconfig(b'ui', b'paginate', b'never')
94 baseui.setconfig(b'ui', b'paginate', b'never')
95 # for better Error reporting of Mercurial
95 # for better Error reporting of Mercurial
96 baseui.setconfig(b'ui', b'message-output', b'stderr')
96 baseui.setconfig(b'ui', b'message-output', b'stderr')
97
97
98 # force mercurial to only use 1 thread, otherwise it may try to set a
98 # force mercurial to only use 1 thread, otherwise it may try to set a
99 # signal in a non-main thread, thus generating a ValueError.
99 # signal in a non-main thread, thus generating a ValueError.
100 baseui.setconfig(b'worker', b'numcpus', 1)
100 baseui.setconfig(b'worker', b'numcpus', 1)
101
101
102 # If there is no config for the largefiles extension, we explicitly disable
102 # If there is no config for the largefiles extension, we explicitly disable
103 # it here. This overrides settings from the repository's hgrc file. Recent
103 # it here. This overrides settings from the repository's hgrc file. Recent
104 # mercurial versions enable largefiles in hgrc on clone from largefile
104 # mercurial versions enable largefiles in hgrc on clone from largefile
105 # repo.
105 # repo.
106 if not baseui.hasconfig(b'extensions', b'largefiles'):
106 if not baseui.hasconfig(b'extensions', b'largefiles'):
107 log.debug('Explicitly disable largefiles extension for repo.')
107 log.debug('Explicitly disable largefiles extension for repo.')
108 baseui.setconfig(b'extensions', b'largefiles', b'!')
108 baseui.setconfig(b'extensions', b'largefiles', b'!')
109
109
110 return baseui
110 return baseui
111
111
112
112
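A hypothetical call, assuming mercurial is importable; the option values are made up for illustration. Each config entry is a (section, option, value) string triple:

repo_config = [
    ('ui', 'username', 'vcsserver'),
    ('phases', 'publish', 'true'),
    ('extensions', 'largefiles', '!'),  # '!' disables an extension
]
baseui = make_ui_from_config(repo_config)
assert baseui.configbool(b'ui', b'quiet')  # forced on above
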
113 def reraise_safe_exceptions(func):
113 def reraise_safe_exceptions(func):
114 """Decorator for converting mercurial exceptions to something neutral."""
114 """Decorator for converting mercurial exceptions to something neutral."""
115
115
116 def wrapper(*args, **kwargs):
116 def wrapper(*args, **kwargs):
117 try:
117 try:
118 return func(*args, **kwargs)
118 return func(*args, **kwargs)
119 except (Abort, InterventionRequired) as e:
119 except (Abort, InterventionRequired) as e:
120 raise_from_original(exceptions.AbortException(e), e)
120 raise_from_original(exceptions.AbortException(e), e)
121 except RepoLookupError as e:
121 except RepoLookupError as e:
122 raise_from_original(exceptions.LookupException(e), e)
122 raise_from_original(exceptions.LookupException(e), e)
123 except RequirementError as e:
123 except RequirementError as e:
124 raise_from_original(exceptions.RequirementException(e), e)
124 raise_from_original(exceptions.RequirementException(e), e)
125 except RepoError as e:
125 except RepoError as e:
126 raise_from_original(exceptions.VcsException(e), e)
126 raise_from_original(exceptions.VcsException(e), e)
127 except LookupError as e:
127 except LookupError as e:
128 raise_from_original(exceptions.LookupException(e), e)
128 raise_from_original(exceptions.LookupException(e), e)
129 except Exception as e:
129 except Exception as e:
130 if not hasattr(e, '_vcs_kind'):
130 if not hasattr(e, '_vcs_kind'):
131 log.exception("Unhandled exception in hg remote call")
131 log.exception("Unhandled exception in hg remote call")
132 raise_from_original(exceptions.UnhandledException(e), e)
132 raise_from_original(exceptions.UnhandledException(e), e)
133
133
134 raise
134 raise
135 return wrapper
135 return wrapper
136
136
137
137
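The same translation pattern in isolation: backend errors collapse into a small, serializable set of exception types. The names here are illustrative, not the vcsserver ones:

import functools
import logging

logger = logging.getLogger(__name__)

class LookupException(Exception):
    pass

class UnhandledException(Exception):
    pass

def translate_errors(func):
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except KeyError as e:
            # an expected failure mode maps to a stable exception type
            raise LookupException(str(e)) from e
        except Exception as e:
            logger.exception('Unhandled exception in remote call')
            raise UnhandledException(str(e)) from e
    return _wrapper

@translate_errors
def fetch(mapping, key):
    return mapping[key]
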
138 class MercurialFactory(RepoFactory):
138 class MercurialFactory(RepoFactory):
139 repo_type = 'hg'
139 repo_type = 'hg'
140
140
141 def _create_config(self, config, hooks=True):
141 def _create_config(self, config, hooks=True):
142 if not hooks:
142 if not hooks:
143
143
144 hooks_to_clean = {
144 hooks_to_clean = {
145
145
146 hooks_config.HOOK_REPO_SIZE,
146 hooks_config.HOOK_REPO_SIZE,
147 hooks_config.HOOK_PRE_PULL,
147 hooks_config.HOOK_PRE_PULL,
148 hooks_config.HOOK_PULL,
148 hooks_config.HOOK_PULL,
149
149
150 hooks_config.HOOK_PRE_PUSH,
150 hooks_config.HOOK_PRE_PUSH,
151 # TODO: what about PRETXT, this was disabled in pre 5.0.0
151 # TODO: what about PRETXT, this was disabled in pre 5.0.0
152 hooks_config.HOOK_PRETX_PUSH,
152 hooks_config.HOOK_PRETX_PUSH,
153
153
154 }
154 }
155 new_config = []
155 new_config = []
156 for section, option, value in config:
156 for section, option, value in config:
157 if section == 'hooks' and option in hooks_to_clean:
157 if section == 'hooks' and option in hooks_to_clean:
158 continue
158 continue
159 new_config.append((section, option, value))
159 new_config.append((section, option, value))
160 config = new_config
160 config = new_config
161
161
162 baseui = make_ui_from_config(config)
162 baseui = make_ui_from_config(config)
163 return baseui
163 return baseui
164
164
165 def _create_repo(self, wire, create):
165 def _create_repo(self, wire, create):
166 baseui = self._create_config(wire["config"])
166 baseui = self._create_config(wire["config"])
167 repo = instance(baseui, safe_bytes(wire["path"]), create)
167 repo = instance(baseui, safe_bytes(wire["path"]), create)
168 log.debug('repository created: got HG object: %s', repo)
168 log.debug('repository created: got HG object: %s', repo)
169 return repo
169 return repo
170
170
171 def repo(self, wire, create=False):
171 def repo(self, wire, create=False):
172 """
172 """
173 Get a repository instance for the given path.
173 Get a repository instance for the given path.
174 """
174 """
175 return self._create_repo(wire, create)
175 return self._create_repo(wire, create)
176
176
177
177
178 def patch_ui_message_output(baseui):
178 def patch_ui_message_output(baseui):
179 baseui.setconfig(b'ui', b'quiet', b'false')
179 baseui.setconfig(b'ui', b'quiet', b'false')
180 output = io.BytesIO()
180 output = io.BytesIO()
181
181
182 def write(data, **unused_kwargs):
182 def write(data, **unused_kwargs):
183 output.write(data)
183 output.write(data)
184
184
185 baseui.status = write
185 baseui.status = write
186 baseui.write = write
186 baseui.write = write
187 baseui.warn = write
187 baseui.warn = write
188 baseui.debug = write
188 baseui.debug = write
189
189
190 return baseui, output
190 return baseui, output
191
191
192
192
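A sketch of using the patched ui, assuming this module's make_ui_from_config and an importable mercurial: every write/warn call lands in the returned buffer instead of the terminal:

baseui, output = patch_ui_message_output(make_ui_from_config([]))
baseui.write(b'hello from mercurial\n')
baseui.warn(b'a warning\n')
assert output.getvalue() == b'hello from mercurial\na warning\n'
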
193 def get_obfuscated_url(url_obj):
193 def get_obfuscated_url(url_obj):
194 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
194 url_obj.passwd = b'*****' if url_obj.passwd else url_obj.passwd
195 url_obj.query = obfuscate_qs(url_obj.query)
195 url_obj.query = obfuscate_qs(url_obj.query)
196 obfuscated_uri = str(url_obj)
196 obfuscated_uri = str(url_obj)
197 return obfuscated_uri
197 return obfuscated_uri
198
198
199
199
200 def normalize_url_for_hg(url: str):
200 def normalize_url_for_hg(url: str):
201 _proto = None
201 _proto = None
202
202
203 if '+' in url[:url.find('://')]:
203 if '+' in url[:url.find('://')]:
204 _proto = url[0:url.find('+')]
204 _proto = url[0:url.find('+')]
205 url = url[url.find('+') + 1:]
205 url = url[url.find('+') + 1:]
206 return url, _proto
206 return url, _proto
207
207
208
208
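The split is easiest to see on concrete inputs; both results follow directly from the code above:

assert normalize_url_for_hg('https://example.com/repo') == ('https://example.com/repo', None)
assert normalize_url_for_hg('svn+https://example.com/repo') == ('https://example.com/repo', 'svn')
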
209 class HgRemote(RemoteBase):
209 class HgRemote(RemoteBase):
210
210
211 def __init__(self, factory):
211 def __init__(self, factory):
212 self._factory = factory
212 self._factory = factory
213 self._bulk_methods = {
213 self._bulk_methods = {
214 "affected_files": self.ctx_files,
214 "affected_files": self.ctx_files,
215 "author": self.ctx_user,
215 "author": self.ctx_user,
216 "branch": self.ctx_branch,
216 "branch": self.ctx_branch,
217 "children": self.ctx_children,
217 "children": self.ctx_children,
218 "date": self.ctx_date,
218 "date": self.ctx_date,
219 "message": self.ctx_description,
219 "message": self.ctx_description,
220 "parents": self.ctx_parents,
220 "parents": self.ctx_parents,
221 "status": self.ctx_status,
221 "status": self.ctx_status,
222 "obsolete": self.ctx_obsolete,
222 "obsolete": self.ctx_obsolete,
223 "phase": self.ctx_phase,
223 "phase": self.ctx_phase,
224 "hidden": self.ctx_hidden,
224 "hidden": self.ctx_hidden,
225 "_file_paths": self.ctx_list,
225 "_file_paths": self.ctx_list,
226 }
226 }
227 self._bulk_file_methods = {
227 self._bulk_file_methods = {
228 "size": self.fctx_size,
228 "size": self.fctx_size,
229 "data": self.fctx_node_data,
229 "data": self.fctx_node_data,
230 "flags": self.fctx_flags,
230 "flags": self.fctx_flags,
231 "is_binary": self.is_binary,
231 "is_binary": self.is_binary,
232 "md5": self.md5_hash,
232 "md5": self.md5_hash,
233 }
233 }
234
234
235 def _get_ctx(self, repo, ref):
235 def _get_ctx(self, repo, ref):
236 return get_ctx(repo, ref)
236 return get_ctx(repo, ref)
237
237
238 @reraise_safe_exceptions
238 @reraise_safe_exceptions
239 def discover_hg_version(self):
239 def discover_hg_version(self):
240 from mercurial import util
240 from mercurial import util
241 return safe_str(util.version())
241 return safe_str(util.version())
242
242
243 @reraise_safe_exceptions
243 @reraise_safe_exceptions
244 def is_empty(self, wire):
244 def is_empty(self, wire):
245 repo = self._factory.repo(wire)
245 repo = self._factory.repo(wire)
246
246
247 try:
247 try:
248 return len(repo) == 0
248 return len(repo) == 0
249 except Exception:
249 except Exception:
250 log.exception("failed to read object_store")
250 log.exception("failed to read object_store")
251 return False
251 return False
252
252
253 @reraise_safe_exceptions
253 @reraise_safe_exceptions
254 def bookmarks(self, wire):
254 def bookmarks(self, wire):
255 cache_on, context_uid, repo_id = self._cache_on(wire)
255 cache_on, context_uid, repo_id = self._cache_on(wire)
256 region = self._region(wire)
256 region = self._region(wire)
257
257
258 @region.conditional_cache_on_arguments(condition=cache_on)
258 @region.conditional_cache_on_arguments(condition=cache_on)
259 def _bookmarks(_context_uid, _repo_id):
259 def _bookmarks(_context_uid, _repo_id):
260 repo = self._factory.repo(wire)
260 repo = self._factory.repo(wire)
261 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo._bookmarks.items()}
261 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo._bookmarks.items()}
262
262
263 return _bookmarks(context_uid, repo_id)
263 return _bookmarks(context_uid, repo_id)
264
264
265 @reraise_safe_exceptions
265 @reraise_safe_exceptions
266 def branches(self, wire, normal, closed):
266 def branches(self, wire, normal, closed):
267 cache_on, context_uid, repo_id = self._cache_on(wire)
267 cache_on, context_uid, repo_id = self._cache_on(wire)
268 region = self._region(wire)
268 region = self._region(wire)
269
269
270 @region.conditional_cache_on_arguments(condition=cache_on)
270 @region.conditional_cache_on_arguments(condition=cache_on)
271 def _branches(_context_uid, _repo_id, _normal, _closed):
271 def _branches(_context_uid, _repo_id, _normal, _closed):
272 repo = self._factory.repo(wire)
272 repo = self._factory.repo(wire)
273 iter_branches = repo.branchmap().iterbranches()
273 iter_branches = repo.branchmap().iterbranches()
274 bt = {}
274 bt = {}
275 for branch_name, _heads, tip_node, is_closed in iter_branches:
275 for branch_name, _heads, tip_node, is_closed in iter_branches:
276 if normal and not is_closed:
276 if normal and not is_closed:
277 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
277 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
278 if closed and is_closed:
278 if closed and is_closed:
279 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
279 bt[safe_str(branch_name)] = ascii_str(hex(tip_node))
280
280
281 return bt
281 return bt
282
282
283 return _branches(context_uid, repo_id, normal, closed)
283 return _branches(context_uid, repo_id, normal, closed)
284
284
285 @reraise_safe_exceptions
285 @reraise_safe_exceptions
286 def bulk_request(self, wire, commit_id, pre_load):
286 def bulk_request(self, wire, commit_id, pre_load):
287 cache_on, context_uid, repo_id = self._cache_on(wire)
287 cache_on, context_uid, repo_id = self._cache_on(wire)
288 region = self._region(wire)
288 region = self._region(wire)
289
289
290 @region.conditional_cache_on_arguments(condition=cache_on)
290 @region.conditional_cache_on_arguments(condition=cache_on)
291 def _bulk_request(_repo_id, _commit_id, _pre_load):
291 def _bulk_request(_repo_id, _commit_id, _pre_load):
292 result = {}
292 result = {}
293 for attr in pre_load:
293 for attr in pre_load:
294 try:
294 try:
295 method = self._bulk_methods[attr]
295 method = self._bulk_methods[attr]
296 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
296 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
297 result[attr] = method(wire, commit_id)
297 result[attr] = method(wire, commit_id)
298 except KeyError as e:
298 except KeyError as e:
299 raise exceptions.VcsException(e)(
299 raise exceptions.VcsException(e)(
300 'Unknown bulk attribute: "%s"' % attr)
300 'Unknown bulk attribute: "%s"' % attr)
301 return result
301 return result
302
302
303 return _bulk_request(repo_id, commit_id, sorted(pre_load))
303 return _bulk_request(repo_id, commit_id, sorted(pre_load))
304
304
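The pre_load dispatch reduces to a dict of loader callables keyed by attribute name; a self-contained sketch (the lambdas stand in for the ctx_* methods):

def bulk(loaders, names):
    out = {}
    for name in names:
        loader = loaders.get(name)
        if loader is None:
            raise ValueError(f'Unknown bulk attribute: "{name}"')
        out[name] = loader()
    return out

loaders = {'branch': lambda: 'default', 'phase': lambda: 0}
assert bulk(loaders, ['branch', 'phase']) == {'branch': 'default', 'phase': 0}
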
305 @reraise_safe_exceptions
305 @reraise_safe_exceptions
306 def ctx_branch(self, wire, commit_id):
306 def ctx_branch(self, wire, commit_id):
307 cache_on, context_uid, repo_id = self._cache_on(wire)
307 cache_on, context_uid, repo_id = self._cache_on(wire)
308 region = self._region(wire)
308 region = self._region(wire)
309
309
310 @region.conditional_cache_on_arguments(condition=cache_on)
310 @region.conditional_cache_on_arguments(condition=cache_on)
311 def _ctx_branch(_repo_id, _commit_id):
311 def _ctx_branch(_repo_id, _commit_id):
312 repo = self._factory.repo(wire)
312 repo = self._factory.repo(wire)
313 ctx = self._get_ctx(repo, commit_id)
313 ctx = self._get_ctx(repo, commit_id)
314 return ctx.branch()
314 return ctx.branch()
315 return _ctx_branch(repo_id, commit_id)
315 return _ctx_branch(repo_id, commit_id)
316
316
317 @reraise_safe_exceptions
317 @reraise_safe_exceptions
318 def ctx_date(self, wire, commit_id):
318 def ctx_date(self, wire, commit_id):
319 cache_on, context_uid, repo_id = self._cache_on(wire)
319 cache_on, context_uid, repo_id = self._cache_on(wire)
320 region = self._region(wire)
320 region = self._region(wire)
321
321
322 @region.conditional_cache_on_arguments(condition=cache_on)
322 @region.conditional_cache_on_arguments(condition=cache_on)
323 def _ctx_date(_repo_id, _commit_id):
323 def _ctx_date(_repo_id, _commit_id):
324 repo = self._factory.repo(wire)
324 repo = self._factory.repo(wire)
325 ctx = self._get_ctx(repo, commit_id)
325 ctx = self._get_ctx(repo, commit_id)
326 return ctx.date()
326 return ctx.date()
327 return _ctx_date(repo_id, commit_id)
327 return _ctx_date(repo_id, commit_id)
328
328
329 @reraise_safe_exceptions
329 @reraise_safe_exceptions
330 def ctx_description(self, wire, revision):
330 def ctx_description(self, wire, revision):
331 repo = self._factory.repo(wire)
331 repo = self._factory.repo(wire)
332 ctx = self._get_ctx(repo, revision)
332 ctx = self._get_ctx(repo, revision)
333 return ctx.description()
333 return ctx.description()
334
334
335 @reraise_safe_exceptions
335 @reraise_safe_exceptions
336 def ctx_files(self, wire, commit_id):
336 def ctx_files(self, wire, commit_id):
337 cache_on, context_uid, repo_id = self._cache_on(wire)
337 cache_on, context_uid, repo_id = self._cache_on(wire)
338 region = self._region(wire)
338 region = self._region(wire)
339
339
340 @region.conditional_cache_on_arguments(condition=cache_on)
340 @region.conditional_cache_on_arguments(condition=cache_on)
341 def _ctx_files(_repo_id, _commit_id):
341 def _ctx_files(_repo_id, _commit_id):
342 repo = self._factory.repo(wire)
342 repo = self._factory.repo(wire)
343 ctx = self._get_ctx(repo, commit_id)
343 ctx = self._get_ctx(repo, commit_id)
344 return ctx.files()
344 return ctx.files()
345
345
346 return _ctx_files(repo_id, commit_id)
346 return _ctx_files(repo_id, commit_id)
347
347
348 @reraise_safe_exceptions
348 @reraise_safe_exceptions
349 def ctx_list(self, path, revision):
349 def ctx_list(self, path, revision):
350 repo = self._factory.repo(path)
350 repo = self._factory.repo(path)
351 ctx = self._get_ctx(repo, revision)
351 ctx = self._get_ctx(repo, revision)
352 return list(ctx)
352 return list(ctx)
353
353
354 @reraise_safe_exceptions
354 @reraise_safe_exceptions
355 def ctx_parents(self, wire, commit_id):
355 def ctx_parents(self, wire, commit_id):
356 cache_on, context_uid, repo_id = self._cache_on(wire)
356 cache_on, context_uid, repo_id = self._cache_on(wire)
357 region = self._region(wire)
357 region = self._region(wire)
358
358
359 @region.conditional_cache_on_arguments(condition=cache_on)
359 @region.conditional_cache_on_arguments(condition=cache_on)
360 def _ctx_parents(_repo_id, _commit_id):
360 def _ctx_parents(_repo_id, _commit_id):
361 repo = self._factory.repo(wire)
361 repo = self._factory.repo(wire)
362 ctx = self._get_ctx(repo, commit_id)
362 ctx = self._get_ctx(repo, commit_id)
363 return [parent.hex() for parent in ctx.parents()
363 return [parent.hex() for parent in ctx.parents()
364 if not (parent.hidden() or parent.obsolete())]
364 if not (parent.hidden() or parent.obsolete())]
365
365
366 return _ctx_parents(repo_id, commit_id)
366 return _ctx_parents(repo_id, commit_id)
367
367
368 @reraise_safe_exceptions
368 @reraise_safe_exceptions
369 def ctx_children(self, wire, commit_id):
369 def ctx_children(self, wire, commit_id):
370 cache_on, context_uid, repo_id = self._cache_on(wire)
370 cache_on, context_uid, repo_id = self._cache_on(wire)
371 region = self._region(wire)
371 region = self._region(wire)
372
372
373 @region.conditional_cache_on_arguments(condition=cache_on)
373 @region.conditional_cache_on_arguments(condition=cache_on)
374 def _ctx_children(_repo_id, _commit_id):
374 def _ctx_children(_repo_id, _commit_id):
375 repo = self._factory.repo(wire)
375 repo = self._factory.repo(wire)
376 ctx = self._get_ctx(repo, commit_id)
376 ctx = self._get_ctx(repo, commit_id)
377 return [child.hex() for child in ctx.children()
377 return [child.hex() for child in ctx.children()
378 if not (child.hidden() or child.obsolete())]
378 if not (child.hidden() or child.obsolete())]
379
379
380 return _ctx_children(repo_id, commit_id)
380 return _ctx_children(repo_id, commit_id)
381
381
382 @reraise_safe_exceptions
382 @reraise_safe_exceptions
383 def ctx_phase(self, wire, commit_id):
383 def ctx_phase(self, wire, commit_id):
384 cache_on, context_uid, repo_id = self._cache_on(wire)
384 cache_on, context_uid, repo_id = self._cache_on(wire)
385 region = self._region(wire)
385 region = self._region(wire)
386
386
387 @region.conditional_cache_on_arguments(condition=cache_on)
387 @region.conditional_cache_on_arguments(condition=cache_on)
388 def _ctx_phase(_context_uid, _repo_id, _commit_id):
388 def _ctx_phase(_context_uid, _repo_id, _commit_id):
389 repo = self._factory.repo(wire)
389 repo = self._factory.repo(wire)
390 ctx = self._get_ctx(repo, commit_id)
390 ctx = self._get_ctx(repo, commit_id)
391 # public=0, draft=1, secret=2
391 # public=0, draft=1, secret=2
392 return ctx.phase()
392 return ctx.phase()
393 return _ctx_phase(context_uid, repo_id, commit_id)
393 return _ctx_phase(context_uid, repo_id, commit_id)
394
394
395 @reraise_safe_exceptions
395 @reraise_safe_exceptions
396 def ctx_obsolete(self, wire, commit_id):
396 def ctx_obsolete(self, wire, commit_id):
397 cache_on, context_uid, repo_id = self._cache_on(wire)
397 cache_on, context_uid, repo_id = self._cache_on(wire)
398 region = self._region(wire)
398 region = self._region(wire)
399
399
400 @region.conditional_cache_on_arguments(condition=cache_on)
400 @region.conditional_cache_on_arguments(condition=cache_on)
401 def _ctx_obsolete(_context_uid, _repo_id, _commit_id):
401 def _ctx_obsolete(_context_uid, _repo_id, _commit_id):
402 repo = self._factory.repo(wire)
402 repo = self._factory.repo(wire)
403 ctx = self._get_ctx(repo, commit_id)
403 ctx = self._get_ctx(repo, commit_id)
404 return ctx.obsolete()
404 return ctx.obsolete()
405 return _ctx_obsolete(context_uid, repo_id, commit_id)
405 return _ctx_obsolete(context_uid, repo_id, commit_id)
406
406
407 @reraise_safe_exceptions
407 @reraise_safe_exceptions
408 def ctx_hidden(self, wire, commit_id):
408 def ctx_hidden(self, wire, commit_id):
409 cache_on, context_uid, repo_id = self._cache_on(wire)
409 cache_on, context_uid, repo_id = self._cache_on(wire)
410 region = self._region(wire)
410 region = self._region(wire)
411
411
412 @region.conditional_cache_on_arguments(condition=cache_on)
412 @region.conditional_cache_on_arguments(condition=cache_on)
413 def _ctx_hidden(_context_uid, _repo_id, _commit_id):
413 def _ctx_hidden(_context_uid, _repo_id, _commit_id):
414 repo = self._factory.repo(wire)
414 repo = self._factory.repo(wire)
415 ctx = self._get_ctx(repo, commit_id)
415 ctx = self._get_ctx(repo, commit_id)
416 return ctx.hidden()
416 return ctx.hidden()
417 return _ctx_hidden(context_uid, repo_id, commit_id)
417 return _ctx_hidden(context_uid, repo_id, commit_id)
418
418
419 @reraise_safe_exceptions
419 @reraise_safe_exceptions
420 def ctx_substate(self, wire, revision):
420 def ctx_substate(self, wire, revision):
421 repo = self._factory.repo(wire)
421 repo = self._factory.repo(wire)
422 ctx = self._get_ctx(repo, revision)
422 ctx = self._get_ctx(repo, revision)
423 return ctx.substate
423 return ctx.substate
424
424
425 @reraise_safe_exceptions
425 @reraise_safe_exceptions
426 def ctx_status(self, wire, revision):
426 def ctx_status(self, wire, revision):
427 repo = self._factory.repo(wire)
427 repo = self._factory.repo(wire)
428 ctx = self._get_ctx(repo, revision)
428 ctx = self._get_ctx(repo, revision)
429 status = repo[ctx.p1().node()].status(other=ctx.node())
429 status = repo[ctx.p1().node()].status(other=ctx.node())
430 # the status object (an odd, custom named tuple in mercurial) does not
430 # the status object (an odd, custom named tuple in mercurial) does not
431 # serialize correctly; we convert it to a list, as the underlying
431 # serialize correctly; we convert it to a list, as the underlying
432 # API expects a list
432 # API expects a list
433 return list(status)
433 return list(status)
434
434
435 @reraise_safe_exceptions
435 @reraise_safe_exceptions
436 def ctx_user(self, wire, revision):
436 def ctx_user(self, wire, revision):
437 repo = self._factory.repo(wire)
437 repo = self._factory.repo(wire)
438 ctx = self._get_ctx(repo, revision)
438 ctx = self._get_ctx(repo, revision)
439 return ctx.user()
439 return ctx.user()
440
440
441 @reraise_safe_exceptions
441 @reraise_safe_exceptions
442 def check_url(self, url, config):
442 def check_url(self, url, config):
443 url, _proto = normalize_url_for_hg(url)
443 url, _proto = normalize_url_for_hg(url)
444 url_obj = url_parser(safe_bytes(url))
444 url_obj = url_parser(safe_bytes(url))
445
445
446 test_uri = safe_str(url_obj.authinfo()[0])
446 test_uri = safe_str(url_obj.authinfo()[0])
447 authinfo = url_obj.authinfo()[1]
447 authinfo = url_obj.authinfo()[1]
448 obfuscated_uri = get_obfuscated_url(url_obj)
448 obfuscated_uri = get_obfuscated_url(url_obj)
449 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
449 log.info("Checking URL for remote cloning/import: %s", obfuscated_uri)
450
450
451 handlers = []
451 handlers = []
452 if authinfo:
452 if authinfo:
453 # create a password manager
453 # create a password manager
454 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
454 passmgr = urllib.request.HTTPPasswordMgrWithDefaultRealm()
455 passmgr.add_password(*authinfo)
455 passmgr.add_password(*authinfo)
456
456
457 handlers.extend((httpbasicauthhandler(passmgr),
457 handlers.extend((httpbasicauthhandler(passmgr),
458 httpdigestauthhandler(passmgr)))
458 httpdigestauthhandler(passmgr)))
459
459
460 o = urllib.request.build_opener(*handlers)
460 o = urllib.request.build_opener(*handlers)
461 o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
461 o.addheaders = [('Content-Type', 'application/mercurial-0.1'),
462 ('Accept', 'application/mercurial-0.1')]
462 ('Accept', 'application/mercurial-0.1')]
463
463
464 q = {"cmd": 'between'}
464 q = {"cmd": 'between'}
465 q.update({'pairs': "{}-{}".format('0' * 40, '0' * 40)})
465 q.update({'pairs': "{}-{}".format('0' * 40, '0' * 40)})
466 qs = '?%s' % urllib.parse.urlencode(q)
466 qs = '?%s' % urllib.parse.urlencode(q)
467 cu = "{}{}".format(test_uri, qs)
467 cu = f"{test_uri}{qs}"
468 req = urllib.request.Request(cu, None, {})
468 req = urllib.request.Request(cu, None, {})
469
469
470 try:
470 try:
471 log.debug("Trying to open URL %s", obfuscated_uri)
471 log.debug("Trying to open URL %s", obfuscated_uri)
472 resp = o.open(req)
472 resp = o.open(req)
473 if resp.code != 200:
473 if resp.code != 200:
474 raise exceptions.URLError()('Return Code is not 200')
474 raise exceptions.URLError()('Return Code is not 200')
475 except Exception as e:
475 except Exception as e:
476 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
476 log.warning("URL cannot be opened: %s", obfuscated_uri, exc_info=True)
477 # means it cannot be cloned
477 # means it cannot be cloned
478 raise exceptions.URLError(e)("[{}] org_exc: {}".format(obfuscated_uri, e))
478 raise exceptions.URLError(e)(f"[{obfuscated_uri}] org_exc: {e}")
479
479
480 # now check if it's a proper hg repo, but don't do it for svn
480 # now check if it's a proper hg repo, but don't do it for svn
481 try:
481 try:
482 if _proto == 'svn':
482 if _proto == 'svn':
483 pass
483 pass
484 else:
484 else:
485 # check for pure hg repos
485 # check for pure hg repos
486 log.debug(
486 log.debug(
487 "Verifying if URL is a Mercurial repository: %s", obfuscated_uri)
487 "Verifying if URL is a Mercurial repository: %s", obfuscated_uri)
488 ui = make_ui_from_config(config)
488 ui = make_ui_from_config(config)
489 peer_checker = makepeer(ui, safe_bytes(url))
489 peer_checker = makepeer(ui, safe_bytes(url))
490 peer_checker.lookup(b'tip')
490 peer_checker.lookup(b'tip')
491 except Exception as e:
491 except Exception as e:
492 log.warning("URL is not a valid Mercurial repository: %s",
492 log.warning("URL is not a valid Mercurial repository: %s",
493 obfuscated_uri)
493 obfuscated_uri)
494 raise exceptions.URLError(e)(
494 raise exceptions.URLError(e)(
495 "url [%s] does not look like an hg repo org_exc: %s"
495 "url [%s] does not look like an hg repo org_exc: %s"
496 % (obfuscated_uri, e))
496 % (obfuscated_uri, e))
497
497
498 log.info("URL is a valid Mercurial repository: %s", obfuscated_uri)
498 log.info("URL is a valid Mercurial repository: %s", obfuscated_uri)
499 return True
499 return True
500
500
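A standalone sketch of the reachability probe performed above: issue the wire protocol 'between' command and treat HTTP 200 as reachable. The URL and network access are assumptions, and unlike check_url this does not verify the peer is really Mercurial:

import urllib.parse
import urllib.request

def probe_hg_http(base_url, timeout=10):
    q = urllib.parse.urlencode({'cmd': 'between', 'pairs': f"{'0' * 40}-{'0' * 40}"})
    req = urllib.request.Request(f'{base_url}?{q}', headers={
        'Accept': 'application/mercurial-0.1',
    })
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return resp.status == 200
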
501 @reraise_safe_exceptions
501 @reraise_safe_exceptions
502 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_git, opt_ignorews, context):
502 def diff(self, wire, commit_id_1, commit_id_2, file_filter, opt_git, opt_ignorews, context):
503 repo = self._factory.repo(wire)
503 repo = self._factory.repo(wire)
504
504
505 if file_filter:
505 if file_filter:
506 # unpack the file-filter
506 # unpack the file-filter
507 repo_path, node_path = file_filter
507 repo_path, node_path = file_filter
508 match_filter = match(safe_bytes(repo_path), b'', [safe_bytes(node_path)])
508 match_filter = match(safe_bytes(repo_path), b'', [safe_bytes(node_path)])
509 else:
509 else:
510 match_filter = file_filter
510 match_filter = file_filter
511 opts = diffopts(git=opt_git, ignorews=opt_ignorews, context=context, showfunc=1)
511 opts = diffopts(git=opt_git, ignorews=opt_ignorews, context=context, showfunc=1)
512
512
513 try:
513 try:
514 diff_iter = patch.diff(
514 diff_iter = patch.diff(
515 repo, node1=commit_id_1, node2=commit_id_2, match=match_filter, opts=opts)
515 repo, node1=commit_id_1, node2=commit_id_2, match=match_filter, opts=opts)
516 return BytesEnvelope(b"".join(diff_iter))
516 return BytesEnvelope(b"".join(diff_iter))
517 except RepoLookupError as e:
517 except RepoLookupError as e:
518 raise exceptions.LookupException(e)()
518 raise exceptions.LookupException(e)()
519
519
520 @reraise_safe_exceptions
520 @reraise_safe_exceptions
521 def node_history(self, wire, revision, path, limit):
521 def node_history(self, wire, revision, path, limit):
522 cache_on, context_uid, repo_id = self._cache_on(wire)
522 cache_on, context_uid, repo_id = self._cache_on(wire)
523 region = self._region(wire)
523 region = self._region(wire)
524
524
525 @region.conditional_cache_on_arguments(condition=cache_on)
525 @region.conditional_cache_on_arguments(condition=cache_on)
526 def _node_history(_context_uid, _repo_id, _revision, _path, _limit):
526 def _node_history(_context_uid, _repo_id, _revision, _path, _limit):
527 repo = self._factory.repo(wire)
527 repo = self._factory.repo(wire)
528
528
529 ctx = self._get_ctx(repo, revision)
529 ctx = self._get_ctx(repo, revision)
530 fctx = ctx.filectx(safe_bytes(path))
530 fctx = ctx.filectx(safe_bytes(path))
531
531
532 def history_iter():
532 def history_iter():
533 limit_rev = fctx.rev()
533 limit_rev = fctx.rev()
534 for obj in reversed(list(fctx.filelog())):
534 for obj in reversed(list(fctx.filelog())):
535 obj = fctx.filectx(obj)
535 obj = fctx.filectx(obj)
536 ctx = obj.changectx()
536 ctx = obj.changectx()
537 if ctx.hidden() or ctx.obsolete():
537 if ctx.hidden() or ctx.obsolete():
538 continue
538 continue
539
539
540 if limit_rev >= obj.rev():
540 if limit_rev >= obj.rev():
541 yield obj
541 yield obj
542
542
543 history = []
543 history = []
544 for cnt, obj in enumerate(history_iter()):
544 for cnt, obj in enumerate(history_iter()):
545 if limit and cnt >= limit:
545 if limit and cnt >= limit:
546 break
546 break
547 history.append(hex(obj.node()))
547 history.append(hex(obj.node()))
548
548
549 return history
549 return history
550 return _node_history(context_uid, repo_id, revision, path, limit)
550 return _node_history(context_uid, repo_id, revision, path, limit)
551
551
552 @reraise_safe_exceptions
552 @reraise_safe_exceptions
553 def node_history_untill(self, wire, revision, path, limit):
553 def node_history_untill(self, wire, revision, path, limit):
554 cache_on, context_uid, repo_id = self._cache_on(wire)
554 cache_on, context_uid, repo_id = self._cache_on(wire)
555 region = self._region(wire)
555 region = self._region(wire)
556
556
557 @region.conditional_cache_on_arguments(condition=cache_on)
557 @region.conditional_cache_on_arguments(condition=cache_on)
558 def _node_history_until(_context_uid, _repo_id):
558 def _node_history_until(_context_uid, _repo_id):
559 repo = self._factory.repo(wire)
559 repo = self._factory.repo(wire)
560 ctx = self._get_ctx(repo, revision)
560 ctx = self._get_ctx(repo, revision)
561 fctx = ctx.filectx(safe_bytes(path))
561 fctx = ctx.filectx(safe_bytes(path))
562
562
563 file_log = list(fctx.filelog())
563 file_log = list(fctx.filelog())
564 if limit:
564 if limit:
565 # Limit to the last n items
565 # Limit to the last n items
566 file_log = file_log[-limit:]
566 file_log = file_log[-limit:]
567
567
568 return [hex(fctx.filectx(cs).node()) for cs in reversed(file_log)]
568 return [hex(fctx.filectx(cs).node()) for cs in reversed(file_log)]
569 return _node_history_until(context_uid, repo_id, revision, path, limit)
569 return _node_history_until(context_uid, repo_id, revision, path, limit)
570
570
571 @reraise_safe_exceptions
571 @reraise_safe_exceptions
572 def bulk_file_request(self, wire, commit_id, path, pre_load):
572 def bulk_file_request(self, wire, commit_id, path, pre_load):
573 cache_on, context_uid, repo_id = self._cache_on(wire)
573 cache_on, context_uid, repo_id = self._cache_on(wire)
574 region = self._region(wire)
574 region = self._region(wire)
575
575
576 @region.conditional_cache_on_arguments(condition=cache_on)
576 @region.conditional_cache_on_arguments(condition=cache_on)
577 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
577 def _bulk_file_request(_repo_id, _commit_id, _path, _pre_load):
578 result = {}
578 result = {}
579 for attr in pre_load:
579 for attr in pre_load:
580 try:
580 try:
581 method = self._bulk_file_methods[attr]
581 method = self._bulk_file_methods[attr]
582 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
582 wire.update({'cache': False}) # disable cache for bulk calls so we don't double cache
583 result[attr] = method(wire, _commit_id, _path)
583 result[attr] = method(wire, _commit_id, _path)
584 except KeyError as e:
584 except KeyError as e:
585 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
585 raise exceptions.VcsException(e)(f'Unknown bulk attribute: "{attr}"')
586 return BinaryEnvelope(result)
586 return BinaryEnvelope(result)
587
587
588 return _bulk_file_request(repo_id, commit_id, path, sorted(pre_load))
588 return _bulk_file_request(repo_id, commit_id, path, sorted(pre_load))
589
589
590 @reraise_safe_exceptions
590 @reraise_safe_exceptions
591 def fctx_annotate(self, wire, revision, path):
591 def fctx_annotate(self, wire, revision, path):
592 repo = self._factory.repo(wire)
592 repo = self._factory.repo(wire)
593 ctx = self._get_ctx(repo, revision)
593 ctx = self._get_ctx(repo, revision)
594 fctx = ctx.filectx(safe_bytes(path))
594 fctx = ctx.filectx(safe_bytes(path))
595
595
596 result = []
596 result = []
597 for i, annotate_obj in enumerate(fctx.annotate(), 1):
597 for i, annotate_obj in enumerate(fctx.annotate(), 1):
598 ln_no = i
598 ln_no = i
599 sha = hex(annotate_obj.fctx.node())
599 sha = hex(annotate_obj.fctx.node())
600 content = annotate_obj.text
600 content = annotate_obj.text
601 result.append((ln_no, sha, content))
601 result.append((ln_no, sha, content))
602 return result
602 return result
603
603
604 @reraise_safe_exceptions
604 @reraise_safe_exceptions
605 def fctx_node_data(self, wire, revision, path):
605 def fctx_node_data(self, wire, revision, path):
606 repo = self._factory.repo(wire)
606 repo = self._factory.repo(wire)
607 ctx = self._get_ctx(repo, revision)
607 ctx = self._get_ctx(repo, revision)
608 fctx = ctx.filectx(safe_bytes(path))
608 fctx = ctx.filectx(safe_bytes(path))
609 return BytesEnvelope(fctx.data())
609 return BytesEnvelope(fctx.data())
610
610
611 @reraise_safe_exceptions
611 @reraise_safe_exceptions
612 def fctx_flags(self, wire, commit_id, path):
612 def fctx_flags(self, wire, commit_id, path):
613 cache_on, context_uid, repo_id = self._cache_on(wire)
613 cache_on, context_uid, repo_id = self._cache_on(wire)
614 region = self._region(wire)
614 region = self._region(wire)
615
615
616 @region.conditional_cache_on_arguments(condition=cache_on)
616 @region.conditional_cache_on_arguments(condition=cache_on)
617 def _fctx_flags(_repo_id, _commit_id, _path):
617 def _fctx_flags(_repo_id, _commit_id, _path):
618 repo = self._factory.repo(wire)
618 repo = self._factory.repo(wire)
619 ctx = self._get_ctx(repo, commit_id)
619 ctx = self._get_ctx(repo, commit_id)
620 fctx = ctx.filectx(safe_bytes(path))
620 fctx = ctx.filectx(safe_bytes(path))
621 return fctx.flags()
621 return fctx.flags()
622
622
623 return _fctx_flags(repo_id, commit_id, path)
623 return _fctx_flags(repo_id, commit_id, path)
624
624
625 @reraise_safe_exceptions
625 @reraise_safe_exceptions
626 def fctx_size(self, wire, commit_id, path):
626 def fctx_size(self, wire, commit_id, path):
627 cache_on, context_uid, repo_id = self._cache_on(wire)
627 cache_on, context_uid, repo_id = self._cache_on(wire)
628 region = self._region(wire)
628 region = self._region(wire)
629
629
630 @region.conditional_cache_on_arguments(condition=cache_on)
630 @region.conditional_cache_on_arguments(condition=cache_on)
631 def _fctx_size(_repo_id, _revision, _path):
631 def _fctx_size(_repo_id, _revision, _path):
632 repo = self._factory.repo(wire)
632 repo = self._factory.repo(wire)
633 ctx = self._get_ctx(repo, commit_id)
633 ctx = self._get_ctx(repo, commit_id)
634 fctx = ctx.filectx(safe_bytes(path))
634 fctx = ctx.filectx(safe_bytes(path))
635 return fctx.size()
635 return fctx.size()
636 return _fctx_size(repo_id, commit_id, path)
636 return _fctx_size(repo_id, commit_id, path)
637
637
638 @reraise_safe_exceptions
638 @reraise_safe_exceptions
639 def get_all_commit_ids(self, wire, name):
639 def get_all_commit_ids(self, wire, name):
640 cache_on, context_uid, repo_id = self._cache_on(wire)
640 cache_on, context_uid, repo_id = self._cache_on(wire)
641 region = self._region(wire)
641 region = self._region(wire)
642
642
643 @region.conditional_cache_on_arguments(condition=cache_on)
643 @region.conditional_cache_on_arguments(condition=cache_on)
644 def _get_all_commit_ids(_context_uid, _repo_id, _name):
644 def _get_all_commit_ids(_context_uid, _repo_id, _name):
645 repo = self._factory.repo(wire)
645 repo = self._factory.repo(wire)
646 revs = [ascii_str(repo[x].hex()) for x in repo.filtered(b'visible').changelog.revs()]
646 revs = [ascii_str(repo[x].hex()) for x in repo.filtered(b'visible').changelog.revs()]
647 return revs
647 return revs
648 return _get_all_commit_ids(context_uid, repo_id, name)
648 return _get_all_commit_ids(context_uid, repo_id, name)
649
649
650 @reraise_safe_exceptions
650 @reraise_safe_exceptions
651 def get_config_value(self, wire, section, name, untrusted=False):
651 def get_config_value(self, wire, section, name, untrusted=False):
652 repo = self._factory.repo(wire)
652 repo = self._factory.repo(wire)
653 return repo.ui.config(ascii_bytes(section), ascii_bytes(name), untrusted=untrusted)
653 return repo.ui.config(ascii_bytes(section), ascii_bytes(name), untrusted=untrusted)
654
654
655 @reraise_safe_exceptions
655 @reraise_safe_exceptions
656 def is_large_file(self, wire, commit_id, path):
656 def is_large_file(self, wire, commit_id, path):
657 cache_on, context_uid, repo_id = self._cache_on(wire)
657 cache_on, context_uid, repo_id = self._cache_on(wire)
658 region = self._region(wire)
658 region = self._region(wire)
659
659
660 @region.conditional_cache_on_arguments(condition=cache_on)
660 @region.conditional_cache_on_arguments(condition=cache_on)
661 def _is_large_file(_context_uid, _repo_id, _commit_id, _path):
661 def _is_large_file(_context_uid, _repo_id, _commit_id, _path):
662 return largefiles.lfutil.isstandin(safe_bytes(path))
662 return largefiles.lfutil.isstandin(safe_bytes(path))
663
663
664 return _is_large_file(context_uid, repo_id, commit_id, path)
664 return _is_large_file(context_uid, repo_id, commit_id, path)
665
665
666 @reraise_safe_exceptions
666 @reraise_safe_exceptions
667 def is_binary(self, wire, revision, path):
667 def is_binary(self, wire, revision, path):
668 cache_on, context_uid, repo_id = self._cache_on(wire)
668 cache_on, context_uid, repo_id = self._cache_on(wire)
669 region = self._region(wire)
669 region = self._region(wire)
670
670
671 @region.conditional_cache_on_arguments(condition=cache_on)
671 @region.conditional_cache_on_arguments(condition=cache_on)
672 def _is_binary(_repo_id, _sha, _path):
672 def _is_binary(_repo_id, _sha, _path):
673 repo = self._factory.repo(wire)
673 repo = self._factory.repo(wire)
674 ctx = self._get_ctx(repo, revision)
674 ctx = self._get_ctx(repo, revision)
675 fctx = ctx.filectx(safe_bytes(path))
675 fctx = ctx.filectx(safe_bytes(path))
676 return fctx.isbinary()
676 return fctx.isbinary()
677
677
678 return _is_binary(repo_id, revision, path)
678 return _is_binary(repo_id, revision, path)
679
679
680 @reraise_safe_exceptions
680 @reraise_safe_exceptions
681 def md5_hash(self, wire, revision, path):
681 def md5_hash(self, wire, revision, path):
682 cache_on, context_uid, repo_id = self._cache_on(wire)
682 cache_on, context_uid, repo_id = self._cache_on(wire)
683 region = self._region(wire)
683 region = self._region(wire)
684
684
685 @region.conditional_cache_on_arguments(condition=cache_on)
685 @region.conditional_cache_on_arguments(condition=cache_on)
686 def _md5_hash(_repo_id, _sha, _path):
686 def _md5_hash(_repo_id, _sha, _path):
687 repo = self._factory.repo(wire)
687 repo = self._factory.repo(wire)
688 ctx = self._get_ctx(repo, revision)
688 ctx = self._get_ctx(repo, revision)
689 fctx = ctx.filectx(safe_bytes(path))
689 fctx = ctx.filectx(safe_bytes(path))
690 return hashlib.md5(fctx.data()).hexdigest()
690 return hashlib.md5(fctx.data()).hexdigest()
691
691
692 return _md5_hash(repo_id, revision, path)
692 return _md5_hash(repo_id, revision, path)
693
693
694 @reraise_safe_exceptions
694 @reraise_safe_exceptions
695 def in_largefiles_store(self, wire, sha):
695 def in_largefiles_store(self, wire, sha):
696 repo = self._factory.repo(wire)
696 repo = self._factory.repo(wire)
697 return largefiles.lfutil.instore(repo, sha)
697 return largefiles.lfutil.instore(repo, sha)
698
698
699 @reraise_safe_exceptions
699 @reraise_safe_exceptions
700 def in_user_cache(self, wire, sha):
700 def in_user_cache(self, wire, sha):
701 repo = self._factory.repo(wire)
701 repo = self._factory.repo(wire)
702 return largefiles.lfutil.inusercache(repo.ui, sha)
702 return largefiles.lfutil.inusercache(repo.ui, sha)
703
703
704 @reraise_safe_exceptions
704 @reraise_safe_exceptions
705 def store_path(self, wire, sha):
705 def store_path(self, wire, sha):
706 repo = self._factory.repo(wire)
706 repo = self._factory.repo(wire)
707 return largefiles.lfutil.storepath(repo, sha)
707 return largefiles.lfutil.storepath(repo, sha)
708
708
709 @reraise_safe_exceptions
709 @reraise_safe_exceptions
710 def link(self, wire, sha, path):
710 def link(self, wire, sha, path):
711 repo = self._factory.repo(wire)
711 repo = self._factory.repo(wire)
712 largefiles.lfutil.link(
712 largefiles.lfutil.link(
713 largefiles.lfutil.usercachepath(repo.ui, sha), path)
713 largefiles.lfutil.usercachepath(repo.ui, sha), path)
714
714
715 @reraise_safe_exceptions
715 @reraise_safe_exceptions
716 def localrepository(self, wire, create=False):
716 def localrepository(self, wire, create=False):
717 self._factory.repo(wire, create=create)
717 self._factory.repo(wire, create=create)
718
718
719 @reraise_safe_exceptions
719 @reraise_safe_exceptions
720 def lookup(self, wire, revision, both):
720 def lookup(self, wire, revision, both):
721 cache_on, context_uid, repo_id = self._cache_on(wire)
721 cache_on, context_uid, repo_id = self._cache_on(wire)
722 region = self._region(wire)
722 region = self._region(wire)
723
723
724 @region.conditional_cache_on_arguments(condition=cache_on)
724 @region.conditional_cache_on_arguments(condition=cache_on)
725 def _lookup(_context_uid, _repo_id, _revision, _both):
725 def _lookup(_context_uid, _repo_id, _revision, _both):
726 repo = self._factory.repo(wire)
726 repo = self._factory.repo(wire)
727 rev = _revision
727 rev = _revision
728 if isinstance(rev, int):
728 if isinstance(rev, int):
729 # NOTE(marcink):
729 # NOTE(marcink):
730 # since Mercurial doesn't support negative indexes properly
730 # since Mercurial doesn't support negative indexes properly
731 # we need to shift by one to get the proper index, e.g.
731 # we need to shift by one to get the proper index, e.g.
732 # repo[-1] => repo[-2]
732 # repo[-1] => repo[-2]
733 # repo[0] => repo[-1]
733 # repo[0] => repo[-1]
734 if rev <= 0:
734 if rev <= 0:
735 rev -= 1
735 rev -= 1
736 try:
736 try:
737 ctx = self._get_ctx(repo, rev)
737 ctx = self._get_ctx(repo, rev)
738 except (TypeError, RepoLookupError, binascii.Error) as e:
738 except (TypeError, RepoLookupError, binascii.Error) as e:
739 e._org_exc_tb = traceback.format_exc()
739 e._org_exc_tb = traceback.format_exc()
740 raise exceptions.LookupException(e)(rev)
740 raise exceptions.LookupException(e)(rev)
741 except LookupError as e:
741 except LookupError as e:
742 e._org_exc_tb = traceback.format_exc()
742 e._org_exc_tb = traceback.format_exc()
743 raise exceptions.LookupException(e)(e.name)
743 raise exceptions.LookupException(e)(e.name)
744
744
745 if not both:
745 if not both:
746 return ctx.hex()
746 return ctx.hex()
747
747
748 ctx = repo[ctx.hex()]
748 ctx = repo[ctx.hex()]
749 return ctx.hex(), ctx.rev()
749 return ctx.hex(), ctx.rev()
750
750
751 return _lookup(context_uid, repo_id, revision, both)
751 return _lookup(context_uid, repo_id, revision, both)
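
# A standalone restatement of the shift described in the NOTE above, with
# the two examples from that comment as assertions (illustrative;
# shift_rev_index is a hypothetical name, not part of this class):
def shift_rev_index(rev):
    # callers use 0-based / negative indexes; Mercurial treats -1 as tip,
    # so non-positive inputs move down by one
    if isinstance(rev, int) and rev <= 0:
        return rev - 1
    return rev

assert shift_rev_index(0) == -1    # repo[0]  => repo[-1]
assert shift_rev_index(-1) == -2   # repo[-1] => repo[-2]
assert shift_rev_index(5) == 5     # positive indexes pass through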
752
752
753 @reraise_safe_exceptions
753 @reraise_safe_exceptions
754 def sync_push(self, wire, url):
754 def sync_push(self, wire, url):
755 if not self.check_url(url, wire['config']):
755 if not self.check_url(url, wire['config']):
756 return
756 return
757
757
758 repo = self._factory.repo(wire)
758 repo = self._factory.repo(wire)
759
759
760 # Disable any prompts for this repo
760 # Disable any prompts for this repo
761 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
761 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
762
762
763 bookmarks = list(dict(repo._bookmarks).keys())
763 bookmarks = list(dict(repo._bookmarks).keys())
764 remote = peer(repo, {}, safe_bytes(url))
764 remote = peer(repo, {}, safe_bytes(url))
765 # Disable any prompts for this remote
765 # Disable any prompts for this remote
766 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
766 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
767
767
768 return exchange.push(
768 return exchange.push(
769 repo, remote, newbranch=True, bookmarks=bookmarks).cgresult
769 repo, remote, newbranch=True, bookmarks=bookmarks).cgresult
770
770
771 @reraise_safe_exceptions
771 @reraise_safe_exceptions
772 def revision(self, wire, rev):
772 def revision(self, wire, rev):
773 repo = self._factory.repo(wire)
773 repo = self._factory.repo(wire)
774 ctx = self._get_ctx(repo, rev)
774 ctx = self._get_ctx(repo, rev)
775 return ctx.rev()
775 return ctx.rev()
776
776
777 @reraise_safe_exceptions
777 @reraise_safe_exceptions
778 def rev_range(self, wire, commit_filter):
778 def rev_range(self, wire, commit_filter):
779 cache_on, context_uid, repo_id = self._cache_on(wire)
779 cache_on, context_uid, repo_id = self._cache_on(wire)
780 region = self._region(wire)
780 region = self._region(wire)
781
781
782 @region.conditional_cache_on_arguments(condition=cache_on)
782 @region.conditional_cache_on_arguments(condition=cache_on)
783 def _rev_range(_context_uid, _repo_id, _filter):
783 def _rev_range(_context_uid, _repo_id, _filter):
784 repo = self._factory.repo(wire)
784 repo = self._factory.repo(wire)
785 revisions = [
785 revisions = [
786 ascii_str(repo[rev].hex())
786 ascii_str(repo[rev].hex())
787 for rev in revrange(repo, list(map(ascii_bytes, commit_filter)))
787 for rev in revrange(repo, list(map(ascii_bytes, commit_filter)))
788 ]
788 ]
789 return revisions
789 return revisions
790
790
791 return _rev_range(context_uid, repo_id, sorted(commit_filter))
791 return _rev_range(context_uid, repo_id, sorted(commit_filter))
792
792
793 @reraise_safe_exceptions
793 @reraise_safe_exceptions
794 def rev_range_hash(self, wire, node):
794 def rev_range_hash(self, wire, node):
795 repo = self._factory.repo(wire)
795 repo = self._factory.repo(wire)
796
796
797 def get_revs(repo, rev_opt):
797 def get_revs(repo, rev_opt):
798 if rev_opt:
798 if rev_opt:
799 revs = revrange(repo, rev_opt)
799 revs = revrange(repo, rev_opt)
800 if len(revs) == 0:
800 if len(revs) == 0:
801 return (nullrev, nullrev)
801 return (nullrev, nullrev)
802 return max(revs), min(revs)
802 return max(revs), min(revs)
803 else:
803 else:
804 return len(repo) - 1, 0
804 return len(repo) - 1, 0
805
805
806 stop, start = get_revs(repo, [node + ':'])
806 stop, start = get_revs(repo, [node + ':'])
807 revs = [ascii_str(repo[r].hex()) for r in range(start, stop + 1)]
807 revs = [ascii_str(repo[r].hex()) for r in range(start, stop + 1)]
808 return revs
808 return revs
809
809
810 @reraise_safe_exceptions
810 @reraise_safe_exceptions
811 def revs_from_revspec(self, wire, rev_spec, *args, **kwargs):
811 def revs_from_revspec(self, wire, rev_spec, *args, **kwargs):
812 org_path = safe_bytes(wire["path"])
812 org_path = safe_bytes(wire["path"])
813 other_path = safe_bytes(kwargs.pop('other_path', ''))
813 other_path = safe_bytes(kwargs.pop('other_path', ''))
814
814
815 # case when we want to compare two independent repositories
815 # case when we want to compare two independent repositories
816 if other_path and other_path != wire["path"]:
816 if other_path and other_path != wire["path"]:
817 baseui = self._factory._create_config(wire["config"])
817 baseui = self._factory._create_config(wire["config"])
818 repo = unionrepo.makeunionrepository(baseui, other_path, org_path)
818 repo = unionrepo.makeunionrepository(baseui, other_path, org_path)
819 else:
819 else:
820 repo = self._factory.repo(wire)
820 repo = self._factory.repo(wire)
821 return list(repo.revs(rev_spec, *args))
821 return list(repo.revs(rev_spec, *args))
822
822
823 @reraise_safe_exceptions
823 @reraise_safe_exceptions
824 def verify(self, wire,):
824 def verify(self, wire,):
825 repo = self._factory.repo(wire)
825 repo = self._factory.repo(wire)
826 baseui = self._factory._create_config(wire['config'])
826 baseui = self._factory._create_config(wire['config'])
827
827
828 baseui, output = patch_ui_message_output(baseui)
828 baseui, output = patch_ui_message_output(baseui)
829
829
830 repo.ui = baseui
830 repo.ui = baseui
831 verify.verify(repo)
831 verify.verify(repo)
832 return output.getvalue()
832 return output.getvalue()
833
833
834 @reraise_safe_exceptions
834 @reraise_safe_exceptions
835 def hg_update_cache(self, wire,):
835 def hg_update_cache(self, wire,):
836 repo = self._factory.repo(wire)
836 repo = self._factory.repo(wire)
837 baseui = self._factory._create_config(wire['config'])
837 baseui = self._factory._create_config(wire['config'])
838 baseui, output = patch_ui_message_output(baseui)
838 baseui, output = patch_ui_message_output(baseui)
839
839
840 repo.ui = baseui
840 repo.ui = baseui
841 with repo.wlock(), repo.lock():
841 with repo.wlock(), repo.lock():
842 repo.updatecaches(full=True)
842 repo.updatecaches(full=True)
843
843
844 return output.getvalue()
844 return output.getvalue()
845
845
846 @reraise_safe_exceptions
846 @reraise_safe_exceptions
847 def hg_rebuild_fn_cache(self, wire,):
847 def hg_rebuild_fn_cache(self, wire,):
848 repo = self._factory.repo(wire)
848 repo = self._factory.repo(wire)
849 baseui = self._factory._create_config(wire['config'])
849 baseui = self._factory._create_config(wire['config'])
850 baseui, output = patch_ui_message_output(baseui)
850 baseui, output = patch_ui_message_output(baseui)
851
851
852 repo.ui = baseui
852 repo.ui = baseui
853
853
854 repair.rebuildfncache(baseui, repo)
854 repair.rebuildfncache(baseui, repo)
855
855
856 return output.getvalue()
856 return output.getvalue()
857
857
858 @reraise_safe_exceptions
858 @reraise_safe_exceptions
859 def tags(self, wire):
859 def tags(self, wire):
860 cache_on, context_uid, repo_id = self._cache_on(wire)
860 cache_on, context_uid, repo_id = self._cache_on(wire)
861 region = self._region(wire)
861 region = self._region(wire)
862
862
863 @region.conditional_cache_on_arguments(condition=cache_on)
863 @region.conditional_cache_on_arguments(condition=cache_on)
864 def _tags(_context_uid, _repo_id):
864 def _tags(_context_uid, _repo_id):
865 repo = self._factory.repo(wire)
865 repo = self._factory.repo(wire)
866 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo.tags().items()}
866 return {safe_str(name): ascii_str(hex(sha)) for name, sha in repo.tags().items()}
867
867
868 return _tags(context_uid, repo_id)
868 return _tags(context_uid, repo_id)
869
869
870 @reraise_safe_exceptions
870 @reraise_safe_exceptions
871 def update(self, wire, node='', clean=False):
871 def update(self, wire, node='', clean=False):
872 repo = self._factory.repo(wire)
872 repo = self._factory.repo(wire)
873 baseui = self._factory._create_config(wire['config'])
873 baseui = self._factory._create_config(wire['config'])
874 node = safe_bytes(node)
874 node = safe_bytes(node)
875
875
876 commands.update(baseui, repo, node=node, clean=clean)
876 commands.update(baseui, repo, node=node, clean=clean)
877
877
878 @reraise_safe_exceptions
878 @reraise_safe_exceptions
879 def identify(self, wire):
879 def identify(self, wire):
880 repo = self._factory.repo(wire)
880 repo = self._factory.repo(wire)
881 baseui = self._factory._create_config(wire['config'])
881 baseui = self._factory._create_config(wire['config'])
882 output = io.BytesIO()
882 output = io.BytesIO()
883 baseui.write = output.write
883 baseui.write = output.write
884 # This is required to get a full node id
884 # This is required to get a full node id
885 baseui.debugflag = True
885 baseui.debugflag = True
886 commands.identify(baseui, repo, id=True)
886 commands.identify(baseui, repo, id=True)
887
887
888 return output.getvalue()
888 return output.getvalue()
889
889
890 @reraise_safe_exceptions
890 @reraise_safe_exceptions
891 def heads(self, wire, branch=None):
891 def heads(self, wire, branch=None):
892 repo = self._factory.repo(wire)
892 repo = self._factory.repo(wire)
893 baseui = self._factory._create_config(wire['config'])
893 baseui = self._factory._create_config(wire['config'])
894 output = io.BytesIO()
894 output = io.BytesIO()
895
895
896 def write(data, **unused_kwargs):
896 def write(data, **unused_kwargs):
897 output.write(data)
897 output.write(data)
898
898
899 baseui.write = write
899 baseui.write = write
900 if branch:
900 if branch:
901 args = [safe_bytes(branch)]
901 args = [safe_bytes(branch)]
902 else:
902 else:
903 args = []
903 args = []
904 commands.heads(baseui, repo, template=b'{node} ', *args)
904 commands.heads(baseui, repo, template=b'{node} ', *args)
905
905
906 return output.getvalue()
906 return output.getvalue()
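
# `identify` and `heads` above share one trick: swapping the ui object's
# `write` for a local hook so a command's console output lands in a buffer.
# A hedged sketch of just that capture pattern (the callback shape mirrors
# the write(data, **unused_kwargs) hooks used in this class):
import io

def capture_ui_output(run_command):
    """Run `run_command` with a write hook and return everything it wrote."""
    buf = io.BytesIO()

    def write(data, **unused_kwargs):
        buf.write(data)

    run_command(write)  # the command writes through our hook
    return buf.getvalue()

assert capture_ui_output(lambda write: write(b'tip-node ')) == b'tip-node '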
907
907
908 @reraise_safe_exceptions
908 @reraise_safe_exceptions
909 def ancestor(self, wire, revision1, revision2):
909 def ancestor(self, wire, revision1, revision2):
910 repo = self._factory.repo(wire)
910 repo = self._factory.repo(wire)
911 changelog = repo.changelog
911 changelog = repo.changelog
912 lookup = repo.lookup
912 lookup = repo.lookup
913 a = changelog.ancestor(lookup(safe_bytes(revision1)), lookup(safe_bytes(revision2)))
913 a = changelog.ancestor(lookup(safe_bytes(revision1)), lookup(safe_bytes(revision2)))
914 return hex(a)
914 return hex(a)
915
915
916 @reraise_safe_exceptions
916 @reraise_safe_exceptions
917 def clone(self, wire, source, dest, update_after_clone=False, hooks=True):
917 def clone(self, wire, source, dest, update_after_clone=False, hooks=True):
918 baseui = self._factory._create_config(wire["config"], hooks=hooks)
918 baseui = self._factory._create_config(wire["config"], hooks=hooks)
919 clone(baseui, safe_bytes(source), safe_bytes(dest), noupdate=not update_after_clone)
919 clone(baseui, safe_bytes(source), safe_bytes(dest), noupdate=not update_after_clone)
920
920
921 @reraise_safe_exceptions
921 @reraise_safe_exceptions
922 def commitctx(self, wire, message, parents, commit_time, commit_timezone, user, files, extra, removed, updated):
922 def commitctx(self, wire, message, parents, commit_time, commit_timezone, user, files, extra, removed, updated):
923
923
924 repo = self._factory.repo(wire)
924 repo = self._factory.repo(wire)
925 baseui = self._factory._create_config(wire['config'])
925 baseui = self._factory._create_config(wire['config'])
926 publishing = baseui.configbool(b'phases', b'publish')
926 publishing = baseui.configbool(b'phases', b'publish')
927
927
928 def _filectxfn(_repo, ctx, path: bytes):
928 def _filectxfn(_repo, ctx, path: bytes):
929 """
929 """
930 Marks given path as added/changed/removed in a given _repo. This is
930 Marks given path as added/changed/removed in a given _repo. This is
931 for internal mercurial commit function.
931 for internal mercurial commit function.
932 """
932 """
933
933
934 # check if this path is removed
934 # check if this path is removed
935 if safe_str(path) in removed:
935 if safe_str(path) in removed:
936 # returning None is a way to mark node for removal
936 # returning None is a way to mark node for removal
937 return None
937 return None
938
938
939 # check if this path is added
939 # check if this path is added
940 for node in updated:
940 for node in updated:
941 if safe_bytes(node['path']) == path:
941 if safe_bytes(node['path']) == path:
942 return memfilectx(
942 return memfilectx(
943 _repo,
943 _repo,
944 changectx=ctx,
944 changectx=ctx,
945 path=safe_bytes(node['path']),
945 path=safe_bytes(node['path']),
946 data=safe_bytes(node['content']),
946 data=safe_bytes(node['content']),
947 islink=False,
947 islink=False,
948 isexec=bool(node['mode'] & stat.S_IXUSR),
948 isexec=bool(node['mode'] & stat.S_IXUSR),
949 copysource=False)
949 copysource=False)
950 abort_exc = exceptions.AbortException()
950 abort_exc = exceptions.AbortException()
951 raise abort_exc(f"Given path haven't been marked as added, changed or removed ({path})")
951 raise abort_exc(f"Given path haven't been marked as added, changed or removed ({path})")
952
952
953 if publishing:
953 if publishing:
954 new_commit_phase = b'public'
954 new_commit_phase = b'public'
955 else:
955 else:
956 new_commit_phase = b'draft'
956 new_commit_phase = b'draft'
957 with repo.ui.configoverride({(b'phases', b'new-commit'): new_commit_phase}):
957 with repo.ui.configoverride({(b'phases', b'new-commit'): new_commit_phase}):
958 kwargs = {safe_bytes(k): safe_bytes(v) for k, v in extra.items()}
958 kwargs = {safe_bytes(k): safe_bytes(v) for k, v in extra.items()}
959 commit_ctx = memctx(
959 commit_ctx = memctx(
960 repo=repo,
960 repo=repo,
961 parents=parents,
961 parents=parents,
962 text=safe_bytes(message),
962 text=safe_bytes(message),
963 files=[safe_bytes(x) for x in files],
963 files=[safe_bytes(x) for x in files],
964 filectxfn=_filectxfn,
964 filectxfn=_filectxfn,
965 user=safe_bytes(user),
965 user=safe_bytes(user),
966 date=(commit_time, commit_timezone),
966 date=(commit_time, commit_timezone),
967 extra=kwargs)
967 extra=kwargs)
968
968
969 n = repo.commitctx(commit_ctx)
969 n = repo.commitctx(commit_ctx)
970 new_id = hex(n)
970 new_id = hex(n)
971
971
972 return new_id
972 return new_id
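
# A toy, pure-Python restatement of the _filectxfn contract used in
# commitctx above: returning None marks a path as removed, a hit in
# `updated` supplies the new content, anything else is an error. The data
# shapes mirror the method's removed/updated arguments; classify_path is a
# hypothetical name for illustration only:
def classify_path(path: bytes, removed, updated):
    if path.decode('utf8') in removed:
        return None  # None is the removal marker, as in _filectxfn
    for node in updated:
        if node['path'].encode('utf8') == path:
            return node['content']
    raise LookupError(f"path not marked as added, changed or removed: {path!r}")

assert classify_path(b'old.txt', ['old.txt'], []) is None
assert classify_path(b'a.txt', [], [{'path': 'a.txt', 'content': 'new'}]) == 'new'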
973
973
974 @reraise_safe_exceptions
974 @reraise_safe_exceptions
975 def pull(self, wire, url, commit_ids=None):
975 def pull(self, wire, url, commit_ids=None):
976 repo = self._factory.repo(wire)
976 repo = self._factory.repo(wire)
977 # Disable any prompts for this repo
977 # Disable any prompts for this repo
978 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
978 repo.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
979
979
980 remote = peer(repo, {}, safe_bytes(url))
980 remote = peer(repo, {}, safe_bytes(url))
981 # Disable any prompts for this remote
981 # Disable any prompts for this remote
982 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
982 remote.ui.setconfig(b'ui', b'interactive', b'off', b'-y')
983
983
984 if commit_ids:
984 if commit_ids:
985 commit_ids = [bin(commit_id) for commit_id in commit_ids]
985 commit_ids = [bin(commit_id) for commit_id in commit_ids]
986
986
987 return exchange.pull(
987 return exchange.pull(
988 repo, remote, heads=commit_ids, force=None).cgresult
988 repo, remote, heads=commit_ids, force=None).cgresult
989
989
990 @reraise_safe_exceptions
990 @reraise_safe_exceptions
991 def pull_cmd(self, wire, source, bookmark='', branch='', revision='', hooks=True):
991 def pull_cmd(self, wire, source, bookmark='', branch='', revision='', hooks=True):
992 repo = self._factory.repo(wire)
992 repo = self._factory.repo(wire)
993 baseui = self._factory._create_config(wire['config'], hooks=hooks)
993 baseui = self._factory._create_config(wire['config'], hooks=hooks)
994
994
995 source = safe_bytes(source)
995 source = safe_bytes(source)
996
996
997 # Mercurial internally has a lot of logic that checks ONLY whether an
997 # Mercurial internally has a lot of logic that checks ONLY whether an
998 # option is defined, so we only pass options that are actually set
998 # option is defined, so we only pass options that are actually set
999 opts = {}
999 opts = {}
1000
1000
1001 if bookmark:
1001 if bookmark:
1002 opts['bookmark'] = [safe_bytes(x) for x in bookmark] \
1002 opts['bookmark'] = [safe_bytes(x) for x in bookmark] \
1003 if isinstance(bookmark, list) else safe_bytes(bookmark)
1003 if isinstance(bookmark, list) else safe_bytes(bookmark)
1004
1004
1005 if branch:
1005 if branch:
1006 opts['branch'] = [safe_bytes(x) for x in branch] \
1006 opts['branch'] = [safe_bytes(x) for x in branch] \
1007 if isinstance(branch, list) else safe_bytes(branch)
1007 if isinstance(branch, list) else safe_bytes(branch)
1008
1008
1009 if revision:
1009 if revision:
1010 opts['rev'] = [safe_bytes(x) for x in revision] \
1010 opts['rev'] = [safe_bytes(x) for x in revision] \
1011 if isinstance(revision, list) else safe_bytes(revision)
1011 if isinstance(revision, list) else safe_bytes(revision)
1012
1012
1013 commands.pull(baseui, repo, source, **opts)
1013 commands.pull(baseui, repo, source, **opts)
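
# The three opts blocks above repeat one normalization rule: a value may be
# a single string or a list of strings, and Mercurial wants bytes either
# way. A hedged helper capturing that rule (illustrative, not part of the
# class); e.g. opts['rev'] = as_bytes_opt(revision) would replace the
# inline conditionals:
def as_bytes_opt(value):
    """Normalize a str-or-list option value to bytes, as pull_cmd does inline."""
    if isinstance(value, list):
        return [safe_bytes(x) for x in value]
    return safe_bytes(value)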
1014
1014
1015 @reraise_safe_exceptions
1015 @reraise_safe_exceptions
1016 def push(self, wire, revisions, dest_path, hooks: bool = True, push_branches: bool = False):
1016 def push(self, wire, revisions, dest_path, hooks: bool = True, push_branches: bool = False):
1017 repo = self._factory.repo(wire)
1017 repo = self._factory.repo(wire)
1018 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1018 baseui = self._factory._create_config(wire['config'], hooks=hooks)
1019
1019
1020 revisions = [safe_bytes(x) for x in revisions] \
1020 revisions = [safe_bytes(x) for x in revisions] \
1021 if isinstance(revisions, list) else safe_bytes(revisions)
1021 if isinstance(revisions, list) else safe_bytes(revisions)
1022
1022
1023 commands.push(baseui, repo, safe_bytes(dest_path),
1023 commands.push(baseui, repo, safe_bytes(dest_path),
1024 rev=revisions,
1024 rev=revisions,
1025 new_branch=push_branches)
1025 new_branch=push_branches)
1026
1026
1027 @reraise_safe_exceptions
1027 @reraise_safe_exceptions
1028 def strip(self, wire, revision, update, backup):
1028 def strip(self, wire, revision, update, backup):
1029 repo = self._factory.repo(wire)
1029 repo = self._factory.repo(wire)
1030 ctx = self._get_ctx(repo, revision)
1030 ctx = self._get_ctx(repo, revision)
1031 hgext_strip.strip(
1031 hgext_strip.strip(
1032 repo.baseui, repo, ctx.node(), update=update, backup=backup)
1032 repo.baseui, repo, ctx.node(), update=update, backup=backup)
1033
1033
1034 @reraise_safe_exceptions
1034 @reraise_safe_exceptions
1035 def get_unresolved_files(self, wire):
1035 def get_unresolved_files(self, wire):
1036 repo = self._factory.repo(wire)
1036 repo = self._factory.repo(wire)
1037
1037
1038 log.debug('Calculating unresolved files for repo: %s', repo)
1038 log.debug('Calculating unresolved files for repo: %s', repo)
1039 output = io.BytesIO()
1039 output = io.BytesIO()
1040
1040
1041 def write(data, **unused_kwargs):
1041 def write(data, **unused_kwargs):
1042 output.write(data)
1042 output.write(data)
1043
1043
1044 baseui = self._factory._create_config(wire['config'])
1044 baseui = self._factory._create_config(wire['config'])
1045 baseui.write = write
1045 baseui.write = write
1046
1046
1047 commands.resolve(baseui, repo, list=True)
1047 commands.resolve(baseui, repo, list=True)
1048 unresolved = output.getvalue().splitlines(False)
1048 unresolved = output.getvalue().splitlines(False)
1049 return unresolved
1049 return unresolved
1050
1050
1051 @reraise_safe_exceptions
1051 @reraise_safe_exceptions
1052 def merge(self, wire, revision):
1052 def merge(self, wire, revision):
1053 repo = self._factory.repo(wire)
1053 repo = self._factory.repo(wire)
1054 baseui = self._factory._create_config(wire['config'])
1054 baseui = self._factory._create_config(wire['config'])
1055 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1055 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1056
1056
1057 # When subrepositories are used, mercurial prompts the user in case of
1057 # When subrepositories are used, mercurial prompts the user in case of
1058 # merge conflicts or different subrepository sources. By setting the
1058 # merge conflicts or different subrepository sources. By setting the
1059 # interactive flag to `False`, mercurial doesn't prompt the user but
1059 # interactive flag to `False`, mercurial doesn't prompt the user but
1060 # instead uses a default value.
1060 # instead uses a default value.
1061 repo.ui.setconfig(b'ui', b'interactive', False)
1061 repo.ui.setconfig(b'ui', b'interactive', False)
1062 commands.merge(baseui, repo, rev=safe_bytes(revision))
1062 commands.merge(baseui, repo, rev=safe_bytes(revision))
1063
1063
1064 @reraise_safe_exceptions
1064 @reraise_safe_exceptions
1065 def merge_state(self, wire):
1065 def merge_state(self, wire):
1066 repo = self._factory.repo(wire)
1066 repo = self._factory.repo(wire)
1067 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1067 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1068
1068
1069 # When subrepositories are used, mercurial prompts the user in case of
1069 # When subrepositories are used, mercurial prompts the user in case of
1070 # merge conflicts or different subrepository sources. By setting the
1070 # merge conflicts or different subrepository sources. By setting the
1071 # interactive flag to `False`, mercurial doesn't prompt the user but
1071 # interactive flag to `False`, mercurial doesn't prompt the user but
1072 # instead uses a default value.
1072 # instead uses a default value.
1073 repo.ui.setconfig(b'ui', b'interactive', False)
1073 repo.ui.setconfig(b'ui', b'interactive', False)
1074 ms = hg_merge.mergestate(repo)
1074 ms = hg_merge.mergestate(repo)
1075 return list(ms.unresolved())
1075 return list(ms.unresolved())
1076
1076
1077 @reraise_safe_exceptions
1077 @reraise_safe_exceptions
1078 def commit(self, wire, message, username, close_branch=False):
1078 def commit(self, wire, message, username, close_branch=False):
1079 repo = self._factory.repo(wire)
1079 repo = self._factory.repo(wire)
1080 baseui = self._factory._create_config(wire['config'])
1080 baseui = self._factory._create_config(wire['config'])
1081 repo.ui.setconfig(b'ui', b'username', safe_bytes(username))
1081 repo.ui.setconfig(b'ui', b'username', safe_bytes(username))
1082 commands.commit(baseui, repo, message=safe_bytes(message), close_branch=close_branch)
1082 commands.commit(baseui, repo, message=safe_bytes(message), close_branch=close_branch)
1083
1083
1084 @reraise_safe_exceptions
1084 @reraise_safe_exceptions
1085 def rebase(self, wire, source='', dest='', abort=False):
1085 def rebase(self, wire, source='', dest='', abort=False):
1086 repo = self._factory.repo(wire)
1086 repo = self._factory.repo(wire)
1087 baseui = self._factory._create_config(wire['config'])
1087 baseui = self._factory._create_config(wire['config'])
1088 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1088 repo.ui.setconfig(b'ui', b'merge', b'internal:dump')
1089 # When subrepositories are used, mercurial prompts the user in case of
1089 # When subrepositories are used, mercurial prompts the user in case of
1090 # merge conflicts or different subrepository sources. By setting the
1090 # merge conflicts or different subrepository sources. By setting the
1091 # interactive flag to `False`, mercurial doesn't prompt the user but
1091 # interactive flag to `False`, mercurial doesn't prompt the user but
1092 # instead uses a default value.
1092 # instead uses a default value.
1093 repo.ui.setconfig(b'ui', b'interactive', False)
1093 repo.ui.setconfig(b'ui', b'interactive', False)
1094
1094
1095 rebase.rebase(baseui, repo, base=safe_bytes(source or ''), dest=safe_bytes(dest or ''),
1095 rebase.rebase(baseui, repo, base=safe_bytes(source or ''), dest=safe_bytes(dest or ''),
1096 abort=abort, keep=not abort)
1096 abort=abort, keep=not abort)
1097
1097
1098 @reraise_safe_exceptions
1098 @reraise_safe_exceptions
1099 def tag(self, wire, name, revision, message, local, user, tag_time, tag_timezone):
1099 def tag(self, wire, name, revision, message, local, user, tag_time, tag_timezone):
1100 repo = self._factory.repo(wire)
1100 repo = self._factory.repo(wire)
1101 ctx = self._get_ctx(repo, revision)
1101 ctx = self._get_ctx(repo, revision)
1102 node = ctx.node()
1102 node = ctx.node()
1103
1103
1104 date = (tag_time, tag_timezone)
1104 date = (tag_time, tag_timezone)
1105 try:
1105 try:
1106 hg_tag.tag(repo, safe_bytes(name), node, safe_bytes(message), local, safe_bytes(user), date)
1106 hg_tag.tag(repo, safe_bytes(name), node, safe_bytes(message), local, safe_bytes(user), date)
1107 except Abort as e:
1107 except Abort as e:
1108 log.exception("Tag operation aborted")
1108 log.exception("Tag operation aborted")
1109 # Exception can contain unicode which we convert
1109 # Exception can contain unicode which we convert
1110 raise exceptions.AbortException(e)(repr(e))
1110 raise exceptions.AbortException(e)(repr(e))
1111
1111
1112 @reraise_safe_exceptions
1112 @reraise_safe_exceptions
1113 def bookmark(self, wire, bookmark, revision=''):
1113 def bookmark(self, wire, bookmark, revision=''):
1114 repo = self._factory.repo(wire)
1114 repo = self._factory.repo(wire)
1115 baseui = self._factory._create_config(wire['config'])
1115 baseui = self._factory._create_config(wire['config'])
1116 revision = revision or ''
1116 revision = revision or ''
1117 commands.bookmark(baseui, repo, safe_bytes(bookmark), rev=safe_bytes(revision), force=True)
1117 commands.bookmark(baseui, repo, safe_bytes(bookmark), rev=safe_bytes(revision), force=True)
1118
1118
1119 @reraise_safe_exceptions
1119 @reraise_safe_exceptions
1120 def install_hooks(self, wire, force=False):
1120 def install_hooks(self, wire, force=False):
1121 # we don't need any special hooks for Mercurial
1121 # we don't need any special hooks for Mercurial
1122 pass
1122 pass
1123
1123
1124 @reraise_safe_exceptions
1124 @reraise_safe_exceptions
1125 def get_hooks_info(self, wire):
1125 def get_hooks_info(self, wire):
1126 return {
1126 return {
1127 'pre_version': vcsserver.__version__,
1127 'pre_version': vcsserver.__version__,
1128 'post_version': vcsserver.__version__,
1128 'post_version': vcsserver.__version__,
1129 }
1129 }
1130
1130
1131 @reraise_safe_exceptions
1131 @reraise_safe_exceptions
1132 def set_head_ref(self, wire, head_name):
1132 def set_head_ref(self, wire, head_name):
1133 pass
1133 pass
1134
1134
1135 @reraise_safe_exceptions
1135 @reraise_safe_exceptions
1136 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1136 def archive_repo(self, wire, archive_name_key, kind, mtime, archive_at_path,
1137 archive_dir_name, commit_id, cache_config):
1137 archive_dir_name, commit_id, cache_config):
1138
1138
1139 def file_walker(_commit_id, path):
1139 def file_walker(_commit_id, path):
1140 repo = self._factory.repo(wire)
1140 repo = self._factory.repo(wire)
1141 ctx = repo[_commit_id]
1141 ctx = repo[_commit_id]
1142 is_root = path in ['', '/']
1142 is_root = path in ['', '/']
1143 if is_root:
1143 if is_root:
1144 matcher = alwaysmatcher(badfn=None)
1144 matcher = alwaysmatcher(badfn=None)
1145 else:
1145 else:
1146 matcher = patternmatcher('', [(b'glob', path+'/**', b'')], badfn=None)
1146 matcher = patternmatcher('', [(b'glob', path+'/**', b'')], badfn=None)
1147 file_iter = ctx.manifest().walk(matcher)
1147 file_iter = ctx.manifest().walk(matcher)
1148
1148
1149 for fn in file_iter:
1149 for fn in file_iter:
1150 file_path = fn
1150 file_path = fn
1151 flags = ctx.flags(fn)
1151 flags = ctx.flags(fn)
1152 mode = 0o755 if b'x' in flags else 0o644
1152 mode = 0o755 if b'x' in flags else 0o644
1153 is_link = b'l' in flags
1153 is_link = b'l' in flags
1154
1154
1155 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1155 yield ArchiveNode(file_path, mode, is_link, ctx[fn].data)
1156
1156
1157 return store_archive_in_cache(
1157 return store_archive_in_cache(
1158 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
1158 file_walker, archive_name_key, kind, mtime, archive_at_path, archive_dir_name, commit_id, cache_config=cache_config)
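
# The flag handling inside file_walker above, restated as a tiny standalone
# function covering the possible manifest flags (sketch; flags_to_mode is a
# hypothetical name):
def flags_to_mode(flags: bytes):
    mode = 0o755 if b'x' in flags else 0o644  # executable vs regular file
    return mode, b'l' in flags                # (mode, is_link)

assert flags_to_mode(b'x') == (0o755, False)
assert flags_to_mode(b'l') == (0o644, True)
assert flags_to_mode(b'') == (0o644, False)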
1159
1159
@@ -1,563 +1,563 b''
1 """
1 """
2 Module provides a class allowing to wrap communication over subprocess.Popen
2 Module provides a class allowing to wrap communication over subprocess.Popen
3 input, output, error streams into a meaningful, non-blocking, concurrent
3 input, output, error streams into a meaningful, non-blocking, concurrent
4 stream processor exposing the output data as an iterator fitting to be a
4 stream processor exposing the output data as an iterator fitting to be a
5 return value passed by a WSGI application to a WSGI server per PEP 3333.
5 return value passed by a WSGI application to a WSGI server per PEP 3333.
6
6
7 Copyright (c) 2011 Daniel Dotsenko <dotsa[at]hotmail.com>
7 Copyright (c) 2011 Daniel Dotsenko <dotsa[at]hotmail.com>
8
8
9 This file is part of git_http_backend.py Project.
9 This file is part of git_http_backend.py Project.
10
10
11 git_http_backend.py Project is free software: you can redistribute it and/or
11 git_http_backend.py Project is free software: you can redistribute it and/or
12 modify it under the terms of the GNU Lesser General Public License as
12 modify it under the terms of the GNU Lesser General Public License as
13 published by the Free Software Foundation, either version 2.1 of the License,
13 published by the Free Software Foundation, either version 2.1 of the License,
14 or (at your option) any later version.
14 or (at your option) any later version.
15
15
16 git_http_backend.py Project is distributed in the hope that it will be useful,
16 git_http_backend.py Project is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU Lesser General Public License for more details.
19 GNU Lesser General Public License for more details.
20
20
21 You should have received a copy of the GNU Lesser General Public License
21 You should have received a copy of the GNU Lesser General Public License
22 along with git_http_backend.py Project.
22 along with git_http_backend.py Project.
23 If not, see <http://www.gnu.org/licenses/>.
23 If not, see <http://www.gnu.org/licenses/>.
24 """
24 """
25 import os
25 import os
26 import collections
26 import collections
27 import logging
27 import logging
28 import subprocess
28 import subprocess
29 import threading
29 import threading
30
30
31 from vcsserver.str_utils import safe_str
31 from vcsserver.str_utils import safe_str
32
32
33 log = logging.getLogger(__name__)
33 log = logging.getLogger(__name__)
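
# A hedged sketch of the WSGI use case the module docstring describes: the
# chunker (defined later in this module) is handed back as the response
# iterable, so the server streams subprocess output instead of buffering it
# all up front. The git command and repository path here are illustrative
# only, not part of this module:
def example_wsgi_app(environ, start_response):
    try:
        chunker = SubprocessIOChunker(
            ['git', 'upload-pack', '--stateless-rpc', '/tmp/repo.git'],
            input_stream=environ['wsgi.input'],
            shell=False,
        )
    except OSError as e:
        start_response('500 Internal Server Error', [('Content-Type', 'text/plain')])
        return [safe_str(e).encode()]
    start_response('200 OK', [('Content-Type', 'application/octet-stream')])
    return chunker  # the WSGI server iterates the chunks, per PEP 3333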
34
34
35
35
36 class StreamFeeder(threading.Thread):
36 class StreamFeeder(threading.Thread):
37 """
37 """
38 Normal writing into a pipe-like object blocks once the buffer is filled.
38 Normal writing into a pipe-like object blocks once the buffer is filled.
39 This thread allows us to feed data from a file-like object into a pipe
39 This thread allows us to feed data from a file-like object into a pipe
40 without blocking the main thread.
40 without blocking the main thread.
41 We close inpipe once the end of the source stream is reached.
41 We close inpipe once the end of the source stream is reached.
42 """
42 """
43
43
44 def __init__(self, source):
44 def __init__(self, source):
45 super(StreamFeeder, self).__init__()
45 super().__init__()
46 self.daemon = True
46 self.daemon = True
47 filelike = False
47 filelike = False
48 self.bytes = bytes()
48 self.bytes = b''
49 if type(source) in (type(''), bytes, bytearray): # string-like
49 if type(source) in (str, bytes, bytearray): # string-like
50 self.bytes = source.encode('utf8') if isinstance(source, str) else bytes(source)  # bytes(str) would raise in py3
50 self.bytes = source.encode('utf8') if isinstance(source, str) else bytes(source)  # bytes(str) would raise in py3
51 else: # can be either file pointer or file-like
51 else: # can be either file pointer or file-like
52 if isinstance(source, int): # file pointer it is
52 if isinstance(source, int): # file pointer it is
53 # converting file descriptor (int) stdin into file-like
53 # converting file descriptor (int) stdin into file-like
54 source = os.fdopen(source, 'rb', 16384)
54 source = os.fdopen(source, 'rb', 16384)
55 # let's see if source is file-like by now
55 # let's see if source is file-like by now
56 filelike = hasattr(source, 'read')
56 filelike = hasattr(source, 'read')
57 if not filelike and not self.bytes:
57 if not filelike and not self.bytes:
58 raise TypeError("StreamFeeder's source object must be a readable "
58 raise TypeError("StreamFeeder's source object must be a readable "
59 "file-like, a file descriptor, or a string-like.")
59 "file-like, a file descriptor, or a string-like.")
60 self.source = source
60 self.source = source
61 self.readiface, self.writeiface = os.pipe()
61 self.readiface, self.writeiface = os.pipe()
62
62
63 def run(self):
63 def run(self):
64 writer = self.writeiface
64 writer = self.writeiface
65 try:
65 try:
66 if self.bytes:
66 if self.bytes:
67 os.write(writer, self.bytes)
67 os.write(writer, self.bytes)
68 else:
68 else:
69 s = self.source
69 s = self.source
70
70
71 while True:
71 while True:
72 _bytes = s.read(4096)
72 _bytes = s.read(4096)
73 if not _bytes:
73 if not _bytes:
74 break
74 break
75 os.write(writer, _bytes)
75 os.write(writer, _bytes)
76
76
77 finally:
77 finally:
78 os.close(writer)
78 os.close(writer)
79
79
80 @property
80 @property
81 def output(self):
81 def output(self):
82 return self.readiface
82 return self.readiface
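
# A minimal usage sketch of StreamFeeder (illustrative): feed a bytes
# payload through the thread and read it back from the pipe's read end.
def _demo_stream_feeder():
    feeder = StreamFeeder(b'hello world')
    feeder.start()
    read_fd = feeder.output
    chunks = []
    while True:
        data = os.read(read_fd, 4096)
        if not data:  # writer closed its end, EOF
            break
        chunks.append(data)
    os.close(read_fd)
    feeder.join()
    assert b''.join(chunks) == b'hello world'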
83
83
84
84
85 class InputStreamChunker(threading.Thread):
85 class InputStreamChunker(threading.Thread):
86 def __init__(self, source, target, buffer_size, chunk_size):
86 def __init__(self, source, target, buffer_size, chunk_size):
87
87
88 super(InputStreamChunker, self).__init__()
88 super().__init__()
89
89
90 self.daemon = True # die die die.
90 self.daemon = True # die die die.
91
91
92 self.source = source
92 self.source = source
93 self.target = target
93 self.target = target
94 self.chunk_count_max = int(buffer_size / chunk_size) + 1
94 self.chunk_count_max = int(buffer_size / chunk_size) + 1
95 self.chunk_size = chunk_size
95 self.chunk_size = chunk_size
96
96
97 self.data_added = threading.Event()
97 self.data_added = threading.Event()
98 self.data_added.clear()
98 self.data_added.clear()
99
99
100 self.keep_reading = threading.Event()
100 self.keep_reading = threading.Event()
101 self.keep_reading.set()
101 self.keep_reading.set()
102
102
103 self.EOF = threading.Event()
103 self.EOF = threading.Event()
104 self.EOF.clear()
104 self.EOF.clear()
105
105
106 self.go = threading.Event()
106 self.go = threading.Event()
107 self.go.set()
107 self.go.set()
108
108
109 def stop(self):
109 def stop(self):
110 self.go.clear()
110 self.go.clear()
111 self.EOF.set()
111 self.EOF.set()
112 try:
112 try:
113 # this is not proper, but is done to force the reader thread let
113 # this is not proper, but is done to force the reader thread let
114 # go of the input because, if successful, .close() will send EOF
114 # go of the input because, if successful, .close() will send EOF
115 # down the pipe.
115 # down the pipe.
116 self.source.close()
116 self.source.close()
117 except Exception:
117 except Exception:
118 pass
118 pass
119
119
120 def run(self):
120 def run(self):
121 s = self.source
121 s = self.source
122 t = self.target
122 t = self.target
123 cs = self.chunk_size
123 cs = self.chunk_size
124 chunk_count_max = self.chunk_count_max
124 chunk_count_max = self.chunk_count_max
125 keep_reading = self.keep_reading
125 keep_reading = self.keep_reading
126 da = self.data_added
126 da = self.data_added
127 go = self.go
127 go = self.go
128
128
129 try:
129 try:
130 b = s.read(cs)
130 b = s.read(cs)
131 except ValueError:
131 except ValueError:
131 b = b''
131 b = b''
133
133
134 timeout_input = 20
134 timeout_input = 20
135 while b and go.is_set():
135 while b and go.is_set():
136 if len(t) > chunk_count_max:
136 if len(t) > chunk_count_max:
137 keep_reading.clear()
137 keep_reading.clear()
138 keep_reading.wait(timeout_input)
138 keep_reading.wait(timeout_input)
139 if len(t) > chunk_count_max + timeout_input:
139 if len(t) > chunk_count_max + timeout_input:
140 log.error("Timed out while waiting for input from subprocess.")
140 log.error("Timed out while waiting for input from subprocess.")
141 os._exit(-1) # this will cause the worker to recycle itself
141 os._exit(-1) # this will cause the worker to recycle itself
142
142
143 t.append(b)
143 t.append(b)
144 da.set()
144 da.set()
145
145
146 try:
146 try:
147 b = s.read(cs)
147 b = s.read(cs)
148 except ValueError: # probably "I/O operation on closed file"
148 except ValueError: # probably "I/O operation on closed file"
149 b = b''
149 b = b''
150
150
151 self.EOF.set()
151 self.EOF.set()
152 da.set() # for cases when done but there was no input.
152 da.set() # for cases when done but there was no input.
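
# The backpressure rule in run() above, restated: reading pauses once the
# queue holds more than chunk_count_max chunks and resumes when the consumer
# sets keep_reading again. A sketch of just the pause predicate
# (should_pause is a hypothetical helper; same arithmetic as __init__):
def should_pause(queue_len, buffer_size, chunk_size):
    chunk_count_max = int(buffer_size / chunk_size) + 1
    return queue_len > chunk_count_max

assert not should_pause(17, 65536, 4096)  # 65536/4096 + 1 == 17, still fits
assert should_pause(18, 65536, 4096)      # one chunk over: pause reading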
153
153
154
154
155 class BufferedGenerator(object):
155 class BufferedGenerator(object):
156 """
156 """
157 Class behaves as a non-blocking, buffered pipe reader.
157 Class behaves as a non-blocking, buffered pipe reader.
158 Reads chunks of data (through a thread)
158 Reads chunks of data (through a thread)
159 from a blocking pipe, and attaches these to an array (Deque) of chunks.
159 from a blocking pipe, and attaches these to an array (Deque) of chunks.
160 Reading is halted in the thread when max chunks is internally buffered.
160 Reading is halted in the thread when max chunks is internally buffered.
161 The .next() may operate in blocking or non-blocking fashion by yielding
161 The .next() may operate in blocking or non-blocking fashion by yielding
162 '' if no data is ready
162 '' if no data is ready
163 to be sent, or by not returning until there is some data to send.
163 to be sent, or by not returning until there is some data to send.
164 When we get EOF from the underlying source pipe we raise the marker to raise
164 When we get EOF from the underlying source pipe we raise the marker to raise
165 StopIteration after the last chunk of data is yielded.
165 StopIteration after the last chunk of data is yielded.
166 """
166 """
167
167
168 def __init__(self, name, source, buffer_size=65536, chunk_size=4096,
168 def __init__(self, name, source, buffer_size=65536, chunk_size=4096,
169 starting_values=None, bottomless=False):
169 starting_values=None, bottomless=False):
170 starting_values = starting_values or []
170 starting_values = starting_values or []
171 self.name = name
171 self.name = name
172 self.buffer_size = buffer_size
172 self.buffer_size = buffer_size
173 self.chunk_size = chunk_size
173 self.chunk_size = chunk_size
174
174
175 if bottomless:
175 if bottomless:
176 maxlen = int(buffer_size / chunk_size)
176 maxlen = int(buffer_size / chunk_size)
177 else:
177 else:
178 maxlen = None
178 maxlen = None
179
179
180 self.data_queue = collections.deque(starting_values, maxlen)
180 self.data_queue = collections.deque(starting_values, maxlen)
181 self.worker = InputStreamChunker(source, self.data_queue, buffer_size, chunk_size)
181 self.worker = InputStreamChunker(source, self.data_queue, buffer_size, chunk_size)
182 if starting_values:
182 if starting_values:
183 self.worker.data_added.set()
183 self.worker.data_added.set()
184 self.worker.start()
184 self.worker.start()
185
185
186 ####################
186 ####################
187 # Generator's methods
187 # Generator's methods
188 ####################
188 ####################
189 def __str__(self):
189 def __str__(self):
190 return f'BufferedGenerator(name={self.name} chunk: {self.chunk_size} on buffer: {self.buffer_size})'
190 return f'BufferedGenerator(name={self.name} chunk: {self.chunk_size} on buffer: {self.buffer_size})'
191
191
192 def __iter__(self):
192 def __iter__(self):
193 return self
193 return self
194
194
195 def __next__(self):
195 def __next__(self):
196
196
197 while not self.length and not self.worker.EOF.is_set():
197 while not self.length and not self.worker.EOF.is_set():
198 self.worker.data_added.clear()
198 self.worker.data_added.clear()
199 self.worker.data_added.wait(0.2)
199 self.worker.data_added.wait(0.2)
200
200
201 if self.length:
201 if self.length:
202 self.worker.keep_reading.set()
202 self.worker.keep_reading.set()
203 return bytes(self.data_queue.popleft())
203 return bytes(self.data_queue.popleft())
204 elif self.worker.EOF.is_set():
204 elif self.worker.EOF.is_set():
205 raise StopIteration
205 raise StopIteration
206
206
207 def throw(self, exc_type, value=None, traceback=None):
207 def throw(self, exc_type, value=None, traceback=None):
208 if not self.worker.EOF.is_set():
208 if not self.worker.EOF.is_set():
209 raise exc_type(value)
209 raise exc_type(value)
210
210
211 def start(self):
211 def start(self):
212 self.worker.start()
212 self.worker.start()
213
213
214 def stop(self):
214 def stop(self):
215 self.worker.stop()
215 self.worker.stop()
216
216
217 def close(self):
217 def close(self):
218 try:
218 try:
219 self.worker.stop()
219 self.worker.stop()
220 self.throw(GeneratorExit)
220 self.throw(GeneratorExit)
221 except (GeneratorExit, StopIteration):
221 except (GeneratorExit, StopIteration):
222 pass
222 pass
223
223
224 ####################
224 ####################
225 # Threaded reader's infrastructure.
225 # Threaded reader's infrastructure.
226 ####################
226 ####################
227 @property
227 @property
228 def input(self):
228 def input(self):
229 return self.worker.w
229 return self.worker.w
230
230
231 @property
231 @property
232 def data_added_event(self):
232 def data_added_event(self):
233 return self.worker.data_added
233 return self.worker.data_added
234
234
235 @property
235 @property
236 def data_added(self):
236 def data_added(self):
237 return self.worker.data_added.is_set()
237 return self.worker.data_added.is_set()
238
238
239 @property
239 @property
240 def reading_paused(self):
240 def reading_paused(self):
241 return not self.worker.keep_reading.is_set()
241 return not self.worker.keep_reading.is_set()
242
242
243 @property
243 @property
244 def done_reading_event(self):
244 def done_reading_event(self):
245 """
245 """
246 Done_reading does not mean that the iterator's buffer is empty.
246 Done_reading does not mean that the iterator's buffer is empty.
247 The iterator might be done reading from the underlying source, but the read
247 The iterator might be done reading from the underlying source, but the read
248 chunks might still be available for serving through the .next() method.
248 chunks might still be available for serving through the .next() method.
249
249
250 :returns: An Event class instance.
250 :returns: An Event class instance.
251 """
251 """
252 return self.worker.EOF
252 return self.worker.EOF
253
253
254 @property
254 @property
255 def done_reading(self):
255 def done_reading(self):
256 """
256 """
257 Done_reading does not mean that the iterator's buffer is empty.
257 Done_reading does not mean that the iterator's buffer is empty.
258 The iterator might be done reading from the underlying source, but the read
258 The iterator might be done reading from the underlying source, but the read
259 chunks might still be available for serving through the .next() method.
259 chunks might still be available for serving through the .next() method.
260
260
261 :returns: A bool value.
261 :returns: A bool value.
262 """
262 """
263 return self.worker.EOF.is_set()
263 return self.worker.EOF.is_set()
264
264
265 @property
265 @property
266 def length(self):
266 def length(self):
267 """
267 """
268 returns int.
268 returns int.
269
269
270 This is the length of the queue of chunks, not the length of
270 This is the length of the queue of chunks, not the length of
271 the combined contents in those chunks.
271 the combined contents in those chunks.
272
272
273 __len__() cannot be meaningfully implemented because this
273 __len__() cannot be meaningfully implemented because this
274 reader is just flying through a bottomless pit of content and
274 reader is just flying through a bottomless pit of content and
275 can only know the length of what it already saw.
275 can only know the length of what it already saw.
276
276
277 If __len__() on WSGI server per PEP 3333 returns a value,
277 If __len__() on WSGI server per PEP 3333 returns a value,
278 the response's length will be set to that. In order not to
278 the response's length will be set to that. In order not to
279 confuse WSGI PEP3333 servers, we will not implement __len__
279 confuse WSGI PEP3333 servers, we will not implement __len__
280 at all.
280 at all.
281 """
281 """
282 return len(self.data_queue)
282 return len(self.data_queue)
283
283
284 def prepend(self, x):
284 def prepend(self, x):
285 self.data_queue.appendleft(x)
285 self.data_queue.appendleft(x)
286
286
287 def append(self, x):
287 def append(self, x):
288 self.data_queue.append(x)
288 self.data_queue.append(x)
289
289
290 def extend(self, o):
290 def extend(self, o):
291 self.data_queue.extend(o)
291 self.data_queue.extend(o)
292
292
293 def __getitem__(self, i):
293 def __getitem__(self, i):
294 return self.data_queue[i]
294 return self.data_queue[i]
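
# A minimal end-to-end sketch of BufferedGenerator (illustrative): write a
# payload into an OS pipe, close the write end to signal EOF, and drain the
# generator until it stops.
def _demo_buffered_generator():
    read_fd, write_fd = os.pipe()
    os.write(write_fd, b'chunked payload')
    os.close(write_fd)  # EOF lets the reader thread finish
    source = os.fdopen(read_fd, 'rb')
    gen = BufferedGenerator('demo', source, buffer_size=64, chunk_size=8)
    assert b''.join(gen) == b'chunked payload'
    source.close()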
295
295
296
296
297 class SubprocessIOChunker(object):
297 class SubprocessIOChunker(object):
298 """
298 """
299 Processor class wrapping handling of subprocess IO.
299 Processor class wrapping handling of subprocess IO.
300
300
301 .. important::
301 .. important::
302
302
303 Watch out for the method `__del__` on this class. If this object
303 Watch out for the method `__del__` on this class. If this object
304 is deleted, it will kill the subprocess, so avoid returning
304 is deleted, it will kill the subprocess, so avoid returning
305 the `output` attribute or using it as in the following
305 the `output` attribute or using it as in the following
306 example::
306 example::
307
307
308 # `args` expected to run a program that produces a lot of output
308 # `args` expected to run a program that produces a lot of output
309 output = ''.join(SubprocessIOChunker(
309 output = ''.join(SubprocessIOChunker(
310 args, shell=False, inputstream=inputstream, env=environ).output)
310 args, shell=False, inputstream=inputstream, env=environ).output)
311
311
312 # `output` will not contain all the data, because the __del__ method
312 # `output` will not contain all the data, because the __del__ method
313 # has already killed the subprocess in this case before all output
313 # has already killed the subprocess in this case before all output
314 # has been consumed.
314 # has been consumed.
315
315
316
316
317
317
318 In a way, this is a "communicate()" replacement with a twist.
318 In a way, this is a "communicate()" replacement with a twist.
319
319
320 - We are multithreaded. Writing in and reading out/err are all separate threads.
320 - We are multithreaded. Writing in and reading out/err are all separate threads.
321 - We support concurrent (in and out) stream processing.
321 - We support concurrent (in and out) stream processing.
322 - The output is not a stream. It's a queue of read string (bytes, not str)
322 - The output is not a stream. It's a queue of read string (bytes, not str)
323 chunks. The object behaves as an iterable; you can iterate it with "for chunk in obj:".
323 chunks. The object behaves as an iterable; you can iterate it with "for chunk in obj:".
324 - We are non-blocking in more respects than communicate()
324 - We are non-blocking in more respects than communicate()
325 (reading from subprocess out pauses when internal buffer is full, but
325 (reading from subprocess out pauses when internal buffer is full, but
326 does not block the parent calling code. On the flip side, reading from
326 does not block the parent calling code. On the flip side, reading from
327 slow-yielding subprocess may block the iteration until data shows up. This
327 slow-yielding subprocess may block the iteration until data shows up. This
328 does not block the parallel inpipe reading occurring in a parallel thread.)
328 does not block the parallel inpipe reading occurring in a parallel thread.)
329
329
330 The purpose of the object is to allow us to wrap subprocess interactions into
330 The purpose of the object is to allow us to wrap subprocess interactions into
331 an iterable that can be passed to a WSGI server as the application's return
331 an iterable that can be passed to a WSGI server as the application's return
332 value. Because of stream-processing-ability, WSGI does not have to read ALL
332 value. Because of stream-processing-ability, WSGI does not have to read ALL
333 of the subprocess's output and buffer it, before handing it to WSGI server for
333 of the subprocess's output and buffer it, before handing it to WSGI server for
334 HTTP response. Instead, the class initializer reads just a bit of the stream
334 HTTP response. Instead, the class initializer reads just a bit of the stream
335 to figure out if an error occurred or is likely to occur and, if not, hands the
335 to figure out if an error occurred or is likely to occur and, if not, hands the
336 further iteration over subprocess output to the server for completion of HTTP
336 further iteration over subprocess output to the server for completion of HTTP
337 response.
337 response.
338
338
339 The real or perceived subprocess error is trapped and raised as one of
339 The real or perceived subprocess error is trapped and raised as one of
340 the OSError family of exceptions.
340 the OSError family of exceptions.
341
341
342 Example usage:
342 Example usage:
343 # try:
343 # try:
344 # answer = SubprocessIOChunker(
344 # answer = SubprocessIOChunker(
345 # cmd,
345 # cmd,
346 # input,
346 # input,
347 # buffer_size = 65536,
347 # buffer_size = 65536,
348 # chunk_size = 4096
348 # chunk_size = 4096
349 # )
349 # )
350 # except (OSError) as e:
350 # except (OSError) as e:
351 # print str(e)
351 # print str(e)
352 # raise e
352 # raise e
353 #
353 #
354 # return answer
354 # return answer
355
355
356
356
357 """
357 """
358
358
359 # TODO: johbo: This is used to make sure that the open end of the PIPE
359 # TODO: johbo: This is used to make sure that the open end of the PIPE
360 # is closed in the end. It would be way better to wrap this into an
360 # is closed in the end. It would be way better to wrap this into an
361 # object, so that it is closed automatically once it is consumed or
361 # object, so that it is closed automatically once it is consumed or
362 # something similar.
362 # something similar.
363 _close_input_fd = None
363 _close_input_fd = None
364
364
365 _closed = False
365 _closed = False
366 _stdout = None
366 _stdout = None
367 _stderr = None
367 _stderr = None
368
368
369 def __init__(self, cmd, input_stream=None, buffer_size=65536,
369 def __init__(self, cmd, input_stream=None, buffer_size=65536,
370 chunk_size=4096, starting_values=None, fail_on_stderr=True,
370 chunk_size=4096, starting_values=None, fail_on_stderr=True,
371 fail_on_return_code=True, **kwargs):
371 fail_on_return_code=True, **kwargs):
372 """
372 """
373 Initializes SubprocessIOChunker
373 Initializes SubprocessIOChunker
374
374
375 :param cmd: A Subprocess.Popen style "cmd". Can be string or array of strings
375 :param cmd: A Subprocess.Popen style "cmd". Can be string or array of strings
376 :param input_stream: (Default: None) A file-like, string, or file pointer.
376 :param input_stream: (Default: None) A file-like, string, or file pointer.
377 :param buffer_size: (Default: 65536) The size of the total buffer per stream, in bytes.
377 :param buffer_size: (Default: 65536) The size of the total buffer per stream, in bytes.
378 :param chunk_size: (Default: 4096) The max size of a chunk; an actual chunk may be smaller.
378 :param chunk_size: (Default: 4096) The max size of a chunk; an actual chunk may be smaller.
379 :param starting_values: (Default: []) An array of strings to put in front of the output queue.
379 :param starting_values: (Default: []) An array of strings to put in front of the output queue.
380 :param fail_on_stderr: (Default: True) Whether to raise an exception in
380 :param fail_on_stderr: (Default: True) Whether to raise an exception in
381 case something is written to stderr.
381 case something is written to stderr.
382 :param fail_on_return_code: (Default: True) Whether to raise an
382 :param fail_on_return_code: (Default: True) Whether to raise an
383 exception if the return code is not 0.
383 exception if the return code is not 0.
384 """
384 """
385
385
386 kwargs['shell'] = kwargs.get('shell', True)
386 kwargs['shell'] = kwargs.get('shell', True)
387
387
388 starting_values = starting_values or []
388 starting_values = starting_values or []
389 if input_stream:
389 if input_stream:
390 input_streamer = StreamFeeder(input_stream)
390 input_streamer = StreamFeeder(input_stream)
391 input_streamer.start()
391 input_streamer.start()
392 input_stream = input_streamer.output
392 input_stream = input_streamer.output
393 self._close_input_fd = input_stream
393 self._close_input_fd = input_stream
394
394
395 self._fail_on_stderr = fail_on_stderr
395 self._fail_on_stderr = fail_on_stderr
396 self._fail_on_return_code = fail_on_return_code
396 self._fail_on_return_code = fail_on_return_code
397 self.cmd = cmd
397 self.cmd = cmd
398
398
399 _p = subprocess.Popen(cmd, bufsize=-1, stdin=input_stream, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
399 _p = subprocess.Popen(cmd, bufsize=-1, stdin=input_stream, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
400 **kwargs)
400 **kwargs)
401 self.process = _p
401 self.process = _p
402
402
403 bg_out = BufferedGenerator('stdout', _p.stdout, buffer_size, chunk_size, starting_values)
403 bg_out = BufferedGenerator('stdout', _p.stdout, buffer_size, chunk_size, starting_values)
404 bg_err = BufferedGenerator('stderr', _p.stderr, 10240, 1, bottomless=True)
404 bg_err = BufferedGenerator('stderr', _p.stderr, 10240, 1, bottomless=True)
405
405
406 while not bg_out.done_reading and not bg_out.reading_paused and not bg_err.length:
406 while not bg_out.done_reading and not bg_out.reading_paused and not bg_err.length:
407 # doing this until we reach either end of file, or end of buffer.
407 # doing this until we reach either end of file, or end of buffer.
408 bg_out.data_added_event.wait(0.2)
408 bg_out.data_added_event.wait(0.2)
409 bg_out.data_added_event.clear()
409 bg_out.data_added_event.clear()
410
410
411 # at this point it's still ambiguous if we are done reading or just full buffer.
411 # at this point it's still ambiguous if we are done reading or just full buffer.
412 # Either way, if error (returned by ended process, or implied based on
412 # Either way, if error (returned by ended process, or implied based on
413 # presence of stuff in stderr output) we error out.
413 # presence of stuff in stderr output) we error out.
414 # Else, we are happy.
414 # Else, we are happy.
415 return_code = _p.poll()
415 return_code = _p.poll()
416 ret_code_ok = return_code in [None, 0]
416 ret_code_ok = return_code in [None, 0]
417 ret_code_fail = return_code is not None and return_code != 0
417 ret_code_fail = return_code is not None and return_code != 0
418 if (
418 if (
419 (ret_code_fail and fail_on_return_code) or
419 (ret_code_fail and fail_on_return_code) or
420 (ret_code_ok and fail_on_stderr and bg_err.length)
420 (ret_code_ok and fail_on_stderr and bg_err.length)
421 ):
421 ):
422
422
423 try:
423 try:
424 _p.terminate()
424 _p.terminate()
425 except Exception:
425 except Exception:
426 pass
426 pass
427
427
428 bg_out.stop()
428 bg_out.stop()
429 out = b''.join(bg_out)
429 out = b''.join(bg_out)
430 self._stdout = out
430 self._stdout = out
431
431
432 bg_err.stop()
432 bg_err.stop()
433 err = b''.join(bg_err)
433 err = b''.join(bg_err)
434 self._stderr = err
434 self._stderr = err
435
435
436 # code from https://github.com/schacon/grack/pull/7
436 # code from https://github.com/schacon/grack/pull/7
            if err.strip() == b'fatal: The remote end hung up unexpectedly' and out.startswith(b'0034shallow '):
                bg_out = iter([out])
                _p = None
            elif err and fail_on_stderr:
                text_err = err.decode()
                raise OSError(
                    f"Subprocess exited due to an error:\n{text_err}")

            if ret_code_fail and fail_on_return_code:
                text_err = err.decode()
                if not err:
                    # stderr may be empty; fall back to stdout, since in many
                    # cases git reports errors on stdout too
                    text_err = out.decode()
                raise OSError(
                    f"Subprocess exited with non 0 ret code:{return_code}: stderr:{text_err}")

        self.stdout = bg_out
        self.stderr = bg_err
        self.inputstream = input_stream

    def __str__(self):
        proc = getattr(self, 'process', 'NO_PROCESS')
        return f'SubprocessIOChunker: {proc}'

    def __iter__(self):
        return self

    def __next__(self):
        # Note: mikhail: we need to be sure that we check the return code
        # after the stdout stream is closed. Some processes, e.g. git, do
        # some magic between closing stdout and terminating the process;
        # as a result, we do not get a return code on "slow" systems.
        result = None
        stop_iteration = None
        try:
            result = next(self.stdout)
        except StopIteration as e:
            stop_iteration = e

        if self.process:
            return_code = self.process.poll()
            ret_code_fail = return_code is not None and return_code != 0
            if ret_code_fail and self._fail_on_return_code:
                self.stop_streams()
                err = self.get_stderr()
                raise OSError(
                    f"Subprocess exited (exit_code:{return_code}) due to an error during iteration:\n{err}")

        if stop_iteration:
            raise stop_iteration
        return result

    def throw(self, exc_type, value=None, traceback=None):
        if self.stdout.length or not self.stdout.done_reading:
            raise exc_type(value)

    def close(self):
        if self._closed:
            return

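        # best-effort cleanup: failures while terminating the process or
        # closing the streams below are deliberately swallowed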
        try:
            self.process.terminate()
        except Exception:
            pass
        if self._close_input_fd:
            os.close(self._close_input_fd)
        try:
            self.stdout.close()
        except Exception:
            pass
        try:
            self.stderr.close()
        except Exception:
            pass
        try:
            os.close(self.inputstream)
        except Exception:
            pass

        self._closed = True

    def stop_streams(self):
        getattr(self.stdout, 'stop', lambda: None)()
        getattr(self.stderr, 'stop', lambda: None)()

    def get_stdout(self):
        if self._stdout:
            return self._stdout
        else:
            return b''.join(self.stdout)

    def get_stderr(self):
        if self._stderr:
            return self._stderr
        else:
            return b''.join(self.stderr)


def run_command(arguments, env=None):
    """
    Run the specified command and return its stdout and stderr as bytes.

    :param arguments: sequence of program arguments (including the program name)
    :type arguments: list[str]
    """
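    # illustrative usage (the command shown here is hypothetical):
    #   stdout, stderr = run_command(['git', '--version'])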

    cmd = arguments
    log.debug('Running subprocessio command %s', cmd)
    proc = None
    try:
        _opts = {'shell': False, 'fail_on_stderr': False}
        if env:
            _opts.update({'env': env})
        proc = SubprocessIOChunker(cmd, **_opts)
        return b''.join(proc), b''.join(proc.stderr)
    except OSError as err:
        cmd = ' '.join(map(safe_str, cmd))  # human friendly CMD
        tb_err = ("Couldn't run subprocessio command (%s).\n"
                  "Original error was:%s\n" % (cmd, err))
        log.exception(tb_err)
        raise Exception(tb_err)
    finally:
        if proc:
            proc.close()

@@ -1,123 +1,123 b''
# RhodeCode VCSServer provides access to different vcs backends via network.
# Copyright (C) 2014-2023 RhodeCode GmbH
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
import base64
import time
import logging

import msgpack

import vcsserver
from vcsserver.str_utils import safe_str, ascii_str

log = logging.getLogger(__name__)


def get_access_path(environ):
    path = environ.get('PATH_INFO')
    return path


def get_user_agent(environ):
    return environ.get('HTTP_USER_AGENT')


def get_call_context(request) -> dict:
    cc = {}
    registry = request.registry
    if hasattr(registry, 'vcs_call_context'):
        cc.update({
            'X-RC-Method': registry.vcs_call_context.get('method'),
            'X-RC-Repo-Name': registry.vcs_call_context.get('repo_name')
        })

    return cc


def get_headers_call_context(environ, strict=True):
    if 'HTTP_X_RC_VCS_STREAM_CALL_CONTEXT' in environ:
        packed_cc = base64.b64decode(environ['HTTP_X_RC_VCS_STREAM_CALL_CONTEXT'])
        return msgpack.unpackb(packed_cc)
    elif strict:
        raise ValueError('Expected header HTTP_X_RC_VCS_STREAM_CALL_CONTEXT not found')
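# Note: the sender side must build this header as the exact inverse of the
# decoding above, i.e. base64.b64encode(msgpack.packb(call_context_dict));
# the 'method'/'repo_name' keys used elsewhere in this module are illustrative
# examples of its contents, not a fixed schema.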


class RequestWrapperTween(object):
    def __init__(self, handler, registry):
        self.handler = handler
        self.registry = registry

        # one-time configuration code goes here

    def __call__(self, request):
        start = time.time()
        log.debug('Starting request time measurement')
        response = None

        try:
            response = self.handler(request)
        finally:
            ua = get_user_agent(request.environ)
            call_context = get_call_context(request)
            vcs_method = call_context.get('X-RC-Method', '_NO_VCS_METHOD')
            repo_name = call_context.get('X-RC-Repo-Name', '')

            count = request.request_count()
            _ver_ = vcsserver.__version__
            _path = safe_str(get_access_path(request.environ))

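            # NOTE: the client IP is not extracted from the WSGI environ here;
            # a fixed loopback address is what gets logged below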
            ip = '127.0.0.1'
            match_route = request.matched_route.name if request.matched_route else "NOT_FOUND"
            resp_code = getattr(response, 'status_code', 'UNDEFINED')

            _view_path = f"{repo_name}@{_path}/{vcs_method}"

            total = time.time() - start

            log.info(
                'Req[%4s] IP: %s %s Request to %s time: %.4fs [%s], VCSServer %s',
                count, ip, request.environ.get('REQUEST_METHOD'),
                _view_path, total, ua, _ver_,
                extra={"time": total, "ver": _ver_, "code": resp_code,
                       "path": _path, "view_name": match_route, "user_agent": ua,
                       "vcs_method": vcs_method, "repo_name": repo_name}
            )

            statsd = request.registry.statsd
            if statsd:
                match_route = request.matched_route.name if request.matched_route else _path
                elapsed_time_ms = round(1000.0 * total)  # use ms only
                statsd.timing(
                    "vcsserver_req_timing.histogram", elapsed_time_ms,
                    tags=[
                        f"view_name:{match_route}",
                        f"code:{resp_code}"
                    ],
                    use_decimals=False
                )
                statsd.incr(
                    "vcsserver_req_total", tags=[
                        f"view_name:{match_route}",
                        f"code:{resp_code}"
                    ])

        return response


def includeme(config):
    config.add_tween(
        'vcsserver.tweens.request_wrapper.RequestWrapperTween',
    )
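
# Pyramid activates this module's includeme() via config.include(), e.g.
# (illustrative): config.include('vcsserver.tweens.request_wrapper')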