python3: fixed various code issues...
super-admin - r4973:5e52ba1a default
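
The substantive change in both hunks below is a Python 3 itertools fix: Python 3 removed `itertools.izip_longest` (the lazy `i`-prefixed variants were dropped because the plain-named functions became lazy), so both call sites now use `itertools.zip_longest`. A minimal sketch of the idiom these call sites rely on, with illustrative argument names:

    import itertools

    # Pair argument names with their defaults, right-aligned, filling the
    # positions that have no default with a sentinel. This is how the API
    # layer below tells required parameters apart from optional ones.
    arglist = ['request', 'apiuser', 'userid']
    defaults = ['<userid-default>']           # illustrative values
    REQUIRED = object()                       # illustrative sentinel

    func_kwargs = dict(itertools.zip_longest(
        reversed(arglist), reversed(defaults), fillvalue=REQUIRED))
    # -> {'userid': '<userid-default>', 'apiuser': REQUIRED, 'request': REQUIRED}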
@@ -1,578 +1,578 b''
# -*- coding: utf-8 -*-

# Copyright (C) 2011-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import itertools
import logging
import sys
import types
import fnmatch

import decorator
import venusian
from collections import OrderedDict

from pyramid.exceptions import ConfigurationError
from pyramid.renderers import render
from pyramid.response import Response
from pyramid.httpexceptions import HTTPNotFound

from rhodecode.api.exc import (
    JSONRPCBaseError, JSONRPCError, JSONRPCForbidden, JSONRPCValidationError)
from rhodecode.apps._base import TemplateArgs
from rhodecode.lib.auth import AuthUser
from rhodecode.lib.base import get_ip_addr, attach_context_attributes
from rhodecode.lib.exc_tracking import store_exception
from rhodecode.lib.ext_json import json
from rhodecode.lib.utils2 import safe_str
from rhodecode.lib.plugins.utils import get_plugin_settings
from rhodecode.model.db import User, UserApiKeys

log = logging.getLogger(__name__)

DEFAULT_RENDERER = 'jsonrpc_renderer'
DEFAULT_URL = '/_admin/apiv2'


def find_methods(jsonrpc_methods, pattern):
    matches = OrderedDict()
    if not isinstance(pattern, (list, tuple)):
        pattern = [pattern]

    for single_pattern in pattern:
        for method_name, method in jsonrpc_methods.items():
            if fnmatch.fnmatch(method_name, single_pattern):
                matches[method_name] = method
    return matches
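
# Example (illustrative): with registered methods {'comment_commit': f,
# 'get_ip': g}, find_methods(methods, '*comment*') returns
# OrderedDict([('comment_commit', f)]); passing a list of patterns ORs them.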


class ExtJsonRenderer(object):
    """
    Custom renderer that makes use of our ext_json lib.
    """

    def __init__(self, serializer=json.dumps, **kw):
        """ Any keyword arguments will be passed to the ``serializer``
        function."""
        self.serializer = serializer
        self.kw = kw

    def __call__(self, info):
        """ Returns a plain JSON-encoded string with content-type
        ``application/json``. The content-type may be overridden by
        setting ``request.response.content_type``."""

        def _render(value, system):
            request = system.get('request')
            if request is not None:
                response = request.response
                ct = response.content_type
                if ct == response.default_content_type:
                    response.content_type = 'application/json'

            return self.serializer(value, **self.kw)

        return _render


def jsonrpc_response(request, result):
    rpc_id = getattr(request, 'rpc_id', None)
    response = request.response

    # store content_type before render is called
    ct = response.content_type

    ret_value = ''
    if rpc_id:
        ret_value = {
            'id': rpc_id,
            'result': result,
            'error': None,
        }

        # fetch deprecation warnings and store them inside the result
        deprecation = getattr(request, 'rpc_deprecation', None)
        if deprecation:
            ret_value['DEPRECATION_WARNING'] = deprecation

    raw_body = render(DEFAULT_RENDERER, ret_value, request=request)
    response.body = safe_str(raw_body, response.charset)

    if ct == response.default_content_type:
        response.content_type = 'application/json'

    return response
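
# The success envelope rendered above looks like (illustrative):
#   {"id": <rpc_id>, "result": <return value>, "error": null}
# plus a "DEPRECATION_WARNING" key when the called method is deprecated.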


def jsonrpc_error(request, message, retid=None, code=None, headers=None):
    """
    Generate a Response object with a JSON-RPC error body

    :param code:
    :param retid:
    :param message:
    """
    err_dict = {'id': retid, 'result': None, 'error': message}
    body = render(DEFAULT_RENDERER, err_dict, request=request).encode('utf-8')

    return Response(
        body=body,
        status=code,
        content_type='application/json',
        headerlist=headers
    )
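
# The error envelope mirrors the success one (illustrative):
#   {"id": <retid>, "result": null, "error": "<message>"}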


def exception_view(exc, request):
    rpc_id = getattr(request, 'rpc_id', None)

    if isinstance(exc, JSONRPCError):
        fault_message = safe_str(exc.message)
        log.debug('json-rpc error rpc_id:%s "%s"', rpc_id, fault_message)
    elif isinstance(exc, JSONRPCValidationError):
        colander_exc = exc.colander_exception
        # TODO(marcink): think of a nicer way to serialize errors?
        fault_message = colander_exc.asdict()
        log.debug('json-rpc colander error rpc_id:%s "%s"', rpc_id, fault_message)
    elif isinstance(exc, JSONRPCForbidden):
        fault_message = 'Access was denied to this resource.'
        log.warning('json-rpc forbidden call rpc_id:%s "%s"', rpc_id, fault_message)
    elif isinstance(exc, HTTPNotFound):
        method = request.rpc_method
        log.debug('json-rpc method `%s` not found in list of '
                  'api calls: %s, rpc_id:%s',
                  method, request.registry.jsonrpc_methods.keys(), rpc_id)

        similar = 'none'
        try:
            similar_patterns = ['*{}*'.format(x) for x in method.split('_')]
            similar_found = find_methods(
                request.registry.jsonrpc_methods, similar_patterns)
            similar = ', '.join(similar_found.keys()) or similar
        except Exception:
            # make the whole above block safe
            pass

        fault_message = "No such method: {}. Similar methods: {}".format(
            method, similar)
    else:
        fault_message = 'undefined error'
        exc_info = exc.exc_info()
        store_exception(id(exc_info), exc_info, prefix='rhodecode-api')

    statsd = request.registry.statsd
    if statsd:
        exc_type = "{}.{}".format(exc.__class__.__module__, exc.__class__.__name__)
        statsd.incr('rhodecode_exception_total',
                    tags=["exc_source:api", "type:{}".format(exc_type)])

    return jsonrpc_error(request, fault_message, rpc_id)
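
# Example (illustrative): calling the misspelled method 'comment_commmit'
# yields the patterns '*comment*' and '*commmit*', so the error suggests
# e.g. 'comment_commit, comment_pull_request' as similar methods.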


def request_view(request):
    """
    Main request handling method. It handles all logic to call a specific
    exposed method
    """
    # cython compatible inspect
    from rhodecode.config.patches import inspect_getargspec
    inspect = inspect_getargspec()

    # check if we can find this session using api_key, get_by_auth_token
    # search not expired tokens only
    try:
        api_user = User.get_by_auth_token(request.rpc_api_key)

        if api_user is None:
            return jsonrpc_error(
                request, retid=request.rpc_id, message='Invalid API KEY')

        if not api_user.active:
            return jsonrpc_error(
                request, retid=request.rpc_id,
                message='Request from this user not allowed')

        # check if we are allowed to use this IP
        auth_u = AuthUser(
            api_user.user_id, request.rpc_api_key, ip_addr=request.rpc_ip_addr)
        if not auth_u.ip_allowed:
            return jsonrpc_error(
                request, retid=request.rpc_id,
                message='Request from IP:%s not allowed' % (
                    request.rpc_ip_addr,))
        else:
            log.info('Access for IP:%s allowed', request.rpc_ip_addr)

        # register our auth-user
        request.rpc_user = auth_u
        request.environ['rc_auth_user_id'] = auth_u.user_id

        # now check if token is valid for API
        auth_token = request.rpc_api_key
        token_match = api_user.authenticate_by_token(
            auth_token, roles=[UserApiKeys.ROLE_API])
        invalid_token = not token_match

        log.debug('Checking if API KEY is valid with proper role')
        if invalid_token:
            return jsonrpc_error(
                request, retid=request.rpc_id,
                message='API KEY invalid or has a bad role for an API call')

    except Exception:
        log.exception('Error on API AUTH')
        return jsonrpc_error(
            request, retid=request.rpc_id, message='Invalid API KEY')

    method = request.rpc_method
    func = request.registry.jsonrpc_methods[method]

    # now that we have a method, add request._req_params to
    # self.kwargs and dispatch control to WSGIController
    argspec = inspect.getargspec(func)
    arglist = argspec[0]
    defaults = list(map(type, argspec[3] or []))
    default_empty = types.NotImplementedType

    # kw arguments required by this method
-    func_kwargs = dict(itertools.izip_longest(
+    func_kwargs = dict(itertools.zip_longest(
        reversed(arglist), reversed(defaults), fillvalue=default_empty))

    # This attribute needs to be the first param of a method that uses
    # api_key; it is translated to an instance of user at that name
    user_var = 'apiuser'
    request_var = 'request'

    for arg in [user_var, request_var]:
        if arg not in arglist:
            return jsonrpc_error(
                request,
                retid=request.rpc_id,
                message='This method [%s] does not support '
                        'required parameter `%s`' % (func.__name__, arg))

    # get our arglist and check if we provided them as args
    for arg, default in func_kwargs.items():
        if arg in [user_var, request_var]:
            # user_var and request_var are pre-hardcoded parameters and we
            # don't need to do any translation
            continue

        # skip the required param check if its default value is
        # NotImplementedType (default_empty)
        if default == default_empty and arg not in request.rpc_params:
            return jsonrpc_error(
                request,
                retid=request.rpc_id,
                message=('Missing non-optional `%s` arg in JSON DATA' % arg)
            )

    # sanitize extra passed arguments
    for k in list(request.rpc_params.keys()):
        if k not in func_kwargs:
            del request.rpc_params[k]

    call_params = request.rpc_params
    call_params.update({
        'request': request,
        'apiuser': auth_u
    })

    # register some common functions for usage
    attach_context_attributes(TemplateArgs(), request, request.rpc_user.user_id)

    statsd = request.registry.statsd

    try:
        ret_value = func(**call_params)
        resp = jsonrpc_response(request, ret_value)
        if statsd:
            statsd.incr('rhodecode_api_call_success_total')
        return resp
    except JSONRPCBaseError:
        raise
    except Exception:
        log.exception('Unhandled exception occurred on api call: %s', func)
        exc_info = sys.exc_info()
        exc_id, exc_type_name = store_exception(
            id(exc_info), exc_info, prefix='rhodecode-api')
        error_headers = [('RhodeCode-Exception-Id', str(exc_id)),
                         ('RhodeCode-Exception-Type', str(exc_type_name))]
        err_resp = jsonrpc_error(
            request, retid=request.rpc_id, message='Internal server error',
            headers=error_headers)
        if statsd:
            statsd.incr('rhodecode_api_call_fail_total')
        return err_resp


def setup_request(request):
    """
    Parse a JSON-RPC request body. It's used inside the predicates method
    to validate and bootstrap requests for usage in rpc calls.

    We need to raise JSONRPCError here if we want to return some errors back
    to the user.
    """

    log.debug('Executing setup request: %r', request)
    request.rpc_ip_addr = get_ip_addr(request.environ)
    # TODO(marcink): deprecate GET at some point
    if request.method not in ['POST', 'GET']:
        log.debug('unsupported request method "%s"', request.method)
        raise JSONRPCError(
            'unsupported request method "%s". Please use POST' % request.method)

    if 'CONTENT_LENGTH' not in request.environ:
        log.debug("No Content-Length")
        raise JSONRPCError("Empty body, No Content-Length in request")

    else:
        length = request.environ['CONTENT_LENGTH']
        log.debug('Content-Length: %s', length)

        if length == 0:
            log.debug("Content-Length is 0")
            raise JSONRPCError("Content-Length is 0")

    raw_body = request.body
    log.debug("Loading JSON body now")
    try:
        json_body = json.loads(raw_body)
    except ValueError as e:
        # catch JSON errors here
        raise JSONRPCError("JSON parse error ERR:%s RAW:%r" % (e, raw_body))

    request.rpc_id = json_body.get('id')
    request.rpc_method = json_body.get('method')

    # check required base parameters
    try:
        api_key = json_body.get('api_key')
        if not api_key:
            api_key = json_body.get('auth_token')

        if not api_key:
            raise KeyError('api_key or auth_token')

        # TODO(marcink): support passing in token in request header

        request.rpc_api_key = api_key
        request.rpc_id = json_body['id']
        request.rpc_method = json_body['method']
        request.rpc_params = json_body['args'] \
            if isinstance(json_body['args'], dict) else {}

        log.debug('method: %s, params: %.10240r', request.rpc_method, request.rpc_params)
    except KeyError as e:
        raise JSONRPCError('Incorrect JSON data. Missing %s' % e)

    log.debug('setup complete, now handling method:%s rpcid:%s',
              request.rpc_method, request.rpc_id, )
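
# A minimal body accepted by setup_request (illustrative); either `api_key`
# or `auth_token` may carry the token:
#   {"id": 1, "auth_token": "<secret>", "method": "get_ip", "args": {}}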


class RoutePredicate(object):
    def __init__(self, val, config):
        self.val = val

    def text(self):
        return 'jsonrpc route = %s' % self.val

    phash = text

    def __call__(self, info, request):
        if self.val:
            # potentially setup and bootstrap our call
            setup_request(request)

        # Always return True so that even if it isn't a valid RPC it
        # will fall through to the underlying handlers like notfound_view
        return True


class NotFoundPredicate(object):
    def __init__(self, val, config):
        self.val = val
        self.methods = config.registry.jsonrpc_methods

    def text(self):
        return 'jsonrpc method not found = {}.'.format(self.val)

    phash = text

    def __call__(self, info, request):
        return hasattr(request, 'rpc_method')


class MethodPredicate(object):
    def __init__(self, val, config):
        self.method = val

    def text(self):
        return 'jsonrpc method = %s' % self.method

    phash = text

    def __call__(self, context, request):
        # we need to explicitly return False here, so pyramid doesn't try to
        # execute our view directly. We need our main handler to execute things
        return getattr(request, 'rpc_method') == self.method


def add_jsonrpc_method(config, view, **kwargs):
    # pop the method name
    method = kwargs.pop('method', None)

    if method is None:
        raise ConfigurationError(
            'Cannot register a JSON-RPC method without specifying the "method"')

    # we define a custom predicate to detect conflicting methods; these
    # predicates are a kind of "translation" from the decorator variables
    # to internal predicate names

    kwargs['jsonrpc_method'] = method

    # register our view into global view store for validation
    config.registry.jsonrpc_methods[method] = view

    # we're using our main request_view handler here, so each method
    # has a unified handler for itself
    config.add_view(request_view, route_name='apiv2', **kwargs)


class jsonrpc_method(object):
    """
    A decorator that works similarly to the @add_view_config decorator,
    but tailored for our JSON RPC
    """

    venusian = venusian  # for testing injection

    def __init__(self, method=None, **kwargs):
        self.method = method
        self.kwargs = kwargs

    def __call__(self, wrapped):
        kwargs = self.kwargs.copy()
        kwargs['method'] = self.method or wrapped.__name__
        depth = kwargs.pop('_depth', 0)

        def callback(context, name, ob):
            config = context.config.with_package(info.module)
            config.add_jsonrpc_method(view=ob, **kwargs)

        info = venusian.attach(wrapped, callback, category='pyramid',
                               depth=depth + 1)
        if info.scope == 'class':
            # ensure that attr is set if decorating a class method
            kwargs.setdefault('attr', wrapped.__name__)

        kwargs['_info'] = info.codeinfo  # fbo action_method
        return wrapped
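
# Usage (illustrative): a function decorated as
#   @jsonrpc_method()
#   def get_ip(request, apiuser, userid=Optional(OAttr('apiuser'))): ...
# is picked up by config.scan() and registered under its function name.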


class jsonrpc_deprecated_method(object):
    """
    Marks a method as deprecated, adds a log.warning, and injects a special
    key into the request variable to mark the method as deprecated.
    Also injects a special docstring that extract_docs will catch to mark
    the method as deprecated.

    :param use_method: specify which method should be used instead of
        the decorated one

    Use like::

        @jsonrpc_method()
        @jsonrpc_deprecated_method(use_method='new_func', deprecated_at_version='3.0.0')
        def old_func(request, apiuser, arg1, arg2):
            ...
    """

    def __init__(self, use_method, deprecated_at_version):
        self.use_method = use_method
        self.deprecated_at_version = deprecated_at_version
        self.deprecated_msg = ''

    def __call__(self, func):
        self.deprecated_msg = 'Please use method `{method}` instead.'.format(
            method=self.use_method)

        docstring = """\n
        .. deprecated:: {version}

            {deprecation_message}

        {original_docstring}
        """
        func.__doc__ = docstring.format(
            version=self.deprecated_at_version,
            deprecation_message=self.deprecated_msg,
            original_docstring=func.__doc__)
        return decorator.decorator(self.__wrapper, func)

    def __wrapper(self, func, *fargs, **fkwargs):
        log.warning('DEPRECATED API CALL on function %s, please '
                    'use `%s` instead', func, self.use_method)
        # the altered docstring above marks the method as deprecated; it is
        # picked up by the fabric file that generates the API docs
        result = func(*fargs, **fkwargs)

        request = fargs[0]
        request.rpc_deprecation = 'DEPRECATED METHOD ' + self.deprecated_msg
        return result


def add_api_methods(config):
    from rhodecode.api.views import (
        deprecated_api, gist_api, pull_request_api, repo_api, repo_group_api,
        server_api, search_api, testing_api, user_api, user_group_api)

    config.scan('rhodecode.api.views')


def includeme(config):
    plugin_module = 'rhodecode.api'
    plugin_settings = get_plugin_settings(
        plugin_module, config.registry.settings)

    if not hasattr(config.registry, 'jsonrpc_methods'):
        config.registry.jsonrpc_methods = OrderedDict()

    # match filter by given method only
    config.add_view_predicate('jsonrpc_method', MethodPredicate)
    config.add_view_predicate('jsonrpc_method_not_found', NotFoundPredicate)

    config.add_renderer(DEFAULT_RENDERER, ExtJsonRenderer(
        serializer=json.dumps, indent=4))
    config.add_directive('add_jsonrpc_method', add_jsonrpc_method)

    config.add_route_predicate(
        'jsonrpc_call', RoutePredicate)

    config.add_route(
        'apiv2', plugin_settings.get('url', DEFAULT_URL), jsonrpc_call=True)

    # register some exception handling view
    config.add_view(exception_view, context=JSONRPCBaseError)
    config.add_notfound_view(exception_view, jsonrpc_method_not_found=True)

    add_api_methods(config)
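
Putting the wiring above together, a client call against the `/_admin/apiv2` endpoint is a single JSON POST. A minimal sketch using only the standard library; the URL and token are placeholders:

    import json
    import urllib.request

    url = 'https://rhodecode.example.com/_admin/apiv2'   # illustrative host
    payload = {'id': 1, 'auth_token': '<secret>',
               'method': 'get_ip', 'args': {}}

    req = urllib.request.Request(
        url, data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'})
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp))  # {'id': 1, 'result': {...}, 'error': None}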
@@ -1,419 +1,419 b''
# -*- coding: utf-8 -*-

# Copyright (C) 2011-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import logging
import itertools
import base64

from rhodecode.api import (
    jsonrpc_method, JSONRPCError, JSONRPCForbidden, find_methods)

from rhodecode.api.utils import (
    Optional, OAttr, has_superadmin_permission, get_user_or_error)
from rhodecode.lib.utils import repo2db_mapper
from rhodecode.lib import system_info
from rhodecode.lib import user_sessions
from rhodecode.lib import exc_tracking
from rhodecode.lib.ext_json import json
from rhodecode.lib.utils2 import safe_int
from rhodecode.model.db import UserIpMap
from rhodecode.model.scm import ScmModel
from rhodecode.model.settings import VcsSettingsModel
from rhodecode.apps.file_store import utils
from rhodecode.apps.file_store.exceptions import FileNotAllowedException, \
    FileOverSizeException

log = logging.getLogger(__name__)


@jsonrpc_method()
def get_server_info(request, apiuser):
    """
    Returns the |RCE| server information.

    This includes the running version of |RCE| and all installed
    packages. This command takes the following options:

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            'modules': [<module name>,...]
            'py_version': <python version>,
            'platform': <platform type>,
            'rhodecode_version': <rhodecode version>
        }
        error : null
    """

    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    server_info = ScmModel().get_server_info(request.environ)
    # rhodecode-index requires those

    server_info['index_storage'] = server_info['search']['value']['location']
    server_info['storage'] = server_info['storage']['value']['path']

    return server_info


@jsonrpc_method()
def get_repo_store(request, apiuser):
    """
    Returns the |RCE| repository storage information.

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            'path': <path_to_repo_store>
        }
        error : null
    """

    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    path = VcsSettingsModel().get_repos_location()
    return {"path": path}


@jsonrpc_method()
def get_ip(request, apiuser, userid=Optional(OAttr('apiuser'))):
    """
    Displays the IP Address as seen from the |RCE| server.

    * This command displays the IP Address, as well as all the defined IP
      addresses for the specified user. If the ``userid`` is not set, the
      data returned is for the user calling the method.

    This command can only be run using an |authtoken| with admin rights to
    the specified repository.

    This command takes the following options:

    :param apiuser: This is filled automatically from |authtoken|.
    :type apiuser: AuthUser
    :param userid: Sets the userid for which associated IP Address data
        is returned.
    :type userid: Optional(str or int)

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            "server_ip_addr": "<ip_from_client>",
            "user_ips": [
                {
                    "ip_addr": "<ip_with_mask>",
                    "ip_range": ["<start_ip>", "<end_ip>"],
                },
                ...
            ]
        }

    """
    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    userid = Optional.extract(userid, evaluate_locals=locals())
    userid = getattr(userid, 'user_id', userid)

    user = get_user_or_error(userid)
    ips = UserIpMap.query().filter(UserIpMap.user == user).all()
    return {
        'server_ip_addr': request.rpc_ip_addr,
        'user_ips': ips
    }


@jsonrpc_method()
def rescan_repos(request, apiuser, remove_obsolete=Optional(False)):
    """
    Triggers a rescan of the specified repositories.

    * If the ``remove_obsolete`` option is set, it also deletes repositories
      that are found in the database but not on the file system, the
      so-called "clean zombies".

    This command can only be run using an |authtoken| with admin rights to
    the specified repository.

    This command takes the following options:

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser
    :param remove_obsolete: Deletes repositories from the database that
        are not found on the filesystem.
    :type remove_obsolete: Optional(``True`` | ``False``)

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            'added': [<added repository name>,...]
            'removed': [<removed repository name>,...]
        }
        error : null

    Example error output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : null
        error : {
            'Error occurred during rescan repositories action'
        }

    """
    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    try:
        rm_obsolete = Optional.extract(remove_obsolete)
        added, removed = repo2db_mapper(ScmModel().repo_scan(),
                                        remove_obsolete=rm_obsolete)
        return {'added': added, 'removed': removed}
    except Exception:
        log.exception('Failed to run repo rescan')
        raise JSONRPCError(
            'Error occurred during rescan repositories action'
        )


@jsonrpc_method()
def cleanup_sessions(request, apiuser, older_then=Optional(60)):
    """
    Triggers a session cleanup action.

    If the ``older_then`` option is set, only sessions that haven't been
    accessed in the given number of days will be removed.

    This command can only be run using an |authtoken| with admin rights to
    the specified repository.

    This command takes the following options:

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser
    :param older_then: Deletes sessions that haven't been accessed
        in the given number of days.
    :type older_then: Optional(int)

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result: {
            "backend": "<type of backend>",
            "sessions_removed": <number_of_removed_sessions>
        }
        error : null

    Example error output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : null
        error : {
            'Error occurred during session cleanup'
        }

    """
    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    older_then = safe_int(Optional.extract(older_then)) or 60
    older_than_seconds = 60 * 60 * 24 * older_then
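    # e.g. the default older_then=60 yields 60 * 60 * 24 * 60 = 5,184,000
    # seconds, i.e. sessions untouched for 60 days become eligible for removal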

    config = system_info.rhodecode_config().get_value()['value']['config']
    session_model = user_sessions.get_session_handler(
        config.get('beaker.session.type', 'memory'))(config)

    backend = session_model.SESSION_TYPE
    try:
        cleaned = session_model.clean_sessions(
            older_than_seconds=older_than_seconds)
        return {'sessions_removed': cleaned, 'backend': backend}
    except user_sessions.CleanupCommand as msg:
        return {'cleanup_command': msg.message, 'backend': backend}
    except Exception as e:
        log.exception('Failed session cleanup')
        raise JSONRPCError(
            'Error occurred during session cleanup'
        )
285 @jsonrpc_method()
285 @jsonrpc_method()
286 def get_method(request, apiuser, pattern=Optional('*')):
286 def get_method(request, apiuser, pattern=Optional('*')):
287 """
287 """
288 Returns list of all available API methods. By default match pattern
288 Returns list of all available API methods. By default match pattern
289 os "*" but any other pattern can be specified. eg *comment* will return
289 os "*" but any other pattern can be specified. eg *comment* will return
290 all methods with comment inside them. If just single method is matched
290 all methods with comment inside them. If just single method is matched
291 returned data will also include method specification
291 returned data will also include method specification
292
292
293 This command can only be run using an |authtoken| with admin rights to
293 This command can only be run using an |authtoken| with admin rights to
294 the specified repository.
294 the specified repository.
295
295
296 This command takes the following options:
296 This command takes the following options:
297
297
298 :param apiuser: This is filled automatically from the |authtoken|.
298 :param apiuser: This is filled automatically from the |authtoken|.
299 :type apiuser: AuthUser
299 :type apiuser: AuthUser
300 :param pattern: pattern to match method names against
300 :param pattern: pattern to match method names against
301 :type pattern: Optional("*")
301 :type pattern: Optional("*")
302
302
303 Example output:
303 Example output:
304
304
305 .. code-block:: bash
305 .. code-block:: bash
306
306
307 id : <id_given_in_input>
307 id : <id_given_in_input>
308 "result": [
308 "result": [
309 "changeset_comment",
309 "changeset_comment",
310 "comment_pull_request",
310 "comment_pull_request",
311 "comment_commit"
311 "comment_commit"
312 ]
312 ]
313 error : null
313 error : null
314
314
315 .. code-block:: bash
315 .. code-block:: bash
316
316
317 id : <id_given_in_input>
317 id : <id_given_in_input>
318 "result": [
318 "result": [
319 "comment_commit",
319 "comment_commit",
320 {
320 {
321 "apiuser": "<RequiredType>",
321 "apiuser": "<RequiredType>",
322 "comment_type": "<Optional:u'note'>",
322 "comment_type": "<Optional:u'note'>",
323 "commit_id": "<RequiredType>",
323 "commit_id": "<RequiredType>",
324 "message": "<RequiredType>",
324 "message": "<RequiredType>",
325 "repoid": "<RequiredType>",
325 "repoid": "<RequiredType>",
326 "request": "<RequiredType>",
326 "request": "<RequiredType>",
327 "resolves_comment_id": "<Optional:None>",
327 "resolves_comment_id": "<Optional:None>",
328 "status": "<Optional:None>",
328 "status": "<Optional:None>",
329 "userid": "<Optional:<OptionalAttr:apiuser>>"
329 "userid": "<Optional:<OptionalAttr:apiuser>>"
330 }
330 }
331 ]
331 ]
332 error : null
332 error : null
333 """
333 """
334 from rhodecode.config.patches import inspect_getargspec
334 from rhodecode.config.patches import inspect_getargspec
335 inspect = inspect_getargspec()
335 inspect = inspect_getargspec()
336
336
337 if not has_superadmin_permission(apiuser):
337 if not has_superadmin_permission(apiuser):
338 raise JSONRPCForbidden()
338 raise JSONRPCForbidden()
339
339
340 pattern = Optional.extract(pattern)
340 pattern = Optional.extract(pattern)
341
341
342 matches = find_methods(request.registry.jsonrpc_methods, pattern)
342 matches = find_methods(request.registry.jsonrpc_methods, pattern)
343
343
344 args_desc = []
344 args_desc = []
345 if len(matches) == 1:
345 if len(matches) == 1:
346 func = matches[matches.keys()[0]]
346 func = matches[list(matches.keys())[0]]
347
347
348 argspec = inspect.getargspec(func)
348 argspec = inspect.getargspec(func)
349 arglist = argspec[0]
349 arglist = argspec[0]
350 defaults = map(repr, argspec[3] or [])
350 defaults = list(map(repr, argspec[3] or []))
351
351
352 default_empty = '<RequiredType>'
352 default_empty = '<RequiredType>'
353
353
354 # kw arguments required by this method
354 # kw arguments required by this method
355 func_kwargs = dict(itertools.izip_longest(
355 func_kwargs = dict(itertools.zip_longest(
356 reversed(arglist), reversed(defaults), fillvalue=default_empty))
356 reversed(arglist), reversed(defaults), fillvalue=default_empty))
357 args_desc.append(func_kwargs)
357 args_desc.append(func_kwargs)
358
358
359 return matches.keys() + args_desc
359 return list(matches.keys()) + args_desc
360
360
361
361
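A quick aside on the ``zip_longest`` call above: pairing the reversed argument
list with the reversed defaults aligns each default with its trailing keyword
argument, and ``fillvalue`` marks everything left over as required. A minimal
sketch with assumed values:

.. code-block:: python

    import itertools

    arglist = ['request', 'apiuser', 'pattern']  # positional order
    defaults = ["'*'"]                           # only `pattern` has a default

    # reversed() lines defaults up with the *last* arguments; unmatched
    # names fall back to the <RequiredType> marker.
    func_kwargs = dict(itertools.zip_longest(
        reversed(arglist), reversed(defaults), fillvalue='<RequiredType>'))
    # {'pattern': "'*'", 'apiuser': '<RequiredType>', 'request': '<RequiredType>'}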
362 @jsonrpc_method()
362 @jsonrpc_method()
363 def store_exception(request, apiuser, exc_data_json, prefix=Optional('rhodecode')):
363 def store_exception(request, apiuser, exc_data_json, prefix=Optional('rhodecode')):
364 """
364 """
365 Stores the sent exception in the built-in exception tracker of the |RCE| server.
365 Stores the sent exception in the built-in exception tracker of the |RCE| server.
366
366
367 This command can only be run using an |authtoken| with super-admin
367 This command can only be run using an |authtoken| with super-admin
368 rights.
368 rights.
369
369
370 This command takes the following options:
370 This command takes the following options:
371
371
372 :param apiuser: This is filled automatically from the |authtoken|.
372 :param apiuser: This is filled automatically from the |authtoken|.
373 :type apiuser: AuthUser
373 :type apiuser: AuthUser
374
374
375 :param exc_data_json: JSON data with the exception, e.g.
375 :param exc_data_json: JSON data with the exception, e.g.
376 {"exc_traceback": "Value `1` is not allowed", "exc_type_name": "ValueError"}
376 {"exc_traceback": "Value `1` is not allowed", "exc_type_name": "ValueError"}
377 :type exc_data_json: JSON data
377 :type exc_data_json: JSON data
378
378
379 :param prefix: prefix for the error type, e.g. 'rhodecode', 'vcsserver', 'rhodecode-tools'
379 :param prefix: prefix for the error type, e.g. 'rhodecode', 'vcsserver', 'rhodecode-tools'
380 :type prefix: Optional("rhodecode")
380 :type prefix: Optional("rhodecode")
381
381
382 Example output:
382 Example output:
383
383
384 .. code-block:: bash
384 .. code-block:: bash
385
385
386 id : <id_given_in_input>
386 id : <id_given_in_input>
387 "result": {
387 "result": {
388 "exc_id": 139718459226384,
388 "exc_id": 139718459226384,
389 "exc_url": "http://localhost:8080/_admin/settings/exceptions/139718459226384"
389 "exc_url": "http://localhost:8080/_admin/settings/exceptions/139718459226384"
390 }
390 }
391 error : null
391 error : null
392 """
392 """
393 if not has_superadmin_permission(apiuser):
393 if not has_superadmin_permission(apiuser):
394 raise JSONRPCForbidden()
394 raise JSONRPCForbidden()
395
395
396 prefix = Optional.extract(prefix)
396 prefix = Optional.extract(prefix)
397 exc_id = exc_tracking.generate_id()
397 exc_id = exc_tracking.generate_id()
398
398
399 try:
399 try:
400 exc_data = json.loads(exc_data_json)
400 exc_data = json.loads(exc_data_json)
401 except Exception:
401 except Exception:
402 log.error('Failed to parse JSON: %r', exc_data_json)
402 log.error('Failed to parse JSON: %r', exc_data_json)
403 raise JSONRPCError('Failed to parse JSON data from exc_data_json field. '
403 raise JSONRPCError('Failed to parse JSON data from exc_data_json field. '
404 'Please make sure it contains valid JSON.')
404 'Please make sure it contains valid JSON.')
405
405
406 try:
406 try:
407 exc_traceback = exc_data['exc_traceback']
407 exc_traceback = exc_data['exc_traceback']
408 exc_type_name = exc_data['exc_type_name']
408 exc_type_name = exc_data['exc_type_name']
409 except KeyError as err:
409 except KeyError as err:
410 raise JSONRPCError('Missing exc_traceback or exc_type_name '
410 raise JSONRPCError('Missing exc_traceback or exc_type_name '
411 'in exc_data_json field. Missing: {}'.format(err))
411 'in exc_data_json field. Missing: {}'.format(err))
412
412
413 exc_tracking._store_exception(
413 exc_tracking._store_exception(
414 exc_id=exc_id, exc_traceback=exc_traceback,
414 exc_id=exc_id, exc_traceback=exc_traceback,
415 exc_type_name=exc_type_name, prefix=prefix)
415 exc_type_name=exc_type_name, prefix=prefix)
416
416
417 exc_url = request.route_url(
417 exc_url = request.route_url(
418 'admin_settings_exception_tracker_show', exception_id=exc_id)
418 'admin_settings_exception_tracker_show', exception_id=exc_id)
419 return {'exc_id': exc_id, 'exc_url': exc_url}
419 return {'exc_id': exc_id, 'exc_url': exc_url}
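For reference, a caller could assemble the ``exc_data_json`` payload for
``store_exception`` along these lines (a hedged sketch; how the JSON is posted
to the API endpoint is left out):

.. code-block:: python

    import json
    import traceback

    try:
        raise ValueError('Value `1` is not allowed')
    except ValueError:
        exc_data_json = json.dumps({
            'exc_traceback': traceback.format_exc(),
            'exc_type_name': 'ValueError',
        })
    # exc_data_json now matches the documented shape and can be sent
    # as the `exc_data_json` argument of the API call.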
@@ -1,479 +1,479 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2016-2020 RhodeCode GmbH
3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import re
21 import re
22 import logging
22 import logging
23 import formencode
23 import formencode
24 import formencode.htmlfill
24 import formencode.htmlfill
25 import datetime
25 import datetime
26 from pyramid.interfaces import IRoutesMapper
26 from pyramid.interfaces import IRoutesMapper
27
27
28 from pyramid.httpexceptions import HTTPFound
28 from pyramid.httpexceptions import HTTPFound
29 from pyramid.renderers import render
29 from pyramid.renderers import render
30 from pyramid.response import Response
30 from pyramid.response import Response
31
31
32 from rhodecode.apps._base import BaseAppView, DataGridAppView
32 from rhodecode.apps._base import BaseAppView, DataGridAppView
33 from rhodecode.apps.ssh_support import SshKeyFileChangeEvent
33 from rhodecode.apps.ssh_support import SshKeyFileChangeEvent
34 from rhodecode import events
34 from rhodecode import events
35
35
36 from rhodecode.lib import helpers as h
36 from rhodecode.lib import helpers as h
37 from rhodecode.lib.auth import (
37 from rhodecode.lib.auth import (
38 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
38 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
39 from rhodecode.lib.utils2 import aslist, safe_unicode
39 from rhodecode.lib.utils2 import aslist, safe_unicode
40 from rhodecode.model.db import (
40 from rhodecode.model.db import (
41 or_, coalesce, User, UserIpMap, UserSshKeys)
41 or_, coalesce, User, UserIpMap, UserSshKeys)
42 from rhodecode.model.forms import (
42 from rhodecode.model.forms import (
43 ApplicationPermissionsForm, ObjectPermissionsForm, UserPermissionsForm)
43 ApplicationPermissionsForm, ObjectPermissionsForm, UserPermissionsForm)
44 from rhodecode.model.meta import Session
44 from rhodecode.model.meta import Session
45 from rhodecode.model.permission import PermissionModel
45 from rhodecode.model.permission import PermissionModel
46 from rhodecode.model.settings import SettingsModel
46 from rhodecode.model.settings import SettingsModel
47
47
48
48
49 log = logging.getLogger(__name__)
49 log = logging.getLogger(__name__)
50
50
51
51
52 class AdminPermissionsView(BaseAppView, DataGridAppView):
52 class AdminPermissionsView(BaseAppView, DataGridAppView):
53 def load_default_context(self):
53 def load_default_context(self):
54 c = self._get_local_tmpl_context()
54 c = self._get_local_tmpl_context()
55 PermissionModel().set_global_permission_choices(
55 PermissionModel().set_global_permission_choices(
56 c, gettext_translator=self.request.translate)
56 c, gettext_translator=self.request.translate)
57 return c
57 return c
58
58
59 @LoginRequired()
59 @LoginRequired()
60 @HasPermissionAllDecorator('hg.admin')
60 @HasPermissionAllDecorator('hg.admin')
61 def permissions_application(self):
61 def permissions_application(self):
62 c = self.load_default_context()
62 c = self.load_default_context()
63 c.active = 'application'
63 c.active = 'application'
64
64
65 c.user = User.get_default_user(refresh=True)
65 c.user = User.get_default_user(refresh=True)
66
66
67 app_settings = c.rc_config
67 app_settings = c.rc_config
68
68
69 defaults = {
69 defaults = {
70 'anonymous': c.user.active,
70 'anonymous': c.user.active,
71 'default_register_message': app_settings.get(
71 'default_register_message': app_settings.get(
72 'rhodecode_register_message')
72 'rhodecode_register_message')
73 }
73 }
74 defaults.update(c.user.get_default_perms())
74 defaults.update(c.user.get_default_perms())
75
75
76 data = render('rhodecode:templates/admin/permissions/permissions.mako',
76 data = render('rhodecode:templates/admin/permissions/permissions.mako',
77 self._get_template_context(c), self.request)
77 self._get_template_context(c), self.request)
78 html = formencode.htmlfill.render(
78 html = formencode.htmlfill.render(
79 data,
79 data,
80 defaults=defaults,
80 defaults=defaults,
81 encoding="UTF-8",
81 encoding="UTF-8",
82 force_defaults=False
82 force_defaults=False
83 )
83 )
84 return Response(html)
84 return Response(html)
85
85
86 @LoginRequired()
86 @LoginRequired()
87 @HasPermissionAllDecorator('hg.admin')
87 @HasPermissionAllDecorator('hg.admin')
88 @CSRFRequired()
88 @CSRFRequired()
89 def permissions_application_update(self):
89 def permissions_application_update(self):
90 _ = self.request.translate
90 _ = self.request.translate
91 c = self.load_default_context()
91 c = self.load_default_context()
92 c.active = 'application'
92 c.active = 'application'
93
93
94 _form = ApplicationPermissionsForm(
94 _form = ApplicationPermissionsForm(
95 self.request.translate,
95 self.request.translate,
96 [x[0] for x in c.register_choices],
96 [x[0] for x in c.register_choices],
97 [x[0] for x in c.password_reset_choices],
97 [x[0] for x in c.password_reset_choices],
98 [x[0] for x in c.extern_activate_choices])()
98 [x[0] for x in c.extern_activate_choices])()
99
99
100 try:
100 try:
101 form_result = _form.to_python(dict(self.request.POST))
101 form_result = _form.to_python(dict(self.request.POST))
102 form_result.update({'perm_user_name': User.DEFAULT_USER})
102 form_result.update({'perm_user_name': User.DEFAULT_USER})
103 PermissionModel().update_application_permissions(form_result)
103 PermissionModel().update_application_permissions(form_result)
104
104
105 settings = [
105 settings = [
106 ('register_message', 'default_register_message'),
106 ('register_message', 'default_register_message'),
107 ]
107 ]
108 for setting, form_key in settings:
108 for setting, form_key in settings:
109 sett = SettingsModel().create_or_update_setting(
109 sett = SettingsModel().create_or_update_setting(
110 setting, form_result[form_key])
110 setting, form_result[form_key])
111 Session().add(sett)
111 Session().add(sett)
112
112
113 Session().commit()
113 Session().commit()
114 h.flash(_('Application permissions updated successfully'),
114 h.flash(_('Application permissions updated successfully'),
115 category='success')
115 category='success')
116
116
117 except formencode.Invalid as errors:
117 except formencode.Invalid as errors:
118 defaults = errors.value
118 defaults = errors.value
119
119
120 data = render(
120 data = render(
121 'rhodecode:templates/admin/permissions/permissions.mako',
121 'rhodecode:templates/admin/permissions/permissions.mako',
122 self._get_template_context(c), self.request)
122 self._get_template_context(c), self.request)
123 html = formencode.htmlfill.render(
123 html = formencode.htmlfill.render(
124 data,
124 data,
125 defaults=defaults,
125 defaults=defaults,
126 errors=errors.error_dict or {},
126 errors=errors.error_dict or {},
127 prefix_error=False,
127 prefix_error=False,
128 encoding="UTF-8",
128 encoding="UTF-8",
129 force_defaults=False
129 force_defaults=False
130 )
130 )
131 return Response(html)
131 return Response(html)
132
132
133 except Exception:
133 except Exception:
134 log.exception("Exception during update of permissions")
134 log.exception("Exception during update of permissions")
135 h.flash(_('Error occurred during update of permissions'),
135 h.flash(_('Error occurred during update of permissions'),
136 category='error')
136 category='error')
137
137
138 affected_user_ids = [User.get_default_user_id()]
138 affected_user_ids = [User.get_default_user_id()]
139 PermissionModel().trigger_permission_flush(affected_user_ids)
139 PermissionModel().trigger_permission_flush(affected_user_ids)
140
140
141 raise HTTPFound(h.route_path('admin_permissions_application'))
141 raise HTTPFound(h.route_path('admin_permissions_application'))
142
142
143 @LoginRequired()
143 @LoginRequired()
144 @HasPermissionAllDecorator('hg.admin')
144 @HasPermissionAllDecorator('hg.admin')
145 def permissions_objects(self):
145 def permissions_objects(self):
146 c = self.load_default_context()
146 c = self.load_default_context()
147 c.active = 'objects'
147 c.active = 'objects'
148
148
149 c.user = User.get_default_user(refresh=True)
149 c.user = User.get_default_user(refresh=True)
150 defaults = {}
150 defaults = {}
151 defaults.update(c.user.get_default_perms())
151 defaults.update(c.user.get_default_perms())
152
152
153 data = render(
153 data = render(
154 'rhodecode:templates/admin/permissions/permissions.mako',
154 'rhodecode:templates/admin/permissions/permissions.mako',
155 self._get_template_context(c), self.request)
155 self._get_template_context(c), self.request)
156 html = formencode.htmlfill.render(
156 html = formencode.htmlfill.render(
157 data,
157 data,
158 defaults=defaults,
158 defaults=defaults,
159 encoding="UTF-8",
159 encoding="UTF-8",
160 force_defaults=False
160 force_defaults=False
161 )
161 )
162 return Response(html)
162 return Response(html)
163
163
164 @LoginRequired()
164 @LoginRequired()
165 @HasPermissionAllDecorator('hg.admin')
165 @HasPermissionAllDecorator('hg.admin')
166 @CSRFRequired()
166 @CSRFRequired()
167 def permissions_objects_update(self):
167 def permissions_objects_update(self):
168 _ = self.request.translate
168 _ = self.request.translate
169 c = self.load_default_context()
169 c = self.load_default_context()
170 c.active = 'objects'
170 c.active = 'objects'
171
171
172 _form = ObjectPermissionsForm(
172 _form = ObjectPermissionsForm(
173 self.request.translate,
173 self.request.translate,
174 [x[0] for x in c.repo_perms_choices],
174 [x[0] for x in c.repo_perms_choices],
175 [x[0] for x in c.group_perms_choices],
175 [x[0] for x in c.group_perms_choices],
176 [x[0] for x in c.user_group_perms_choices],
176 [x[0] for x in c.user_group_perms_choices],
177 )()
177 )()
178
178
179 try:
179 try:
180 form_result = _form.to_python(dict(self.request.POST))
180 form_result = _form.to_python(dict(self.request.POST))
181 form_result.update({'perm_user_name': User.DEFAULT_USER})
181 form_result.update({'perm_user_name': User.DEFAULT_USER})
182 PermissionModel().update_object_permissions(form_result)
182 PermissionModel().update_object_permissions(form_result)
183
183
184 Session().commit()
184 Session().commit()
185 h.flash(_('Object permissions updated successfully'),
185 h.flash(_('Object permissions updated successfully'),
186 category='success')
186 category='success')
187
187
188 except formencode.Invalid as errors:
188 except formencode.Invalid as errors:
189 defaults = errors.value
189 defaults = errors.value
190
190
191 data = render(
191 data = render(
192 'rhodecode:templates/admin/permissions/permissions.mako',
192 'rhodecode:templates/admin/permissions/permissions.mako',
193 self._get_template_context(c), self.request)
193 self._get_template_context(c), self.request)
194 html = formencode.htmlfill.render(
194 html = formencode.htmlfill.render(
195 data,
195 data,
196 defaults=defaults,
196 defaults=defaults,
197 errors=errors.error_dict or {},
197 errors=errors.error_dict or {},
198 prefix_error=False,
198 prefix_error=False,
199 encoding="UTF-8",
199 encoding="UTF-8",
200 force_defaults=False
200 force_defaults=False
201 )
201 )
202 return Response(html)
202 return Response(html)
203 except Exception:
203 except Exception:
204 log.exception("Exception during update of permissions")
204 log.exception("Exception during update of permissions")
205 h.flash(_('Error occurred during update of permissions'),
205 h.flash(_('Error occurred during update of permissions'),
206 category='error')
206 category='error')
207
207
208 affected_user_ids = [User.get_default_user_id()]
208 affected_user_ids = [User.get_default_user_id()]
209 PermissionModel().trigger_permission_flush(affected_user_ids)
209 PermissionModel().trigger_permission_flush(affected_user_ids)
210
210
211 raise HTTPFound(h.route_path('admin_permissions_object'))
211 raise HTTPFound(h.route_path('admin_permissions_object'))
212
212
213 @LoginRequired()
213 @LoginRequired()
214 @HasPermissionAllDecorator('hg.admin')
214 @HasPermissionAllDecorator('hg.admin')
215 def permissions_branch(self):
215 def permissions_branch(self):
216 c = self.load_default_context()
216 c = self.load_default_context()
217 c.active = 'branch'
217 c.active = 'branch'
218
218
219 c.user = User.get_default_user(refresh=True)
219 c.user = User.get_default_user(refresh=True)
220 defaults = {}
220 defaults = {}
221 defaults.update(c.user.get_default_perms())
221 defaults.update(c.user.get_default_perms())
222
222
223 data = render(
223 data = render(
224 'rhodecode:templates/admin/permissions/permissions.mako',
224 'rhodecode:templates/admin/permissions/permissions.mako',
225 self._get_template_context(c), self.request)
225 self._get_template_context(c), self.request)
226 html = formencode.htmlfill.render(
226 html = formencode.htmlfill.render(
227 data,
227 data,
228 defaults=defaults,
228 defaults=defaults,
229 encoding="UTF-8",
229 encoding="UTF-8",
230 force_defaults=False
230 force_defaults=False
231 )
231 )
232 return Response(html)
232 return Response(html)
233
233
234 @LoginRequired()
234 @LoginRequired()
235 @HasPermissionAllDecorator('hg.admin')
235 @HasPermissionAllDecorator('hg.admin')
236 def permissions_global(self):
236 def permissions_global(self):
237 c = self.load_default_context()
237 c = self.load_default_context()
238 c.active = 'global'
238 c.active = 'global'
239
239
240 c.user = User.get_default_user(refresh=True)
240 c.user = User.get_default_user(refresh=True)
241 defaults = {}
241 defaults = {}
242 defaults.update(c.user.get_default_perms())
242 defaults.update(c.user.get_default_perms())
243
243
244 data = render(
244 data = render(
245 'rhodecode:templates/admin/permissions/permissions.mako',
245 'rhodecode:templates/admin/permissions/permissions.mako',
246 self._get_template_context(c), self.request)
246 self._get_template_context(c), self.request)
247 html = formencode.htmlfill.render(
247 html = formencode.htmlfill.render(
248 data,
248 data,
249 defaults=defaults,
249 defaults=defaults,
250 encoding="UTF-8",
250 encoding="UTF-8",
251 force_defaults=False
251 force_defaults=False
252 )
252 )
253 return Response(html)
253 return Response(html)
254
254
255 @LoginRequired()
255 @LoginRequired()
256 @HasPermissionAllDecorator('hg.admin')
256 @HasPermissionAllDecorator('hg.admin')
257 @CSRFRequired()
257 @CSRFRequired()
258 def permissions_global_update(self):
258 def permissions_global_update(self):
259 _ = self.request.translate
259 _ = self.request.translate
260 c = self.load_default_context()
260 c = self.load_default_context()
261 c.active = 'global'
261 c.active = 'global'
262
262
263 _form = UserPermissionsForm(
263 _form = UserPermissionsForm(
264 self.request.translate,
264 self.request.translate,
265 [x[0] for x in c.repo_create_choices],
265 [x[0] for x in c.repo_create_choices],
266 [x[0] for x in c.repo_create_on_write_choices],
266 [x[0] for x in c.repo_create_on_write_choices],
267 [x[0] for x in c.repo_group_create_choices],
267 [x[0] for x in c.repo_group_create_choices],
268 [x[0] for x in c.user_group_create_choices],
268 [x[0] for x in c.user_group_create_choices],
269 [x[0] for x in c.fork_choices],
269 [x[0] for x in c.fork_choices],
270 [x[0] for x in c.inherit_default_permission_choices])()
270 [x[0] for x in c.inherit_default_permission_choices])()
271
271
272 try:
272 try:
273 form_result = _form.to_python(dict(self.request.POST))
273 form_result = _form.to_python(dict(self.request.POST))
274 form_result.update({'perm_user_name': User.DEFAULT_USER})
274 form_result.update({'perm_user_name': User.DEFAULT_USER})
275 PermissionModel().update_user_permissions(form_result)
275 PermissionModel().update_user_permissions(form_result)
276
276
277 Session().commit()
277 Session().commit()
278 h.flash(_('Global permissions updated successfully'),
278 h.flash(_('Global permissions updated successfully'),
279 category='success')
279 category='success')
280
280
281 except formencode.Invalid as errors:
281 except formencode.Invalid as errors:
282 defaults = errors.value
282 defaults = errors.value
283
283
284 data = render(
284 data = render(
285 'rhodecode:templates/admin/permissions/permissions.mako',
285 'rhodecode:templates/admin/permissions/permissions.mako',
286 self._get_template_context(c), self.request)
286 self._get_template_context(c), self.request)
287 html = formencode.htmlfill.render(
287 html = formencode.htmlfill.render(
288 data,
288 data,
289 defaults=defaults,
289 defaults=defaults,
290 errors=errors.error_dict or {},
290 errors=errors.error_dict or {},
291 prefix_error=False,
291 prefix_error=False,
292 encoding="UTF-8",
292 encoding="UTF-8",
293 force_defaults=False
293 force_defaults=False
294 )
294 )
295 return Response(html)
295 return Response(html)
296 except Exception:
296 except Exception:
297 log.exception("Exception during update of permissions")
297 log.exception("Exception during update of permissions")
298 h.flash(_('Error occurred during update of permissions'),
298 h.flash(_('Error occurred during update of permissions'),
299 category='error')
299 category='error')
300
300
301 affected_user_ids = [User.get_default_user_id()]
301 affected_user_ids = [User.get_default_user_id()]
302 PermissionModel().trigger_permission_flush(affected_user_ids)
302 PermissionModel().trigger_permission_flush(affected_user_ids)
303
303
304 raise HTTPFound(h.route_path('admin_permissions_global'))
304 raise HTTPFound(h.route_path('admin_permissions_global'))
305
305
306 @LoginRequired()
306 @LoginRequired()
307 @HasPermissionAllDecorator('hg.admin')
307 @HasPermissionAllDecorator('hg.admin')
308 def permissions_ips(self):
308 def permissions_ips(self):
309 c = self.load_default_context()
309 c = self.load_default_context()
310 c.active = 'ips'
310 c.active = 'ips'
311
311
312 c.user = User.get_default_user(refresh=True)
312 c.user = User.get_default_user(refresh=True)
313 c.user_ip_map = (
313 c.user_ip_map = (
314 UserIpMap.query().filter(UserIpMap.user == c.user).all())
314 UserIpMap.query().filter(UserIpMap.user == c.user).all())
315
315
316 return self._get_template_context(c)
316 return self._get_template_context(c)
317
317
318 @LoginRequired()
318 @LoginRequired()
319 @HasPermissionAllDecorator('hg.admin')
319 @HasPermissionAllDecorator('hg.admin')
320 def permissions_overview(self):
320 def permissions_overview(self):
321 c = self.load_default_context()
321 c = self.load_default_context()
322 c.active = 'perms'
322 c.active = 'perms'
323
323
324 c.user = User.get_default_user(refresh=True)
324 c.user = User.get_default_user(refresh=True)
325 c.perm_user = c.user.AuthUser()
325 c.perm_user = c.user.AuthUser()
326 return self._get_template_context(c)
326 return self._get_template_context(c)
327
327
328 @LoginRequired()
328 @LoginRequired()
329 @HasPermissionAllDecorator('hg.admin')
329 @HasPermissionAllDecorator('hg.admin')
330 def auth_token_access(self):
330 def auth_token_access(self):
331 from rhodecode import CONFIG
331 from rhodecode import CONFIG
332
332
333 c = self.load_default_context()
333 c = self.load_default_context()
334 c.active = 'auth_token_access'
334 c.active = 'auth_token_access'
335
335
336 c.user = User.get_default_user(refresh=True)
336 c.user = User.get_default_user(refresh=True)
337 c.perm_user = c.user.AuthUser()
337 c.perm_user = c.user.AuthUser()
338
338
339 mapper = self.request.registry.queryUtility(IRoutesMapper)
339 mapper = self.request.registry.queryUtility(IRoutesMapper)
340 c.view_data = []
340 c.view_data = []
341
341
342 _argument_prog = re.compile('\{(.*?)\}|:\((.*)\)')
342 _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')
343 introspector = self.request.registry.introspector
343 introspector = self.request.registry.introspector
344
344
345 view_intr = {}
345 view_intr = {}
346 for view_data in introspector.get_category('views'):
346 for view_data in introspector.get_category('views'):
347 intr = view_data['introspectable']
347 intr = view_data['introspectable']
348
348
349 if 'route_name' in intr and intr['attr']:
349 if 'route_name' in intr and intr['attr']:
350 view_intr[intr['route_name']] = '{}:{}'.format(
350 view_intr[intr['route_name']] = '{}:{}'.format(
351 str(intr['derived_callable'].__name__), intr['attr']
351 str(intr['derived_callable'].__name__), intr['attr']
352 )
352 )
353
353
354 c.whitelist_key = 'api_access_controllers_whitelist'
354 c.whitelist_key = 'api_access_controllers_whitelist'
355 c.whitelist_file = CONFIG.get('__file__')
355 c.whitelist_file = CONFIG.get('__file__')
356 whitelist_views = aslist(
356 whitelist_views = aslist(
357 CONFIG.get(c.whitelist_key), sep=',')
357 CONFIG.get(c.whitelist_key), sep=',')
358
358
359 for route_info in mapper.get_routes():
359 for route_info in mapper.get_routes():
360 if not route_info.name.startswith('__'):
360 if not route_info.name.startswith('__'):
361 routepath = route_info.pattern
361 routepath = route_info.pattern
362
362
363 def replace(matchobj):
363 def replace(matchobj):
364 if matchobj.group(1):
364 if matchobj.group(1):
365 return "{%s}" % matchobj.group(1).split(':')[0]
365 return "{%s}" % matchobj.group(1).split(':')[0]
366 else:
366 else:
367 return "{%s}" % matchobj.group(2)
367 return "{%s}" % matchobj.group(2)
368
368
369 routepath = _argument_prog.sub(replace, routepath)
369 routepath = _argument_prog.sub(replace, routepath)
370
370
371 if not routepath.startswith('/'):
371 if not routepath.startswith('/'):
372 routepath = '/' + routepath
372 routepath = '/' + routepath
373
373
374 view_fqn = view_intr.get(route_info.name, 'NOT AVAILABLE')
374 view_fqn = view_intr.get(route_info.name, 'NOT AVAILABLE')
375 active = view_fqn in whitelist_views
375 active = view_fqn in whitelist_views
376 c.view_data.append((route_info.name, view_fqn, routepath, active))
376 c.view_data.append((route_info.name, view_fqn, routepath, active))
377
377
378 c.whitelist_views = whitelist_views
378 c.whitelist_views = whitelist_views
379 return self._get_template_context(c)
379 return self._get_template_context(c)
380
380
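A note on ``_argument_prog``: it normalizes both pyramid-style ``{name:regex}``
placeholders and legacy ``:(name)`` ones into plain ``{name}`` segments, which
keeps the whitelist paths readable. A self-contained illustration (the route
patterns are assumed examples, not actual routes):

.. code-block:: python

    import re

    _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')

    def replace(matchobj):
        if matchobj.group(1):
            return "{%s}" % matchobj.group(1).split(':')[0]
        return "{%s}" % matchobj.group(2)

    assert (_argument_prog.sub(replace, '/{repo_name:.*?[^/]}/settings')
            == '/{repo_name}/settings')
    assert _argument_prog.sub(replace, '/repo/:(id)') == '/repo/{id}'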
381 def ssh_enabled(self):
381 def ssh_enabled(self):
382 return self.request.registry.settings.get(
382 return self.request.registry.settings.get(
383 'ssh.generate_authorized_keyfile')
383 'ssh.generate_authorized_keyfile')
384
384
385 @LoginRequired()
385 @LoginRequired()
386 @HasPermissionAllDecorator('hg.admin')
386 @HasPermissionAllDecorator('hg.admin')
387 def ssh_keys(self):
387 def ssh_keys(self):
388 c = self.load_default_context()
388 c = self.load_default_context()
389 c.active = 'ssh_keys'
389 c.active = 'ssh_keys'
390 c.ssh_enabled = self.ssh_enabled()
390 c.ssh_enabled = self.ssh_enabled()
391 return self._get_template_context(c)
391 return self._get_template_context(c)
392
392
393 @LoginRequired()
393 @LoginRequired()
394 @HasPermissionAllDecorator('hg.admin')
394 @HasPermissionAllDecorator('hg.admin')
395 def ssh_keys_data(self):
395 def ssh_keys_data(self):
396 _ = self.request.translate
396 _ = self.request.translate
397 self.load_default_context()
397 self.load_default_context()
398 column_map = {
398 column_map = {
399 'fingerprint': 'ssh_key_fingerprint',
399 'fingerprint': 'ssh_key_fingerprint',
400 'username': User.username
400 'username': User.username
401 }
401 }
402 draw, start, limit = self._extract_chunk(self.request)
402 draw, start, limit = self._extract_chunk(self.request)
403 search_q, order_by, order_dir = self._extract_ordering(
403 search_q, order_by, order_dir = self._extract_ordering(
404 self.request, column_map=column_map)
404 self.request, column_map=column_map)
405
405
406 ssh_keys_data_total_count = UserSshKeys.query()\
406 ssh_keys_data_total_count = UserSshKeys.query()\
407 .count()
407 .count()
408
408
409 # json generate
409 # json generate
410 base_q = UserSshKeys.query().join(UserSshKeys.user)
410 base_q = UserSshKeys.query().join(UserSshKeys.user)
411
411
412 if search_q:
412 if search_q:
413 like_expression = u'%{}%'.format(safe_unicode(search_q))
413 like_expression = u'%{}%'.format(safe_unicode(search_q))
414 base_q = base_q.filter(or_(
414 base_q = base_q.filter(or_(
415 User.username.ilike(like_expression),
415 User.username.ilike(like_expression),
416 UserSshKeys.ssh_key_fingerprint.ilike(like_expression),
416 UserSshKeys.ssh_key_fingerprint.ilike(like_expression),
417 ))
417 ))
418
418
419 users_data_total_filtered_count = base_q.count()
419 users_data_total_filtered_count = base_q.count()
420
420
421 sort_col = self._get_order_col(order_by, UserSshKeys)
421 sort_col = self._get_order_col(order_by, UserSshKeys)
422 if sort_col:
422 if sort_col:
423 if order_dir == 'asc':
423 if order_dir == 'asc':
424 # handle null values properly to order by NULL last
424 # handle null values properly to order by NULL last
425 if order_by in ['created_on']:
425 if order_by in ['created_on']:
426 sort_col = coalesce(sort_col, datetime.date.max)
426 sort_col = coalesce(sort_col, datetime.date.max)
427 sort_col = sort_col.asc()
427 sort_col = sort_col.asc()
428 else:
428 else:
429 # handle null values properly to order by NULL last
429 # handle null values properly to order by NULL last
430 if order_by in ['created_on']:
430 if order_by in ['created_on']:
431 sort_col = coalesce(sort_col, datetime.date.min)
431 sort_col = coalesce(sort_col, datetime.date.min)
432 sort_col = sort_col.desc()
432 sort_col = sort_col.desc()
433
433
434 base_q = base_q.order_by(sort_col)
434 base_q = base_q.order_by(sort_col)
435 base_q = base_q.offset(start).limit(limit)
435 base_q = base_q.offset(start).limit(limit)
436
436
437 ssh_keys = base_q.all()
437 ssh_keys = base_q.all()
438
438
439 ssh_keys_data = []
439 ssh_keys_data = []
440 for ssh_key in ssh_keys:
440 for ssh_key in ssh_keys:
441 ssh_keys_data.append({
441 ssh_keys_data.append({
442 "username": h.gravatar_with_user(self.request, ssh_key.user.username),
442 "username": h.gravatar_with_user(self.request, ssh_key.user.username),
443 "fingerprint": ssh_key.ssh_key_fingerprint,
443 "fingerprint": ssh_key.ssh_key_fingerprint,
444 "description": ssh_key.description,
444 "description": ssh_key.description,
445 "created_on": h.format_date(ssh_key.created_on),
445 "created_on": h.format_date(ssh_key.created_on),
446 "accessed_on": h.format_date(ssh_key.accessed_on),
446 "accessed_on": h.format_date(ssh_key.accessed_on),
447 "action": h.link_to(
447 "action": h.link_to(
448 _('Edit'), h.route_path('edit_user_ssh_keys',
448 _('Edit'), h.route_path('edit_user_ssh_keys',
449 user_id=ssh_key.user.user_id))
449 user_id=ssh_key.user.user_id))
450 })
450 })
451
451
452 data = ({
452 data = ({
453 'draw': draw,
453 'draw': draw,
454 'data': ssh_keys_data,
454 'data': ssh_keys_data,
455 'recordsTotal': ssh_keys_data_total_count,
455 'recordsTotal': ssh_keys_data_total_count,
456 'recordsFiltered': users_data_total_filtered_count,
456 'recordsFiltered': users_data_total_filtered_count,
457 })
457 })
458
458
459 return data
459 return data
460
460
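The ``coalesce`` calls above are a standard SQL idiom for forcing NULLs to
sort last: substituting ``date.max`` when ascending and ``date.min`` when
descending makes NULL rows compare after every real value. Sketched as a small
helper (a hedged generalization, not code from this module):

.. code-block:: python

    import datetime
    from sqlalchemy import func

    def null_last_order(sort_col, descending=False):
        # NULLs become the extreme value for the chosen direction,
        # so they always land at the end of the result set.
        if descending:
            return func.coalesce(sort_col, datetime.date.min).desc()
        return func.coalesce(sort_col, datetime.date.max).asc()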
461 @LoginRequired()
461 @LoginRequired()
462 @HasPermissionAllDecorator('hg.admin')
462 @HasPermissionAllDecorator('hg.admin')
463 @CSRFRequired()
463 @CSRFRequired()
464 def ssh_keys_update(self):
464 def ssh_keys_update(self):
465 _ = self.request.translate
465 _ = self.request.translate
466 self.load_default_context()
466 self.load_default_context()
467
467
468 ssh_enabled = self.ssh_enabled()
468 ssh_enabled = self.ssh_enabled()
469 key_file = self.request.registry.settings.get(
469 key_file = self.request.registry.settings.get(
470 'ssh.authorized_keys_file_path')
470 'ssh.authorized_keys_file_path')
471 if ssh_enabled:
471 if ssh_enabled:
472 events.trigger(SshKeyFileChangeEvent(), self.request.registry)
472 events.trigger(SshKeyFileChangeEvent(), self.request.registry)
473 h.flash(_('Updated SSH keys file: {}').format(key_file),
473 h.flash(_('Updated SSH keys file: {}').format(key_file),
474 category='success')
474 category='success')
475 else:
475 else:
476 h.flash(_('SSH key support is disabled in the .ini file'),
476 h.flash(_('SSH key support is disabled in the .ini file'),
477 category='warning')
477 category='warning')
478
478
479 raise HTTPFound(h.route_path('admin_permissions_ssh_keys'))
479 raise HTTPFound(h.route_path('admin_permissions_ssh_keys'))
@@ -1,58 +1,57 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2016-2020 RhodeCode GmbH
3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21 import io
22 import uuid
22 import uuid
23 from io import StringIO
24 import pathlib2
23 import pathlib2
25
24
26
25
27 def get_file_storage(settings):
26 def get_file_storage(settings):
28 from rhodecode.apps.file_store.backends.local_store import LocalFileStorage
27 from rhodecode.apps.file_store.backends.local_store import LocalFileStorage
29 from rhodecode.apps.file_store import config_keys
28 from rhodecode.apps.file_store import config_keys
30 store_path = settings.get(config_keys.store_path)
29 store_path = settings.get(config_keys.store_path)
31 return LocalFileStorage(base_path=store_path)
30 return LocalFileStorage(base_path=store_path)
32
31
33
32
34 def splitext(filename):
33 def splitext(filename):
35 ext = ''.join(pathlib2.Path(filename).suffixes)
34 ext = ''.join(pathlib2.Path(filename).suffixes)
36 return filename, ext
35 return filename, ext
37
36
38
37
39 def uid_filename(filename, randomized=True):
38 def uid_filename(filename, randomized=True):
40 """
39 """
41 Generates a randomized or stable (uuid) filename,
40 Generates a randomized or stable (uuid) filename,
42 preserving the original extension.
41 preserving the original extension.
43
42
44 :param filename: the original filename
43 :param filename: the original filename
45 :param randomized: define if filename should be stable (sha1 based) or randomized
44 :param randomized: define if filename should be stable (sha1 based) or randomized
46 """
45 """
47
46
48 _, ext = splitext(filename)
47 _, ext = splitext(filename)
49 if randomized:
48 if randomized:
50 uid = uuid.uuid4()
49 uid = uuid.uuid4()
51 else:
50 else:
52 hash_key = '{}.{}'.format(filename, 'store')
51 hash_key = '{}.{}'.format(filename, 'store')
53 uid = uuid.uuid5(uuid.NAMESPACE_URL, hash_key)
52 uid = uuid.uuid5(uuid.NAMESPACE_URL, hash_key)
54 return str(uid) + ext.lower()
53 return str(uid) + ext.lower()
55
54
56
55
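As a sanity check on ``uid_filename``: ``uuid.uuid5`` is a hash of
``(namespace, name)`` and therefore deterministic for the same input, while
``uuid.uuid4`` is random, which is exactly the stable/randomized split
implemented above. A minimal illustration (the filename is an assumed example):

.. code-block:: python

    import uuid

    hash_key = '{}.{}'.format('report.tar.gz', 'store')
    # uuid5 is deterministic: same input, same UUID
    assert (uuid.uuid5(uuid.NAMESPACE_URL, hash_key)
            == uuid.uuid5(uuid.NAMESPACE_URL, hash_key))
    # uuid4 is random: two calls virtually never collide
    assert uuid.uuid4() != uuid.uuid4()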
57 def bytes_to_file_obj(bytes_data):
56 def bytes_to_file_obj(bytes_data):
58 return StringIO.StringIO(bytes_data)
57 return io.BytesIO(bytes_data)
@@ -1,580 +1,580 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import urllib.parse
21 import urllib.parse
22
22
23 import mock
23 import mock
24 import pytest
24 import pytest
25
25
26 from rhodecode.tests import (
26 from rhodecode.tests import (
27 assert_session_flash, HG_REPO, TEST_USER_ADMIN_LOGIN,
27 assert_session_flash, HG_REPO, TEST_USER_ADMIN_LOGIN,
28 no_newline_id_generator)
28 no_newline_id_generator)
29 from rhodecode.tests.fixture import Fixture
29 from rhodecode.tests.fixture import Fixture
30 from rhodecode.lib.auth import check_password
30 from rhodecode.lib.auth import check_password
31 from rhodecode.lib import helpers as h
31 from rhodecode.lib import helpers as h
32 from rhodecode.model.auth_token import AuthTokenModel
32 from rhodecode.model.auth_token import AuthTokenModel
33 from rhodecode.model.db import User, Notification, UserApiKeys
33 from rhodecode.model.db import User, Notification, UserApiKeys
34 from rhodecode.model.meta import Session
34 from rhodecode.model.meta import Session
35
35
36 fixture = Fixture()
36 fixture = Fixture()
37
37
38 whitelist_view = ['RepoCommitsView:repo_commit_raw']
38 whitelist_view = ['RepoCommitsView:repo_commit_raw']
39
39
40
40
41 def route_path(name, params=None, **kwargs):
41 def route_path(name, params=None, **kwargs):
42 import urllib.request, urllib.parse, urllib.error
42 import urllib.request, urllib.parse, urllib.error
43 from rhodecode.apps._base import ADMIN_PREFIX
43 from rhodecode.apps._base import ADMIN_PREFIX
44
44
45 base_url = {
45 base_url = {
46 'login': ADMIN_PREFIX + '/login',
46 'login': ADMIN_PREFIX + '/login',
47 'logout': ADMIN_PREFIX + '/logout',
47 'logout': ADMIN_PREFIX + '/logout',
48 'register': ADMIN_PREFIX + '/register',
48 'register': ADMIN_PREFIX + '/register',
49 'reset_password':
49 'reset_password':
50 ADMIN_PREFIX + '/password_reset',
50 ADMIN_PREFIX + '/password_reset',
51 'reset_password_confirmation':
51 'reset_password_confirmation':
52 ADMIN_PREFIX + '/password_reset_confirmation',
52 ADMIN_PREFIX + '/password_reset_confirmation',
53
53
54 'admin_permissions_application':
54 'admin_permissions_application':
55 ADMIN_PREFIX + '/permissions/application',
55 ADMIN_PREFIX + '/permissions/application',
56 'admin_permissions_application_update':
56 'admin_permissions_application_update':
57 ADMIN_PREFIX + '/permissions/application/update',
57 ADMIN_PREFIX + '/permissions/application/update',
58
58
59 'repo_commit_raw': '/{repo_name}/raw-changeset/{commit_id}'
59 'repo_commit_raw': '/{repo_name}/raw-changeset/{commit_id}'
60
60
61 }[name].format(**kwargs)
61 }[name].format(**kwargs)
62
62
63 if params:
63 if params:
64 base_url = '{}?{}'.format(base_url, urllib.parse.urlencode(params))
64 base_url = '{}?{}'.format(base_url, urllib.parse.urlencode(params))
65 return base_url
65 return base_url
66
66
67
67
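The ``route_path`` helper above is just a test-local URL table plus
query-string encoding; a quick illustration of how the tests use it (assuming
``ADMIN_PREFIX`` is ``'/_admin'``; the concrete values are made up):

.. code-block:: python

    assert (route_path('login', params={'came_from': '/_admin/users'})
            == '/_admin/login?came_from=%2F_admin%2Fusers')
    assert (route_path('repo_commit_raw',
                       repo_name='some-repo', commit_id='abcdef12')
            == '/some-repo/raw-changeset/abcdef12')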
68 @pytest.mark.usefixtures('app')
68 @pytest.mark.usefixtures('app')
69 class TestLoginController(object):
69 class TestLoginController(object):
70 destroy_users = set()
70 destroy_users = set()
71
71
72 @classmethod
72 @classmethod
73 def teardown_class(cls):
73 def teardown_class(cls):
74 fixture.destroy_users(cls.destroy_users)
74 fixture.destroy_users(cls.destroy_users)
75
75
76 def teardown_method(self, method):
76 def teardown_method(self, method):
77 for n in Notification.query().all():
77 for n in Notification.query().all():
78 Session().delete(n)
78 Session().delete(n)
79
79
80 Session().commit()
80 Session().commit()
81 assert Notification.query().all() == []
81 assert Notification.query().all() == []
82
82
83 def test_index(self):
83 def test_index(self):
84 response = self.app.get(route_path('login'))
84 response = self.app.get(route_path('login'))
85 assert response.status == '200 OK'
85 assert response.status == '200 OK'
86 # Test response...
86 # Test response...
87
87
88 def test_login_admin_ok(self):
88 def test_login_admin_ok(self):
89 response = self.app.post(route_path('login'),
89 response = self.app.post(route_path('login'),
90 {'username': 'test_admin',
90 {'username': 'test_admin',
91 'password': 'test12'}, status=302)
91 'password': 'test12'}, status=302)
92 response = response.follow()
92 response = response.follow()
93 session = response.get_session_from_response()
93 session = response.get_session_from_response()
94 username = session['rhodecode_user'].get('username')
94 username = session['rhodecode_user'].get('username')
95 assert username == 'test_admin'
95 assert username == 'test_admin'
96 response.mustcontain('logout')
96 response.mustcontain('logout')
97
97
98 def test_login_regular_ok(self):
98 def test_login_regular_ok(self):
99 response = self.app.post(route_path('login'),
99 response = self.app.post(route_path('login'),
100 {'username': 'test_regular',
100 {'username': 'test_regular',
101 'password': 'test12'}, status=302)
101 'password': 'test12'}, status=302)
102
102
103 response = response.follow()
103 response = response.follow()
104 session = response.get_session_from_response()
104 session = response.get_session_from_response()
105 username = session['rhodecode_user'].get('username')
105 username = session['rhodecode_user'].get('username')
106 assert username == 'test_regular'
106 assert username == 'test_regular'
107 response.mustcontain('logout')
107 response.mustcontain('logout')
108
108
109 def test_login_regular_forbidden_when_super_admin_restriction(self):
109 def test_login_regular_forbidden_when_super_admin_restriction(self):
110 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
110 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
111 with fixture.auth_restriction(self.app._pyramid_registry,
111 with fixture.auth_restriction(self.app._pyramid_registry,
112 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SUPER_ADMIN):
112 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SUPER_ADMIN):
113 response = self.app.post(route_path('login'),
113 response = self.app.post(route_path('login'),
114 {'username': 'test_regular',
114 {'username': 'test_regular',
115 'password': 'test12'})
115 'password': 'test12'})
116
116
117 response.mustcontain('invalid user name')
117 response.mustcontain('invalid user name')
118 response.mustcontain('invalid password')
118 response.mustcontain('invalid password')
119
119
120 def test_login_regular_forbidden_when_scope_restriction(self):
120 def test_login_regular_forbidden_when_scope_restriction(self):
121 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
121 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
122 with fixture.scope_restriction(self.app._pyramid_registry,
122 with fixture.scope_restriction(self.app._pyramid_registry,
123 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SCOPE_VCS):
123 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SCOPE_VCS):
124 response = self.app.post(route_path('login'),
124 response = self.app.post(route_path('login'),
125 {'username': 'test_regular',
125 {'username': 'test_regular',
126 'password': 'test12'})
126 'password': 'test12'})
127
127
128 response.mustcontain('invalid user name')
128 response.mustcontain('invalid user name')
129 response.mustcontain('invalid password')
129 response.mustcontain('invalid password')
130
130
131 def test_login_ok_came_from(self):
131 def test_login_ok_came_from(self):
132 test_came_from = '/_admin/users?branch=stable'
132 test_came_from = '/_admin/users?branch=stable'
133 _url = '{}?came_from={}'.format(route_path('login'), test_came_from)
133 _url = '{}?came_from={}'.format(route_path('login'), test_came_from)
134 response = self.app.post(
134 response = self.app.post(
135 _url, {'username': 'test_admin', 'password': 'test12'}, status=302)
135 _url, {'username': 'test_admin', 'password': 'test12'}, status=302)
136
136
137 assert 'branch=stable' in response.location
137 assert 'branch=stable' in response.location
138 response = response.follow()
138 response = response.follow()
139
139
140 assert response.status == '200 OK'
140 assert response.status == '200 OK'
141 response.mustcontain('Users administration')
141 response.mustcontain('Users administration')
142
142
143 def test_redirect_to_login_with_get_args(self):
143 def test_redirect_to_login_with_get_args(self):
144 with fixture.anon_access(False):
144 with fixture.anon_access(False):
145 kwargs = {'branch': 'stable'}
145 kwargs = {'branch': 'stable'}
146 response = self.app.get(
146 response = self.app.get(
147 h.route_path('repo_summary', repo_name=HG_REPO, _query=kwargs),
147 h.route_path('repo_summary', repo_name=HG_REPO, _query=kwargs),
148 status=302)
148 status=302)
149
149
150 response_query = urllib.parse.urlparse.parse_qsl(response.location)
150 response_query = urllib.parse.parse_qsl(response.location)
151 assert 'branch=stable' in response_query[0][1]
151 assert 'branch=stable' in response_query[0][1]
152
152
153 def test_login_form_with_get_args(self):
153 def test_login_form_with_get_args(self):
154 _url = '{}?came_from=/_admin/users,branch=stable'.format(route_path('login'))
154 _url = '{}?came_from=/_admin/users,branch=stable'.format(route_path('login'))
155 response = self.app.get(_url)
155 response = self.app.get(_url)
156 assert 'branch%3Dstable' in response.form.action
156 assert 'branch%3Dstable' in response.form.action
157
157
158 @pytest.mark.parametrize("url_came_from", [
158 @pytest.mark.parametrize("url_came_from", [
159 'data:text/html,<script>window.alert("xss")</script>',
159 'data:text/html,<script>window.alert("xss")</script>',
160 'mailto:test@rhodecode.org',
160 'mailto:test@rhodecode.org',
161 'file:///etc/passwd',
161 'file:///etc/passwd',
162 'ftp://some.ftp.server',
162 'ftp://some.ftp.server',
163 'http://other.domain',
163 'http://other.domain',
164 '/\r\nX-Forwarded-Host: http://example.org',
164 '/\r\nX-Forwarded-Host: http://example.org',
165 ], ids=no_newline_id_generator)
165 ], ids=no_newline_id_generator)
166 def test_login_bad_came_froms(self, url_came_from):
166 def test_login_bad_came_froms(self, url_came_from):
167 _url = '{}?came_from={}'.format(route_path('login'), url_came_from)
167 _url = '{}?came_from={}'.format(route_path('login'), url_came_from)
168 response = self.app.post(
168 response = self.app.post(
169 _url,
169 _url,
170 {'username': 'test_admin', 'password': 'test12'})
170 {'username': 'test_admin', 'password': 'test12'})
171 assert response.status == '302 Found'
171 assert response.status == '302 Found'
172 response = response.follow()
172 response = response.follow()
173 assert response.status == '200 OK'
173 assert response.status == '200 OK'
174 assert response.request.path == '/'
174 assert response.request.path == '/'
175
175
176 def test_login_short_password(self):
176 def test_login_short_password(self):
177 response = self.app.post(route_path('login'),
177 response = self.app.post(route_path('login'),
178 {'username': 'test_admin',
178 {'username': 'test_admin',
179 'password': 'as'})
179 'password': 'as'})
180 assert response.status == '200 OK'
180 assert response.status == '200 OK'
181
181
182 response.mustcontain('Enter 3 characters or more')
182 response.mustcontain('Enter 3 characters or more')
183
183
184 def test_login_wrong_non_ascii_password(self, user_regular):
184 def test_login_wrong_non_ascii_password(self, user_regular):
185 response = self.app.post(
185 response = self.app.post(
186 route_path('login'),
186 route_path('login'),
187 {'username': user_regular.username,
187 {'username': user_regular.username,
188 'password': u'invalid-non-asci\xe4'.encode('utf8')})
188 'password': u'invalid-non-asci\xe4'.encode('utf8')})
189
189
190 response.mustcontain('invalid user name')
190 response.mustcontain('invalid user name')
191 response.mustcontain('invalid password')
191 response.mustcontain('invalid password')
192
192
193 def test_login_with_non_ascii_password(self, user_util):
193 def test_login_with_non_ascii_password(self, user_util):
194 password = u'valid-non-ascii\xe4'
194 password = u'valid-non-ascii\xe4'
195 user = user_util.create_user(password=password)
195 user = user_util.create_user(password=password)
196 response = self.app.post(
196 response = self.app.post(
197 route_path('login'),
197 route_path('login'),
198 {'username': user.username,
198 {'username': user.username,
199 'password': password})
199 'password': password})
200 assert response.status_code == 302
200 assert response.status_code == 302
201
201
202 def test_login_wrong_username_password(self):
202 def test_login_wrong_username_password(self):
203 response = self.app.post(route_path('login'),
203 response = self.app.post(route_path('login'),
204 {'username': 'error',
204 {'username': 'error',
205 'password': 'test12'})
205 'password': 'test12'})
206
206
207 response.mustcontain('invalid user name')
207 response.mustcontain('invalid user name')
208 response.mustcontain('invalid password')
208 response.mustcontain('invalid password')
209
209
210 def test_login_admin_ok_password_migration(self, real_crypto_backend):
210 def test_login_admin_ok_password_migration(self, real_crypto_backend):
211 from rhodecode.lib import auth
211 from rhodecode.lib import auth
212
212
213 # create new user, with sha256 password
213 # create new user, with sha256 password
214 temp_user = 'test_admin_sha256'
214 temp_user = 'test_admin_sha256'
215 user = fixture.create_user(temp_user)
215 user = fixture.create_user(temp_user)
216 user.password = auth._RhodeCodeCryptoSha256().hash_create(
216 user.password = auth._RhodeCodeCryptoSha256().hash_create(
217 b'test123')
217 b'test123')
218 Session().add(user)
218 Session().add(user)
219 Session().commit()
219 Session().commit()
220 self.destroy_users.add(temp_user)
220 self.destroy_users.add(temp_user)
221 response = self.app.post(route_path('login'),
221 response = self.app.post(route_path('login'),
222 {'username': temp_user,
222 {'username': temp_user,
223 'password': 'test123'}, status=302)
223 'password': 'test123'}, status=302)
224
224
225 response = response.follow()
225 response = response.follow()
226 session = response.get_session_from_response()
226 session = response.get_session_from_response()
227 username = session['rhodecode_user'].get('username')
227 username = session['rhodecode_user'].get('username')
228 assert username == temp_user
228 assert username == temp_user
229 response.mustcontain('logout')
229 response.mustcontain('logout')
230
230
231 # new password should be bcrypted, after log-in and transfer
231 # new password should be bcrypted, after log-in and transfer
        user = User.get_by_username(temp_user)
        assert user.password.startswith('$')

    # REGISTRATIONS
    def test_register(self):
        response = self.app.get(route_path('register'))
        response.mustcontain('Create an Account')

    def test_register_err_same_username(self):
        uname = 'test_admin'
        response = self.app.post(
            route_path('register'),
            {
                'username': uname,
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'goodmail@domain.com',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        assertr = response.assert_response()
        msg = 'Username "%(username)s" already exists'
        msg = msg % {'username': uname}
        assertr.element_contains('#username+.error-message', msg)

    def test_register_err_same_email(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'test_admin_0',
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'test_admin@mail.com',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        assertr = response.assert_response()
        msg = u'This e-mail address is already taken'
        assertr.element_contains('#email+.error-message', msg)

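    # NOTE: despite its name, the test below verifies that the e-mail
    # uniqueness check is case-insensitive: a differently-cased duplicate
    # of an existing address is still rejected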
    def test_register_err_same_email_case_sensitive(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'test_admin_1',
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'TesT_Admin@mail.COM',
                'firstname': 'test',
                'lastname': 'test'
            }
        )
        assertr = response.assert_response()
        msg = u'This e-mail address is already taken'
        assertr.element_contains('#email+.error-message', msg)

    def test_register_err_wrong_data(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'xs',
                'password': 'test',
                'password_confirmation': 'test',
                'email': 'goodmailm',
                'firstname': 'test',
                'lastname': 'test'
            }
        )
        assert response.status == '200 OK'
        response.mustcontain('An email address must contain a single @')
        response.mustcontain('Enter a value 6 characters long or more')

    def test_register_err_username(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'error user',
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'goodmailm',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        response.mustcontain('An email address must contain a single @')
        response.mustcontain(
            'Username may only contain '
            'alphanumeric characters underscores, '
            'periods or dashes and must begin with '
            'alphanumeric character')

    def test_register_err_case_sensitive(self):
        usr = 'Test_Admin'
        response = self.app.post(
            route_path('register'),
            {
                'username': usr,
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'goodmailm',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        assertr = response.assert_response()
        msg = u'Username "%(username)s" already exists'
        msg = msg % {'username': usr}
        assertr.element_contains('#username+.error-message', msg)

    def test_register_special_chars(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'xxxaxn',
                'password': 'ąćźżąśśśś',
                'password_confirmation': 'ąćźżąśśśś',
                'email': 'goodmailm@test.plx',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        msg = u'Invalid characters (non-ascii) in password'
        response.mustcontain(msg)

    def test_register_password_mismatch(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'xs',
                'password': '123qwe',
                'password_confirmation': 'qwe123',
                'email': 'goodmailm@test.plxa',
                'firstname': 'test',
                'lastname': 'test'
            }
        )
        msg = u'Passwords do not match'
        response.mustcontain(msg)

    def test_register_ok(self):
        username = 'test_regular4'
        password = 'qweqwe'
        email = 'marcin@test.com'
        name = 'testname'
        lastname = 'testlastname'

        # this initializes a session
        response = self.app.get(route_path('register'))
        response.mustcontain('Create an Account')
        response = self.app.post(
            route_path('register'),
            {
                'username': username,
                'password': password,
                'password_confirmation': password,
                'email': email,
                'firstname': name,
                'lastname': lastname,
                'admin': True
            },
            status=302
        )  # the 'admin' flag sent above should be overridden by the server

        assert_session_flash(
            response, 'You have successfully registered with RhodeCode. You can log-in now.')

        ret = Session().query(User).filter(
            User.username == 'test_regular4').one()
        assert ret.username == username
        assert check_password(password, ret.password)
        assert ret.email == email
        assert ret.name == name
        assert ret.lastname == lastname
        assert ret.auth_tokens is not None
        assert not ret.admin

    def test_forgot_password_wrong_mail(self):
        bad_email = 'marcin@wrongmail.org'
        # this initializes a session
        self.app.get(route_path('reset_password'))

        response = self.app.post(
            route_path('reset_password'), {'email': bad_email, }
        )
        assert_session_flash(response,
            'If such email exists, a password reset link was sent to it.')

    def test_forgot_password(self, user_util):
        # this initializes a session
        self.app.get(route_path('reset_password'))

        user = user_util.create_user()
        user_id = user.user_id
        email = user.email

        response = self.app.post(route_path('reset_password'), {'email': email, })

        assert_session_flash(response,
            'If such email exists, a password reset link was sent to it.')

        # BAD KEY
        confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), 'badkey')
        response = self.app.get(confirm_url, status=302)
        assert response.location.endswith(route_path('reset_password'))
        assert_session_flash(response, 'Given reset token is invalid')

        response.follow()  # cleanup flash

        # GOOD KEY
        key = UserApiKeys.query()\
            .filter(UserApiKeys.user_id == user_id)\
            .filter(UserApiKeys.role == UserApiKeys.ROLE_PASSWORD_RESET)\
            .first()

        assert key

        confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), key.api_key)
        response = self.app.get(confirm_url)
        assert response.status == '302 Found'
        assert response.location.endswith(route_path('login'))

        assert_session_flash(
            response,
            'Your password reset was successful, '
            'a new password has been sent to your email')

        response.follow()

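    # NOTE: helper for the auth-token tests below; judging from their usage,
    # whitelist entries take the 'ControllerName:view_name' form, optionally
    # suffixed with '@<token>' to bind an entry to one specific auth token
    # (see the bound-token test further down)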
    def _get_api_whitelist(self, values=None):
        config = {'api_access_controllers_whitelist': values or []}
        return config

    @pytest.mark.parametrize("test_name, auth_token", [
        ('none', None),
        ('empty_string', ''),
        ('fake_number', '123456'),
        ('proper_auth_token', None)
    ])
    def test_access_not_whitelisted_page_via_auth_token(
            self, test_name, auth_token, user_admin):

        whitelist = self._get_api_whitelist([])
        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert [] == whitelist['api_access_controllers_whitelist']
            if test_name == 'proper_auth_token':
                # substitute the builtin admin api_key (the parametrized
                # value is None for this case)
                auth_token = user_admin.api_key

            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=auth_token)),
                    status=302)

    @pytest.mark.parametrize("test_name, auth_token, code", [
        ('none', None, 302),
        ('empty_string', '', 302),
        ('fake_number', '123456', 302),
        ('proper_auth_token', None, 200)
    ])
    def test_access_whitelisted_page_via_auth_token(
            self, test_name, auth_token, code, user_admin):

        whitelist = self._get_api_whitelist(whitelist_view)

        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert whitelist_view == whitelist['api_access_controllers_whitelist']

            if test_name == 'proper_auth_token':
                auth_token = user_admin.api_key
                assert auth_token

            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=auth_token)),
                    status=code)

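    # NOTE: the '@<token>' suffix in a whitelist entry binds it to a single
    # auth token; per the parametrization below, a non-matching token gets
    # a 302 redirect instead of the raw commit (200)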
    @pytest.mark.parametrize("test_name, auth_token, code", [
        ('proper_auth_token', None, 200),
        ('wrong_auth_token', '123456', 302),
    ])
    def test_access_whitelisted_page_via_auth_token_bound_to_token(
            self, test_name, auth_token, code, user_admin):

        expected_token = auth_token
        if test_name == 'proper_auth_token':
            auth_token = user_admin.api_key
            expected_token = auth_token
            assert auth_token

        whitelist = self._get_api_whitelist([
            'RepoCommitsView:repo_commit_raw@{}'.format(expected_token)])

        with mock.patch.dict('rhodecode.CONFIG', whitelist):

            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=auth_token)),
                    status=code)

    def test_access_page_via_extra_auth_token(self):
        whitelist = self._get_api_whitelist(whitelist_view)
        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert whitelist_view == \
                whitelist['api_access_controllers_whitelist']

            new_auth_token = AuthTokenModel().create(
                TEST_USER_ADMIN_LOGIN, 'test')
            Session().commit()
            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=new_auth_token.api_key)),
                    status=200)

    def test_access_page_via_expired_auth_token(self):
        whitelist = self._get_api_whitelist(whitelist_view)
        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert whitelist_view == \
                whitelist['api_access_controllers_whitelist']

            new_auth_token = AuthTokenModel().create(
                TEST_USER_ADMIN_LOGIN, 'test')
            Session().commit()
            # patch the api key and make it expired
            new_auth_token.expires = 0
            Session().add(new_auth_token)
            Session().commit()
            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=new_auth_token.api_key)),
                    status=302)
@@ -1,1877 +1,1877 @@
# -*- coding: utf-8 -*-

# Copyright (C) 2011-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import logging
import collections

import formencode
import formencode.htmlfill
import peppercorn
from pyramid.httpexceptions import (
    HTTPFound, HTTPNotFound, HTTPForbidden, HTTPBadRequest, HTTPConflict)

from pyramid.renderers import render

from rhodecode.apps._base import RepoAppView, DataGridAppView

from rhodecode.lib import helpers as h, diffs, codeblocks, channelstream
from rhodecode.lib.base import vcs_operation_context
from rhodecode.lib.diffs import load_cached_diff, cache_diff, diff_cache_exist
from rhodecode.lib.exceptions import CommentVersionMismatch
from rhodecode.lib.ext_json import json
from rhodecode.lib.auth import (
    LoginRequired, HasRepoPermissionAny, HasRepoPermissionAnyDecorator,
    NotAnonymous, CSRFRequired)
from rhodecode.lib.utils2 import str2bool, safe_str, safe_unicode, safe_int, aslist, retry
from rhodecode.lib.vcs.backends.base import (
    EmptyCommit, UpdateFailureReason, unicode_to_reference)
from rhodecode.lib.vcs.exceptions import (
    CommitDoesNotExistError, RepositoryRequirementError, EmptyRepositoryError)
from rhodecode.model.changeset_status import ChangesetStatusModel
from rhodecode.model.comment import CommentsModel
from rhodecode.model.db import (
    func, false, or_, PullRequest, ChangesetComment, ChangesetStatus, Repository,
    PullRequestReviewers)
from rhodecode.model.forms import PullRequestForm
from rhodecode.model.meta import Session
from rhodecode.model.pull_request import PullRequestModel, MergeCheck
from rhodecode.model.scm import ScmModel

log = logging.getLogger(__name__)


class RepoPullRequestsView(RepoAppView, DataGridAppView):

    def load_default_context(self):
        c = self._get_local_tmpl_context(include_app_defaults=True)
        c.REVIEW_STATUS_APPROVED = ChangesetStatus.STATUS_APPROVED
        c.REVIEW_STATUS_REJECTED = ChangesetStatus.STATUS_REJECTED
        # backward compat.: old PRs are rendered with the plain renderer
        c.renderer = 'plain'
        return c

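    # NOTE: the dict built below follows the DataTables ajax protocol:
    # 'draw', 'data', 'recordsTotal' and 'recordsFiltered' are consumed
    # directly by the frontend grid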
    def _get_pull_requests_list(
            self, repo_name, source, filter_type, opened_by, statuses):

        draw, start, limit = self._extract_chunk(self.request)
        search_q, order_by, order_dir = self._extract_ordering(self.request)
        _render = self.request.get_partial_renderer(
            'rhodecode:templates/data_table/_dt_elements.mako')

        # pagination

        if filter_type == 'awaiting_review':
            pull_requests = PullRequestModel().get_awaiting_review(
                repo_name,
                search_q=search_q, statuses=statuses,
                offset=start, length=limit, order_by=order_by, order_dir=order_dir)
            pull_requests_total_count = PullRequestModel().count_awaiting_review(
                repo_name,
                search_q=search_q, statuses=statuses)
        elif filter_type == 'awaiting_my_review':
            pull_requests = PullRequestModel().get_awaiting_my_review(
                repo_name, self._rhodecode_user.user_id,
                search_q=search_q, statuses=statuses,
                offset=start, length=limit, order_by=order_by, order_dir=order_dir)
            pull_requests_total_count = PullRequestModel().count_awaiting_my_review(
                repo_name, self._rhodecode_user.user_id,
                search_q=search_q, statuses=statuses)
        else:
            pull_requests = PullRequestModel().get_all(
                repo_name, search_q=search_q, source=source, opened_by=opened_by,
                statuses=statuses, offset=start, length=limit,
                order_by=order_by, order_dir=order_dir)
            pull_requests_total_count = PullRequestModel().count_all(
                repo_name, search_q=search_q, source=source, statuses=statuses,
                opened_by=opened_by)

        data = []
        comments_model = CommentsModel()
        for pr in pull_requests:
            comments_count = comments_model.get_all_comments(
                self.db_repo.repo_id, pull_request=pr,
                include_drafts=False, count_only=True)

            review_statuses = pr.reviewers_statuses(user=self._rhodecode_db_user)
            my_review_status = ChangesetStatus.STATUS_NOT_REVIEWED
            if review_statuses and review_statuses[4]:
                _review_obj, _user, _reasons, _mandatory, statuses = review_statuses
                my_review_status = statuses[0][1].status

            data.append({
                'name': _render('pullrequest_name',
                                pr.pull_request_id, pr.pull_request_state,
                                pr.work_in_progress, pr.target_repo.repo_name,
                                short=True),
                'name_raw': pr.pull_request_id,
                'status': _render('pullrequest_status',
                                  pr.calculated_review_status()),
                'my_status': _render('pullrequest_status',
                                     my_review_status),
                'title': _render('pullrequest_title', pr.title, pr.description),
                'description': h.escape(pr.description),
                'updated_on': _render('pullrequest_updated_on',
                                      h.datetime_to_time(pr.updated_on),
                                      pr.versions_count),
                'updated_on_raw': h.datetime_to_time(pr.updated_on),
                'created_on': _render('pullrequest_updated_on',
                                      h.datetime_to_time(pr.created_on)),
                'created_on_raw': h.datetime_to_time(pr.created_on),
                'state': pr.pull_request_state,
                'author': _render('pullrequest_author',
                                  pr.author.full_contact, ),
                'author_raw': pr.author.full_name,
                'comments': _render('pullrequest_comments', comments_count),
                'comments_raw': comments_count,
                'closed': pr.is_closed(),
            })

        data = ({
            'draw': draw,
            'data': data,
            'recordsTotal': pull_requests_total_count,
            'recordsFiltered': pull_requests_total_count,
        })
        return data

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_list(self):
        c = self.load_default_context()

        req_get = self.request.GET
        c.source = str2bool(req_get.get('source'))
        c.closed = str2bool(req_get.get('closed'))
        c.my = str2bool(req_get.get('my'))
        c.awaiting_review = str2bool(req_get.get('awaiting_review'))
        c.awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))

        c.active = 'open'
        if c.my:
            c.active = 'my'
        if c.closed:
            c.active = 'closed'
        if c.awaiting_review and not c.source:
            c.active = 'awaiting'
        if c.source and not c.awaiting_review:
            c.active = 'source'
        if c.awaiting_my_review:
            c.active = 'awaiting_my'

        return self._get_template_context(c)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_list_data(self):
        self.load_default_context()

        # additional filters
        req_get = self.request.GET
        source = str2bool(req_get.get('source'))
        closed = str2bool(req_get.get('closed'))
        my = str2bool(req_get.get('my'))
        awaiting_review = str2bool(req_get.get('awaiting_review'))
        awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))

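        # filter precedence: 'awaiting_review' wins over 'awaiting_my_review';
        # with neither flag set, the unfiltered listing is used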
        filter_type = 'awaiting_review' if awaiting_review \
            else 'awaiting_my_review' if awaiting_my_review \
            else None

        opened_by = None
        if my:
            opened_by = [self._rhodecode_user.user_id]

        statuses = [PullRequest.STATUS_NEW, PullRequest.STATUS_OPEN]
        if closed:
            statuses = [PullRequest.STATUS_CLOSED]

        data = self._get_pull_requests_list(
            repo_name=self.db_repo_name, source=source,
            filter_type=filter_type, opened_by=opened_by, statuses=statuses)

        return data

    def _is_diff_cache_enabled(self, target_repo):
        caching_enabled = self._get_general_setting(
            target_repo, 'rhodecode_diff_cache')
        log.debug('Diff caching enabled: %s', caching_enabled)
        return caching_enabled

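    # NOTE: with use_ancestor=True the diff below is computed against the
    # common ancestor commit, so the resulting diffset contains only the
    # changes introduced by the pull request itself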
    def _get_diffset(self, source_repo_name, source_repo,
                     ancestor_commit,
                     source_ref_id, target_ref_id,
                     target_commit, source_commit, diff_limit, file_limit,
                     fulldiff, hide_whitespace_changes, diff_context, use_ancestor=True):

        target_commit_final = target_commit
        source_commit_final = source_commit

        if use_ancestor:
            # we might want to not use it for versions
            target_ref_id = ancestor_commit.raw_id
            target_commit_final = ancestor_commit

        vcs_diff = PullRequestModel().get_diff(
            source_repo, source_ref_id, target_ref_id,
            hide_whitespace_changes, diff_context)

        diff_processor = diffs.DiffProcessor(
            vcs_diff, format='newdiff', diff_limit=diff_limit,
            file_limit=file_limit, show_full_diff=fulldiff)

        _parsed = diff_processor.prepare()

        diffset = codeblocks.DiffSet(
            repo_name=self.db_repo_name,
            source_repo_name=source_repo_name,
            source_node_getter=codeblocks.diffset_node_getter(target_commit_final),
            target_node_getter=codeblocks.diffset_node_getter(source_commit_final),
        )
        diffset = self.path_filter.render_patchset_filtered(
            diffset, _parsed, target_ref_id, source_ref_id)

        return diffset

    def _get_range_diffset(self, source_scm, source_repo,
                           commit1, commit2, diff_limit, file_limit,
                           fulldiff, hide_whitespace_changes, diff_context):
        vcs_diff = source_scm.get_diff(
            commit1, commit2,
            ignore_whitespace=hide_whitespace_changes,
            context=diff_context)

        diff_processor = diffs.DiffProcessor(
            vcs_diff, format='newdiff', diff_limit=diff_limit,
            file_limit=file_limit, show_full_diff=fulldiff)

        _parsed = diff_processor.prepare()

        diffset = codeblocks.DiffSet(
            repo_name=source_repo.repo_name,
            source_node_getter=codeblocks.diffset_node_getter(commit1),
            target_node_getter=codeblocks.diffset_node_getter(commit2))

        diffset = self.path_filter.render_patchset_filtered(
            diffset, _parsed, commit1.raw_id, commit2.raw_id)

        return diffset

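    # NOTE: comments are aggregated per PR version below; judging from the
    # at_version switch, 'display' holds the comments rendered at a given
    # version while 'until' accumulates everything up to that version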
    def register_comments_vars(self, c, pull_request, versions, include_drafts=True):
        comments_model = CommentsModel()

        # GENERAL COMMENTS with versions #
        q = comments_model._all_general_comments_of_pull_request(pull_request)
        q = q.order_by(ChangesetComment.comment_id.asc())
        if not include_drafts:
            q = q.filter(ChangesetComment.draft == false())
        general_comments = q

        # pick comments we want to render at current version
        c.comment_versions = comments_model.aggregate_comments(
            general_comments, versions, c.at_version_num)

        # INLINE COMMENTS with versions #
        q = comments_model._all_inline_comments_of_pull_request(pull_request)
        q = q.order_by(ChangesetComment.comment_id.asc())
        if not include_drafts:
            q = q.filter(ChangesetComment.draft == false())
        inline_comments = q

        c.inline_versions = comments_model.aggregate_comments(
            inline_comments, versions, c.at_version_num, inline=True)

        # Comments inline+general
        if c.at_version:
            c.inline_comments_flat = c.inline_versions[c.at_version_num]['display']
            c.comments = c.comment_versions[c.at_version_num]['display']
        else:
            c.inline_comments_flat = c.inline_versions[c.at_version_num]['until']
            c.comments = c.comment_versions[c.at_version_num]['until']

        return general_comments, inline_comments

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_show(self):
        _ = self.request.translate
        c = self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id

        c.state_progressing = pull_request.is_state_changing()
        c.pr_broadcast_channel = channelstream.pr_channel(pull_request)

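        # NOTE: 'force_state' acts as an admin escape hatch: a super-admin
        # or repository admin can push a PR stuck in a transitional state
        # back to 'created'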
        _new_state = {
            'created': PullRequest.STATE_CREATED,
        }.get(self.request.GET.get('force_state'))
        can_force_state = c.is_super_admin or HasRepoPermissionAny('repository.admin')(c.repo_name)

        if can_force_state and _new_state:
            with pull_request.set_state(PullRequest.STATE_UPDATING, final_state=_new_state):
                h.flash(
                    _('Pull Request state was force changed to `{}`').format(_new_state),
                    category='success')
                Session().commit()

            raise HTTPFound(h.route_path(
                'pullrequest_show', repo_name=self.db_repo_name,
                pull_request_id=pull_request_id))

        version = self.request.GET.get('version')
        from_version = self.request.GET.get('from_version') or version
        merge_checks = self.request.GET.get('merge_checks')
        c.fulldiff = str2bool(self.request.GET.get('fulldiff'))
        force_refresh = str2bool(self.request.GET.get('force_refresh'))
        c.range_diff_on = self.request.GET.get('range-diff') == "1"

        # fetch global flags of ignore ws or context lines
        diff_context = diffs.get_diff_context(self.request)
        hide_whitespace_changes = diffs.get_diff_whitespace_flag(self.request)

        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)

        pr_closed = pull_request_latest.is_closed()

        if pr_closed and (version or from_version):
            # do not allow browsing versions of a closed PR
            raise HTTPFound(h.route_path(
                'pullrequest_show', repo_name=self.db_repo_name,
                pull_request_id=pull_request_id))

        versions = pull_request_display_obj.versions()

        c.commit_versions = PullRequestModel().pr_commits_versions(versions)

        # used to store per-commit range diffs
        c.changes = collections.OrderedDict()

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        c.at_version_index = ChangesetComment.get_index_from_version(
            c.at_version_num, versions)

        (prev_pull_request_latest,
         prev_pull_request_at_ver,
         prev_pull_request_display_obj,
         prev_at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=from_version)

        c.from_version = prev_at_version
        c.from_version_num = (prev_at_version
                              if prev_at_version and prev_at_version != PullRequest.LATEST_VER
                              else None)
        c.from_version_index = ChangesetComment.get_index_from_version(
            c.from_version_num, versions)

        # define if we're in COMPARE mode or VIEW at version mode
        compare = at_version != prev_at_version

        # the repository being viewed must be the one this pull request
        # was opened against, i.e. its target_repo must match
        if self.db_repo_name != pull_request_at_ver.target_repo.repo_name:
            log.warning('Mismatch between the current repo: %s, and target %s',
                        self.db_repo_name, pull_request_at_ver.target_repo.repo_name)
            raise HTTPNotFound()

        c.shadow_clone_url = PullRequestModel().get_shadow_clone_url(pull_request_at_ver)

        c.pull_request = pull_request_display_obj
        c.renderer = pull_request_at_ver.description_renderer or c.renderer
        c.pull_request_latest = pull_request_latest

        # inject latest version
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]

        if compare or (at_version and at_version != PullRequest.LATEST_VER):
            c.allowed_to_change_status = False
            c.allowed_to_update = False
            c.allowed_to_merge = False
            c.allowed_to_delete = False
            c.allowed_to_comment = False
            c.allowed_to_close = False
        else:
            can_change_status = PullRequestModel().check_user_change_status(
                pull_request_at_ver, self._rhodecode_user)
            c.allowed_to_change_status = can_change_status and not pr_closed

            c.allowed_to_update = PullRequestModel().check_user_update(
                pull_request_latest, self._rhodecode_user) and not pr_closed
            c.allowed_to_merge = PullRequestModel().check_user_merge(
                pull_request_latest, self._rhodecode_user) and not pr_closed
            c.allowed_to_delete = PullRequestModel().check_user_delete(
                pull_request_latest, self._rhodecode_user) and not pr_closed
            c.allowed_to_comment = not pr_closed
            c.allowed_to_close = c.allowed_to_merge and not pr_closed

        c.forbid_adding_reviewers = False

        if pull_request_latest.reviewer_data and \
                'rules' in pull_request_latest.reviewer_data:
            rules = pull_request_latest.reviewer_data['rules'] or {}
            try:
                c.forbid_adding_reviewers = rules.get('forbid_adding_reviewers')
            except Exception:
                pass

        # check merge capabilities
        _merge_check = MergeCheck.validate(
            pull_request_latest, auth_user=self._rhodecode_user,
            translator=self.request.translate,
            force_shadow_repo_refresh=force_refresh)

        c.pr_merge_errors = _merge_check.error_details
        c.pr_merge_possible = not _merge_check.failed
        c.pr_merge_message = _merge_check.merge_msg
        c.pr_merge_source_commit = _merge_check.source_commit
        c.pr_merge_target_commit = _merge_check.target_commit

        c.pr_merge_info = MergeCheck.get_merge_conditions(
            pull_request_latest, translator=self.request.translate)

        c.pull_request_review_status = _merge_check.review_status
        if merge_checks:
            self.request.override_renderer = \
                'rhodecode:templates/pullrequests/pullrequest_merge_checks.mako'
            return self._get_template_context(c)

        c.reviewers_count = pull_request.reviewers_count
        c.observers_count = pull_request.observers_count

        # reviewers and statuses
        c.pull_request_default_reviewers_data_json = json.dumps(pull_request.reviewer_data)
        c.pull_request_set_reviewers_data_json = collections.OrderedDict({'reviewers': []})
        c.pull_request_set_observers_data_json = collections.OrderedDict({'observers': []})

        for review_obj, member, reasons, mandatory, status in pull_request_at_ver.reviewers_statuses():
            member_reviewer = h.reviewer_as_json(
                member, reasons=reasons, mandatory=mandatory,
                role=review_obj.role,
                user_group=review_obj.rule_user_group_data()
            )

            current_review_status = status[0][1].status if status else ChangesetStatus.STATUS_NOT_REVIEWED
            member_reviewer['review_status'] = current_review_status
            member_reviewer['review_status_label'] = h.commit_status_lbl(current_review_status)
            member_reviewer['allowed_to_update'] = c.allowed_to_update
            c.pull_request_set_reviewers_data_json['reviewers'].append(member_reviewer)

        c.pull_request_set_reviewers_data_json = json.dumps(c.pull_request_set_reviewers_data_json)

        for observer_obj, member in pull_request_at_ver.observers():
            member_observer = h.reviewer_as_json(
                member, reasons=[], mandatory=False,
                role=observer_obj.role,
                user_group=observer_obj.rule_user_group_data()
            )
            member_observer['allowed_to_update'] = c.allowed_to_update
            c.pull_request_set_observers_data_json['observers'].append(member_observer)

        c.pull_request_set_observers_data_json = json.dumps(c.pull_request_set_observers_data_json)

        general_comments, inline_comments = \
            self.register_comments_vars(c, pull_request_latest, versions)

        # TODOs
        c.unresolved_comments = CommentsModel() \
            .get_pull_request_unresolved_todos(pull_request_latest)
        c.resolved_comments = CommentsModel() \
            .get_pull_request_resolved_todos(pull_request_latest)

        # Drafts
        c.draft_comments = CommentsModel().get_pull_request_drafts(
            self._rhodecode_db_user.user_id,
            pull_request_latest)

515 # if we use version, then do not show later comments
515 # if we use version, then do not show later comments
516 # than current version
516 # than current version
517 display_inline_comments = collections.defaultdict(
517 display_inline_comments = collections.defaultdict(
518 lambda: collections.defaultdict(list))
518 lambda: collections.defaultdict(list))
519 for co in inline_comments:
519 for co in inline_comments:
520 if c.at_version_num:
520 if c.at_version_num:
521 # pick comments that are at least UPTO given version, so we
521 # pick comments that are at least UPTO given version, so we
522 # don't render comments for higher version
522 # don't render comments for higher version
523 should_render = co.pull_request_version_id and \
523 should_render = co.pull_request_version_id and \
524 co.pull_request_version_id <= c.at_version_num
524 co.pull_request_version_id <= c.at_version_num
525 else:
525 else:
526 # showing all, for 'latest'
526 # showing all, for 'latest'
527 should_render = True
527 should_render = True
528
528
529 if should_render:
529 if should_render:
530 display_inline_comments[co.f_path][co.line_no].append(co)
530 display_inline_comments[co.f_path][co.line_no].append(co)
531
531
532 # load diff data into template context, if we use compare mode then
532 # load diff data into template context, if we use compare mode then
533 # diff is calculated based on changes between versions of PR
533 # diff is calculated based on changes between versions of PR
534
534
535 source_repo = pull_request_at_ver.source_repo
535 source_repo = pull_request_at_ver.source_repo
536 source_ref_id = pull_request_at_ver.source_ref_parts.commit_id
536 source_ref_id = pull_request_at_ver.source_ref_parts.commit_id
537
537
538 target_repo = pull_request_at_ver.target_repo
538 target_repo = pull_request_at_ver.target_repo
539 target_ref_id = pull_request_at_ver.target_ref_parts.commit_id
539 target_ref_id = pull_request_at_ver.target_ref_parts.commit_id
540
540
541 if compare:
541 if compare:
542 # in compare switch the diff base to latest commit from prev version
542 # in compare switch the diff base to latest commit from prev version
543 target_ref_id = prev_pull_request_display_obj.revisions[0]
543 target_ref_id = prev_pull_request_display_obj.revisions[0]
544
544
545 # despite opening commits for bookmarks/branches/tags, we always
545 # despite opening commits for bookmarks/branches/tags, we always
546 # convert this to rev to prevent changes after bookmark or branch change
546 # convert this to rev to prevent changes after bookmark or branch change
547 c.source_ref_type = 'rev'
547 c.source_ref_type = 'rev'
548 c.source_ref = source_ref_id
548 c.source_ref = source_ref_id
549
549
550 c.target_ref_type = 'rev'
550 c.target_ref_type = 'rev'
551 c.target_ref = target_ref_id
551 c.target_ref = target_ref_id
552
552
553 c.source_repo = source_repo
553 c.source_repo = source_repo
554 c.target_repo = target_repo
554 c.target_repo = target_repo
555
555
556 c.commit_ranges = []
556 c.commit_ranges = []
557 source_commit = EmptyCommit()
557 source_commit = EmptyCommit()
558 target_commit = EmptyCommit()
558 target_commit = EmptyCommit()
559 c.missing_requirements = False
559 c.missing_requirements = False
560
560
561 source_scm = source_repo.scm_instance()
561 source_scm = source_repo.scm_instance()
562 target_scm = target_repo.scm_instance()
562 target_scm = target_repo.scm_instance()
563
563
564 shadow_scm = None
564 shadow_scm = None
565 try:
565 try:
566 shadow_scm = pull_request_latest.get_shadow_repo()
566 shadow_scm = pull_request_latest.get_shadow_repo()
567 except Exception:
567 except Exception:
568 log.debug('Failed to get shadow repo', exc_info=True)
568 log.debug('Failed to get shadow repo', exc_info=True)
569 # try first the existing source_repo, and then shadow
569 # try first the existing source_repo, and then shadow
570 # repo if we can obtain one
570 # repo if we can obtain one
571 commits_source_repo = source_scm
571 commits_source_repo = source_scm
572 if shadow_scm:
572 if shadow_scm:
573 commits_source_repo = shadow_scm
573 commits_source_repo = shadow_scm
574
574
575 c.commits_source_repo = commits_source_repo
575 c.commits_source_repo = commits_source_repo
576 c.ancestor = None # set it to None, to hide it from PR view
576 c.ancestor = None # set it to None, to hide it from PR view
577
577
578 # empty version means latest, so we keep this to prevent
578 # empty version means latest, so we keep this to prevent
579 # double caching
579 # double caching
580 version_normalized = version or PullRequest.LATEST_VER
580 version_normalized = version or PullRequest.LATEST_VER
581 from_version_normalized = from_version or PullRequest.LATEST_VER
581 from_version_normalized = from_version or PullRequest.LATEST_VER
582
582
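        # the diff cache key is built from the pull request id, the normalized
        # version pair, both ref ids and the whitespace/context/full-diff
        # flags, so changing any of these inputs selects a different cache file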
        cache_path = self.rhodecode_vcs_repo.get_create_shadow_cache_pr_path(target_repo)
        cache_file_path = diff_cache_exist(
            cache_path, 'pull_request', pull_request_id, version_normalized,
            from_version_normalized, source_ref_id, target_ref_id,
            hide_whitespace_changes, diff_context, c.fulldiff)

        caching_enabled = self._is_diff_cache_enabled(c.target_repo)
        force_recache = self.get_recache_flag()

        cached_diff = None
        if caching_enabled:
            cached_diff = load_cached_diff(cache_file_path)

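        # a usable commits cache entry is the 5-tuple returned by get_commits():
        # (ancestor_commit, commit_cache, missing_requirements,
        #  source_commit, target_commit); the checks below also require the
        # ancestor and source commits to be present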
        has_proper_commit_cache = (
            cached_diff and cached_diff.get('commits')
            and len(cached_diff.get('commits', [])) == 5
            and cached_diff.get('commits')[0]
            and cached_diff.get('commits')[3])

        if not force_recache and not c.range_diff_on and has_proper_commit_cache:
            diff_commit_cache = \
                (ancestor_commit, commit_cache, missing_requirements,
                 source_commit, target_commit) = cached_diff['commits']
        else:
            # NOTE(marcink): we reach potentially unreachable errors when a PR has
            # merge errors resulting in potentially hidden commits in the shadow repo.
            maybe_unreachable = _merge_check.MERGE_CHECK in _merge_check.error_details \
                and _merge_check.merge_response
            maybe_unreachable = maybe_unreachable \
                and _merge_check.merge_response.metadata.get('unresolved_files')
            log.debug("Using unreachable commits due to MERGE_CHECK in merge simulation")
            diff_commit_cache = \
                (ancestor_commit, commit_cache, missing_requirements,
                 source_commit, target_commit) = self.get_commits(
                    commits_source_repo,
                    pull_request_at_ver,
                    source_commit,
                    source_ref_id,
                    source_scm,
                    target_commit,
                    target_ref_id,
                    target_scm,
                    maybe_unreachable=maybe_unreachable)

        # register our commit range
        for comm in commit_cache.values():
            c.commit_ranges.append(comm)

        c.missing_requirements = missing_requirements
        c.ancestor_commit = ancestor_commit
        c.statuses = source_repo.statuses(
            [x.raw_id for x in c.commit_ranges])

        # auto collapse if we have more than limit
        collapse_limit = diffs.DiffProcessor._collapse_commits_over
        c.collapse_all_commits = len(c.commit_ranges) > collapse_limit
        c.compare_mode = compare

        # diff_limit is the old behavior: it cuts off the whole diff if the
        # limit is applied, otherwise it just hides the big files from the
        # front-end
        diff_limit = c.visual.cut_off_limit_diff
        file_limit = c.visual.cut_off_limit_file

        c.missing_commits = False
        if (c.missing_requirements
                or isinstance(source_commit, EmptyCommit)
                or source_commit == target_commit):

            c.missing_commits = True
        else:
            c.inline_comments = display_inline_comments

            use_ancestor = True
            if from_version_normalized != version_normalized:
                use_ancestor = False

            has_proper_diff_cache = cached_diff and cached_diff.get('diff')
            if not force_recache and has_proper_diff_cache:
                c.diffset = cached_diff['diff']
            else:
                try:
                    c.diffset = self._get_diffset(
                        c.source_repo.repo_name, commits_source_repo,
                        c.ancestor_commit,
                        source_ref_id, target_ref_id,
                        target_commit, source_commit,
                        diff_limit, file_limit, c.fulldiff,
                        hide_whitespace_changes, diff_context,
                        use_ancestor=use_ancestor
                    )

                    # save cached diff
                    if caching_enabled:
                        cache_diff(cache_file_path, c.diffset, diff_commit_cache)
                except CommitDoesNotExistError:
                    log.exception('Failed to generate diffset')
                    c.missing_commits = True

        if not c.missing_commits:

            c.limited_diff = c.diffset.limited_diff

            # calculate removed files that are bound to comments
            comment_deleted_files = [
                fname for fname in display_inline_comments
                if fname not in c.diffset.file_stats]

            c.deleted_files_comments = collections.defaultdict(dict)
            for fname, per_line_comments in display_inline_comments.items():
                if fname in comment_deleted_files:
                    c.deleted_files_comments[fname]['stats'] = 0
                    c.deleted_files_comments[fname]['comments'] = list()
                    for lno, comments in per_line_comments.items():
                        c.deleted_files_comments[fname]['comments'].extend(comments)

        # maybe calculate the range diff
        if c.range_diff_on:
            # TODO(marcink): set whitespace/context
            context_lcl = 3
            ign_whitespace_lcl = False

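            # in range-diff mode every commit in the range gets its own
            # per-commit diffset, cached under the commit's raw_id instead of
            # the pull request version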
            for commit in c.commit_ranges:
                commit2 = commit
                commit1 = commit.first_parent

                range_diff_cache_file_path = diff_cache_exist(
                    cache_path, 'diff', commit.raw_id,
                    ign_whitespace_lcl, context_lcl, c.fulldiff)

                cached_diff = None
                if caching_enabled:
                    cached_diff = load_cached_diff(range_diff_cache_file_path)

                has_proper_diff_cache = cached_diff and cached_diff.get('diff')
                if not force_recache and has_proper_diff_cache:
                    diffset = cached_diff['diff']
                else:
                    diffset = self._get_range_diffset(
                        commits_source_repo, source_repo,
                        commit1, commit2, diff_limit, file_limit,
                        c.fulldiff, ign_whitespace_lcl, context_lcl
                    )

                    # save cached diff
                    if caching_enabled:
                        cache_diff(range_diff_cache_file_path, diffset, None)

                c.changes[commit.raw_id] = diffset

        # this is a hack to properly display links; when creating a PR, the
        # compare view and others use a different notation, and
        # compare_commits.mako renders links based on the target_repo.
        # We need to swap that here to generate it properly on the html side
        c.target_repo = c.source_repo

        c.commit_statuses = ChangesetStatus.STATUSES

        c.show_version_changes = not pr_closed
        if c.show_version_changes:
            cur_obj = pull_request_at_ver
            prev_obj = prev_pull_request_at_ver

            old_commit_ids = prev_obj.revisions
            new_commit_ids = cur_obj.revisions
            commit_changes = PullRequestModel()._calculate_commit_id_changes(
                old_commit_ids, new_commit_ids)
            c.commit_changes_summary = commit_changes

            # calculate the diff for commits between versions
            c.commit_changes = []

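            # mark() tags every commit id in `cs` with the change-type flag
            # `fw`; a minimal illustration (hypothetical ids, not part of the
            # original code):
            #   mark(['aaa', 'bbb'], 'a') -> [('a', 'aaa'), ('a', 'bbb')]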
            def mark(cs, fw):
                return list(h.itertools.zip_longest([], cs, fillvalue=fw))

            for c_type, raw_id in mark(commit_changes.added, 'a') \
                    + mark(commit_changes.removed, 'r') \
                    + mark(commit_changes.common, 'c'):

                if raw_id in commit_cache:
                    commit = commit_cache[raw_id]
                else:
                    try:
                        commit = commits_source_repo.get_commit(raw_id)
                    except CommitDoesNotExistError:
                        # in case we fail extracting, still use a "dummy" commit
                        # for display in the commit diff
                        commit = h.AttributeDict(
                            {'raw_id': raw_id,
                             'message': 'EMPTY or MISSING COMMIT'})
                c.commit_changes.append([c_type, commit])

            # current user review statuses for each version
            c.review_versions = {}
            is_reviewer = PullRequestModel().is_user_reviewer(
                pull_request, self._rhodecode_user)
            if is_reviewer:
                for co in general_comments:
                    if co.author.user_id == self._rhodecode_user.user_id:
                        status = co.status_change
                        if status:
                            _ver_pr = status[0].comment.pull_request_version_id
                            c.review_versions[_ver_pr] = status[0]

        return self._get_template_context(c)

    def get_commits(
            self, commits_source_repo, pull_request_at_ver, source_commit,
            source_ref_id, source_scm, target_commit, target_ref_id, target_scm,
            maybe_unreachable=False):
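        """
        Load the pull request's commit range plus its boundary commits.

        Returns the 5-tuple (ancestor_commit, commit_cache,
        missing_requirements, source_commit, target_commit); this is the same
        shape that is stored in, and validated against, the diff cache.
        """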

        commit_cache = collections.OrderedDict()
        missing_requirements = False

        try:
            pre_load = ["author", "date", "message", "branch", "parents"]

            pull_request_commits = pull_request_at_ver.revisions
            log.debug('Loading %s commits from %s',
                      len(pull_request_commits), commits_source_repo)

            for rev in pull_request_commits:
                comm = commits_source_repo.get_commit(commit_id=rev, pre_load=pre_load,
                                                      maybe_unreachable=maybe_unreachable)
                commit_cache[comm.raw_id] = comm

            # Order here matters, we first need to get target, and then
            # the source
            target_commit = commits_source_repo.get_commit(
                commit_id=safe_str(target_ref_id))

            source_commit = commits_source_repo.get_commit(
                commit_id=safe_str(source_ref_id), maybe_unreachable=True)
        except CommitDoesNotExistError:
            log.warning('Failed to get commit from `{}` repo'.format(
                commits_source_repo), exc_info=True)
        except RepositoryRequirementError:
            log.warning('Failed to get all required data from repo', exc_info=True)
            missing_requirements = True

        pr_ancestor_id = pull_request_at_ver.common_ancestor_id

        try:
            ancestor_commit = source_scm.get_commit(pr_ancestor_id)
        except Exception:
            ancestor_commit = None

        return ancestor_commit, commit_cache, missing_requirements, source_commit, target_commit

    def assure_not_empty_repo(self):
        _ = self.request.translate

        try:
            self.db_repo.scm_instance().get_commit()
        except EmptyRepositoryError:
            h.flash(h.literal(_('There are no commits yet')),
                    category='warning')
            raise HTTPFound(
                h.route_path('repo_summary', repo_name=self.db_repo.repo_name))

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_new(self):
        _ = self.request.translate
        c = self.load_default_context()

        self.assure_not_empty_repo()
        source_repo = self.db_repo

        commit_id = self.request.GET.get('commit')
        branch_ref = self.request.GET.get('branch')
        bookmark_ref = self.request.GET.get('bookmark')

        try:
            source_repo_data = PullRequestModel().generate_repo_data(
                source_repo, commit_id=commit_id,
                branch=branch_ref, bookmark=bookmark_ref,
                translator=self.request.translate)
        except CommitDoesNotExistError as e:
            log.exception(e)
            h.flash(_('Commit does not exist'), 'error')
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=source_repo.repo_name))

        default_target_repo = source_repo

        if source_repo.parent and c.has_origin_repo_read_perm:
            parent_vcs_obj = source_repo.parent.scm_instance()
            if parent_vcs_obj and not parent_vcs_obj.is_empty():
                # change default if we have a parent repo
                default_target_repo = source_repo.parent

        target_repo_data = PullRequestModel().generate_repo_data(
            default_target_repo, translator=self.request.translate)

        selected_source_ref = source_repo_data['refs']['selected_ref']
        title_source_ref = ''
        if selected_source_ref:
            title_source_ref = selected_source_ref.split(':', 2)[1]
        c.default_title = PullRequestModel().generate_pullrequest_title(
            source=source_repo.repo_name,
            source_ref=title_source_ref,
            target=default_target_repo.repo_name
        )

        c.default_repo_data = {
            'source_repo_name': source_repo.repo_name,
            'source_refs_json': json.dumps(source_repo_data),
            'target_repo_name': default_target_repo.repo_name,
            'target_refs_json': json.dumps(target_repo_data),
        }
        c.default_source_ref = selected_source_ref

        return self._get_template_context(c)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_repo_refs(self):
        self.load_default_context()
        target_repo_name = self.request.matchdict['target_repo_name']
        repo = Repository.get_by_repo_name(target_repo_name)
        if not repo:
            raise HTTPNotFound()

        target_perm = HasRepoPermissionAny(
            'repository.read', 'repository.write', 'repository.admin')(
                target_repo_name)
        if not target_perm:
            raise HTTPNotFound()

        return PullRequestModel().generate_repo_data(
            repo, translator=self.request.translate)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_repo_targets(self):
        _ = self.request.translate
        filter_query = self.request.GET.get('query')

        # get the parents
        parent_target_repos = []
        if self.db_repo.parent:
            parents_query = Repository.query() \
                .order_by(func.length(Repository.repo_name)) \
                .filter(Repository.fork_id == self.db_repo.parent.repo_id)

            if filter_query:
                ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
                parents_query = parents_query.filter(
                    Repository.repo_name.ilike(ilike_expression))
            parents = parents_query.limit(20).all()

            for parent in parents:
                parent_vcs_obj = parent.scm_instance()
                if parent_vcs_obj and not parent_vcs_obj.is_empty():
                    parent_target_repos.append(parent)

        # get other forks, and repo itself
        query = Repository.query() \
            .order_by(func.length(Repository.repo_name)) \
            .filter(
                or_(Repository.repo_id == self.db_repo.repo_id,  # repo itself
                    Repository.fork_id == self.db_repo.repo_id)  # forks of this repo
            ) \
            .filter(~Repository.repo_id.in_([x.repo_id for x in parent_target_repos]))

        if filter_query:
            ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
            query = query.filter(Repository.repo_name.ilike(ilike_expression))

        limit = max(20 - len(parent_target_repos), 5)  # not less than 5
        target_repos = query.limit(limit).all()

        all_target_repos = target_repos + parent_target_repos

        repos = []
        # This checks permissions to the repositories
        for obj in ScmModel().get_repos(all_target_repos):
            repos.append({
                'id': obj['name'],
                'text': obj['name'],
                'type': 'repo',
                'repo_id': obj['dbrepo']['repo_id'],
                'repo_type': obj['dbrepo']['repo_type'],
                'private': obj['dbrepo']['private'],
            })

        data = {
            'more': False,
            'results': [{
                'text': _('Repositories'),
                'children': repos
            }] if repos else []
        }
        return data

    @classmethod
    def get_comment_ids(cls, post_data):
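        """
        Parse the `comments` field of the submitted form data (a
        comma-separated list of ids) into a list of positive integer
        comment ids.
        """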
        return [e for e in map(safe_int, aslist(post_data.get('comments'), ','))
                if e and e > 0]

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_comments(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id
        version = self.request.GET.get('version')

        _render = self.request.get_partial_renderer(
            'rhodecode:templates/base/sidebar.mako')
        c = _render.get_call_context()

        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)
        versions = pull_request_display_obj.versions()
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        self.register_comments_vars(c, pull_request_latest, versions, include_drafts=False)
        all_comments = c.inline_comments_flat + c.comments

        existing_ids = self.get_comment_ids(self.request.POST)
        return _render('comments_table', all_comments, len(all_comments),
                       existing_ids=existing_ids)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_todos(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id
        version = self.request.GET.get('version')

        _render = self.request.get_partial_renderer(
            'rhodecode:templates/base/sidebar.mako')
        c = _render.get_call_context()
        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)
        versions = pull_request_display_obj.versions()
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        c.unresolved_comments = CommentsModel() \
            .get_pull_request_unresolved_todos(pull_request, include_drafts=False)
        c.resolved_comments = CommentsModel() \
            .get_pull_request_resolved_todos(pull_request, include_drafts=False)

        all_comments = c.unresolved_comments + c.resolved_comments
        existing_ids = self.get_comment_ids(self.request.POST)
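        # the rendered table lists both resolved and unresolved TODOs, while
        # the count passed to the template reflects only the unresolved ones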
        return _render('comments_table', all_comments, len(c.unresolved_comments),
                       todo_comments=True, existing_ids=existing_ids)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_drafts(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id
        version = self.request.GET.get('version')

        _render = self.request.get_partial_renderer(
            'rhodecode:templates/base/sidebar.mako')
        c = _render.get_call_context()

        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)
        versions = pull_request_display_obj.versions()
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        c.draft_comments = CommentsModel() \
            .get_pull_request_drafts(self._rhodecode_db_user.user_id, pull_request)

        all_comments = c.draft_comments

        existing_ids = self.get_comment_ids(self.request.POST)
        return _render('comments_table', all_comments, len(all_comments),
                       existing_ids=existing_ids, draft_comments=True)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_create(self):
        _ = self.request.translate
        self.assure_not_empty_repo()
        self.load_default_context()

        controls = peppercorn.parse(self.request.POST.items())

        try:
            form = PullRequestForm(
                self.request.translate, self.db_repo.repo_id)()
            _form = form.to_python(controls)
        except formencode.Invalid as errors:
            if errors.error_dict.get('revisions'):
                msg = 'Revisions: %s' % errors.error_dict['revisions']
            elif errors.error_dict.get('pullrequest_title'):
                msg = errors.error_dict.get('pullrequest_title')
            else:
                msg = _('Error creating pull request: {}').format(errors)
            log.exception(msg)
            h.flash(msg, 'error')

            # would rather just go back to form ...
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name))

        source_repo = _form['source_repo']
        source_ref = _form['source_ref']
        target_repo = _form['target_repo']
        target_ref = _form['target_ref']
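        # reverse the submitted revisions so they are stored oldest-first
        # (the form appears to send them in newest-first display order)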
        commit_ids = _form['revisions'][::-1]
        common_ancestor_id = _form['common_ancestor']

        # find the ancestor for this pr
        source_db_repo = Repository.get_by_repo_name(_form['source_repo'])
        target_db_repo = Repository.get_by_repo_name(_form['target_repo'])

        if not (source_db_repo and target_db_repo):
            h.flash(_('source_repo or target repo not found'), category='error')
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name))

        # re-check permissions here
        # on the source_repo we must have read permissions

        source_perm = HasRepoPermissionAny(
            'repository.read', 'repository.write', 'repository.admin')(
                source_db_repo.repo_name)
        if not source_perm:
            msg = _('Not enough permissions to source repo `{}`.').format(
                source_db_repo.repo_name)
            h.flash(msg, category='error')
            # copy the args back to redirect
            org_query = self.request.GET.mixed()
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name,
                             _query=org_query))

        # on the target repo we must have read permissions, and later on
        # we also want to check branch permissions here
        target_perm = HasRepoPermissionAny(
            'repository.read', 'repository.write', 'repository.admin')(
                target_db_repo.repo_name)
        if not target_perm:
            msg = _('Not enough permissions to target repo `{}`.').format(
                target_db_repo.repo_name)
            h.flash(msg, category='error')
            # copy the args back to redirect
            org_query = self.request.GET.mixed()
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name,
                             _query=org_query))

        source_scm = source_db_repo.scm_instance()
        target_scm = target_db_repo.scm_instance()

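        # refs are serialized as 'type:name:commit_id' strings;
        # unicode_to_reference unpacks them into objects exposing
        # .type, .name and .commit_id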
1185 source_ref_obj = unicode_to_reference(source_ref)
1185 source_ref_obj = unicode_to_reference(source_ref)
1186 target_ref_obj = unicode_to_reference(target_ref)
1186 target_ref_obj = unicode_to_reference(target_ref)
1187
1187
1188 source_commit = source_scm.get_commit(source_ref_obj.commit_id)
1188 source_commit = source_scm.get_commit(source_ref_obj.commit_id)
1189 target_commit = target_scm.get_commit(target_ref_obj.commit_id)
1189 target_commit = target_scm.get_commit(target_ref_obj.commit_id)
1190
1190
1191 ancestor = source_scm.get_common_ancestor(
1191 ancestor = source_scm.get_common_ancestor(
1192 source_commit.raw_id, target_commit.raw_id, target_scm)
1192 source_commit.raw_id, target_commit.raw_id, target_scm)
1193
1193
1194 # recalculate target ref based on ancestor
1194 # recalculate target ref based on ancestor
1195 target_ref = ':'.join((target_ref_obj.type, target_ref_obj.name, ancestor))
1195 target_ref = ':'.join((target_ref_obj.type, target_ref_obj.name, ancestor))
1196
1196
1197 get_default_reviewers_data, validate_default_reviewers, validate_observers = \
1197 get_default_reviewers_data, validate_default_reviewers, validate_observers = \
1198 PullRequestModel().get_reviewer_functions()
1198 PullRequestModel().get_reviewer_functions()
1199
1199
1200 # recalculate reviewers logic, to make sure we can validate this
1200 # recalculate reviewers logic, to make sure we can validate this
1201 reviewer_rules = get_default_reviewers_data(
1201 reviewer_rules = get_default_reviewers_data(
1202 self._rhodecode_db_user,
1202 self._rhodecode_db_user,
1203 source_db_repo,
1203 source_db_repo,
1204 source_ref_obj,
1204 source_ref_obj,
1205 target_db_repo,
1205 target_db_repo,
1206 target_ref_obj,
1206 target_ref_obj,
1207 include_diff_info=False)
1207 include_diff_info=False)
1208
1208
1209 reviewers = validate_default_reviewers(_form['review_members'], reviewer_rules)
1209 reviewers = validate_default_reviewers(_form['review_members'], reviewer_rules)
1210 observers = validate_observers(_form['observer_members'], reviewer_rules)
1210 observers = validate_observers(_form['observer_members'], reviewer_rules)
1211
1211
1212 pullrequest_title = _form['pullrequest_title']
1212 pullrequest_title = _form['pullrequest_title']
1213 title_source_ref = source_ref_obj.name
1213 title_source_ref = source_ref_obj.name
1214 if not pullrequest_title:
1214 if not pullrequest_title:
1215 pullrequest_title = PullRequestModel().generate_pullrequest_title(
1215 pullrequest_title = PullRequestModel().generate_pullrequest_title(
1216 source=source_repo,
1216 source=source_repo,
1217 source_ref=title_source_ref,
1217 source_ref=title_source_ref,
1218 target=target_repo
1218 target=target_repo
1219 )
1219 )
1220
1220
1221 description = _form['pullrequest_desc']
1221 description = _form['pullrequest_desc']
1222 description_renderer = _form['description_renderer']
1222 description_renderer = _form['description_renderer']
1223
1223
1224 try:
1224 try:
1225 pull_request = PullRequestModel().create(
1225 pull_request = PullRequestModel().create(
1226 created_by=self._rhodecode_user.user_id,
1226 created_by=self._rhodecode_user.user_id,
1227 source_repo=source_repo,
1227 source_repo=source_repo,
1228 source_ref=source_ref,
1228 source_ref=source_ref,
1229 target_repo=target_repo,
1229 target_repo=target_repo,
1230 target_ref=target_ref,
1230 target_ref=target_ref,
1231 revisions=commit_ids,
1231 revisions=commit_ids,
1232 common_ancestor_id=common_ancestor_id,
1232 common_ancestor_id=common_ancestor_id,
1233 reviewers=reviewers,
1233 reviewers=reviewers,
1234 observers=observers,
1234 observers=observers,
1235 title=pullrequest_title,
1235 title=pullrequest_title,
1236 description=description,
1236 description=description,
1237 description_renderer=description_renderer,
1237 description_renderer=description_renderer,
1238 reviewer_data=reviewer_rules,
1238 reviewer_data=reviewer_rules,
1239 auth_user=self._rhodecode_user
1239 auth_user=self._rhodecode_user
1240 )
1240 )
1241 Session().commit()
1241 Session().commit()
1242
1242
1243 h.flash(_('Successfully opened new pull request'),
1243 h.flash(_('Successfully opened new pull request'),
1244 category='success')
1244 category='success')
1245 except Exception:
1245 except Exception:
1246 msg = _('Error occurred during creation of this pull request.')
1246 msg = _('Error occurred during creation of this pull request.')
1247 log.exception(msg)
1247 log.exception(msg)
1248 h.flash(msg, category='error')
1248 h.flash(msg, category='error')
1249
1249
1250 # copy the args back to redirect
1250 # copy the args back to redirect
1251 org_query = self.request.GET.mixed()
1251 org_query = self.request.GET.mixed()
1252 raise HTTPFound(
1252 raise HTTPFound(
1253 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1253 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1254 _query=org_query))
1254 _query=org_query))
1255
1255
1256 raise HTTPFound(
1256 raise HTTPFound(
1257 h.route_path('pullrequest_show', repo_name=target_repo,
1257 h.route_path('pullrequest_show', repo_name=target_repo,
1258 pull_request_id=pull_request.pull_request_id))
1258 pull_request_id=pull_request.pull_request_id))
1259
1259
    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_update(self):
        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        _ = self.request.translate

        c = self.load_default_context()
        redirect_url = None
        # run this check first, because we want to know as early as possible
        # in the flow whether the pr is currently being updated
        is_state_changing = pull_request.is_state_changing()

        if pull_request.is_closed():
            log.debug('update: forbidden because pull request is closed')
            msg = _(u'Cannot update closed pull requests.')
            h.flash(msg, category='error')
            return {'response': True,
                    'redirect_url': redirect_url}

        c.pr_broadcast_channel = channelstream.pr_channel(pull_request)

        # only owner or admin can update it
        allowed_to_update = PullRequestModel().check_user_update(
            pull_request, self._rhodecode_user)

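        # A sketch of the POST payload shapes handled below (assumed, for
        # illustration only; field names are taken from the reads in this
        # branch): a peppercorn-encoded `review_members` or `observer_members`
        # sequence, or boolean-ish flags such as `update_commits=true`,
        # `force_refresh=true`, `edit_pull_request=true`. peppercorn.parse()
        # turns the flat POST key/value pairs into nested lists/dicts.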
        if allowed_to_update:
            controls = peppercorn.parse(self.request.POST.items())
            force_refresh = str2bool(self.request.POST.get('force_refresh', 'false'))
            do_update_commits = str2bool(self.request.POST.get('update_commits', 'false'))

            if 'review_members' in controls:
                self._update_reviewers(
                    c,
                    pull_request, controls['review_members'],
                    pull_request.reviewer_data,
                    PullRequestReviewers.ROLE_REVIEWER)
            elif 'observer_members' in controls:
                self._update_reviewers(
                    c,
                    pull_request, controls['observer_members'],
                    pull_request.reviewer_data,
                    PullRequestReviewers.ROLE_OBSERVER)
            elif do_update_commits:
                if is_state_changing:
                    log.debug('commits update: forbidden because pull request is in state %s',
                              pull_request.pull_request_state)
                    msg = _(u'Cannot update pull requests commits in state other than `{}`. '
                            u'Current state is: `{}`').format(
                        PullRequest.STATE_CREATED, pull_request.pull_request_state)
                    h.flash(msg, category='error')
                    return {'response': True,
                            'redirect_url': redirect_url}

                self._update_commits(c, pull_request)
                if force_refresh:
                    redirect_url = h.route_path(
                        'pullrequest_show', repo_name=self.db_repo_name,
                        pull_request_id=pull_request.pull_request_id,
                        _query={"force_refresh": 1})
            elif str2bool(self.request.POST.get('edit_pull_request', 'false')):
                self._edit_pull_request(pull_request)
            else:
                log.error('Unhandled update data.')
                raise HTTPBadRequest()

            return {'response': True,
                    'redirect_url': redirect_url}
        raise HTTPForbidden()

    def _edit_pull_request(self, pull_request):
        """
        Edit title and description
        """
        _ = self.request.translate

        try:
            PullRequestModel().edit(
                pull_request,
                self.request.POST.get('title'),
                self.request.POST.get('description'),
                self.request.POST.get('description_renderer'),
                self._rhodecode_user)
        except ValueError:
            msg = _(u'Cannot update closed pull requests.')
            h.flash(msg, category='error')
            return
        else:
            Session().commit()

            msg = _(u'Pull request title & description updated.')
            h.flash(msg, category='success')
            return

    def _update_commits(self, c, pull_request):
        _ = self.request.translate
        log.debug('pull-request: running update commits actions')

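        # retry the commits update on any Exception: up to 3 attempts with a
        # 2 second delay between them, matching the decorator arguments below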
        @retry(exception=Exception, n_tries=3, delay=2)
        def commits_update():
            return PullRequestModel().update_commits(
                pull_request, self._rhodecode_db_user)

        with pull_request.set_state(PullRequest.STATE_UPDATING):
            resp = commits_update()  # retry x3

        if resp.executed:

            if resp.target_changed and resp.source_changed:
                changed = 'target and source repositories'
            elif resp.target_changed and not resp.source_changed:
                changed = 'target repository'
            elif not resp.target_changed and resp.source_changed:
                changed = 'source repository'
            else:
                changed = 'nothing'

            msg = _(u'Pull request updated to "{source_commit_id}" with '
                    u'{count_added} added, {count_removed} removed commits. '
                    u'Source of changes: {change_source}.')
            msg = msg.format(
                source_commit_id=pull_request.source_ref_parts.commit_id,
                count_added=len(resp.changes.added),
                count_removed=len(resp.changes.removed),
                change_source=changed)
            h.flash(msg, category='success')
            channelstream.pr_update_channelstream_push(
                self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
        else:
            msg = PullRequestModel.UPDATE_STATUS_MESSAGES[resp.reason]
            warning_reasons = [
                UpdateFailureReason.NO_CHANGE,
                UpdateFailureReason.WRONG_REF_TYPE,
            ]
            category = 'warning' if resp.reason in warning_reasons else 'error'
            h.flash(msg, category=category)

    def _update_reviewers(self, c, pull_request, review_members, reviewer_rules, role):
        _ = self.request.translate

        get_default_reviewers_data, validate_default_reviewers, validate_observers = \
            PullRequestModel().get_reviewer_functions()

        if role == PullRequestReviewers.ROLE_REVIEWER:
            try:
                reviewers = validate_default_reviewers(review_members, reviewer_rules)
            except ValueError as e:
                log.error('Reviewers Validation: {}'.format(e))
                h.flash(e, category='error')
                return

            old_calculated_status = pull_request.calculated_review_status()
            PullRequestModel().update_reviewers(
                pull_request, reviewers, self._rhodecode_db_user)

            Session().commit()

            msg = _('Pull request reviewers updated.')
            h.flash(msg, category='success')
            channelstream.pr_update_channelstream_push(
                self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)

            # trigger status changed if change in reviewers changes the status
            calculated_status = pull_request.calculated_review_status()
            if old_calculated_status != calculated_status:
                PullRequestModel().trigger_pull_request_hook(
                    pull_request, self._rhodecode_user, 'review_status_change',
                    data={'status': calculated_status})

        elif role == PullRequestReviewers.ROLE_OBSERVER:
            try:
                observers = validate_observers(review_members, reviewer_rules)
            except ValueError as e:
                log.error('Observers Validation: {}'.format(e))
                h.flash(e, category='error')
                return

            PullRequestModel().update_observers(
                pull_request, observers, self._rhodecode_db_user)

            Session().commit()
            msg = _('Pull request observers updated.')
            h.flash(msg, category='success')
            channelstream.pr_update_channelstream_push(
                self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_merge(self):
        """
        Merge will perform a server-side merge of the specified
        pull request, if the pull request is approved and mergeable.
        After successful merging, the pull request is automatically
        closed, with a relevant comment.
        """
        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        _ = self.request.translate

        if pull_request.is_state_changing():
            log.debug('merge: forbidden because pull request is in state %s',
                      pull_request.pull_request_state)
            msg = _(u'Cannot merge pull requests in state other than `{}`. '
                    u'Current state is: `{}`').format(PullRequest.STATE_CREATED,
                                                      pull_request.pull_request_state)
            h.flash(msg, category='error')
            raise HTTPFound(
                h.route_path('pullrequest_show',
                             repo_name=pull_request.target_repo.repo_name,
                             pull_request_id=pull_request.pull_request_id))

        self.load_default_context()

        with pull_request.set_state(PullRequest.STATE_UPDATING):
            check = MergeCheck.validate(
                pull_request, auth_user=self._rhodecode_user,
                translator=self.request.translate)
        merge_possible = not check.failed

        for err_type, error_msg in check.errors:
            h.flash(error_msg, category=err_type)

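        # only attempt the merge when every MergeCheck passed; the merge
        # itself runs under the STATE_UPDATING guard again, which the other
        # views treat as a busy signal (see the is_state_changing() checks
        # above and in pull_request_update)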
        if merge_possible:
            log.debug("Pre-conditions checked, trying to merge.")
            extras = vcs_operation_context(
                self.request.environ, repo_name=pull_request.target_repo.repo_name,
                username=self._rhodecode_db_user.username, action='push',
                scm=pull_request.target_repo.repo_type)
            with pull_request.set_state(PullRequest.STATE_UPDATING):
                self._merge_pull_request(
                    pull_request, self._rhodecode_db_user, extras)
        else:
            log.debug("Pre-conditions failed, NOT merging.")

        raise HTTPFound(
            h.route_path('pullrequest_show',
                         repo_name=pull_request.target_repo.repo_name,
                         pull_request_id=pull_request.pull_request_id))

    def _merge_pull_request(self, pull_request, user, extras):
        _ = self.request.translate
        merge_resp = PullRequestModel().merge_repo(pull_request, user, extras=extras)

        if merge_resp.executed:
            log.debug("The merge was successful, closing the pull request.")
            PullRequestModel().close_pull_request(
                pull_request.pull_request_id, user)
            Session().commit()
            msg = _('Pull request was successfully merged and closed.')
            h.flash(msg, category='success')
        else:
            log.debug(
                "The merge was not successful. Merge response: %s", merge_resp)
            msg = merge_resp.merge_status_message
            h.flash(msg, category='error')

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_delete(self):
        _ = self.request.translate

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        self.load_default_context()

        pr_closed = pull_request.is_closed()
        allowed_to_delete = PullRequestModel().check_user_delete(
            pull_request, self._rhodecode_user) and not pr_closed

        # only owner can delete it!
        if allowed_to_delete:
            PullRequestModel().delete(pull_request, self._rhodecode_user)
            Session().commit()
            h.flash(_('Successfully deleted pull request'),
                    category='success')
            raise HTTPFound(h.route_path('pullrequest_show_all',
                                         repo_name=self.db_repo_name))

        log.warning('user %s tried to delete pull request without access',
                    self._rhodecode_user)
        raise HTTPNotFound()

    def _pull_request_comments_create(self, pull_request, comments):
        _ = self.request.translate
        data = {}
        if not comments:
            return
        pull_request_id = pull_request.pull_request_id

        all_drafts = len([x for x in comments if str2bool(x['is_draft'])]) == len(comments)

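        # all_drafts is True only when every submitted comment is a draft;
        # it is used at the end of this method to skip the channelstream
        # broadcast for draft-only submissions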
        for entry in comments:
            c = self.load_default_context()
            comment_type = entry['comment_type']
            text = entry['text']
            status = entry['status']
            is_draft = str2bool(entry['is_draft'])
            resolves_comment_id = entry['resolves_comment_id']
            close_pull_request = entry['close_pull_request']
            f_path = entry['f_path']
            line_no = entry['line']
            target_elem_id = 'file-{}'.format(h.safeid(h.safe_unicode(f_path)))

            # the logic here works as follows: if we submit a close-pr
            # comment, use the `close_pull_request_with_comment` function,
            # otherwise handle the regular comment logic

            if close_pull_request:
                # only owner or admin or person with write permissions
                allowed_to_close = PullRequestModel().check_user_update(
                    pull_request, self._rhodecode_user)
                if not allowed_to_close:
                    log.debug('comment: forbidden because not allowed to close '
                              'pull request %s', pull_request_id)
                    raise HTTPForbidden()

                # This also triggers `review_status_change`
                comment, status = PullRequestModel().close_pull_request_with_comment(
                    pull_request, self._rhodecode_user, self.db_repo, message=text,
                    auth_user=self._rhodecode_user)
                Session().flush()
                is_inline = comment.is_inline

                PullRequestModel().trigger_pull_request_hook(
                    pull_request, self._rhodecode_user, 'comment',
                    data={'comment': comment})

            else:
                # regular comment case: could be inline, or one with status.
                # for that case we also check permissions.
                # Additionally ensure that a draft comment, if somehow sent,
                # can never change the status
                allowed_to_change_status = PullRequestModel().check_user_change_status(
                    pull_request, self._rhodecode_user) and not is_draft

                if status and allowed_to_change_status:
                    message = (_('Status change %(transition_icon)s %(status)s')
                               % {'transition_icon': '>',
                                  'status': ChangesetStatus.get_status_lbl(status)})
                    text = text or message

                comment = CommentsModel().create(
                    text=text,
                    repo=self.db_repo.repo_id,
                    user=self._rhodecode_user.user_id,
                    pull_request=pull_request,
                    f_path=f_path,
                    line_no=line_no,
                    status_change=(ChangesetStatus.get_status_lbl(status)
                                   if status and allowed_to_change_status else None),
                    status_change_type=(status
                                        if status and allowed_to_change_status else None),
                    comment_type=comment_type,
                    is_draft=is_draft,
                    resolves_comment_id=resolves_comment_id,
                    auth_user=self._rhodecode_user,
                    send_email=not is_draft,  # skip notification for draft comments
                )
                is_inline = comment.is_inline

                if allowed_to_change_status:
                    # calculate old status before we change it
                    old_calculated_status = pull_request.calculated_review_status()

                    # get status if set !
                    if status:
                        ChangesetStatusModel().set_status(
                            self.db_repo.repo_id,
                            status,
                            self._rhodecode_user.user_id,
                            comment,
                            pull_request=pull_request
                        )

                    Session().flush()
                    # this is somehow required to get access to some relationship
                    # loaded on comment
                    Session().refresh(comment)

                    # skip notifications for drafts
                    if not is_draft:
                        PullRequestModel().trigger_pull_request_hook(
                            pull_request, self._rhodecode_user, 'comment',
                            data={'comment': comment})

                    # we now calculate the status of pull request, and based on that
                    # calculation we set the commits status
                    calculated_status = pull_request.calculated_review_status()
                    if old_calculated_status != calculated_status:
                        PullRequestModel().trigger_pull_request_hook(
                            pull_request, self._rhodecode_user, 'review_status_change',
                            data={'status': calculated_status})

            comment_id = comment.comment_id
            data[comment_id] = {
                'target_id': target_elem_id
            }
            Session().flush()

            c.co = comment
            c.at_version_num = None
            c.is_new = True
            rendered_comment = render(
                'rhodecode:templates/changeset/changeset_comment_block.mako',
                self._get_template_context(c), self.request)

            data[comment_id].update(comment.get_dict())
            data[comment_id].update({'rendered_text': rendered_comment})

        Session().commit()

        # skip channelstream for draft comments
        if not all_drafts:
            comment_broadcast_channel = channelstream.comment_channel(
                self.db_repo_name, pull_request_obj=pull_request)

            comment_data = data
            posted_comment_type = 'inline' if is_inline else 'general'
            if len(data) == 1:
                msg = _('posted {} new {} comment').format(len(data), posted_comment_type)
            else:
                msg = _('posted {} new {} comments').format(len(data), posted_comment_type)

            channelstream.comment_channelstream_push(
                self.request, comment_broadcast_channel, self._rhodecode_user, msg,
                comment_data=comment_data)

        return data

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_comment_create(self):
        _ = self.request.translate

        pull_request = PullRequest.get_or_404(self.request.matchdict['pull_request_id'])

        if pull_request.is_closed():
            log.debug('comment: forbidden because pull request is closed')
            raise HTTPForbidden()

        allowed_to_comment = PullRequestModel().check_user_comment(
            pull_request, self._rhodecode_user)
        if not allowed_to_comment:
            log.debug('comment: forbidden because pull request is from forbidden repo')
            raise HTTPForbidden()

        comment_data = {
            'comment_type': self.request.POST.get('comment_type'),
            'text': self.request.POST.get('text'),
            'status': self.request.POST.get('changeset_status', None),
            'is_draft': self.request.POST.get('draft'),
            'resolves_comment_id': self.request.POST.get('resolves_comment_id', None),
            'close_pull_request': self.request.POST.get('close_pull_request'),
            'f_path': self.request.POST.get('f_path'),
            'line': self.request.POST.get('line'),
        }
        data = self._pull_request_comments_create(pull_request, [comment_data])

        return data

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_comment_delete(self):
        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])

        comment = ChangesetComment.get_or_404(
            self.request.matchdict['comment_id'])
        comment_id = comment.comment_id

        if comment.immutable:
            # don't allow deleting comments that are immutable
            raise HTTPForbidden()

        if pull_request.is_closed():
            log.debug('comment: forbidden because pull request is closed')
            raise HTTPForbidden()

        if not comment:
            log.debug('Comment with id:%s not found, skipping', comment_id)
            # comment already deleted in another call probably
            return True

        if comment.pull_request.is_closed():
            # don't allow deleting comments on closed pull request
            raise HTTPForbidden()

        is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
        super_admin = h.HasPermissionAny('hg.admin')()
        comment_owner = comment.author.user_id == self._rhodecode_user.user_id
        is_repo_comment = comment.repo.repo_name == self.db_repo_name
        comment_repo_admin = is_repo_admin and is_repo_comment

        if comment.draft and not comment_owner:
            # we never allow anyone but the owner to delete draft comments
            raise HTTPNotFound()

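        # deletion is granted to three roles: the global super admin, the
        # comment's own author, and an admin of the repository the comment
        # belongs to (comment_repo_admin combines both checks above)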
        if super_admin or comment_owner or comment_repo_admin:
            old_calculated_status = comment.pull_request.calculated_review_status()
            CommentsModel().delete(comment=comment, auth_user=self._rhodecode_user)
            Session().commit()
            calculated_status = comment.pull_request.calculated_review_status()
            if old_calculated_status != calculated_status:
                PullRequestModel().trigger_pull_request_hook(
                    comment.pull_request, self._rhodecode_user, 'review_status_change',
                    data={'status': calculated_status})
            return True
        else:
            log.warning('No permissions for user %s to delete comment_id: %s',
                        self._rhodecode_db_user, comment_id)
            raise HTTPNotFound()

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_comment_edit(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id']
        )
        comment = ChangesetComment.get_or_404(
            self.request.matchdict['comment_id']
        )
        comment_id = comment.comment_id

        if comment.immutable:
            # don't allow editing comments that are immutable
            raise HTTPForbidden()

        if pull_request.is_closed():
            log.debug('comment: forbidden because pull request is closed')
            raise HTTPForbidden()

        if comment.pull_request.is_closed():
            # don't allow editing comments on closed pull request
            raise HTTPForbidden()

        is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
        super_admin = h.HasPermissionAny('hg.admin')()
        comment_owner = comment.author.user_id == self._rhodecode_user.user_id
        is_repo_comment = comment.repo.repo_name == self.db_repo_name
        comment_repo_admin = is_repo_admin and is_repo_comment

        if super_admin or comment_owner or comment_repo_admin:
            text = self.request.POST.get('text')
            version = self.request.POST.get('version')
            if text == comment.text:
                log.warning(
                    'Comment(PR): '
                    'Trying to create new version '
                    'with the same comment body {}'.format(
                        comment_id,
                    )
                )
                raise HTTPNotFound()

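            # `version` implements optimistic concurrency control: the client
            # sends the comment version it last saw, and CommentsModel().edit
            # below raises CommentVersionMismatch (mapped to HTTP 409 Conflict)
            # when a newer version was saved in the meantime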
            if version.isdigit():
                version = int(version)
            else:
                log.warning(
                    'Comment(PR): Wrong version type {} {} '
                    'for comment {}'.format(
                        version,
                        type(version),
                        comment_id,
                    )
                )
                raise HTTPNotFound()

            try:
                comment_history = CommentsModel().edit(
                    comment_id=comment_id,
                    text=text,
                    auth_user=self._rhodecode_user,
                    version=version,
                )
            except CommentVersionMismatch:
                raise HTTPConflict()

            if not comment_history:
                raise HTTPNotFound()

            Session().commit()
            if not comment.draft:
                PullRequestModel().trigger_pull_request_hook(
                    pull_request, self._rhodecode_user, 'comment_edit',
                    data={'comment': comment})

            return {
                'comment_history_id': comment_history.comment_history_id,
                'comment_id': comment.comment_id,
                'comment_version': comment_history.version,
                'comment_author_username': comment_history.author.username,
                'comment_author_gravatar': h.gravatar_url(comment_history.author.email, 16),
                'comment_created_on': h.age_component(comment_history.created_on,
                                                      time_is_local=True),
            }
        else:
            log.warning('No permissions for user %s to edit comment_id: %s',
                        self._rhodecode_db_user, comment_id)
            raise HTTPNotFound()
@@ -1,127 +1,125 @@
# -*- coding: utf-8 -*-

# Copyright (C) 2017-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import logging

from pyramid.httpexceptions import HTTPFound, HTTPNotFound

import formencode

from rhodecode.apps._base import RepoAppView
from rhodecode.lib import audit_logger
from rhodecode.lib import helpers as h
from rhodecode.lib.auth import (
    LoginRequired, HasRepoPermissionAnyDecorator, CSRFRequired)
from rhodecode.model.forms import IssueTrackerPatternsForm
from rhodecode.model.meta import Session
from rhodecode.model.settings import SettingsModel

log = logging.getLogger(__name__)


class RepoSettingsIssueTrackersView(RepoAppView):
    def load_default_context(self):
        c = self._get_local_tmpl_context()
        return c

    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    def repo_issuetracker(self):
        c = self.load_default_context()
        c.active = 'issuetracker'
        c.data = 'data'

        c.settings_model = self.db_repo_patterns
        c.global_patterns = c.settings_model.get_global_settings()
        c.repo_patterns = c.settings_model.get_repo_settings()

        return self._get_template_context(c)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    @CSRFRequired()
    def repo_issuetracker_test(self):
        return h.urlify_commit_message(
            self.request.POST.get('test_text', ''),
            self.db_repo_name)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    @CSRFRequired()
    def repo_issuetracker_delete(self):
        _ = self.request.translate
        uid = self.request.POST.get('uid')
        repo_settings = self.db_repo_patterns
        try:
            repo_settings.delete_entries(uid)
        except Exception:
            h.flash(_('Error occurred during deleting issue tracker entry'),
                    category='error')
            raise HTTPNotFound()

        SettingsModel().invalidate_settings_cache()
        h.flash(_('Removed issue tracker entry.'), category='success')

        return {'deleted': uid}

    def _update_patterns(self, form, repo_settings):
        for uid in form['delete_patterns']:
            repo_settings.delete_entries(uid)

        for pattern_data in form['patterns']:
            for setting_key, pattern, type_ in pattern_data:
                sett = repo_settings.create_or_update_setting(
                    setting_key, pattern.strip(), type_)
                Session().add(sett)

        Session().commit()

    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    @CSRFRequired()
    def repo_issuetracker_update(self):
        _ = self.request.translate
        # Save inheritance
        repo_settings = self.db_repo_patterns
        inherited = (
            self.request.POST.get('inherit_global_issuetracker') == "inherited")
        repo_settings.inherit_global_settings = inherited
        Session().commit()

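        # note: the inheritance flag is committed above, before the pattern
        # form is validated, so toggling inheritance sticks even when the
        # submitted patterns fail validation below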
        try:
            form = IssueTrackerPatternsForm(self.request.translate)().to_python(self.request.POST)
        except formencode.Invalid as errors:
            log.exception('Failed to add new pattern')
            error = errors
            h.flash(_('Invalid issue tracker pattern: {}'.format(error)),
                    category='error')
            raise HTTPFound(
                h.route_path('edit_repo_issuetracker',
                             repo_name=self.db_repo_name))

        if form:
            self._update_patterns(form, repo_settings)

        h.flash(_('Updated issue tracker entries'), category='success')
        raise HTTPFound(
            h.route_path('edit_repo_issuetracker', repo_name=self.db_repo_name))

@@ -1,290 +1,290 @@
# -*- coding: utf-8 -*-

# Copyright (C) 2011-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import logging
import string
import time

import rhodecode

from rhodecode.lib.view_utils import get_format_ref_id
from rhodecode.apps._base import RepoAppView
from rhodecode.config.conf import LANGUAGES_EXTENSIONS_MAP
from rhodecode.lib import helpers as h, rc_cache
from rhodecode.lib.utils2 import safe_str, safe_int
from rhodecode.lib.auth import LoginRequired, HasRepoPermissionAnyDecorator
from rhodecode.lib.ext_json import json
from rhodecode.lib.vcs.backends.base import EmptyCommit
from rhodecode.lib.vcs.exceptions import (
    CommitError, EmptyRepositoryError, CommitDoesNotExistError)
from rhodecode.model.db import Statistics, CacheKey, User
from rhodecode.model.meta import Session
from rhodecode.model.scm import ScmModel

log = logging.getLogger(__name__)


class RepoSummaryView(RepoAppView):

    def load_default_context(self):
        c = self._get_local_tmpl_context(include_app_defaults=True)
        c.rhodecode_repo = None
        if not c.repository_requirements_missing:
            c.rhodecode_repo = self.rhodecode_vcs_repo
        return c

    def _load_commits_context(self, c):
        p = safe_int(self.request.GET.get('page'), 1)
        size = safe_int(self.request.GET.get('size'), 10)

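        # url_maker callback for the paginator below: every generated page
        # link keeps the currently selected page size in the query string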
        def url_generator(page_num):
            query_params = {
                'page': page_num,
                'size': size
            }
            return h.route_path(
                'repo_summary_commits',
                repo_name=c.rhodecode_db_repo.repo_name, _query=query_params)

        pre_load = self.get_commit_preload_attrs()

        try:
            collection = self.rhodecode_vcs_repo.get_commits(
                pre_load=pre_load, translate_tags=False)
        except EmptyRepositoryError:
            collection = self.rhodecode_vcs_repo

        c.repo_commits = h.RepoPage(
            collection, page=p, items_per_page=size, url_maker=url_generator)
        page_ids = [x.raw_id for x in c.repo_commits]
        c.comments = self.db_repo.get_comments(page_ids)
        c.statuses = self.db_repo.statuses(page_ids)

    def _prepare_and_set_clone_url(self, c):
        username = ''
        if self._rhodecode_user.username != User.DEFAULT_USER:
            username = safe_str(self._rhodecode_user.username)

        _def_clone_uri = c.clone_uri_tmpl
        _def_clone_uri_id = c.clone_uri_id_tmpl
        _def_clone_uri_ssh = c.clone_uri_ssh_tmpl

        c.clone_repo_url = self.db_repo.clone_url(
            user=username, uri_tmpl=_def_clone_uri)
        c.clone_repo_url_id = self.db_repo.clone_url(
            user=username, uri_tmpl=_def_clone_uri_id)
        c.clone_repo_url_ssh = self.db_repo.clone_url(
            uri_tmpl=_def_clone_uri_ssh, ssh=True)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def summary_commits(self):
        c = self.load_default_context()
        self._prepare_and_set_clone_url(c)
        self._load_commits_context(c)
        return self._get_template_context(c)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def summary(self):
        c = self.load_default_context()

        # Prepare the clone URL
        self._prepare_and_set_clone_url(c)

        # If enabled, get statistics data
        c.show_stats = bool(self.db_repo.enable_statistics)

119 stats = Session().query(Statistics) \
119 stats = Session().query(Statistics) \
120 .filter(Statistics.repository == self.db_repo) \
120 .filter(Statistics.repository == self.db_repo) \
121 .scalar()
121 .scalar()
122
122
123 c.stats_percentage = 0
123 c.stats_percentage = 0
124
124
125 if stats and stats.languages:
125 if stats and stats.languages:
126 c.no_data = not self.db_repo.enable_statistics
126 c.no_data = not self.db_repo.enable_statistics
127 lang_stats_d = json.loads(stats.languages)
127 lang_stats_d = json.loads(stats.languages)
128
128
129 # Sort first by decreasing count and second by the file extension,
129 # Sort first by decreasing count and second by the file extension,
130 # so we have a consistent output.
130 # so we have a consistent output.
131 lang_stats_items = sorted(lang_stats_d.items(),
131 lang_stats_items = sorted(lang_stats_d.items(),
132 key=lambda k: (-k[1], k[0]))[:10]
132 key=lambda k: (-k[1], k[0]))[:10]
133 lang_stats = [(x, {"count": y,
133 lang_stats = [(x, {"count": y,
134 "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
134 "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
135 for x, y in lang_stats_items]
135 for x, y in lang_stats_items]
136
136
137 c.trending_languages = json.dumps(lang_stats)
137 c.trending_languages = json.dumps(lang_stats)
138 else:
138 else:
139 c.no_data = True
139 c.no_data = True
140 c.trending_languages = json.dumps([])
140 c.trending_languages = json.dumps([])
141
141
142 scm_model = ScmModel()
142 scm_model = ScmModel()
143 c.enable_downloads = self.db_repo.enable_downloads
143 c.enable_downloads = self.db_repo.enable_downloads
144 c.repository_followers = scm_model.get_followers(self.db_repo)
144 c.repository_followers = scm_model.get_followers(self.db_repo)
145 c.repository_forks = scm_model.get_forks(self.db_repo)
145 c.repository_forks = scm_model.get_forks(self.db_repo)
146
146
147 # first interaction with the VCS instance after here...
147 # first interaction with the VCS instance after here...
148 if c.repository_requirements_missing:
148 if c.repository_requirements_missing:
149 self.request.override_renderer = \
149 self.request.override_renderer = \
150 'rhodecode:templates/summary/missing_requirements.mako'
150 'rhodecode:templates/summary/missing_requirements.mako'
151 return self._get_template_context(c)
151 return self._get_template_context(c)
152
152
153 c.readme_data, c.readme_file = \
153 c.readme_data, c.readme_file = \
154 self._get_readme_data(self.db_repo, c.visual.default_renderer)
154 self._get_readme_data(self.db_repo, c.visual.default_renderer)
155
155
156 # loads the summary commits template context
156 # loads the summary commits template context
157 self._load_commits_context(c)
157 self._load_commits_context(c)
158
158
159 return self._get_template_context(c)
159 return self._get_template_context(c)
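# A minimal, self-contained sketch of the trending-languages sort used in
# summary() above; the counts are made up, the real data comes from the
# Statistics.languages JSON blob:
lang_stats_d = {'py': 120, 'js': 45, 'html': 45}
top = sorted(lang_stats_d.items(), key=lambda k: (-k[1], k[0]))[:10]
print(top)  # [('py', 120), ('html', 45), ('js', 45)] - count desc, then extension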
160
160
161 @LoginRequired()
161 @LoginRequired()
162 @HasRepoPermissionAnyDecorator(
162 @HasRepoPermissionAnyDecorator(
163 'repository.read', 'repository.write', 'repository.admin')
163 'repository.read', 'repository.write', 'repository.admin')
164 def repo_stats(self):
164 def repo_stats(self):
165 show_stats = bool(self.db_repo.enable_statistics)
165 show_stats = bool(self.db_repo.enable_statistics)
166 repo_id = self.db_repo.repo_id
166 repo_id = self.db_repo.repo_id
167
167
168 landing_commit = self.db_repo.get_landing_commit()
168 landing_commit = self.db_repo.get_landing_commit()
169 if isinstance(landing_commit, EmptyCommit):
169 if isinstance(landing_commit, EmptyCommit):
170 return {'size': 0, 'code_stats': {}}
170 return {'size': 0, 'code_stats': {}}
171
171
172 cache_seconds = safe_int(rhodecode.CONFIG.get('rc_cache.cache_repo.expiration_time'), 0)
172 cache_seconds = safe_int(rhodecode.CONFIG.get('rc_cache.cache_repo.expiration_time'), 0)
173 cache_on = cache_seconds > 0
173 cache_on = cache_seconds > 0
174
174
175 log.debug(
175 log.debug(
176 'Computing REPO STATS for repo_id %s commit_id `%s` '
176 'Computing REPO STATS for repo_id %s commit_id `%s` '
177 'with caching: %s [TTL: %ss]',
177 'with caching: %s [TTL: %ss]',
178 repo_id, landing_commit, cache_on, cache_seconds or 0)
178 repo_id, landing_commit, cache_on, cache_seconds or 0)
179
179
180 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
180 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
181 region = rc_cache.get_or_create_region('cache_repo', cache_namespace_uid)
181 region = rc_cache.get_or_create_region('cache_repo', cache_namespace_uid)
182
182
183 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
183 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
184 condition=cache_on)
184 condition=cache_on)
185 def compute_stats(repo_id, commit_id, _show_stats):
185 def compute_stats(repo_id, commit_id, _show_stats):
186 code_stats = {}
186 code_stats = {}
187 size = 0
187 size = 0
188 try:
188 try:
189 commit = self.db_repo.get_commit(commit_id)
189 commit = self.db_repo.get_commit(commit_id)
190
190
191 for node in commit.get_filenodes_generator():
191 for node in commit.get_filenodes_generator():
192 size += node.size
192 size += node.size
193 if not _show_stats:
193 if not _show_stats:
194 continue
194 continue
195 ext = string.lower(node.extension)
195 ext = node.extension.lower()
196 ext_info = LANGUAGES_EXTENSIONS_MAP.get(ext)
196 ext_info = LANGUAGES_EXTENSIONS_MAP.get(ext)
197 if ext_info:
197 if ext_info:
198 if ext in code_stats:
198 if ext in code_stats:
199 code_stats[ext]['count'] += 1
199 code_stats[ext]['count'] += 1
200 else:
200 else:
201 code_stats[ext] = {"count": 1, "desc": ext_info}
201 code_stats[ext] = {"count": 1, "desc": ext_info}
202 except (EmptyRepositoryError, CommitDoesNotExistError):
202 except (EmptyRepositoryError, CommitDoesNotExistError):
203 pass
203 pass
204 return {'size': h.format_byte_size_binary(size),
204 return {'size': h.format_byte_size_binary(size),
205 'code_stats': code_stats}
205 'code_stats': code_stats}
206
206
207 stats = compute_stats(self.db_repo.repo_id, landing_commit.raw_id, show_stats)
207 stats = compute_stats(self.db_repo.repo_id, landing_commit.raw_id, show_stats)
208 return stats
208 return stats
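# A rough, generic stand-in for the conditional_cache_on_arguments pattern used
# in repo_stats() above: compute on every call when caching is off, memoize
# when it is on. Real dogpile regions add TTL expiration and invalidation on
# top of this idea; this sketch is an illustration, not the rc_cache API.
def conditional_cache(condition):
    def wrapper(func):
        cache = {}
        def inner(*args):
            if not condition:
                return func(*args)
            if args not in cache:
                cache[args] = func(*args)
            return cache[args]
        return inner
    return wrapper

@conditional_cache(condition=True)
def expensive(x):
    return x * x  # cached after the first call per argument tuple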
209
209
210 @LoginRequired()
210 @LoginRequired()
211 @HasRepoPermissionAnyDecorator(
211 @HasRepoPermissionAnyDecorator(
212 'repository.read', 'repository.write', 'repository.admin')
212 'repository.read', 'repository.write', 'repository.admin')
213 def repo_refs_data(self):
213 def repo_refs_data(self):
214 _ = self.request.translate
214 _ = self.request.translate
215 self.load_default_context()
215 self.load_default_context()
216
216
217 repo = self.rhodecode_vcs_repo
217 repo = self.rhodecode_vcs_repo
218 refs_to_create = [
218 refs_to_create = [
219 (_("Branch"), repo.branches, 'branch'),
219 (_("Branch"), repo.branches, 'branch'),
220 (_("Tag"), repo.tags, 'tag'),
220 (_("Tag"), repo.tags, 'tag'),
221 (_("Bookmark"), repo.bookmarks, 'book'),
221 (_("Bookmark"), repo.bookmarks, 'book'),
222 ]
222 ]
223 res = self._create_reference_data(repo, self.db_repo_name, refs_to_create)
223 res = self._create_reference_data(repo, self.db_repo_name, refs_to_create)
224 data = {
224 data = {
225 'more': False,
225 'more': False,
226 'results': res
226 'results': res
227 }
227 }
228 return data
228 return data
229
229
230 @LoginRequired()
230 @LoginRequired()
231 @HasRepoPermissionAnyDecorator(
231 @HasRepoPermissionAnyDecorator(
232 'repository.read', 'repository.write', 'repository.admin')
232 'repository.read', 'repository.write', 'repository.admin')
233 def repo_refs_changelog_data(self):
233 def repo_refs_changelog_data(self):
234 _ = self.request.translate
234 _ = self.request.translate
235 self.load_default_context()
235 self.load_default_context()
236
236
237 repo = self.rhodecode_vcs_repo
237 repo = self.rhodecode_vcs_repo
238
238
239 refs_to_create = [
239 refs_to_create = [
240 (_("Branches"), repo.branches, 'branch'),
240 (_("Branches"), repo.branches, 'branch'),
241 (_("Closed branches"), repo.branches_closed, 'branch_closed'),
241 (_("Closed branches"), repo.branches_closed, 'branch_closed'),
242 # TODO: enable when vcs can handle bookmarks filters
242 # TODO: enable when vcs can handle bookmarks filters
243 # (_("Bookmarks"), repo.bookmarks, "book"),
243 # (_("Bookmarks"), repo.bookmarks, "book"),
244 ]
244 ]
245 res = self._create_reference_data(
245 res = self._create_reference_data(
246 repo, self.db_repo_name, refs_to_create)
246 repo, self.db_repo_name, refs_to_create)
247 data = {
247 data = {
248 'more': False,
248 'more': False,
249 'results': res
249 'results': res
250 }
250 }
251 return data
251 return data
252
252
253 def _create_reference_data(self, repo, full_repo_name, refs_to_create):
253 def _create_reference_data(self, repo, full_repo_name, refs_to_create):
254 format_ref_id = get_format_ref_id(repo)
254 format_ref_id = get_format_ref_id(repo)
255
255
256 result = []
256 result = []
257 for title, refs, ref_type in refs_to_create:
257 for title, refs, ref_type in refs_to_create:
258 if refs:
258 if refs:
259 result.append({
259 result.append({
260 'text': title,
260 'text': title,
261 'children': self._create_reference_items(
261 'children': self._create_reference_items(
262 repo, full_repo_name, refs, ref_type,
262 repo, full_repo_name, refs, ref_type,
263 format_ref_id),
263 format_ref_id),
264 })
264 })
265 return result
265 return result
266
266
267 def _create_reference_items(self, repo, full_repo_name, refs, ref_type, format_ref_id):
267 def _create_reference_items(self, repo, full_repo_name, refs, ref_type, format_ref_id):
268 result = []
268 result = []
269 is_svn = h.is_svn(repo)
269 is_svn = h.is_svn(repo)
270 for ref_name, raw_id in refs.items():
270 for ref_name, raw_id in refs.items():
271 files_url = self._create_files_url(
271 files_url = self._create_files_url(
272 repo, full_repo_name, ref_name, raw_id, is_svn)
272 repo, full_repo_name, ref_name, raw_id, is_svn)
273 result.append({
273 result.append({
274 'text': ref_name,
274 'text': ref_name,
275 'id': format_ref_id(ref_name, raw_id),
275 'id': format_ref_id(ref_name, raw_id),
276 'raw_id': raw_id,
276 'raw_id': raw_id,
277 'type': ref_type,
277 'type': ref_type,
278 'files_url': files_url,
278 'files_url': files_url,
279 'idx': 0,
279 'idx': 0,
280 })
280 })
281 return result
281 return result
282
282
283 def _create_files_url(self, repo, full_repo_name, ref_name, raw_id, is_svn):
283 def _create_files_url(self, repo, full_repo_name, ref_name, raw_id, is_svn):
284 use_commit_id = '/' in ref_name or is_svn
284 use_commit_id = '/' in ref_name or is_svn
285 return h.route_path(
285 return h.route_path(
286 'repo_files',
286 'repo_files',
287 repo_name=full_repo_name,
287 repo_name=full_repo_name,
288 f_path=ref_name if is_svn else '',
288 f_path=ref_name if is_svn else '',
289 commit_id=raw_id if use_commit_id else ref_name,
289 commit_id=raw_id if use_commit_id else ref_name,
290 _query=dict(at=ref_name))
290 _query=dict(at=ref_name))
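# Illustrative shape of the payload assembled by repo_refs_data() above; the
# ref names and ids are made up, and the exact 'id' format depends on the
# get_format_ref_id() helper for the backend in use:
example_payload = {
    'more': False,
    'results': [{
        'text': 'Branch',
        'children': [{
            'text': 'default',
            'id': 'default',  # format_ref_id(ref_name, raw_id)
            'raw_id': 'deadbeefcafe',
            'type': 'branch',
            'files_url': '/some-repo/files/default/?at=default',
            'idx': 0,
        }],
    }],
}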
@@ -1,172 +1,172 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2012-2020 RhodeCode GmbH
3 # Copyright (C) 2012-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 RhodeCode authentication library for PAM
22 RhodeCode authentication library for PAM
23 """
23 """
24
24
25 import colander
25 import colander
26 import grp
26 import grp
27 import logging
27 import logging
28 import pam
28 import pam
29 import pwd
29 import pwd
30 import re
30 import re
31 import socket
31 import socket
32
32
33 from rhodecode.translation import _
33 from rhodecode.translation import _
34 from rhodecode.authentication.base import (
34 from rhodecode.authentication.base import (
35 RhodeCodeExternalAuthPlugin, hybrid_property)
35 RhodeCodeExternalAuthPlugin, hybrid_property)
36 from rhodecode.authentication.schema import AuthnPluginSettingsSchemaBase
36 from rhodecode.authentication.schema import AuthnPluginSettingsSchemaBase
37 from rhodecode.authentication.routes import AuthnPluginResourceBase
37 from rhodecode.authentication.routes import AuthnPluginResourceBase
38 from rhodecode.lib.colander_utils import strip_whitespace
38 from rhodecode.lib.colander_utils import strip_whitespace
39
39
40 log = logging.getLogger(__name__)
40 log = logging.getLogger(__name__)
41
41
42
42
43 def plugin_factory(plugin_id, *args, **kwargs):
43 def plugin_factory(plugin_id, *args, **kwargs):
44 """
44 """
45 Factory function that is called during plugin discovery.
45 Factory function that is called during plugin discovery.
46 It returns the plugin instance.
46 It returns the plugin instance.
47 """
47 """
48 plugin = RhodeCodeAuthPlugin(plugin_id)
48 plugin = RhodeCodeAuthPlugin(plugin_id)
49 return plugin
49 return plugin
50
50
51
51
52 class PamAuthnResource(AuthnPluginResourceBase):
52 class PamAuthnResource(AuthnPluginResourceBase):
53 pass
53 pass
54
54
55
55
56 class PamSettingsSchema(AuthnPluginSettingsSchemaBase):
56 class PamSettingsSchema(AuthnPluginSettingsSchemaBase):
57 service = colander.SchemaNode(
57 service = colander.SchemaNode(
58 colander.String(),
58 colander.String(),
59 default='login',
59 default='login',
60 description=_('PAM service name to use for authentication.'),
60 description=_('PAM service name to use for authentication.'),
61 preparer=strip_whitespace,
61 preparer=strip_whitespace,
62 title=_('PAM service name'),
62 title=_('PAM service name'),
63 widget='string')
63 widget='string')
64 gecos = colander.SchemaNode(
64 gecos = colander.SchemaNode(
65 colander.String(),
65 colander.String(),
66 default='(?P<last_name>.+),\s*(?P<first_name>\w+)',
66 default=r'(?P<last_name>.+),\s*(?P<first_name>\w+)',
67 description=_('Regular expression for extracting user name/email etc. '
67 description=_('Regular expression for extracting user name/email etc. '
68 'from Unix userinfo.'),
68 'from Unix userinfo.'),
69 preparer=strip_whitespace,
69 preparer=strip_whitespace,
70 title=_('Gecos Regex'),
70 title=_('Gecos Regex'),
71 widget='string')
71 widget='string')
72
72
73
73
74 class RhodeCodeAuthPlugin(RhodeCodeExternalAuthPlugin):
74 class RhodeCodeAuthPlugin(RhodeCodeExternalAuthPlugin):
75 uid = 'pam'
75 uid = 'pam'
76 # PAM authentication can be slow. Repository operations involve a lot of
76 # PAM authentication can be slow. Repository operations involve a lot of
77 # auth calls. A little caching helps speed up push/pull operations significantly.
77 # auth calls. A little caching helps speed up push/pull operations significantly.
78 AUTH_CACHE_TTL = 4
78 AUTH_CACHE_TTL = 4
79
79
80 def includeme(self, config):
80 def includeme(self, config):
81 config.add_authn_plugin(self)
81 config.add_authn_plugin(self)
82 config.add_authn_resource(self.get_id(), PamAuthnResource(self))
82 config.add_authn_resource(self.get_id(), PamAuthnResource(self))
83 config.add_view(
83 config.add_view(
84 'rhodecode.authentication.views.AuthnPluginViewBase',
84 'rhodecode.authentication.views.AuthnPluginViewBase',
85 attr='settings_get',
85 attr='settings_get',
86 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
86 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
87 request_method='GET',
87 request_method='GET',
88 route_name='auth_home',
88 route_name='auth_home',
89 context=PamAuthnResource)
89 context=PamAuthnResource)
90 config.add_view(
90 config.add_view(
91 'rhodecode.authentication.views.AuthnPluginViewBase',
91 'rhodecode.authentication.views.AuthnPluginViewBase',
92 attr='settings_post',
92 attr='settings_post',
93 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
93 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
94 request_method='POST',
94 request_method='POST',
95 route_name='auth_home',
95 route_name='auth_home',
96 context=PamAuthnResource)
96 context=PamAuthnResource)
97
97
98 def get_display_name(self, load_from_settings=False):
98 def get_display_name(self, load_from_settings=False):
99 return _('PAM')
99 return _('PAM')
100
100
101 @classmethod
101 @classmethod
102 def docs(cls):
102 def docs(cls):
103 return "https://docs.rhodecode.com/RhodeCode-Enterprise/auth/auth-pam.html"
103 return "https://docs.rhodecode.com/RhodeCode-Enterprise/auth/auth-pam.html"
104
104
105 @hybrid_property
105 @hybrid_property
106 def name(self):
106 def name(self):
107 return u"pam"
107 return u"pam"
108
108
109 def get_settings_schema(self):
109 def get_settings_schema(self):
110 return PamSettingsSchema()
110 return PamSettingsSchema()
111
111
112 def use_fake_password(self):
112 def use_fake_password(self):
113 return True
113 return True
114
114
115 def auth(self, userobj, username, password, settings, **kwargs):
115 def auth(self, userobj, username, password, settings, **kwargs):
116 if not username or not password:
116 if not username or not password:
117 log.debug('Empty username or password, skipping...')
117 log.debug('Empty username or password, skipping...')
118 return None
118 return None
119 _pam = pam.pam()
119 _pam = pam.pam()
120 auth_result = _pam.authenticate(username, password, settings["service"])
120 auth_result = _pam.authenticate(username, password, settings["service"])
121
121
122 if not auth_result:
122 if not auth_result:
123 log.error("PAM was unable to authenticate user: %s", username)
123 log.error("PAM was unable to authenticate user: %s", username)
124 return None
124 return None
125
125
126 log.debug('Got PAM response %s', auth_result)
126 log.debug('Got PAM response %s', auth_result)
127
127
128 # old attrs fetched from RhodeCode database
128 # old attrs fetched from RhodeCode database
129 default_email = "%s@%s" % (username, socket.gethostname())
129 default_email = "%s@%s" % (username, socket.gethostname())
130 admin = getattr(userobj, 'admin', False)
130 admin = getattr(userobj, 'admin', False)
131 active = getattr(userobj, 'active', True)
131 active = getattr(userobj, 'active', True)
132 email = getattr(userobj, 'email', '') or default_email
132 email = getattr(userobj, 'email', '') or default_email
133 username = getattr(userobj, 'username', username)
133 username = getattr(userobj, 'username', username)
134 firstname = getattr(userobj, 'firstname', '')
134 firstname = getattr(userobj, 'firstname', '')
135 lastname = getattr(userobj, 'lastname', '')
135 lastname = getattr(userobj, 'lastname', '')
136 extern_type = getattr(userobj, 'extern_type', '')
136 extern_type = getattr(userobj, 'extern_type', '')
137
137
138 user_attrs = {
138 user_attrs = {
139 'username': username,
139 'username': username,
140 'firstname': firstname,
140 'firstname': firstname,
141 'lastname': lastname,
141 'lastname': lastname,
142 'groups': [g.gr_name for g in grp.getgrall()
142 'groups': [g.gr_name for g in grp.getgrall()
143 if username in g.gr_mem],
143 if username in g.gr_mem],
144 'user_group_sync': True,
144 'user_group_sync': True,
145 'email': email,
145 'email': email,
146 'admin': admin,
146 'admin': admin,
147 'active': active,
147 'active': active,
148 'active_from_extern': None,
148 'active_from_extern': None,
149 'extern_name': username,
149 'extern_name': username,
150 'extern_type': extern_type,
150 'extern_type': extern_type,
151 }
151 }
152
152
153 try:
153 try:
154 user_data = pwd.getpwnam(username)
154 user_data = pwd.getpwnam(username)
155 regex = settings["gecos"]
155 regex = settings["gecos"]
156 match = re.search(regex, user_data.pw_gecos)
156 match = re.search(regex, user_data.pw_gecos)
157 if match:
157 if match:
158 user_attrs["firstname"] = match.group('first_name')
158 user_attrs["firstname"] = match.group('first_name')
159 user_attrs["lastname"] = match.group('last_name')
159 user_attrs["lastname"] = match.group('last_name')
160 except Exception:
160 except Exception:
161 log.warning("Cannot extract additional info for PAM user")
161 log.warning("Cannot extract additional info for PAM user")
163
163
164 log.debug("pamuser: %s", user_attrs)
164 log.debug("pamuser: %s", user_attrs)
165 log.info('user `%s` authenticated correctly', user_attrs['username'],
165 log.info('user `%s` authenticated correctly', user_attrs['username'],
166 extra={"action": "user_auth_ok", "auth_module": "auth_pam", "username": user_attrs["username"]})
166 extra={"action": "user_auth_ok", "auth_module": "auth_pam", "username": user_attrs["username"]})
167 return user_attrs
167 return user_attrs
168
168
169
169
170 def includeme(config):
170 def includeme(config):
171 plugin_id = 'egg:rhodecode-enterprise-ce#{}'.format(RhodeCodeAuthPlugin.uid)
171 plugin_id = 'egg:rhodecode-enterprise-ce#{}'.format(RhodeCodeAuthPlugin.uid)
172 plugin_factory(plugin_id).includeme(config)
172 plugin_factory(plugin_id).includeme(config)
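# A small, runnable sketch of the gecos extraction performed in auth() above;
# the sample gecos string is illustrative, not read from a real passwd entry:
import re

GECOS_RE = r'(?P<last_name>.+),\s*(?P<first_name>\w+)'
match = re.search(GECOS_RE, 'Doe, John')
if match:
    print(match.group('first_name'), match.group('last_name'))  # John Doe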
@@ -1,89 +1,90 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import os
21 import os
22 import logging
22 import logging
23 import rhodecode
23 import rhodecode
24 import collections
24
25
25 from rhodecode.config import utils
26 from rhodecode.config import utils
26
27
27 from rhodecode.lib.utils import load_rcextensions
28 from rhodecode.lib.utils import load_rcextensions
28 from rhodecode.lib.utils2 import str2bool
29 from rhodecode.lib.utils2 import str2bool
29 from rhodecode.lib.vcs import connect_vcs
30 from rhodecode.lib.vcs import connect_vcs
30
31
31 log = logging.getLogger(__name__)
32 log = logging.getLogger(__name__)
32
33
33
34
34 def load_pyramid_environment(global_config, settings):
35 def load_pyramid_environment(global_config, settings):
35 # Some parts of the code expect a merge of global and app settings.
36 # Some parts of the code expect a merge of global and app settings.
36 settings_merged = global_config.copy()
37 settings_merged = global_config.copy()
37 settings_merged.update(settings)
38 settings_merged.update(settings)
38
39
39 # TODO(marcink): probably not required anymore
40 # TODO(marcink): probably not required anymore
40 # configure channelstream,
41 # configure channelstream,
41 settings_merged['channelstream_config'] = {
42 settings_merged['channelstream_config'] = {
42 'enabled': str2bool(settings_merged.get('channelstream.enabled', False)),
43 'enabled': str2bool(settings_merged.get('channelstream.enabled', False)),
43 'server': settings_merged.get('channelstream.server'),
44 'server': settings_merged.get('channelstream.server'),
44 'secret': settings_merged.get('channelstream.secret')
45 'secret': settings_merged.get('channelstream.secret')
45 }
46 }
46
47
47 # If this is a test run we prepare the test environment like
48 # If this is a test run we prepare the test environment like
48 # creating a test database, test search index and test repositories.
49 # creating a test database, test search index and test repositories.
49 # This has to be done before the database connection is initialized.
50 # This has to be done before the database connection is initialized.
50 if settings['is_test']:
51 if settings['is_test']:
51 rhodecode.is_test = True
52 rhodecode.is_test = True
52 rhodecode.disable_error_handler = True
53 rhodecode.disable_error_handler = True
53 from rhodecode import authentication
54 from rhodecode import authentication
54 authentication.plugin_default_auth_ttl = 0
55 authentication.plugin_default_auth_ttl = 0
55
56
56 utils.initialize_test_environment(settings_merged)
57 utils.initialize_test_environment(settings_merged)
57
58
58 # Initialize the database connection.
59 # Initialize the database connection.
59 utils.initialize_database(settings_merged)
60 utils.initialize_database(settings_merged)
60
61
61 load_rcextensions(root_path=settings_merged['here'])
62 load_rcextensions(root_path=settings_merged['here'])
62
63
63 # Limit backends to `vcs.backends` from configuration, and preserve the order
64 # Limit backends to `vcs.backends` from configuration, and preserve the order
64 for alias in list(rhodecode.BACKENDS.keys()):
65 for alias in list(rhodecode.BACKENDS.keys()):
65 if alias not in settings['vcs.backends']:
66 if alias not in settings['vcs.backends']:
66 del rhodecode.BACKENDS[alias]
67 del rhodecode.BACKENDS[alias]
67
68
68 _sorted_backend = sorted(rhodecode.BACKENDS.items(),
69 _sorted_backend = sorted(rhodecode.BACKENDS.items(),
69 key=lambda item: settings['vcs.backends'].index(item[0]))
70 key=lambda item: settings['vcs.backends'].index(item[0]))
70 rhodecode.BACKENDS = rhodecode.OrderedDict(_sorted_backend)
71 rhodecode.BACKENDS = collections.OrderedDict(_sorted_backend)
71
72
72 log.info('Enabled VCS backends: %s', rhodecode.BACKENDS.keys())
73 log.info('Enabled VCS backends: %s', rhodecode.BACKENDS.keys())
73
74
74 # initialize vcs client and optionally run the server if enabled
75 # initialize vcs client and optionally run the server if enabled
75 vcs_server_uri = settings['vcs.server']
76 vcs_server_uri = settings['vcs.server']
76 vcs_server_enabled = settings['vcs.server.enable']
77 vcs_server_enabled = settings['vcs.server.enable']
77
78
78 utils.configure_vcs(settings)
79 utils.configure_vcs(settings)
79
80
80 # Store the settings to make them available to other modules.
81 # Store the settings to make them available to other modules.
81
82
82 rhodecode.PYRAMID_SETTINGS = settings_merged
83 rhodecode.PYRAMID_SETTINGS = settings_merged
83 rhodecode.CONFIG = settings_merged
84 rhodecode.CONFIG = settings_merged
84 rhodecode.CONFIG['default_user_id'] = utils.get_default_user_id()
85 rhodecode.CONFIG['default_user_id'] = utils.get_default_user_id()
85
86
86 if vcs_server_enabled:
87 if vcs_server_enabled:
87 connect_vcs(vcs_server_uri, utils.get_vcs_server_protocol(settings))
88 connect_vcs(vcs_server_uri, utils.get_vcs_server_protocol(settings))
88 else:
89 else:
89 log.warning('vcs-server not enabled, vcs connection unavailable')
90 log.warning('vcs-server not enabled, vcs connection unavailable')
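# A self-contained sketch of the backend filter/sort above, run against
# illustrative data (the real rhodecode.BACKENDS maps aliases to backend classes):
import collections

BACKENDS = collections.OrderedDict([('git', object), ('hg', object), ('svn', object)])
enabled = ['hg', 'git']  # stand-in for settings['vcs.backends']

for alias in list(BACKENDS.keys()):
    if alias not in enabled:
        del BACKENDS[alias]

BACKENDS = collections.OrderedDict(
    sorted(BACKENDS.items(), key=lambda item: enabled.index(item[0])))
print(list(BACKENDS))  # ['hg', 'git']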
@@ -1,282 +1,280 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Adapters
3 Adapters
4 --------
4 --------
5
5
6 .. contents::
6 .. contents::
7 :backlinks: none
7 :backlinks: none
8
8
9 The :func:`authomatic.login` function needs access to functionality like
9 The :func:`authomatic.login` function needs access to functionality like
10 getting the **URL** of the handler where it is being called, getting the
10 getting the **URL** of the handler where it is being called, getting the
11 **request params** and **cookies**, and **writing the body**, **headers**
11 **request params** and **cookies**, and **writing the body**, **headers**
12 and **status** to the response.
12 and **status** to the response.
13
13
14 Since implementation of these features varies across Python web frameworks,
14 Since implementation of these features varies across Python web frameworks,
15 the Authomatic library uses **adapters** to unify these differences into a
15 the Authomatic library uses **adapters** to unify these differences into a
16 single interface.
16 single interface.
17
17
18 Available Adapters
18 Available Adapters
19 ^^^^^^^^^^^^^^^^^^
19 ^^^^^^^^^^^^^^^^^^
20
20
21 If you are missing an adapter for the framework of your choice, please
21 If you are missing an adapter for the framework of your choice, please
22 open an `enhancement issue <https://github.com/authomatic/authomatic/issues>`_
22 open an `enhancement issue <https://github.com/authomatic/authomatic/issues>`_
23 or consider a contribution to this module by
23 or consider a contribution to this module by
24 :ref:`implementing <implement_adapters>` one yourself.
24 :ref:`implementing <implement_adapters>` one yourself.
25 It's very easy and shouldn't take you more than a few minutes.
25 It's very easy and shouldn't take you more than a few minutes.
26
26
27 .. autoclass:: DjangoAdapter
27 .. autoclass:: DjangoAdapter
28 :members:
28 :members:
29
29
30 .. autoclass:: Webapp2Adapter
30 .. autoclass:: Webapp2Adapter
31 :members:
31 :members:
32
32
33 .. autoclass:: WebObAdapter
33 .. autoclass:: WebObAdapter
34 :members:
34 :members:
35
35
36 .. autoclass:: WerkzeugAdapter
36 .. autoclass:: WerkzeugAdapter
37 :members:
37 :members:
38
38
39 .. _implement_adapters:
39 .. _implement_adapters:
40
40
41 Implementing an Adapter
41 Implementing an Adapter
42 ^^^^^^^^^^^^^^^^^^^^^^^
42 ^^^^^^^^^^^^^^^^^^^^^^^
43
43
44 Implementing an adapter for a Python web framework is pretty easy.
44 Implementing an adapter for a Python web framework is pretty easy.
45
45
46 Do it by subclassing the :class:`.BaseAdapter` abstract class.
46 Do it by subclassing the :class:`.BaseAdapter` abstract class.
47 There are only **six** members that you need to implement.
47 There are only **six** members that you need to implement.
48
48
49 Moreover if your framework is based on the |webob|_ or |werkzeug|_ package
49 Moreover if your framework is based on the |webob|_ or |werkzeug|_ package
50 you can subclass the :class:`.WebObAdapter` or :class:`.WerkzeugAdapter`
50 you can subclass the :class:`.WebObAdapter` or :class:`.WerkzeugAdapter`
51 respectively.
51 respectively.
52
52
53 .. autoclass:: BaseAdapter
53 .. autoclass:: BaseAdapter
54 :members:
54 :members:
55
55
56 """
56 """
57
57
58 import abc
58 import abc
59 from authomatic.core import Response
59 from authomatic.core import Response
60
60
61
61
62 class BaseAdapter(object):
62 class BaseAdapter(object, metaclass=abc.ABCMeta):
63 """
63 """
64 Base class for platform adapters.
64 Base class for platform adapters.
65
65
66 Defines common interface for WSGI framework specific functionality.
66 Defines common interface for WSGI framework specific functionality.
67
67
68 """
68 """
69
69
70 __metaclass__ = abc.ABCMeta
71
72 @abc.abstractproperty
70 @abc.abstractproperty
73 def params(self):
71 def params(self):
74 """
72 """
75 Must return a :class:`dict` of all request parameters of any HTTP
73 Must return a :class:`dict` of all request parameters of any HTTP
76 method.
74 method.
77
75
78 :returns:
76 :returns:
79 :class:`dict`
77 :class:`dict`
80
78
81 """
79 """
82
80
83 @abc.abstractproperty
81 @abc.abstractproperty
84 def url(self):
82 def url(self):
85 """
83 """
86 Must return the url of the actual request including path but without
84 Must return the url of the actual request including path but without
87 query and fragment.
85 query and fragment.
88
86
89 :returns:
87 :returns:
90 :class:`str`
88 :class:`str`
91
89
92 """
90 """
93
91
94 @abc.abstractproperty
92 @abc.abstractproperty
95 def cookies(self):
93 def cookies(self):
96 """
94 """
97 Must return cookies as a :class:`dict`.
95 Must return cookies as a :class:`dict`.
98
96
99 :returns:
97 :returns:
100 :class:`dict`
98 :class:`dict`
101
99
102 """
100 """
103
101
104 @abc.abstractmethod
102 @abc.abstractmethod
105 def write(self, value):
103 def write(self, value):
106 """
104 """
107 Must write specified value to response.
105 Must write specified value to response.
108
106
109 :param str value:
107 :param str value:
110 String to be written to response.
108 String to be written to response.
111
109
112 """
110 """
113
111
114 @abc.abstractmethod
112 @abc.abstractmethod
115 def set_header(self, key, value):
113 def set_header(self, key, value):
116 """
114 """
117 Must set response headers to ``Key: value``.
115 Must set response headers to ``Key: value``.
118
116
119 :param str key:
117 :param str key:
120 Header name.
118 Header name.
121
119
122 :param str value:
120 :param str value:
123 Header value.
121 Header value.
124
122
125 """
123 """
126
124
127 @abc.abstractmethod
125 @abc.abstractmethod
128 def set_status(self, status):
126 def set_status(self, status):
129 """
127 """
130 Must set the response status, e.g. ``'302 Found'``.
128 Must set the response status, e.g. ``'302 Found'``.
131
129
132 :param str status:
130 :param str status:
133 The HTTP response status.
131 The HTTP response status.
134
132
135 """
133 """
136
134
137
135
138 class DjangoAdapter(BaseAdapter):
136 class DjangoAdapter(BaseAdapter):
139 """
137 """
140 Adapter for the |django|_ framework.
138 Adapter for the |django|_ framework.
141 """
139 """
142
140
143 def __init__(self, request, response):
141 def __init__(self, request, response):
144 """
142 """
145 :param request:
143 :param request:
146 An instance of the :class:`django.http.HttpRequest` class.
144 An instance of the :class:`django.http.HttpRequest` class.
147
145
148 :param response:
146 :param response:
149 An instance of the :class:`django.http.HttpResponse` class.
147 An instance of the :class:`django.http.HttpResponse` class.
150 """
148 """
151 self.request = request
149 self.request = request
152 self.response = response
150 self.response = response
153
151
154 @property
152 @property
155 def params(self):
153 def params(self):
156 params = {}
154 params = {}
157 params.update(self.request.GET.dict())
155 params.update(self.request.GET.dict())
158 params.update(self.request.POST.dict())
156 params.update(self.request.POST.dict())
159 return params
157 return params
160
158
161 @property
159 @property
162 def url(self):
160 def url(self):
163 return self.request.build_absolute_uri(self.request.path)
161 return self.request.build_absolute_uri(self.request.path)
164
162
165 @property
163 @property
166 def cookies(self):
164 def cookies(self):
167 return dict(self.request.COOKIES)
165 return dict(self.request.COOKIES)
168
166
169 def write(self, value):
167 def write(self, value):
170 self.response.write(value)
168 self.response.write(value)
171
169
172 def set_header(self, key, value):
170 def set_header(self, key, value):
173 self.response[key] = value
171 self.response[key] = value
174
172
175 def set_status(self, status):
173 def set_status(self, status):
176 status_code, reason = status.split(' ', 1)
174 status_code, reason = status.split(' ', 1)
177 self.response.status_code = int(status_code)
175 self.response.status_code = int(status_code)
178
176
179
177
180 class WebObAdapter(BaseAdapter):
178 class WebObAdapter(BaseAdapter):
181 """
179 """
182 Adapter for the |webob|_ package.
180 Adapter for the |webob|_ package.
183 """
181 """
184
182
185 def __init__(self, request, response):
183 def __init__(self, request, response):
186 """
184 """
187 :param request:
185 :param request:
188 A |webob|_ :class:`Request` instance.
186 A |webob|_ :class:`Request` instance.
189
187
190 :param response:
188 :param response:
191 A |webob|_ :class:`Response` instance.
189 A |webob|_ :class:`Response` instance.
192 """
190 """
193 self.request = request
191 self.request = request
194 self.response = response
192 self.response = response
195
193
196 # =========================================================================
194 # =========================================================================
197 # Request
195 # Request
198 # =========================================================================
196 # =========================================================================
199
197
200 @property
198 @property
201 def url(self):
199 def url(self):
202 return self.request.path_url
200 return self.request.path_url
203
201
204 @property
202 @property
205 def params(self):
203 def params(self):
206 return dict(self.request.params)
204 return dict(self.request.params)
207
205
208 @property
206 @property
209 def cookies(self):
207 def cookies(self):
210 return dict(self.request.cookies)
208 return dict(self.request.cookies)
211
209
212 # =========================================================================
210 # =========================================================================
213 # Response
211 # Response
214 # =========================================================================
212 # =========================================================================
215
213
216 def write(self, value):
214 def write(self, value):
217 self.response.write(value)
215 self.response.write(value)
218
216
219 def set_header(self, key, value):
217 def set_header(self, key, value):
220 self.response.headers[key] = str(value)
218 self.response.headers[key] = str(value)
221
219
222 def set_status(self, status):
220 def set_status(self, status):
223 self.response.status = status
221 self.response.status = status
224
222
225
223
226 class Webapp2Adapter(WebObAdapter):
224 class Webapp2Adapter(WebObAdapter):
227 """
225 """
228 Adapter for the |webapp2|_ framework.
226 Adapter for the |webapp2|_ framework.
229
227
230 Inherits from the :class:`.WebObAdapter`.
228 Inherits from the :class:`.WebObAdapter`.
231
229
232 """
230 """
233
231
234 def __init__(self, handler):
232 def __init__(self, handler):
235 """
233 """
236 :param handler:
234 :param handler:
237 A :class:`webapp2.RequestHandler` instance.
235 A :class:`webapp2.RequestHandler` instance.
238 """
236 """
239 self.request = handler.request
237 self.request = handler.request
240 self.response = handler.response
238 self.response = handler.response
241
239
242
240
243 class WerkzeugAdapter(BaseAdapter):
241 class WerkzeugAdapter(BaseAdapter):
244 """
242 """
245 Adapter for |flask|_ and other |werkzeug|_ based frameworks.
243 Adapter for |flask|_ and other |werkzeug|_ based frameworks.
246
244
247 Thanks to `Mark Steve Samson <http://marksteve.com>`_.
245 Thanks to `Mark Steve Samson <http://marksteve.com>`_.
248
246
249 """
247 """
250
248
251 @property
249 @property
252 def params(self):
250 def params(self):
253 return self.request.args
251 return self.request.args
254
252
255 @property
253 @property
256 def url(self):
254 def url(self):
257 return self.request.base_url
255 return self.request.base_url
258
256
259 @property
257 @property
260 def cookies(self):
258 def cookies(self):
261 return self.request.cookies
259 return self.request.cookies
262
260
263 def __init__(self, request, response):
261 def __init__(self, request, response):
264 """
262 """
265 :param request:
263 :param request:
266 Instance of the :class:`werkzeug.wrappers.Request` class.
264 Instance of the :class:`werkzeug.wrappers.Request` class.
267
265
268 :param response:
266 :param response:
269 Instance of the :class:`werkzeug.wrappers.Response` class.
267 Instance of the :class:`werkzeug.wrappers.Response` class.
270 """
268 """
271
269
272 self.request = request
270 self.request = request
273 self.response = response
271 self.response = response
274
272
275 def write(self, value):
273 def write(self, value):
276 self.response.data = self.response.data + value
274 self.response.data = self.response.data + value
277
275
278 def set_header(self, key, value):
276 def set_header(self, key, value):
279 self.response.headers[key] = value
277 self.response.headers[key] = value
280
278
281 def set_status(self, status):
279 def set_status(self, status):
282 self.response.status = status
280 self.response.status = status
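# A minimal sketch of a custom adapter, along the lines the "Implementing an
# Adapter" section invites; the PlainDictAdapter name and its dict-based
# request/response carriers are illustrative assumptions, not authomatic API:
class PlainDictAdapter(BaseAdapter):

    def __init__(self, request, response):
        self.request = request    # e.g. {'params': {}, 'url': '...', 'cookies': {}}
        self.response = response  # e.g. {'body': '', 'headers': {}, 'status': ''}

    @property
    def params(self):
        return dict(self.request.get('params', {}))

    @property
    def url(self):
        return self.request['url']

    @property
    def cookies(self):
        return dict(self.request.get('cookies', {}))

    def write(self, value):
        self.response['body'] += value

    def set_header(self, key, value):
        self.response['headers'][key] = value

    def set_status(self, status):
        self.response['status'] = status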
@@ -1,305 +1,305 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2017-2020 RhodeCode GmbH
3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import logging
21 import logging
22 import datetime
22 import datetime
23
23
24 from rhodecode.lib.jsonalchemy import JsonRaw
24 from rhodecode.lib.jsonalchemy import JsonRaw
25 from rhodecode.model import meta
25 from rhodecode.model import meta
26 from rhodecode.model.db import User, UserLog, Repository
26 from rhodecode.model.db import User, UserLog, Repository
27
27
28
28
29 log = logging.getLogger(__name__)
29 log = logging.getLogger(__name__)
30
30
31 # action as key, and expected action_data as value
31 # action as key, and expected action_data as value
32 ACTIONS_V1 = {
32 ACTIONS_V1 = {
33 'user.login.success': {'user_agent': ''},
33 'user.login.success': {'user_agent': ''},
34 'user.login.failure': {'user_agent': ''},
34 'user.login.failure': {'user_agent': ''},
35 'user.logout': {'user_agent': ''},
35 'user.logout': {'user_agent': ''},
36 'user.register': {},
36 'user.register': {},
37 'user.password.reset_request': {},
37 'user.password.reset_request': {},
38 'user.push': {'user_agent': '', 'commit_ids': []},
38 'user.push': {'user_agent': '', 'commit_ids': []},
39 'user.pull': {'user_agent': ''},
39 'user.pull': {'user_agent': ''},
40
40
41 'user.create': {'data': {}},
41 'user.create': {'data': {}},
42 'user.delete': {'old_data': {}},
42 'user.delete': {'old_data': {}},
43 'user.edit': {'old_data': {}},
43 'user.edit': {'old_data': {}},
44 'user.edit.permissions': {},
44 'user.edit.permissions': {},
45 'user.edit.ip.add': {'ip': {}, 'user': {}},
45 'user.edit.ip.add': {'ip': {}, 'user': {}},
46 'user.edit.ip.delete': {'ip': {}, 'user': {}},
46 'user.edit.ip.delete': {'ip': {}, 'user': {}},
47 'user.edit.token.add': {'token': {}, 'user': {}},
47 'user.edit.token.add': {'token': {}, 'user': {}},
48 'user.edit.token.delete': {'token': {}, 'user': {}},
48 'user.edit.token.delete': {'token': {}, 'user': {}},
49 'user.edit.email.add': {'email': ''},
49 'user.edit.email.add': {'email': ''},
50 'user.edit.email.delete': {'email': ''},
50 'user.edit.email.delete': {'email': ''},
51 'user.edit.ssh_key.add': {'token': {}, 'user': {}},
51 'user.edit.ssh_key.add': {'token': {}, 'user': {}},
52 'user.edit.ssh_key.delete': {'token': {}, 'user': {}},
52 'user.edit.ssh_key.delete': {'token': {}, 'user': {}},
53 'user.edit.password_reset.enabled': {},
53 'user.edit.password_reset.enabled': {},
54 'user.edit.password_reset.disabled': {},
54 'user.edit.password_reset.disabled': {},
55
55
56 'user_group.create': {'data': {}},
56 'user_group.create': {'data': {}},
57 'user_group.delete': {'old_data': {}},
57 'user_group.delete': {'old_data': {}},
58 'user_group.edit': {'old_data': {}},
58 'user_group.edit': {'old_data': {}},
59 'user_group.edit.permissions': {},
59 'user_group.edit.permissions': {},
60 'user_group.edit.member.add': {'user': {}},
60 'user_group.edit.member.add': {'user': {}},
61 'user_group.edit.member.delete': {'user': {}},
61 'user_group.edit.member.delete': {'user': {}},
62
62
63 'repo.create': {'data': {}},
63 'repo.create': {'data': {}},
64 'repo.fork': {'data': {}},
64 'repo.fork': {'data': {}},
65 'repo.edit': {'old_data': {}},
65 'repo.edit': {'old_data': {}},
66 'repo.edit.permissions': {},
66 'repo.edit.permissions': {},
67 'repo.edit.permissions.branch': {},
67 'repo.edit.permissions.branch': {},
68 'repo.archive': {'old_data': {}},
68 'repo.archive': {'old_data': {}},
69 'repo.delete': {'old_data': {}},
69 'repo.delete': {'old_data': {}},
70
70
71 'repo.archive.download': {'user_agent': '', 'archive_name': '',
71 'repo.archive.download': {'user_agent': '', 'archive_name': '',
72 'archive_spec': '', 'archive_cached': ''},
72 'archive_spec': '', 'archive_cached': ''},
73
73
74 'repo.permissions.branch_rule.create': {},
74 'repo.permissions.branch_rule.create': {},
75 'repo.permissions.branch_rule.edit': {},
75 'repo.permissions.branch_rule.edit': {},
76 'repo.permissions.branch_rule.delete': {},
76 'repo.permissions.branch_rule.delete': {},
77
77
78 'repo.pull_request.create': '',
78 'repo.pull_request.create': '',
79 'repo.pull_request.edit': '',
79 'repo.pull_request.edit': '',
80 'repo.pull_request.delete': '',
80 'repo.pull_request.delete': '',
81 'repo.pull_request.close': '',
81 'repo.pull_request.close': '',
82 'repo.pull_request.merge': '',
82 'repo.pull_request.merge': '',
83 'repo.pull_request.vote': '',
83 'repo.pull_request.vote': '',
84 'repo.pull_request.comment.create': '',
84 'repo.pull_request.comment.create': '',
85 'repo.pull_request.comment.edit': '',
85 'repo.pull_request.comment.edit': '',
86 'repo.pull_request.comment.delete': '',
86 'repo.pull_request.comment.delete': '',
87
87
88 'repo.pull_request.reviewer.add': '',
88 'repo.pull_request.reviewer.add': '',
89 'repo.pull_request.reviewer.delete': '',
89 'repo.pull_request.reviewer.delete': '',
90
90
91 'repo.pull_request.observer.add': '',
91 'repo.pull_request.observer.add': '',
92 'repo.pull_request.observer.delete': '',
92 'repo.pull_request.observer.delete': '',
93
93
94 'repo.commit.strip': {'commit_id': ''},
94 'repo.commit.strip': {'commit_id': ''},
95 'repo.commit.comment.create': {'data': {}},
95 'repo.commit.comment.create': {'data': {}},
96 'repo.commit.comment.delete': {'data': {}},
96 'repo.commit.comment.delete': {'data': {}},
97 'repo.commit.comment.edit': {'data': {}},
97 'repo.commit.comment.edit': {'data': {}},
98 'repo.commit.vote': '',
98 'repo.commit.vote': '',
99
99
100 'repo.artifact.add': '',
100 'repo.artifact.add': '',
101 'repo.artifact.delete': '',
101 'repo.artifact.delete': '',
102
102
103 'repo_group.create': {'data': {}},
103 'repo_group.create': {'data': {}},
104 'repo_group.edit': {'old_data': {}},
104 'repo_group.edit': {'old_data': {}},
105 'repo_group.edit.permissions': {},
105 'repo_group.edit.permissions': {},
106 'repo_group.delete': {'old_data': {}},
106 'repo_group.delete': {'old_data': {}},
107 }
107 }
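# Quick illustration of the action-spec lookup contract (values come straight
# from the ACTIONS_V1 map above; store() below raises ValueError for unknown keys):
spec = ACTIONS_V1.get('repo.create')
assert spec == {'data': {}}
assert ACTIONS_V1.get('no.such.action') is None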
108
108
109 ACTIONS = ACTIONS_V1
109 ACTIONS = ACTIONS_V1
110
110
111 SOURCE_WEB = 'source_web'
111 SOURCE_WEB = 'source_web'
112 SOURCE_API = 'source_api'
112 SOURCE_API = 'source_api'
113
113
114
114
115 class UserWrap(object):
115 class UserWrap(object):
116 """
116 """
117 Fake object used to imitate AuthUser
117 Fake object used to imitate AuthUser
118 """
118 """
119
119
120 def __init__(self, user_id=None, username=None, ip_addr=None):
120 def __init__(self, user_id=None, username=None, ip_addr=None):
121 self.user_id = user_id
121 self.user_id = user_id
122 self.username = username
122 self.username = username
123 self.ip_addr = ip_addr
123 self.ip_addr = ip_addr
124
124
125
125
126 class RepoWrap(object):
126 class RepoWrap(object):
127 """
127 """
128 Fake object used to imitate RepoObject that audit logger requires
128 Fake object used to imitate RepoObject that audit logger requires
129 """
129 """
130
130
131 def __init__(self, repo_id=None, repo_name=None):
131 def __init__(self, repo_id=None, repo_name=None):
132 self.repo_id = repo_id
132 self.repo_id = repo_id
133 self.repo_name = repo_name
133 self.repo_name = repo_name
134
134
135
135
136 def _store_log(action_name, action_data, user_id, username, user_data,
136 def _store_log(action_name, action_data, user_id, username, user_data,
137 ip_address, repository_id, repository_name):
137 ip_address, repository_id, repository_name):
138 user_log = UserLog()
138 user_log = UserLog()
139 user_log.version = UserLog.VERSION_2
139 user_log.version = UserLog.VERSION_2
140
140
141 user_log.action = action_name
141 user_log.action = action_name
142 user_log.action_data = action_data or JsonRaw(u'{}')
142 user_log.action_data = action_data or JsonRaw('{}')
143
143
144 user_log.user_ip = ip_address
144 user_log.user_ip = ip_address
145
145
146 user_log.user_id = user_id
146 user_log.user_id = user_id
147 user_log.username = username
147 user_log.username = username
148 user_log.user_data = user_data or JsonRaw(u'{}')
148 user_log.user_data = user_data or JsonRaw('{}')
149
149
150 user_log.repository_id = repository_id
150 user_log.repository_id = repository_id
151 user_log.repository_name = repository_name
151 user_log.repository_name = repository_name
152
152
153 user_log.action_date = datetime.datetime.now()
153 user_log.action_date = datetime.datetime.now()
154
154
155 return user_log
155 return user_log
156
156
157
157
158 def store_web(*args, **kwargs):
158 def store_web(*args, **kwargs):
159 action_data = {}
159 action_data = {}
160 org_action_data = kwargs.pop('action_data', {})
160 org_action_data = kwargs.pop('action_data', {})
161 action_data.update(org_action_data)
161 action_data.update(org_action_data)
162 action_data['source'] = SOURCE_WEB
162 action_data['source'] = SOURCE_WEB
163 kwargs['action_data'] = action_data
163 kwargs['action_data'] = action_data
164
164
165 return store(*args, **kwargs)
165 return store(*args, **kwargs)
166
166
167
167
168 def store_api(*args, **kwargs):
168 def store_api(*args, **kwargs):
169 action_data = {}
169 action_data = {}
170 org_action_data = kwargs.pop('action_data', {})
170 org_action_data = kwargs.pop('action_data', {})
171 action_data.update(org_action_data)
171 action_data.update(org_action_data)
172 action_data['source'] = SOURCE_API
172 action_data['source'] = SOURCE_API
173 kwargs['action_data'] = action_data
173 kwargs['action_data'] = action_data
174
174
175 return store(*args, **kwargs)
175 return store(*args, **kwargs)
176
176
177
177
178 def store(action, user, action_data=None, user_data=None, ip_addr=None,
178 def store(action, user, action_data=None, user_data=None, ip_addr=None,
179 repo=None, sa_session=None, commit=False):
179 repo=None, sa_session=None, commit=False):
180 """
180 """
181 Audit logger for various actions made by users; typically this
181 Audit logger for various actions made by users; typically this
182 results in a call such as::
182 results in a call such as::
183
183
184 from rhodecode.lib import audit_logger
184 from rhodecode.lib import audit_logger
185
185
186 audit_logger.store(
186 audit_logger.store(
187 'repo.edit', user=self._rhodecode_user)
187 'repo.edit', user=self._rhodecode_user)
188 audit_logger.store(
188 audit_logger.store(
189 'repo.delete', action_data={'data': repo_data},
189 'repo.delete', action_data={'data': repo_data},
190 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'))
190 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'))
191
191
192 # repo action
192 # repo action
193 audit_logger.store(
193 audit_logger.store(
194 'repo.delete',
194 'repo.delete',
195 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'),
195 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'),
196 repo=audit_logger.RepoWrap(repo_name='some-repo'))
196 repo=audit_logger.RepoWrap(repo_name='some-repo'))
197
197
198 # repo action, when we already have the repository object
198 # repo action, when we already have the repository object
199 audit_logger.store(
199 audit_logger.store(
200 'repo.delete', action_data={'source': audit_logger.SOURCE_WEB, },
200 'repo.delete', action_data={'source': audit_logger.SOURCE_WEB, },
201 user=self._rhodecode_user,
201 user=self._rhodecode_user,
202 repo=repo_object)
202 repo=repo_object)
203
203
204 # alternative wrapper to the above
204 # alternative wrapper to the above
205 audit_logger.store_web(
205 audit_logger.store_web(
206 'repo.delete', action_data={},
206 'repo.delete', action_data={},
207 user=self._rhodecode_user,
207 user=self._rhodecode_user,
208 repo=repo_object)
208 repo=repo_object)
209
209
210 # without a user?
210 # without a user?
211 audit_logger.store(
211 audit_logger.store(
212 'user.login.failure',
212 'user.login.failure',
213 user=audit_logger.UserWrap(
213 user=audit_logger.UserWrap(
214 username=self.request.params.get('username'),
214 username=self.request.params.get('username'),
215 ip_addr=self.request.remote_addr))
215 ip_addr=self.request.remote_addr))
216
216
217 """
217 """
218 from rhodecode.lib.utils2 import safe_unicode
218 from rhodecode.lib.utils2 import safe_unicode
219 from rhodecode.lib.auth import AuthUser
219 from rhodecode.lib.auth import AuthUser
220
220
221 action_spec = ACTIONS.get(action, None)
221 action_spec = ACTIONS.get(action, None)
222 if action_spec is None:
222 if action_spec is None:
223 raise ValueError('Action `{}` is not supported'.format(action))
223 raise ValueError('Action `{}` is not supported'.format(action))
224
224
225 if not sa_session:
225 if not sa_session:
226 sa_session = meta.Session()
226 sa_session = meta.Session()
227
227
228 try:
228 try:
229 username = getattr(user, 'username', None)
229 username = getattr(user, 'username', None)
230 if not username:
230 if not username:
231 pass
231 pass
232
232
233 user_id = getattr(user, 'user_id', None)
233 user_id = getattr(user, 'user_id', None)
234 if not user_id:
234 if not user_id:
235 # maybe we have a username? Try to resolve user_id from it
235 # maybe we have a username? Try to resolve user_id from it
236 if username:
236 if username:
237 user_id = getattr(
237 user_id = getattr(
238 User.get_by_username(username), 'user_id', None)
238 User.get_by_username(username), 'user_id', None)
239
239
240 ip_addr = ip_addr or getattr(user, 'ip_addr', None)
240 ip_addr = ip_addr or getattr(user, 'ip_addr', None)
241 if not ip_addr:
241 if not ip_addr:
242 pass
242 pass
243
243
244 if not user_data:
244 if not user_data:
245 # try to get this from the auth user
245 # try to get this from the auth user
246 if isinstance(user, AuthUser):
246 if isinstance(user, AuthUser):
247 user_data = {
247 user_data = {
248 'username': user.username,
248 'username': user.username,
249 'email': user.email,
249 'email': user.email,
250 }
250 }
251
251
252 repository_name = getattr(repo, 'repo_name', None)
252 repository_name = getattr(repo, 'repo_name', None)
253 repository_id = getattr(repo, 'repo_id', None)
253 repository_id = getattr(repo, 'repo_id', None)
254 if not repository_id:
254 if not repository_id:
255 # maybe we have a repo_name? Try to resolve repo_id from it
255 # maybe we have a repo_name? Try to resolve repo_id from it
256 if repository_name:
256 if repository_name:
257 repository_id = getattr(
257 repository_id = getattr(
258 Repository.get_by_repo_name(repository_name), 'repo_id', None)
258 Repository.get_by_repo_name(repository_name), 'repo_id', None)
259
259
260 action_name = safe_unicode(action)
260 action_name = safe_unicode(action)
261 ip_address = safe_unicode(ip_addr)
261 ip_address = safe_unicode(ip_addr)
262
262
263 with sa_session.no_autoflush:
263 with sa_session.no_autoflush:
264
264
265 user_log = _store_log(
265 user_log = _store_log(
266 action_name=action_name,
266 action_name=action_name,
267 action_data=action_data or {},
267 action_data=action_data or {},
268 user_id=user_id,
268 user_id=user_id,
269 username=username,
269 username=username,
270 user_data=user_data or {},
270 user_data=user_data or {},
271 ip_address=ip_address,
271 ip_address=ip_address,
272 repository_id=repository_id,
272 repository_id=repository_id,
273 repository_name=repository_name
273 repository_name=repository_name
274 )
274 )
275
275
276 sa_session.add(user_log)
276 sa_session.add(user_log)
277 if commit:
277 if commit:
278 sa_session.commit()
278 sa_session.commit()
279 entry_id = user_log.entry_id or ''
279 entry_id = user_log.entry_id or ''
280
280
281 update_user_last_activity(sa_session, user_id)
281 update_user_last_activity(sa_session, user_id)
282
282
283 if commit:
283 if commit:
284 sa_session.commit()
284 sa_session.commit()
285
285
286 log.info('AUDIT[%s]: Logging action: `%s` by user:id:%s[%s] ip:%s',
286 log.info('AUDIT[%s]: Logging action: `%s` by user:id:%s[%s] ip:%s',
287 entry_id, action_name, user_id, username, ip_address,
287 entry_id, action_name, user_id, username, ip_address,
288 extra={"entry_id": entry_id, "action": action_name,
288 extra={"entry_id": entry_id, "action": action_name,
289 "user_id": user_id, "ip": ip_address})
289 "user_id": user_id, "ip": ip_address})
290
290
291 except Exception:
291 except Exception:
292 log.exception('AUDIT: failed to store audit log')
292 log.exception('AUDIT: failed to store audit log')
293
293
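# --- Editor's note: the lookups above lean on getattr(obj, name, default)
# returning the default when User.get_by_username()/Repository.get_by_repo_name()
# yield None. A minimal self-contained sketch of that pattern:
class _DemoRecord(object):
    user_id = 42

def _demo_lookup(name):
    return _DemoRecord() if name == 'known' else None

assert getattr(_demo_lookup('known'), 'user_id', None) == 42
assert getattr(_demo_lookup('missing'), 'user_id', None) is None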
294
294
295 def update_user_last_activity(sa_session, user_id):
295 def update_user_last_activity(sa_session, user_id):
296 _last_activity = datetime.datetime.now()
296 _last_activity = datetime.datetime.now()
297 try:
297 try:
298 sa_session.query(User).filter(User.user_id == user_id).update(
298 sa_session.query(User).filter(User.user_id == user_id).update(
299 {"last_activity": _last_activity})
299 {"last_activity": _last_activity})
300 log.debug(
300 log.debug(
301 'updated user `%s` last activity to:%s', user_id, _last_activity)
301 'updated user `%s` last activity to:%s', user_id, _last_activity)
302 except Exception:
302 except Exception:
303 log.exception("Failed last activity update for user_id: %s", user_id)
303 log.exception("Failed last activity update for user_id: %s", user_id)
304 sa_session.rollback()
304 sa_session.rollback()
305
305
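# --- Editor's illustration (assumptions: SQLAlchemy 1.4+, in-memory SQLite and
# a toy model, not the RhodeCode schema) of the Query.update() pattern above,
# which issues a single UPDATE without loading the row into the ORM session:
import datetime
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, Session

_Base = declarative_base()

class _DemoUser(_Base):
    __tablename__ = 'demo_users'
    user_id = sa.Column(sa.Integer, primary_key=True)
    last_activity = sa.Column(sa.DateTime)

_engine = sa.create_engine('sqlite://')
_Base.metadata.create_all(_engine)
with Session(_engine) as session:
    session.add(_DemoUser(user_id=1))
    session.commit()
    # one UPDATE statement; no _DemoUser instance is loaded or dirtied
    session.query(_DemoUser).filter(_DemoUser.user_id == 1).update(
        {"last_activity": datetime.datetime.now()})
    session.commit()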
@@ -1,371 +1,371 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2016-2020 RhodeCode GmbH
3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import os
21 import os
22 import hashlib
22 import hashlib
23 import itsdangerous
23 import itsdangerous
24 import logging
24 import logging
25 import requests
25 import requests
26 import datetime
26 import datetime
27
27
28 from dogpile.util.readwrite_lock import ReadWriteMutex
28 from dogpile.util.readwrite_lock import ReadWriteMutex
29 from pyramid.threadlocal import get_current_registry
29 from pyramid.threadlocal import get_current_registry
30
30
31 import rhodecode.lib.helpers as h
31 import rhodecode.lib.helpers as h
32 from rhodecode.lib.auth import HasRepoPermissionAny
32 from rhodecode.lib.auth import HasRepoPermissionAny
33 from rhodecode.lib.ext_json import json
33 from rhodecode.lib.ext_json import json
34 from rhodecode.model.db import User
34 from rhodecode.model.db import User
35
35
36 log = logging.getLogger(__name__)
36 log = logging.getLogger(__name__)
37
37
38 LOCK = ReadWriteMutex()
38 LOCK = ReadWriteMutex()
39
39
40 USER_STATE_PUBLIC_KEYS = [
40 USER_STATE_PUBLIC_KEYS = [
41 'id', 'username', 'first_name', 'last_name',
41 'id', 'username', 'first_name', 'last_name',
42 'icon_link', 'display_name', 'display_link']
42 'icon_link', 'display_name', 'display_link']
43
43
44
44
45 class ChannelstreamException(Exception):
45 class ChannelstreamException(Exception):
46 pass
46 pass
47
47
48
48
49 class ChannelstreamConnectionException(ChannelstreamException):
49 class ChannelstreamConnectionException(ChannelstreamException):
50 pass
50 pass
51
51
52
52
53 class ChannelstreamPermissionException(ChannelstreamException):
53 class ChannelstreamPermissionException(ChannelstreamException):
54 pass
54 pass
55
55
56
56
57 def get_channelstream_server_url(config, endpoint):
57 def get_channelstream_server_url(config, endpoint):
58 return 'http://{}{}'.format(config['server'], endpoint)
58 return 'http://{}{}'.format(config['server'], endpoint)
59
59
60
60
61 def channelstream_request(config, payload, endpoint, raise_exc=True):
61 def channelstream_request(config, payload, endpoint, raise_exc=True):
62 signer = itsdangerous.TimestampSigner(config['secret'])
62 signer = itsdangerous.TimestampSigner(config['secret'])
63 sig_for_server = signer.sign(endpoint)
63 sig_for_server = signer.sign(endpoint)
64 secret_headers = {'x-channelstream-secret': sig_for_server,
64 secret_headers = {'x-channelstream-secret': sig_for_server,
65 'x-channelstream-endpoint': endpoint,
65 'x-channelstream-endpoint': endpoint,
66 'Content-Type': 'application/json'}
66 'Content-Type': 'application/json'}
67 req_url = get_channelstream_server_url(config, endpoint)
67 req_url = get_channelstream_server_url(config, endpoint)
68
68
69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
70 response = None
70 response = None
71 try:
71 try:
72 response = requests.post(req_url, data=json.dumps(payload),
72 response = requests.post(req_url, data=json.dumps(payload),
73 headers=secret_headers).json()
73 headers=secret_headers).json()
74 except requests.ConnectionError:
74 except requests.ConnectionError:
75 log.exception('ConnectionError occurred for endpoint %s', req_url)
75 log.exception('ConnectionError occurred for endpoint %s', req_url)
76 if raise_exc:
76 if raise_exc:
77 raise ChannelstreamConnectionException(req_url)
77 raise ChannelstreamConnectionException(req_url)
78 except Exception:
78 except Exception:
79 log.exception('Exception related to Channelstream happened')
79 log.exception('Exception related to Channelstream happened')
80 if raise_exc:
80 if raise_exc:
81 raise ChannelstreamConnectionException()
81 raise ChannelstreamConnectionException()
82 log.debug('Got channelstream response: %s', response)
82 log.debug('Got channelstream response: %s', response)
83 return response
83 return response
84
84
85
85
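# --- Editor's sketch of the verification counterpart (assumption: the
# Channelstream server checks the same itsdangerous TimestampSigner signature;
# the secret and max_age here are illustrative values):
def _verify_endpoint_sig(secret, endpoint, signed_value, max_age=60):
    signer = itsdangerous.TimestampSigner(secret)
    try:
        unsigned = signer.unsign(signed_value, max_age=max_age)
    except itsdangerous.BadSignature:  # also catches SignatureExpired
        return False
    return unsigned == endpoint.encode('utf8')

_demo_sig = itsdangerous.TimestampSigner('demo-secret').sign('/message')
assert _verify_endpoint_sig('demo-secret', '/message', _demo_sig)
assert not _verify_endpoint_sig('wrong-secret', '/message', _demo_sig)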
86 def get_user_data(user_id):
86 def get_user_data(user_id):
87 user = User.get(user_id)
87 user = User.get(user_id)
88 return {
88 return {
89 'id': user.user_id,
89 'id': user.user_id,
90 'username': user.username,
90 'username': user.username,
91 'first_name': user.first_name,
91 'first_name': user.first_name,
92 'last_name': user.last_name,
92 'last_name': user.last_name,
93 'icon_link': h.gravatar_url(user.email, 60),
93 'icon_link': h.gravatar_url(user.email, 60),
94 'display_name': h.person(user, 'username_or_name_or_email'),
94 'display_name': h.person(user, 'username_or_name_or_email'),
95 'display_link': h.link_to_user(user),
95 'display_link': h.link_to_user(user),
96 'notifications': user.user_data.get('notification_status', True)
96 'notifications': user.user_data.get('notification_status', True)
97 }
97 }
98
98
99
99
100 def broadcast_validator(channel_name):
100 def broadcast_validator(channel_name):
101 """ checks if user can access the broadcast channel """
101 """ checks if user can access the broadcast channel """
102 if channel_name == 'broadcast':
102 if channel_name == 'broadcast':
103 return True
103 return True
104
104
105
105
106 def repo_validator(channel_name):
106 def repo_validator(channel_name):
107 """ checks if user can access the broadcast channel """
107 """ checks if user can access the broadcast channel """
108 channel_prefix = '/repo$'
108 channel_prefix = '/repo$'
109 if channel_name.startswith(channel_prefix):
109 if channel_name.startswith(channel_prefix):
110 elements = channel_name[len(channel_prefix):].split('$')
110 elements = channel_name[len(channel_prefix):].split('$')
111 repo_name = elements[0]
111 repo_name = elements[0]
112 can_access = HasRepoPermissionAny(
112 can_access = HasRepoPermissionAny(
113 'repository.read',
113 'repository.read',
114 'repository.write',
114 'repository.write',
115 'repository.admin')(repo_name)
115 'repository.admin')(repo_name)
116 log.debug(
116 log.debug(
117 'permission check for %s channel resulted in %s',
117 'permission check for %s channel resulted in %s',
118 repo_name, can_access)
118 repo_name, can_access)
119 if can_access:
119 if can_access:
120 return True
120 return True
121 return False
121 return False
122
122
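# --- Editor's illustration of the channel naming scheme assumed above,
# '/repo$<repo_name>$<suffix>'; pure string handling, no permission checks:
def _parse_repo_channel_demo(channel_name, prefix='/repo$'):
    if not channel_name.startswith(prefix):
        return None
    return channel_name[len(prefix):].split('$')[0]

assert _parse_repo_channel_demo('/repo$some-repo$/pr/7') == 'some-repo'
assert _parse_repo_channel_demo('broadcast') is None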
123
123
124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
125 valid_channels = []
125 valid_channels = []
126
126
127 validators = [broadcast_validator, repo_validator]
127 validators = [broadcast_validator, repo_validator]
128 if plugin_validators:
128 if plugin_validators:
129 validators.extend(plugin_validators)
129 validators.extend(plugin_validators)
130 for channel_name in channels:
130 for channel_name in channels:
131 is_valid = False
131 is_valid = False
132 for validator in validators:
132 for validator in validators:
133 if validator(channel_name):
133 if validator(channel_name):
134 is_valid = True
134 is_valid = True
135 break
135 break
136 if is_valid:
136 if is_valid:
137 valid_channels.append(channel_name)
137 valid_channels.append(channel_name)
138 else:
138 else:
139 if should_raise:
139 if should_raise:
140 raise ChannelstreamPermissionException()
140 raise ChannelstreamPermissionException()
141 return valid_channels
141 return valid_channels
142
142
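# --- Editor's usage sketch: with should_raise=False, channels that no
# validator accepts are silently dropped (the plugin validator here is a toy):
def _allow_lobby_demo(channel_name):
    return channel_name == 'lobby'

assert check_channel_permissions(
    ['lobby', 'secret'], [_allow_lobby_demo], should_raise=False) == ['lobby']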
143
143
144 def get_channels_info(self, channels):
144 def get_channels_info(config, channels):
145 payload = {'channels': channels}
145 payload = {'channels': channels}
146 # gather persistence info
146 # gather persistence info
147 return channelstream_request(self._config(), payload, '/info')
147 return channelstream_request(config, payload, '/info')
148
148
149
149
150 def parse_channels_info(info_result, include_channel_info=None):
150 def parse_channels_info(info_result, include_channel_info=None):
151 """
151 """
152 Returns data that contains only secure information that can be
152 Returns data that contains only secure information that can be
153 presented to clients
153 presented to clients
154 """
154 """
155 include_channel_info = include_channel_info or []
155 include_channel_info = include_channel_info or []
156
156
157 user_state_dict = {}
157 user_state_dict = {}
158 for userinfo in info_result['users']:
158 for userinfo in info_result['users']:
159 user_state_dict[userinfo['user']] = {
159 user_state_dict[userinfo['user']] = {
160 k: v for k, v in userinfo['state'].items()
160 k: v for k, v in userinfo['state'].items()
161 if k in USER_STATE_PUBLIC_KEYS
161 if k in USER_STATE_PUBLIC_KEYS
162 }
162 }
163
163
164 channels_info = {}
164 channels_info = {}
165
165
166 for c_name, c_info in info_result['channels'].items():
166 for c_name, c_info in info_result['channels'].items():
167 if c_name not in include_channel_info:
167 if c_name not in include_channel_info:
168 continue
168 continue
169 connected_list = []
169 connected_list = []
170 for username in c_info['users']:
170 for username in c_info['users']:
171 connected_list.append({
171 connected_list.append({
172 'user': username,
172 'user': username,
173 'state': user_state_dict[username]
173 'state': user_state_dict[username]
174 })
174 })
175 channels_info[c_name] = {'users': connected_list,
175 channels_info[c_name] = {'users': connected_list,
176 'history': c_info['history']}
176 'history': c_info['history']}
177
177
178 return channels_info
178 return channels_info
179
179
180
180
181 def log_filepath(history_location, channel_name):
181 def log_filepath(history_location, channel_name):
182 hasher = hashlib.sha256()
182 hasher = hashlib.sha256()
183 hasher.update(channel_name.encode('utf8'))
183 hasher.update(channel_name.encode('utf8'))
184 filename = '{}.log'.format(hasher.hexdigest())
184 filename = '{}.log'.format(hasher.hexdigest())
185 filepath = os.path.join(history_location, filename)
185 filepath = os.path.join(history_location, filename)
186 return filepath
186 return filepath
187
187
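# --- Editor's illustration: history files are named by the SHA-256 of the
# channel name, so any channel maps to a deterministic, filesystem-safe path:
_demo_p1 = log_filepath('/tmp', '/repo$some-repo$/pr/7')
_demo_p2 = log_filepath('/tmp', 'broadcast')
assert _demo_p1.endswith('.log') and _demo_p1 != _demo_p2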
188
188
189 def read_history(history_location, channel_name):
189 def read_history(history_location, channel_name):
190 filepath = log_filepath(history_location, channel_name)
190 filepath = log_filepath(history_location, channel_name)
191 if not os.path.exists(filepath):
191 if not os.path.exists(filepath):
192 return []
192 return []
193 history_lines_limit = -100
193 history_lines_limit = -100
194 history = []
194 history = []
195 with open(filepath, 'rb') as f:
195 with open(filepath, 'rb') as f:
196 for line in f.readlines()[history_lines_limit:]:
196 for line in f.readlines()[history_lines_limit:]:
197 try:
197 try:
198 history.append(json.loads(line))
198 history.append(json.loads(line))
199 except Exception:
199 except Exception:
200 log.exception('Failed to load history')
200 log.exception('Failed to load history')
201 return history
201 return history
202
202
203
203
204 def update_history_from_logs(config, channels, payload):
204 def update_history_from_logs(config, channels, payload):
205 history_location = config.get('history.location')
205 history_location = config.get('history.location')
206 for channel in channels:
206 for channel in channels:
207 history = read_history(history_location, channel)
207 history = read_history(history_location, channel)
208 payload['channels_info'][channel]['history'] = history
208 payload['channels_info'][channel]['history'] = history
209
209
210
210
211 def write_history(config, message):
211 def write_history(config, message):
212 """ writes a message to a base64encoded filename """
212 """ writes a message to a base64encoded filename """
213 history_location = config.get('history.location')
213 history_location = config.get('history.location')
214 if not os.path.exists(history_location):
214 if not os.path.exists(history_location):
215 return
215 return
216 try:
216 try:
217 LOCK.acquire_write_lock()
217 LOCK.acquire_write_lock()
218 filepath = log_filepath(history_location, message['channel'])
218 filepath = log_filepath(history_location, message['channel'])
219 json_message = json.dumps(message)
219 json_message = json.dumps(message)
220 with open(filepath, 'ab') as f:
220 with open(filepath, 'ab') as f:
221 f.write(json_message)
221 f.write(json_message.encode('utf8'))
222 f.write('\n')
222 f.write(b'\n')
223 finally:
223 finally:
224 LOCK.release_write_lock()
224 LOCK.release_write_lock()
225
225
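# --- Editor's sketch of the JSON-lines round trip behind write_history() and
# read_history(), self-contained with the stdlib (temp dir, not module config);
# note the explicit encode/bytes, matching the binary file modes used above:
import json as _json
import os as _os
import tempfile as _tempfile

_demo_dir = _tempfile.mkdtemp()
_demo_file = _os.path.join(_demo_dir, 'demo.log')
with open(_demo_file, 'ab') as _f:
    _f.write(_json.dumps({'channel': 'broadcast', 'message': 'hi'}).encode('utf8'))
    _f.write(b'\n')
with open(_demo_file, 'rb') as _f:
    assert _json.loads(_f.readlines()[-1])['message'] == 'hi'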
226
226
227 def get_connection_validators(registry):
227 def get_connection_validators(registry):
228 validators = []
228 validators = []
229 for k, config in registry.rhodecode_plugins.items():
229 for k, config in registry.rhodecode_plugins.items():
230 validator = config.get('channelstream', {}).get('connect_validator')
230 validator = config.get('channelstream', {}).get('connect_validator')
231 if validator:
231 if validator:
232 validators.append(validator)
232 validators.append(validator)
233 return validators
233 return validators
234
234
235
235
236 def get_channelstream_config(registry=None):
236 def get_channelstream_config(registry=None):
237 if not registry:
237 if not registry:
238 registry = get_current_registry()
238 registry = get_current_registry()
239
239
240 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
240 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
241 channelstream_config = rhodecode_plugins.get('channelstream', {})
241 channelstream_config = rhodecode_plugins.get('channelstream', {})
242 return channelstream_config
242 return channelstream_config
243
243
244
244
245 def post_message(channel, message, username, registry=None):
245 def post_message(channel, message, username, registry=None):
246 channelstream_config = get_channelstream_config(registry)
246 channelstream_config = get_channelstream_config(registry)
247 if not channelstream_config.get('enabled'):
247 if not channelstream_config.get('enabled'):
248 return
248 return
249
249
250 message_obj = message
250 message_obj = message
251 if isinstance(message, str):
251 if isinstance(message, str):
252 message_obj = {
252 message_obj = {
253 'message': message,
253 'message': message,
254 'level': 'success',
254 'level': 'success',
255 'topic': '/notifications'
255 'topic': '/notifications'
256 }
256 }
257
257
258 log.debug('Channelstream: sending notification to channel %s', channel)
258 log.debug('Channelstream: sending notification to channel %s', channel)
259 payload = {
259 payload = {
260 'type': 'message',
260 'type': 'message',
261 'timestamp': datetime.datetime.utcnow(),
261 'timestamp': datetime.datetime.utcnow(),
262 'user': 'system',
262 'user': 'system',
263 'exclude_users': [username],
263 'exclude_users': [username],
264 'channel': channel,
264 'channel': channel,
265 'message': message_obj
265 'message': message_obj
266 }
266 }
267
267
268 try:
268 try:
269 return channelstream_request(
269 return channelstream_request(
270 channelstream_config, [payload], '/message',
270 channelstream_config, [payload], '/message',
271 raise_exc=False)
271 raise_exc=False)
272 except ChannelstreamException:
272 except ChannelstreamException:
273 log.exception('Failed to send channelstream data')
273 log.exception('Failed to send channelstream data')
274 raise
274 raise
275
275
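# --- Editor's note: post_message() normalizes a bare string into the
# message_obj dict shown above. A minimal sketch of that normalization step:
def _as_message_obj_demo(message):
    if isinstance(message, str):
        return {'message': message, 'level': 'success',
                'topic': '/notifications'}
    return message

assert _as_message_obj_demo('saved')['topic'] == '/notifications'
assert _as_message_obj_demo({'topic': '/comment'}) == {'topic': '/comment'}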
276
276
277 def _reload_link(label):
277 def _reload_link(label):
278 return (
278 return (
279 '<a onclick="window.location.reload()">'
279 '<a onclick="window.location.reload()">'
280 '<strong>{}</strong>'
280 '<strong>{}</strong>'
281 '</a>'.format(label)
281 '</a>'.format(label)
282 )
282 )
283
283
284
284
285 def pr_channel(pull_request):
285 def pr_channel(pull_request):
286 repo_name = pull_request.target_repo.repo_name
286 repo_name = pull_request.target_repo.repo_name
287 pull_request_id = pull_request.pull_request_id
287 pull_request_id = pull_request.pull_request_id
288 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
288 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
289 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
289 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
290 return channel
290 return channel
291
291
292
292
293 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
293 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
294 channel = None
294 channel = None
295 if commit_obj:
295 if commit_obj:
296 channel = u'/repo${}$/commit/{}'.format(
296 channel = '/repo${}$/commit/{}'.format(
297 repo_name, commit_obj.raw_id
297 repo_name, commit_obj.raw_id
298 )
298 )
299 elif pull_request_obj:
299 elif pull_request_obj:
300 channel = u'/repo${}$/pr/{}'.format(
300 channel = '/repo${}$/pr/{}'.format(
301 repo_name, pull_request_obj.pull_request_id
301 repo_name, pull_request_obj.pull_request_id
302 )
302 )
303 log.debug('Getting comment channelstream broadcast channel: %s', channel)
303 log.debug('Getting comment channelstream broadcast channel: %s', channel)
304
304
305 return channel
305 return channel
306
306
307
307
308 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
308 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
309 """
309 """
310 Channel push on pull request update
310 Channel push on pull request update
311 """
311 """
312 if not pr_broadcast_channel:
312 if not pr_broadcast_channel:
313 return
313 return
314
314
315 _ = request.translate
315 _ = request.translate
316
316
317 message = '{} {}'.format(
317 message = '{} {}'.format(
318 msg,
318 msg,
319 _reload_link(_(' Reload page to load changes')))
319 _reload_link(_(' Reload page to load changes')))
320
320
321 message_obj = {
321 message_obj = {
322 'message': message,
322 'message': message,
323 'level': 'success',
323 'level': 'success',
324 'topic': '/notifications'
324 'topic': '/notifications'
325 }
325 }
326
326
327 post_message(
327 post_message(
328 pr_broadcast_channel, message_obj, user.username,
328 pr_broadcast_channel, message_obj, user.username,
329 registry=request.registry)
329 registry=request.registry)
330
330
331
331
332 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
332 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
333 """
333 """
334 Channelstream push on comment action, on commit, or pull-request
334 Channelstream push on comment action, on commit, or pull-request
335 """
335 """
336 if not comment_broadcast_channel:
336 if not comment_broadcast_channel:
337 return
337 return
338
338
339 _ = request.translate
339 _ = request.translate
340
340
341 comment_data = kwargs.pop('comment_data', {})
341 comment_data = kwargs.pop('comment_data', {})
342 user_data = kwargs.pop('user_data', {})
342 user_data = kwargs.pop('user_data', {})
343 comment_id = comment_data.keys()[0] if comment_data else ''
343 comment_id = list(comment_data.keys())[0] if comment_data else ''
344
344
345 message = '<strong>{}</strong> {} #{}'.format(
345 message = '<strong>{}</strong> {} #{}'.format(
346 user.username,
346 user.username,
347 msg,
347 msg,
348 comment_id,
348 comment_id,
349 )
349 )
350
350
351 message_obj = {
351 message_obj = {
352 'message': message,
352 'message': message,
353 'level': 'success',
353 'level': 'success',
354 'topic': '/notifications'
354 'topic': '/notifications'
355 }
355 }
356
356
357 post_message(
357 post_message(
358 comment_broadcast_channel, message_obj, user.username,
358 comment_broadcast_channel, message_obj, user.username,
359 registry=request.registry)
359 registry=request.registry)
360
360
361 message_obj = {
361 message_obj = {
362 'message': None,
362 'message': None,
363 'user': user.username,
363 'user': user.username,
364 'comment_id': comment_id,
364 'comment_id': comment_id,
365 'comment_data': comment_data,
365 'comment_data': comment_data,
366 'user_data': user_data,
366 'user_data': user_data,
367 'topic': '/comment'
367 'topic': '/comment'
368 }
368 }
369 post_message(
369 post_message(
370 comment_broadcast_channel, message_obj, user.username,
370 comment_broadcast_channel, message_obj, user.username,
371 registry=request.registry)
371 registry=request.registry)
@@ -1,797 +1,797 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import logging
21 import logging
22 import difflib
22 import difflib
23 from itertools import groupby
23 from itertools import groupby
24
24
25 from pygments import lex
25 from pygments import lex
26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
27 from pygments.lexers.special import TextLexer, Token
27 from pygments.lexers.special import TextLexer, Token
28 from pygments.lexers import get_lexer_by_name
28 from pygments.lexers import get_lexer_by_name
29
29
30 from rhodecode.lib.helpers import (
30 from rhodecode.lib.helpers import (
31 get_lexer_for_filenode, html_escape, get_custom_lexer)
31 get_lexer_for_filenode, html_escape, get_custom_lexer)
32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
33 from rhodecode.lib.vcs.nodes import FileNode
33 from rhodecode.lib.vcs.nodes import FileNode
34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
35 from rhodecode.lib.diff_match_patch import diff_match_patch
35 from rhodecode.lib.diff_match_patch import diff_match_patch
36 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
36 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
37
37
38
38
39 plain_text_lexer = get_lexer_by_name(
39 plain_text_lexer = get_lexer_by_name(
40 'text', stripall=False, stripnl=False, ensurenl=False)
40 'text', stripall=False, stripnl=False, ensurenl=False)
41
41
42
42
43 log = logging.getLogger(__name__)
43 log = logging.getLogger(__name__)
44
44
45
45
46 def filenode_as_lines_tokens(filenode, lexer=None):
46 def filenode_as_lines_tokens(filenode, lexer=None):
47 org_lexer = lexer
47 org_lexer = lexer
48 lexer = lexer or get_lexer_for_filenode(filenode)
48 lexer = lexer or get_lexer_for_filenode(filenode)
49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
50 lexer, filenode, org_lexer)
50 lexer, filenode, org_lexer)
51 content = filenode.content
51 content = filenode.content
52 tokens = tokenize_string(content, lexer)
52 tokens = tokenize_string(content, lexer)
53 lines = split_token_stream(tokens, content)
53 lines = split_token_stream(tokens, content)
54 rv = list(lines)
54 rv = list(lines)
55 return rv
55 return rv
56
56
57
57
58 def tokenize_string(content, lexer):
58 def tokenize_string(content, lexer):
59 """
59 """
60 Use pygments to tokenize some content based on a lexer
60 Use pygments to tokenize some content based on a lexer
61 ensuring all original new lines and whitespace is preserved
61 ensuring all original new lines and whitespace is preserved
62 """
62 """
63
63
64 lexer.stripall = False
64 lexer.stripall = False
65 lexer.stripnl = False
65 lexer.stripnl = False
66 lexer.ensurenl = False
66 lexer.ensurenl = False
67
67
68 if isinstance(lexer, TextLexer):
68 if isinstance(lexer, TextLexer):
69 lexed = [(Token.Text, content)]
69 lexed = [(Token.Text, content)]
70 else:
70 else:
71 lexed = lex(content, lexer)
71 lexed = lex(content, lexer)
72
72
73 for token_type, token_text in lexed:
73 for token_type, token_text in lexed:
74 yield pygment_token_class(token_type), token_text
74 yield pygment_token_class(token_type), token_text
75
75
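# --- Editor's usage sketch (get_lexer_by_name is already imported above; the
# exact token class strings vary across pygments versions, so we only assert
# that the original text survives tokenization intact, per the docstring):
_demo_tokens = list(tokenize_string('x = 1\n', get_lexer_by_name('python')))
assert ''.join(text for _cls, text in _demo_tokens) == 'x = 1\n'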
76
76
77 def split_token_stream(tokens, content):
77 def split_token_stream(tokens, content):
78 """
78 """
79 Take a list of (TokenType, text) tuples and split them by newlines
79 Take a list of (TokenType, text) tuples and split them by newlines
80
80
81 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')], content)
81 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')], content)
82 yields: [(TEXT, 'some')],
82 yields: [(TEXT, 'some')],
83 [(TEXT, 'text'), (TEXT, 'more')], [(TEXT, '')]
83 [(TEXT, 'text'), (TEXT, 'more')], [(TEXT, '')]
84 """
84 """
85
85
86 token_buffer = []
86 token_buffer = []
87 for token_class, token_text in tokens:
87 for token_class, token_text in tokens:
88 parts = token_text.split('\n')
88 parts = token_text.split('\n')
89 for part in parts[:-1]:
89 for part in parts[:-1]:
90 token_buffer.append((token_class, part))
90 token_buffer.append((token_class, part))
91 yield token_buffer
91 yield token_buffer
92 token_buffer = []
92 token_buffer = []
93
93
94 token_buffer.append((token_class, parts[-1]))
94 token_buffer.append((token_class, parts[-1]))
95
95
96 if token_buffer:
96 if token_buffer:
97 yield token_buffer
97 yield token_buffer
98 elif content:
98 elif content:
99 # this is a special case, we have the content, but tokenization didn't produce
99 # this is a special case, we have the content, but tokenization didn't produce
100 # any results. This can happen if known file extensions like .css have some bogus
100 # any results. This can happen if known file extensions like .css have some bogus
101 # unicode content without any newline characters
101 # unicode content without any newline characters
102 yield [(pygment_token_class(Token.Text), content)]
102 yield [(pygment_token_class(Token.Text), content)]
103
103
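# --- Editor's usage sketch, matching the docstring above (token classes are
# plain strings here; the trailing empty token comes from the final newline):
_demo_lines = list(split_token_stream(
    [('t', 'some\ntext'), ('t', 'more\n')], 'some\ntext\nmore\n'))
assert _demo_lines[0] == [('t', 'some')]
assert _demo_lines[1] == [('t', 'text'), ('t', 'more')]
assert _demo_lines[2] == [('t', '')]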
104
104
105 def filenode_as_annotated_lines_tokens(filenode):
105 def filenode_as_annotated_lines_tokens(filenode):
106 """
106 """
107 Take a file node and return a list of annotations => lines; if no annotation
107 Take a file node and return a list of annotations => lines; if no annotation
108 is found, it will be None.
108 is found, it will be None.
109
109
110 eg:
110 eg:
111
111
112 [
112 [
113 (annotation1, [
113 (annotation1, [
114 (1, line1_tokens_list),
114 (1, line1_tokens_list),
115 (2, line2_tokens_list),
115 (2, line2_tokens_list),
116 ]),
116 ]),
117 (annotation2, [
117 (annotation2, [
118 (3, line1_tokens_list),
118 (3, line1_tokens_list),
119 ]),
119 ]),
120 (None, [
120 (None, [
121 (4, line1_tokens_list),
121 (4, line1_tokens_list),
122 ]),
122 ]),
123 (annotation1, [
123 (annotation1, [
124 (5, line1_tokens_list),
124 (5, line1_tokens_list),
125 (6, line2_tokens_list),
125 (6, line2_tokens_list),
126 ])
126 ])
127 ]
127 ]
128 """
128 """
129
129
130 commit_cache = {} # cache commit_getter lookups
130 commit_cache = {} # cache commit_getter lookups
131
131
132 def _get_annotation(commit_id, commit_getter):
132 def _get_annotation(commit_id, commit_getter):
133 if commit_id not in commit_cache:
133 if commit_id not in commit_cache:
134 commit_cache[commit_id] = commit_getter()
134 commit_cache[commit_id] = commit_getter()
135 return commit_cache[commit_id]
135 return commit_cache[commit_id]
136
136
137 annotation_lookup = {
137 annotation_lookup = {
138 line_no: _get_annotation(commit_id, commit_getter)
138 line_no: _get_annotation(commit_id, commit_getter)
139 for line_no, commit_id, commit_getter, line_content
139 for line_no, commit_id, commit_getter, line_content
140 in filenode.annotate
140 in filenode.annotate
141 }
141 }
142
142
143 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
143 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
144 for line_no, tokens
144 for line_no, tokens
145 in enumerate(filenode_as_lines_tokens(filenode), 1))
145 in enumerate(filenode_as_lines_tokens(filenode), 1))
146
146
147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
148
148
149 for annotation, group in grouped_annotations_lines:
149 for annotation, group in grouped_annotations_lines:
150 yield (
150 yield (
151 annotation, [(line_no, tokens)
151 annotation, [(line_no, tokens)
152 for (_, line_no, tokens) in group]
152 for (_, line_no, tokens) in group]
153 )
153 )
154
154
155
155
156 def render_tokenstream(tokenstream):
156 def render_tokenstream(tokenstream):
157 result = []
157 result = []
158 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
158 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
159
159
160 if token_class:
160 if token_class:
161 result.append(u'<span class="%s">' % token_class)
161 result.append('<span class="%s">' % token_class)
162 else:
162 else:
163 result.append(u'<span>')
163 result.append('<span>')
164
164
165 for op_tag, token_text in token_ops_texts:
165 for op_tag, token_text in token_ops_texts:
166
166
167 if op_tag:
167 if op_tag:
168 result.append(u'<%s>' % op_tag)
168 result.append('<%s>' % op_tag)
169
169
170 # NOTE(marcink): in some cases of mixed encodings, we might run into
170 # NOTE(marcink): in some cases of mixed encodings, we might run into
171 # trouble in html_escape; in this case we force token_text to unicode,
171 # trouble in html_escape; in this case we force token_text to unicode,
172 # which ensures "correct" data even at the cost of how it renders
172 # which ensures "correct" data even at the cost of how it renders
173 try:
173 try:
174 escaped_text = html_escape(token_text)
174 escaped_text = html_escape(token_text)
175 except TypeError:
175 except TypeError:
176 escaped_text = html_escape(safe_unicode(token_text))
176 escaped_text = html_escape(safe_unicode(token_text))
177
177
178 # TODO: dan: investigate showing hidden characters like space/nl/tab
178 # TODO: dan: investigate showing hidden characters like space/nl/tab
179 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
179 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
180 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
180 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
181 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
181 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
182
182
183 result.append(escaped_text)
183 result.append(escaped_text)
184
184
185 if op_tag:
185 if op_tag:
186 result.append(u'</%s>' % op_tag)
186 result.append('</%s>' % op_tag)
187
187
188 result.append(u'</span>')
188 result.append('</span>')
189
189
190 html = ''.join(result)
190 html = ''.join(result)
191 return html
191 return html
192
192
193
193
194 def rollup_tokenstream(tokenstream):
194 def rollup_tokenstream(tokenstream):
195 """
195 """
196 Group a token stream of the format:
196 Group a token stream of the format:
197
197
198 ('class', 'op', 'text')
198 ('class', 'op', 'text')
199 or
199 or
200 ('class', 'text')
200 ('class', 'text')
201
201
202 into
202 into
203
203
204 [('class1',
204 [('class1',
205 [('op1', 'text'),
205 [('op1', 'text'),
206 ('op2', 'text')]),
206 ('op2', 'text')]),
207 ('class2',
207 ('class2',
208 [('op3', 'text')])]
208 [('op3', 'text')])]
209
209
210 This is used to get the minimal tags necessary when
210 This is used to get the minimal tags necessary when
211 rendering to html, e.g. for a token stream like:
211 rendering to html, e.g. for a token stream like:
212
212
213 <span class="A"><ins>he</ins>llo</span>
213 <span class="A"><ins>he</ins>llo</span>
214 vs
214 vs
215 <span class="A"><ins>he</ins></span><span class="A">llo</span>
215 <span class="A"><ins>he</ins></span><span class="A">llo</span>
216
216
217 If a 2-tuple is passed in, the output op will be an empty string.
217 If a 2-tuple is passed in, the output op will be an empty string.
218
218
219 eg:
219 eg:
220
220
221 >>> rollup_tokenstream([('classA', '', 'h'),
221 >>> rollup_tokenstream([('classA', '', 'h'),
222 ('classA', 'del', 'ell'),
222 ('classA', 'del', 'ell'),
223 ('classA', '', 'o'),
223 ('classA', '', 'o'),
224 ('classB', '', ' '),
224 ('classB', '', ' '),
225 ('classA', '', 'the'),
225 ('classA', '', 'the'),
226 ('classA', '', 're'),
226 ('classA', '', 're'),
227 ])
227 ])
228
228
229 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]),
229 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]),
230 ('classB', [('', ' ')]),
230 ('classB', [('', ' ')]),
231 ('classA', [('', 'there')])]
231 ('classA', [('', 'there')])]
232
232
233 """
233 """
234 if tokenstream and len(tokenstream[0]) == 2:
234 if tokenstream and len(tokenstream[0]) == 2:
235 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
235 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
236
236
237 result = []
237 result = []
238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
239 ops = []
239 ops = []
240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
241 text_buffer = []
241 text_buffer = []
242 for t_class, t_op, t_text in token_text_list:
242 for t_class, t_op, t_text in token_text_list:
243 text_buffer.append(t_text)
243 text_buffer.append(t_text)
244 ops.append((token_op, ''.join(text_buffer)))
244 ops.append((token_op, ''.join(text_buffer)))
245 result.append((token_class, ops))
245 result.append((token_class, ops))
246 return result
246 return result
247
247
248
248
249 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
249 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
250 """
250 """
251 Converts a list of (token_class, token_text) tuples to a list of
251 Converts a list of (token_class, token_text) tuples to a list of
252 (token_class, token_op, token_text) tuples where token_op is one of
252 (token_class, token_op, token_text) tuples where token_op is one of
253 ('ins', 'del', '')
253 ('ins', 'del', '')
254
254
255 :param old_tokens: list of (token_class, token_text) tuples of old line
255 :param old_tokens: list of (token_class, token_text) tuples of old line
256 :param new_tokens: list of (token_class, token_text) tuples of new line
256 :param new_tokens: list of (token_class, token_text) tuples of new line
257 :param use_diff_match_patch: boolean, will use google's diff match patch
257 :param use_diff_match_patch: boolean, will use google's diff match patch
258 library which has options to 'smooth' out the character by character
258 library which has options to 'smooth' out the character by character
259 differences making nicer ins/del blocks
259 differences making nicer ins/del blocks
260 """
260 """
261
261
262 old_tokens_result = []
262 old_tokens_result = []
263 new_tokens_result = []
263 new_tokens_result = []
264
264
265 similarity = difflib.SequenceMatcher(None,
265 similarity = difflib.SequenceMatcher(None,
266 ''.join(token_text for token_class, token_text in old_tokens),
266 ''.join(token_text for token_class, token_text in old_tokens),
267 ''.join(token_text for token_class, token_text in new_tokens)
267 ''.join(token_text for token_class, token_text in new_tokens)
268 ).ratio()
268 ).ratio()
269
269
270 if similarity < 0.6: # return, the blocks are too different
270 if similarity < 0.6: # return, the blocks are too different
271 for token_class, token_text in old_tokens:
271 for token_class, token_text in old_tokens:
272 old_tokens_result.append((token_class, '', token_text))
272 old_tokens_result.append((token_class, '', token_text))
273 for token_class, token_text in new_tokens:
273 for token_class, token_text in new_tokens:
274 new_tokens_result.append((token_class, '', token_text))
274 new_tokens_result.append((token_class, '', token_text))
275 return old_tokens_result, new_tokens_result, similarity
275 return old_tokens_result, new_tokens_result, similarity
276
276
277 token_sequence_matcher = difflib.SequenceMatcher(None,
277 token_sequence_matcher = difflib.SequenceMatcher(None,
278 [x[1] for x in old_tokens],
278 [x[1] for x in old_tokens],
279 [x[1] for x in new_tokens])
279 [x[1] for x in new_tokens])
280
280
281 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
281 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
282 # check the differences by token block types first to give a more
282 # check the differences by token block types first to give a more
283 # nicer "block" level replacement vs character diffs
283 # nicer "block" level replacement vs character diffs
284
284
285 if tag == 'equal':
285 if tag == 'equal':
286 for token_class, token_text in old_tokens[o1:o2]:
286 for token_class, token_text in old_tokens[o1:o2]:
287 old_tokens_result.append((token_class, '', token_text))
287 old_tokens_result.append((token_class, '', token_text))
288 for token_class, token_text in new_tokens[n1:n2]:
288 for token_class, token_text in new_tokens[n1:n2]:
289 new_tokens_result.append((token_class, '', token_text))
289 new_tokens_result.append((token_class, '', token_text))
290 elif tag == 'delete':
290 elif tag == 'delete':
291 for token_class, token_text in old_tokens[o1:o2]:
291 for token_class, token_text in old_tokens[o1:o2]:
292 old_tokens_result.append((token_class, 'del', token_text))
292 old_tokens_result.append((token_class, 'del', token_text))
293 elif tag == 'insert':
293 elif tag == 'insert':
294 for token_class, token_text in new_tokens[n1:n2]:
294 for token_class, token_text in new_tokens[n1:n2]:
295 new_tokens_result.append((token_class, 'ins', token_text))
295 new_tokens_result.append((token_class, 'ins', token_text))
296 elif tag == 'replace':
296 elif tag == 'replace':
297 # if same type token blocks must be replaced, do a diff on the
297 # if same type token blocks must be replaced, do a diff on the
298 # characters in the token blocks to show individual changes
298 # characters in the token blocks to show individual changes
299
299
300 old_char_tokens = []
300 old_char_tokens = []
301 new_char_tokens = []
301 new_char_tokens = []
302 for token_class, token_text in old_tokens[o1:o2]:
302 for token_class, token_text in old_tokens[o1:o2]:
303 for char in token_text:
303 for char in token_text:
304 old_char_tokens.append((token_class, char))
304 old_char_tokens.append((token_class, char))
305
305
306 for token_class, token_text in new_tokens[n1:n2]:
306 for token_class, token_text in new_tokens[n1:n2]:
307 for char in token_text:
307 for char in token_text:
308 new_char_tokens.append((token_class, char))
308 new_char_tokens.append((token_class, char))
309
309
310 old_string = ''.join([token_text for
310 old_string = ''.join([token_text for
311 token_class, token_text in old_char_tokens])
311 token_class, token_text in old_char_tokens])
312 new_string = ''.join([token_text for
312 new_string = ''.join([token_text for
313 token_class, token_text in new_char_tokens])
313 token_class, token_text in new_char_tokens])
314
314
315 char_sequence = difflib.SequenceMatcher(
315 char_sequence = difflib.SequenceMatcher(
316 None, old_string, new_string)
316 None, old_string, new_string)
317 copcodes = char_sequence.get_opcodes()
317 copcodes = char_sequence.get_opcodes()
318 obuffer, nbuffer = [], []
318 obuffer, nbuffer = [], []
319
319
320 if use_diff_match_patch:
320 if use_diff_match_patch:
321 dmp = diff_match_patch()
321 dmp = diff_match_patch()
322 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
322 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
323 reps = dmp.diff_main(old_string, new_string)
323 reps = dmp.diff_main(old_string, new_string)
324 dmp.diff_cleanupEfficiency(reps)
324 dmp.diff_cleanupEfficiency(reps)
325
325
326 a, b = 0, 0
326 a, b = 0, 0
327 for op, rep in reps:
327 for op, rep in reps:
328 l = len(rep)
328 l = len(rep)
329 if op == 0:
329 if op == 0:
330 for i, c in enumerate(rep):
330 for i, c in enumerate(rep):
331 obuffer.append((old_char_tokens[a+i][0], '', c))
331 obuffer.append((old_char_tokens[a+i][0], '', c))
332 nbuffer.append((new_char_tokens[b+i][0], '', c))
332 nbuffer.append((new_char_tokens[b+i][0], '', c))
333 a += l
333 a += l
334 b += l
334 b += l
335 elif op == -1:
335 elif op == -1:
336 for i, c in enumerate(rep):
336 for i, c in enumerate(rep):
337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
338 a += l
338 a += l
339 elif op == 1:
339 elif op == 1:
340 for i, c in enumerate(rep):
340 for i, c in enumerate(rep):
341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
342 b += l
342 b += l
343 else:
343 else:
344 for ctag, co1, co2, cn1, cn2 in copcodes:
344 for ctag, co1, co2, cn1, cn2 in copcodes:
345 if ctag == 'equal':
345 if ctag == 'equal':
346 for token_class, token_text in old_char_tokens[co1:co2]:
346 for token_class, token_text in old_char_tokens[co1:co2]:
347 obuffer.append((token_class, '', token_text))
347 obuffer.append((token_class, '', token_text))
348 for token_class, token_text in new_char_tokens[cn1:cn2]:
348 for token_class, token_text in new_char_tokens[cn1:cn2]:
349 nbuffer.append((token_class, '', token_text))
349 nbuffer.append((token_class, '', token_text))
350 elif ctag == 'delete':
350 elif ctag == 'delete':
351 for token_class, token_text in old_char_tokens[co1:co2]:
351 for token_class, token_text in old_char_tokens[co1:co2]:
352 obuffer.append((token_class, 'del', token_text))
352 obuffer.append((token_class, 'del', token_text))
353 elif ctag == 'insert':
353 elif ctag == 'insert':
354 for token_class, token_text in new_char_tokens[cn1:cn2]:
354 for token_class, token_text in new_char_tokens[cn1:cn2]:
355 nbuffer.append((token_class, 'ins', token_text))
355 nbuffer.append((token_class, 'ins', token_text))
356 elif ctag == 'replace':
356 elif ctag == 'replace':
357 for token_class, token_text in old_char_tokens[co1:co2]:
357 for token_class, token_text in old_char_tokens[co1:co2]:
358 obuffer.append((token_class, 'del', token_text))
358 obuffer.append((token_class, 'del', token_text))
359 for token_class, token_text in new_char_tokens[cn1:cn2]:
359 for token_class, token_text in new_char_tokens[cn1:cn2]:
360 nbuffer.append((token_class, 'ins', token_text))
360 nbuffer.append((token_class, 'ins', token_text))
361
361
362 old_tokens_result.extend(obuffer)
362 old_tokens_result.extend(obuffer)
363 new_tokens_result.extend(nbuffer)
363 new_tokens_result.extend(nbuffer)
364
364
365 return old_tokens_result, new_tokens_result, similarity
365 return old_tokens_result, new_tokens_result, similarity
366
366
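# --- Editor's usage sketch: diffing two one-line token streams along the
# stdlib difflib path (use_diff_match_patch=False); similarity here is 0.8,
# so the character-level replace branch runs:
_demo_old, _demo_new, _demo_sim = tokens_diff(
    [('t', 'hello')], [('t', 'hallo')], use_diff_match_patch=False)
assert ('t', 'del', 'e') in _demo_old
assert ('t', 'ins', 'a') in _demo_new
assert _demo_sim >= 0.6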
367
367
368 def diffset_node_getter(commit):
368 def diffset_node_getter(commit):
369 def get_node(fname):
369 def get_node(fname):
370 try:
370 try:
371 return commit.get_node(fname)
371 return commit.get_node(fname)
372 except NodeDoesNotExistError:
372 except NodeDoesNotExistError:
373 return None
373 return None
374
374
375 return get_node
375 return get_node
376
376
377
377
378 class DiffSet(object):
378 class DiffSet(object):
379 """
379 """
380 An object for parsing the diff result from diffs.DiffProcessor and
380 An object for parsing the diff result from diffs.DiffProcessor and
381 adding highlighting, side by side/unified renderings and line diffs
381 adding highlighting, side by side/unified renderings and line diffs
382 """
382 """
383
383
384 HL_REAL = 'REAL' # highlights using original file, slow
384 HL_REAL = 'REAL' # highlights using original file, slow
385 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
385 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
386 # in the case of multiline code
386 # in the case of multiline code
387 HL_NONE = 'NONE' # no highlighting, fastest
387 HL_NONE = 'NONE' # no highlighting, fastest
388
388
389 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
389 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
390 source_repo_name=None,
390 source_repo_name=None,
391 source_node_getter=lambda filename: None,
391 source_node_getter=lambda filename: None,
392 target_repo_name=None,
392 target_repo_name=None,
393 target_node_getter=lambda filename: None,
393 target_node_getter=lambda filename: None,
394 source_nodes=None, target_nodes=None,
394 source_nodes=None, target_nodes=None,
395 # files over this size will use fast highlighting
395 # files over this size will use fast highlighting
396 max_file_size_limit=150 * 1024,
396 max_file_size_limit=150 * 1024,
397 ):
397 ):
398
398
399 self.highlight_mode = highlight_mode
399 self.highlight_mode = highlight_mode
400 self.highlighted_filenodes = {
400 self.highlighted_filenodes = {
401 'before': {},
401 'before': {},
402 'after': {}
402 'after': {}
403 }
403 }
404 self.source_node_getter = source_node_getter
404 self.source_node_getter = source_node_getter
405 self.target_node_getter = target_node_getter
405 self.target_node_getter = target_node_getter
406 self.source_nodes = source_nodes or {}
406 self.source_nodes = source_nodes or {}
407 self.target_nodes = target_nodes or {}
407 self.target_nodes = target_nodes or {}
408 self.repo_name = repo_name
408 self.repo_name = repo_name
409 self.target_repo_name = target_repo_name or repo_name
409 self.target_repo_name = target_repo_name or repo_name
410 self.source_repo_name = source_repo_name or repo_name
410 self.source_repo_name = source_repo_name or repo_name
411 self.max_file_size_limit = max_file_size_limit
411 self.max_file_size_limit = max_file_size_limit
412
412
413 def render_patchset(self, patchset, source_ref=None, target_ref=None):
413 def render_patchset(self, patchset, source_ref=None, target_ref=None):
414 diffset = AttributeDict(dict(
414 diffset = AttributeDict(dict(
415 lines_added=0,
415 lines_added=0,
416 lines_deleted=0,
416 lines_deleted=0,
417 changed_files=0,
417 changed_files=0,
418 files=[],
418 files=[],
419 file_stats={},
419 file_stats={},
420 limited_diff=isinstance(patchset, LimitedDiffContainer),
420 limited_diff=isinstance(patchset, LimitedDiffContainer),
421 repo_name=self.repo_name,
421 repo_name=self.repo_name,
422 target_repo_name=self.target_repo_name,
422 target_repo_name=self.target_repo_name,
423 source_repo_name=self.source_repo_name,
423 source_repo_name=self.source_repo_name,
424 source_ref=source_ref,
424 source_ref=source_ref,
425 target_ref=target_ref,
425 target_ref=target_ref,
426 ))
426 ))
427 for patch in patchset:
427 for patch in patchset:
428 diffset.file_stats[patch['filename']] = patch['stats']
428 diffset.file_stats[patch['filename']] = patch['stats']
429 filediff = self.render_patch(patch)
429 filediff = self.render_patch(patch)
430 filediff.diffset = StrictAttributeDict(dict(
430 filediff.diffset = StrictAttributeDict(dict(
431 source_ref=diffset.source_ref,
431 source_ref=diffset.source_ref,
432 target_ref=diffset.target_ref,
432 target_ref=diffset.target_ref,
433 repo_name=diffset.repo_name,
433 repo_name=diffset.repo_name,
434 source_repo_name=diffset.source_repo_name,
434 source_repo_name=diffset.source_repo_name,
435 target_repo_name=diffset.target_repo_name,
435 target_repo_name=diffset.target_repo_name,
436 ))
436 ))
437 diffset.files.append(filediff)
437 diffset.files.append(filediff)
438 diffset.changed_files += 1
438 diffset.changed_files += 1
439 if not patch['stats']['binary']:
439 if not patch['stats']['binary']:
440 diffset.lines_added += patch['stats']['added']
440 diffset.lines_added += patch['stats']['added']
441 diffset.lines_deleted += patch['stats']['deleted']
441 diffset.lines_deleted += patch['stats']['deleted']
442
442
443 return diffset
443 return diffset
444
444
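# Editor's note (illustrative, derived from the fields populated above):
# the returned AttributeDict reads like a plain object, e.g.:
#
#   rendered = diff_set.render_patchset(patchset, source_ref, target_ref)
#   for filediff in rendered.files:
#       print(filediff.patch['filename'], filediff.operation)
#   print(rendered.changed_files, rendered.lines_added, rendered.lines_deleted)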
445 _lexer_cache = {}
445 _lexer_cache = {}
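# Editor's note: _lexer_cache is a class-level dict, so every DiffSet
# instance in the process shares it; the assignment
# self._lexer_cache[filename] = lexer below mutates that shared mapping
# rather than a per-instance cache.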
446
446
447 def _get_lexer_for_filename(self, filename, filenode=None):
447 def _get_lexer_for_filename(self, filename, filenode=None):
448 # cached because we might need to call it twice for source/target
448 # cached because we might need to call it twice for source/target
449 if filename not in self._lexer_cache:
449 if filename not in self._lexer_cache:
450 if filenode:
450 if filenode:
451 lexer = filenode.lexer
451 lexer = filenode.lexer
452 extension = filenode.extension
452 extension = filenode.extension
453 else:
453 else:
454 lexer = FileNode.get_lexer(filename=filename)
454 lexer = FileNode.get_lexer(filename=filename)
455 extension = filename.split('.')[-1]
455 extension = filename.split('.')[-1]
456
456
457 lexer = get_custom_lexer(extension) or lexer
457 lexer = get_custom_lexer(extension) or lexer
458 self._lexer_cache[filename] = lexer
458 self._lexer_cache[filename] = lexer
459 return self._lexer_cache[filename]
459 return self._lexer_cache[filename]
460
460
461 def render_patch(self, patch):
461 def render_patch(self, patch):
462 log.debug('rendering diff for %r', patch['filename'])
462 log.debug('rendering diff for %r', patch['filename'])
463
463
464 source_filename = patch['original_filename']
464 source_filename = patch['original_filename']
465 target_filename = patch['filename']
465 target_filename = patch['filename']
466
466
467 source_lexer = plain_text_lexer
467 source_lexer = plain_text_lexer
468 target_lexer = plain_text_lexer
468 target_lexer = plain_text_lexer
469
469
470 if not patch['stats']['binary']:
470 if not patch['stats']['binary']:
471 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
471 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
472 hl_mode = node_hl_mode or self.highlight_mode
472 hl_mode = node_hl_mode or self.highlight_mode
473
473
474 if hl_mode == self.HL_REAL:
474 if hl_mode == self.HL_REAL:
475 if (source_filename and patch['operation'] in ('D', 'M')
475 if (source_filename and patch['operation'] in ('D', 'M')
476 and source_filename not in self.source_nodes):
476 and source_filename not in self.source_nodes):
477 self.source_nodes[source_filename] = (
477 self.source_nodes[source_filename] = (
478 self.source_node_getter(source_filename))
478 self.source_node_getter(source_filename))
479
479
480 if (target_filename and patch['operation'] in ('A', 'M')
480 if (target_filename and patch['operation'] in ('A', 'M')
481 and target_filename not in self.target_nodes):
481 and target_filename not in self.target_nodes):
482 self.target_nodes[target_filename] = (
482 self.target_nodes[target_filename] = (
483 self.target_node_getter(target_filename))
483 self.target_node_getter(target_filename))
484
484
485 elif hl_mode == self.HL_FAST:
485 elif hl_mode == self.HL_FAST:
486 source_lexer = self._get_lexer_for_filename(source_filename)
486 source_lexer = self._get_lexer_for_filename(source_filename)
487 target_lexer = self._get_lexer_for_filename(target_filename)
487 target_lexer = self._get_lexer_for_filename(target_filename)
488
488
489 source_file = self.source_nodes.get(source_filename, source_filename)
489 source_file = self.source_nodes.get(source_filename, source_filename)
490 target_file = self.target_nodes.get(target_filename, target_filename)
490 target_file = self.target_nodes.get(target_filename, target_filename)
491 raw_id_uid = ''
491 raw_id_uid = ''
492 if self.source_nodes.get(source_filename):
492 if self.source_nodes.get(source_filename):
493 raw_id_uid = self.source_nodes[source_filename].commit.raw_id
493 raw_id_uid = self.source_nodes[source_filename].commit.raw_id
494
494
495 if not raw_id_uid and self.target_nodes.get(target_filename):
495 if not raw_id_uid and self.target_nodes.get(target_filename):
496 # in case this is a new file we only have it in target
496 # in case this is a new file we only have it in target
497 raw_id_uid = self.target_nodes[target_filename].commit.raw_id
497 raw_id_uid = self.target_nodes[target_filename].commit.raw_id
498
498
499 source_filenode, target_filenode = None, None
499 source_filenode, target_filenode = None, None
500
500
501 # TODO: dan: FileNode.lexer works on the content of the file - which
501 # TODO: dan: FileNode.lexer works on the content of the file - which
502 # can be slow - issue #4289 explains a lexer clean up - which once
502 # can be slow - issue #4289 explains a lexer clean up - which once
503 # done can allow caching a lexer for a filenode to avoid the file lookup
503 # done can allow caching a lexer for a filenode to avoid the file lookup
504 if isinstance(source_file, FileNode):
504 if isinstance(source_file, FileNode):
505 source_filenode = source_file
505 source_filenode = source_file
506 #source_lexer = source_file.lexer
506 #source_lexer = source_file.lexer
507 source_lexer = self._get_lexer_for_filename(source_filename)
507 source_lexer = self._get_lexer_for_filename(source_filename)
508 source_file.lexer = source_lexer
508 source_file.lexer = source_lexer
509
509
510 if isinstance(target_file, FileNode):
510 if isinstance(target_file, FileNode):
511 target_filenode = target_file
511 target_filenode = target_file
512 #target_lexer = target_file.lexer
512 #target_lexer = target_file.lexer
513 target_lexer = self._get_lexer_for_filename(target_filename)
513 target_lexer = self._get_lexer_for_filename(target_filename)
514 target_file.lexer = target_lexer
514 target_file.lexer = target_lexer
515
515
516 source_file_path, target_file_path = None, None
516 source_file_path, target_file_path = None, None
517
517
518 if source_filename != '/dev/null':
518 if source_filename != '/dev/null':
519 source_file_path = source_filename
519 source_file_path = source_filename
520 if target_filename != '/dev/null':
520 if target_filename != '/dev/null':
521 target_file_path = target_filename
521 target_file_path = target_filename
522
522
523 source_file_type = source_lexer.name
523 source_file_type = source_lexer.name
524 target_file_type = target_lexer.name
524 target_file_type = target_lexer.name
525
525
526 filediff = AttributeDict({
526 filediff = AttributeDict({
527 'source_file_path': source_file_path,
527 'source_file_path': source_file_path,
528 'target_file_path': target_file_path,
528 'target_file_path': target_file_path,
529 'source_filenode': source_filenode,
529 'source_filenode': source_filenode,
530 'target_filenode': target_filenode,
530 'target_filenode': target_filenode,
531 'source_file_type': target_file_type,
531 'source_file_type': target_file_type,
532 'target_file_type': source_file_type,
532 'target_file_type': source_file_type,
533 'patch': {'filename': patch['filename'], 'stats': patch['stats']},
533 'patch': {'filename': patch['filename'], 'stats': patch['stats']},
534 'operation': patch['operation'],
534 'operation': patch['operation'],
535 'source_mode': patch['stats']['old_mode'],
535 'source_mode': patch['stats']['old_mode'],
536 'target_mode': patch['stats']['new_mode'],
536 'target_mode': patch['stats']['new_mode'],
537 'limited_diff': patch['is_limited_diff'],
537 'limited_diff': patch['is_limited_diff'],
538 'hunks': [],
538 'hunks': [],
539 'hunk_ops': None,
539 'hunk_ops': None,
540 'diffset': self,
540 'diffset': self,
541 'raw_id': raw_id_uid,
541 'raw_id': raw_id_uid,
542 })
542 })
543
543
544 file_chunks = patch['chunks'][1:]
544 file_chunks = patch['chunks'][1:]
545 for i, hunk in enumerate(file_chunks, 1):
545 for i, hunk in enumerate(file_chunks, 1):
546 hunkbit = self.parse_hunk(hunk, source_file, target_file)
546 hunkbit = self.parse_hunk(hunk, source_file, target_file)
547 hunkbit.source_file_path = source_file_path
547 hunkbit.source_file_path = source_file_path
548 hunkbit.target_file_path = target_file_path
548 hunkbit.target_file_path = target_file_path
549 hunkbit.index = i
549 hunkbit.index = i
550 filediff.hunks.append(hunkbit)
550 filediff.hunks.append(hunkbit)
551
551
552 # Simulate a hunk for OPS-type lines, which don't really contain any diff;
552 # Simulate a hunk for OPS-type lines, which don't really contain any diff;
553 # this allows commenting on them
553 # this allows commenting on them
554 if not file_chunks:
554 if not file_chunks:
555 actions = []
555 actions = []
556 for op_id, op_text in filediff.patch['stats']['ops'].items():
556 for op_id, op_text in filediff.patch['stats']['ops'].items():
557 if op_id == DEL_FILENODE:
557 if op_id == DEL_FILENODE:
558 actions.append(u'file was removed')
558 actions.append('file was removed')
559 elif op_id == BIN_FILENODE:
559 elif op_id == BIN_FILENODE:
560 actions.append(u'binary diff hidden')
560 actions.append('binary diff hidden')
561 else:
561 else:
562 actions.append(safe_unicode(op_text))
562 actions.append(safe_unicode(op_text))
563 action_line = u'NO CONTENT: ' + \
563 action_line = 'NO CONTENT: ' + \
564 u', '.join(actions) or u'UNDEFINED_ACTION'
564 (', '.join(actions) or 'UNDEFINED_ACTION')
565
565
566 hunk_ops = {'source_length': 0, 'source_start': 0,
566 hunk_ops = {'source_length': 0, 'source_start': 0,
567 'lines': [
567 'lines': [
568 {'new_lineno': 0, 'old_lineno': 1,
568 {'new_lineno': 0, 'old_lineno': 1,
569 'action': 'unmod-no-hl', 'line': action_line}
569 'action': 'unmod-no-hl', 'line': action_line}
570 ],
570 ],
571 'section_header': u'', 'target_start': 1, 'target_length': 1}
571 'section_header': '', 'target_start': 1, 'target_length': 1}
572
572
573 hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
573 hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
574 hunkbit.source_file_path = source_file_path
574 hunkbit.source_file_path = source_file_path
575 hunkbit.target_file_path = target_file_path
575 hunkbit.target_file_path = target_file_path
576 filediff.hunk_ops = hunkbit
576 filediff.hunk_ops = hunkbit
577 return filediff
577 return filediff
578
578
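# Editor's note -- a hedged sketch of the ``patch`` dict that render_patch()
# consumes (keys taken from the accesses above, values illustrative):
#
#   patch = {
#       'filename': 'setup.py', 'original_filename': 'setup.py',
#       'operation': 'M',             # 'A' added, 'M' modified, 'D' deleted
#       'is_limited_diff': False,
#       'chunks': [...],              # ops header first, then the real hunks
#       'stats': {'added': 3, 'deleted': 1, 'binary': False,
#                 'old_mode': '100644', 'new_mode': '100644', 'ops': {}},
#   }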
579 def parse_hunk(self, hunk, source_file, target_file):
579 def parse_hunk(self, hunk, source_file, target_file):
580 result = AttributeDict(dict(
580 result = AttributeDict(dict(
581 source_start=hunk['source_start'],
581 source_start=hunk['source_start'],
582 source_length=hunk['source_length'],
582 source_length=hunk['source_length'],
583 target_start=hunk['target_start'],
583 target_start=hunk['target_start'],
584 target_length=hunk['target_length'],
584 target_length=hunk['target_length'],
585 section_header=hunk['section_header'],
585 section_header=hunk['section_header'],
586 lines=[],
586 lines=[],
587 ))
587 ))
588 before, after = [], []
588 before, after = [], []
589
589
590 for line in hunk['lines']:
590 for line in hunk['lines']:
591 if line['action'] in ['unmod', 'unmod-no-hl']:
591 if line['action'] in ['unmod', 'unmod-no-hl']:
592 no_hl = line['action'] == 'unmod-no-hl'
592 no_hl = line['action'] == 'unmod-no-hl'
593 result.lines.extend(
593 result.lines.extend(
594 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
594 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
595 after.append(line)
595 after.append(line)
596 before.append(line)
596 before.append(line)
597 elif line['action'] == 'add':
597 elif line['action'] == 'add':
598 after.append(line)
598 after.append(line)
599 elif line['action'] == 'del':
599 elif line['action'] == 'del':
600 before.append(line)
600 before.append(line)
601 elif line['action'] == 'old-no-nl':
601 elif line['action'] == 'old-no-nl':
602 before.append(line)
602 before.append(line)
603 elif line['action'] == 'new-no-nl':
603 elif line['action'] == 'new-no-nl':
604 after.append(line)
604 after.append(line)
605
605
606 all_actions = [x['action'] for x in after] + [x['action'] for x in before]
606 all_actions = [x['action'] for x in after] + [x['action'] for x in before]
607 no_hl = set(all_actions) == {'unmod-no-hl'}
607 no_hl = set(all_actions) == {'unmod-no-hl'}
608 result.lines.extend(
608 result.lines.extend(
609 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
609 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
610 # NOTE(marcink): we must keep list() call here so we can cache the result...
610 # NOTE(marcink): we must keep list() call here so we can cache the result...
611 result.unified = list(self.as_unified(result.lines))
611 result.unified = list(self.as_unified(result.lines))
612 result.sideside = result.lines
612 result.sideside = result.lines
613
613
614 return result
614 return result
615
615
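# Editor's note (illustrative only; compare the simulated hunk_ops dict
# above): parse_hunk() expects DiffProcessor-style hunks shaped roughly as
#
#   {'source_start': 10, 'source_length': 3,
#    'target_start': 10, 'target_length': 4,
#    'section_header': '@@ -10,3 +10,4 @@ def foo():',
#    'lines': [{'action': 'unmod', 'line': 'x = 1',
#               'old_lineno': 10, 'new_lineno': 10}, ...]}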
616 def parse_lines(self, before_lines, after_lines, source_file, target_file,
616 def parse_lines(self, before_lines, after_lines, source_file, target_file,
617 no_hl=False):
617 no_hl=False):
618 # TODO: dan: investigate doing the diff comparison and fast highlighting
618 # TODO: dan: investigate doing the diff comparison and fast highlighting
619 # on the entire before and after buffered block lines rather than by
619 # on the entire before and after buffered block lines rather than by
620 # line, this means we can get better 'fast' highlighting if the context
620 # line, this means we can get better 'fast' highlighting if the context
621 # allows it - eg.
621 # allows it - eg.
622 # line 4: """
622 # line 4: """
623 # line 5: this gets highlighted as a string
623 # line 5: this gets highlighted as a string
624 # line 6: """
624 # line 6: """
625
625
626 lines = []
626 lines = []
627
627
628 before_newline = AttributeDict()
628 before_newline = AttributeDict()
629 after_newline = AttributeDict()
629 after_newline = AttributeDict()
630 if before_lines and before_lines[-1]['action'] == 'old-no-nl':
630 if before_lines and before_lines[-1]['action'] == 'old-no-nl':
631 before_newline_line = before_lines.pop(-1)
631 before_newline_line = before_lines.pop(-1)
632 before_newline.content = '\n {}'.format(
632 before_newline.content = '\n {}'.format(
633 render_tokenstream(
633 render_tokenstream(
634 [(x[0], '', x[1])
634 [(x[0], '', x[1])
635 for x in [('nonl', before_newline_line['line'])]]))
635 for x in [('nonl', before_newline_line['line'])]]))
636
636
637 if after_lines and after_lines[-1]['action'] == 'new-no-nl':
637 if after_lines and after_lines[-1]['action'] == 'new-no-nl':
638 after_newline_line = after_lines.pop(-1)
638 after_newline_line = after_lines.pop(-1)
639 after_newline.content = '\n {}'.format(
639 after_newline.content = '\n {}'.format(
640 render_tokenstream(
640 render_tokenstream(
641 [(x[0], '', x[1])
641 [(x[0], '', x[1])
642 for x in [('nonl', after_newline_line['line'])]]))
642 for x in [('nonl', after_newline_line['line'])]]))
643
643
644 while before_lines or after_lines:
644 while before_lines or after_lines:
645 before, after = None, None
645 before, after = None, None
646 before_tokens, after_tokens = None, None
646 before_tokens, after_tokens = None, None
647
647
648 if before_lines:
648 if before_lines:
649 before = before_lines.pop(0)
649 before = before_lines.pop(0)
650 if after_lines:
650 if after_lines:
651 after = after_lines.pop(0)
651 after = after_lines.pop(0)
652
652
653 original = AttributeDict()
653 original = AttributeDict()
654 modified = AttributeDict()
654 modified = AttributeDict()
655
655
656 if before:
656 if before:
657 if before['action'] == 'old-no-nl':
657 if before['action'] == 'old-no-nl':
658 before_tokens = [('nonl', before['line'])]
658 before_tokens = [('nonl', before['line'])]
659 else:
659 else:
660 before_tokens = self.get_line_tokens(
660 before_tokens = self.get_line_tokens(
661 line_text=before['line'], line_number=before['old_lineno'],
661 line_text=before['line'], line_number=before['old_lineno'],
662 input_file=source_file, no_hl=no_hl, source='before')
662 input_file=source_file, no_hl=no_hl, source='before')
663 original.lineno = before['old_lineno']
663 original.lineno = before['old_lineno']
664 original.content = before['line']
664 original.content = before['line']
665 original.action = self.action_to_op(before['action'])
665 original.action = self.action_to_op(before['action'])
666
666
667 original.get_comment_args = (
667 original.get_comment_args = (
668 source_file, 'o', before['old_lineno'])
668 source_file, 'o', before['old_lineno'])
669
669
670 if after:
670 if after:
671 if after['action'] == 'new-no-nl':
671 if after['action'] == 'new-no-nl':
672 after_tokens = [('nonl', after['line'])]
672 after_tokens = [('nonl', after['line'])]
673 else:
673 else:
674 after_tokens = self.get_line_tokens(
674 after_tokens = self.get_line_tokens(
675 line_text=after['line'], line_number=after['new_lineno'],
675 line_text=after['line'], line_number=after['new_lineno'],
676 input_file=target_file, no_hl=no_hl, source='after')
676 input_file=target_file, no_hl=no_hl, source='after')
677 modified.lineno = after['new_lineno']
677 modified.lineno = after['new_lineno']
678 modified.content = after['line']
678 modified.content = after['line']
679 modified.action = self.action_to_op(after['action'])
679 modified.action = self.action_to_op(after['action'])
680
680
681 modified.get_comment_args = (target_file, 'n', after['new_lineno'])
681 modified.get_comment_args = (target_file, 'n', after['new_lineno'])
682
682
683 # diff the lines
683 # diff the lines
684 if before_tokens and after_tokens:
684 if before_tokens and after_tokens:
685 o_tokens, m_tokens, similarity = tokens_diff(
685 o_tokens, m_tokens, similarity = tokens_diff(
686 before_tokens, after_tokens)
686 before_tokens, after_tokens)
687 original.content = render_tokenstream(o_tokens)
687 original.content = render_tokenstream(o_tokens)
688 modified.content = render_tokenstream(m_tokens)
688 modified.content = render_tokenstream(m_tokens)
689 elif before_tokens:
689 elif before_tokens:
690 original.content = render_tokenstream(
690 original.content = render_tokenstream(
691 [(x[0], '', x[1]) for x in before_tokens])
691 [(x[0], '', x[1]) for x in before_tokens])
692 elif after_tokens:
692 elif after_tokens:
693 modified.content = render_tokenstream(
693 modified.content = render_tokenstream(
694 [(x[0], '', x[1]) for x in after_tokens])
694 [(x[0], '', x[1]) for x in after_tokens])
695
695
696 if not before_lines and before_newline:
696 if not before_lines and before_newline:
697 original.content += before_newline.content
697 original.content += before_newline.content
698 before_newline = None
698 before_newline = None
699 if not after_lines and after_newline:
699 if not after_lines and after_newline:
700 modified.content += after_newline.content
700 modified.content += after_newline.content
701 after_newline = None
701 after_newline = None
702
702
703 lines.append(AttributeDict({
703 lines.append(AttributeDict({
704 'original': original,
704 'original': original,
705 'modified': modified,
705 'modified': modified,
706 }))
706 }))
707
707
708 return lines
708 return lines
709
709
710 def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
710 def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
711 filenode = None
711 filenode = None
712 filename = None
712 filename = None
713
713
714 if isinstance(input_file, str):
714 if isinstance(input_file, str):
715 filename = input_file
715 filename = input_file
716 elif isinstance(input_file, FileNode):
716 elif isinstance(input_file, FileNode):
717 filenode = input_file
717 filenode = input_file
718 filename = input_file.unicode_path
718 filename = input_file.unicode_path
719
719
720 hl_mode = self.HL_NONE if no_hl else self.highlight_mode
720 hl_mode = self.HL_NONE if no_hl else self.highlight_mode
721 if hl_mode == self.HL_REAL and filenode:
721 if hl_mode == self.HL_REAL and filenode:
722 lexer = self._get_lexer_for_filename(filename)
722 lexer = self._get_lexer_for_filename(filename)
723 file_size_allowed = input_file.size < self.max_file_size_limit
723 file_size_allowed = input_file.size < self.max_file_size_limit
724 if line_number and file_size_allowed:
724 if line_number and file_size_allowed:
725 return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
725 return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
726
726
727 if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
727 if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
728 lexer = self._get_lexer_for_filename(filename)
728 lexer = self._get_lexer_for_filename(filename)
729 return list(tokenize_string(line_text, lexer))
729 return list(tokenize_string(line_text, lexer))
730
730
731 return list(tokenize_string(line_text, plain_text_lexer))
731 return list(tokenize_string(line_text, plain_text_lexer))
732
732
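# Editor's note (summary of the branches above): get_line_tokens() degrades
# gracefully -- HL_REAL with a real FileNode tokenizes the whole file once
# (cached per 'before'/'after' side), HL_REAL or HL_FAST with only a
# filename tokenizes the single line with a filename-guessed lexer, and
# everything else falls back to the plain text lexer.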
733 def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
733 def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
734
734
735 def tokenize(_filenode):
735 def tokenize(_filenode):
736 self.highlighted_filenodes[source][_filenode] = filenode_as_lines_tokens(_filenode, lexer)
736 self.highlighted_filenodes[source][_filenode] = filenode_as_lines_tokens(_filenode, lexer)
737
737
738 if filenode not in self.highlighted_filenodes[source]:
738 if filenode not in self.highlighted_filenodes[source]:
739 tokenize(filenode)
739 tokenize(filenode)
740
740
741 try:
741 try:
742 return self.highlighted_filenodes[source][filenode][line_number - 1]
742 return self.highlighted_filenodes[source][filenode][line_number - 1]
743 except Exception:
743 except Exception:
744 log.exception('diff rendering error')
744 log.exception('diff rendering error')
745 return [('', u'L{}: rhodecode diff rendering error'.format(line_number))]
745 return [('', 'L{}: rhodecode diff rendering error'.format(line_number))]
746
746
747 def action_to_op(self, action):
747 def action_to_op(self, action):
748 return {
748 return {
749 'add': '+',
749 'add': '+',
750 'del': '-',
750 'del': '-',
751 'unmod': ' ',
751 'unmod': ' ',
752 'unmod-no-hl': ' ',
752 'unmod-no-hl': ' ',
753 'old-no-nl': ' ',
753 'old-no-nl': ' ',
754 'new-no-nl': ' ',
754 'new-no-nl': ' ',
755 }.get(action, action)
755 }.get(action, action)
756
756
757 def as_unified(self, lines):
757 def as_unified(self, lines):
758 """
758 """
759 Return a generator that yields the lines of a diff in unified order
759 Return a generator that yields the lines of a diff in unified order
760 """
760 """
761 def generator():
761 def generator():
762 buf = []
762 buf = []
763 for line in lines:
763 for line in lines:
764
764
765 if buf and (not line.original or line.original.action == ' '):
765 if buf and (not line.original or line.original.action == ' '):
766 for b in buf:
766 for b in buf:
767 yield b
767 yield b
768 buf = []
768 buf = []
769
769
770 if line.original:
770 if line.original:
771 if line.original.action == ' ':
771 if line.original.action == ' ':
772 yield (line.original.lineno, line.modified.lineno,
772 yield (line.original.lineno, line.modified.lineno,
773 line.original.action, line.original.content,
773 line.original.action, line.original.content,
774 line.original.get_comment_args)
774 line.original.get_comment_args)
775 continue
775 continue
776
776
777 if line.original.action == '-':
777 if line.original.action == '-':
778 yield (line.original.lineno, None,
778 yield (line.original.lineno, None,
779 line.original.action, line.original.content,
779 line.original.action, line.original.content,
780 line.original.get_comment_args)
780 line.original.get_comment_args)
781
781
782 if line.modified.action == '+':
782 if line.modified.action == '+':
783 buf.append((
783 buf.append((
784 None, line.modified.lineno,
784 None, line.modified.lineno,
785 line.modified.action, line.modified.content,
785 line.modified.action, line.modified.content,
786 line.modified.get_comment_args))
786 line.modified.get_comment_args))
787 continue
787 continue
788
788
789 if line.modified:
789 if line.modified:
790 yield (None, line.modified.lineno,
790 yield (None, line.modified.lineno,
791 line.modified.action, line.modified.content,
791 line.modified.action, line.modified.content,
792 line.modified.get_comment_args)
792 line.modified.get_comment_args)
793
793
794 for b in buf:
794 for b in buf:
795 yield b
795 yield b
796
796
797 return generator()
797 return generator()
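# Editor's note (illustrative): each item yielded by as_unified() is a
# 5-tuple of (old_lineno, new_lineno, action, content, get_comment_args),
# so a caller could render a unified view roughly as:
#
#   for old_no, new_no, op, content, _args in hunk.unified:
#       print('{:>4} {:>4} {} {}'.format(old_no or '', new_no or '', op, content))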
@@ -1,680 +1,680 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 Database creation and setup module for RhodeCode Enterprise. Used for creation
22 Database creation and setup module for RhodeCode Enterprise. Used for creation
23 of the database as well as for migration operations
23 of the database as well as for migration operations
24 """
24 """
25
25
26 import os
26 import os
27 import sys
27 import sys
28 import time
28 import time
29 import uuid
29 import uuid
30 import logging
30 import logging
31 import getpass
31 import getpass
32 from os.path import dirname as dn, join as jn
32 from os.path import dirname as dn, join as jn
33
33
34 from sqlalchemy.engine import create_engine
34 from sqlalchemy.engine import create_engine
35
35
36 from rhodecode import __dbversion__
36 from rhodecode import __dbversion__
37 from rhodecode.model import init_model
37 from rhodecode.model import init_model
38 from rhodecode.model.user import UserModel
38 from rhodecode.model.user import UserModel
39 from rhodecode.model.db import (
39 from rhodecode.model.db import (
40 User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
40 User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
41 DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
41 DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
42 from rhodecode.model.meta import Session, Base
42 from rhodecode.model.meta import Session, Base
43 from rhodecode.model.permission import PermissionModel
43 from rhodecode.model.permission import PermissionModel
44 from rhodecode.model.repo import RepoModel
44 from rhodecode.model.repo import RepoModel
45 from rhodecode.model.repo_group import RepoGroupModel
45 from rhodecode.model.repo_group import RepoGroupModel
46 from rhodecode.model.settings import SettingsModel
46 from rhodecode.model.settings import SettingsModel
47
47
48
48
49 log = logging.getLogger(__name__)
49 log = logging.getLogger(__name__)
50
50
51
51
52 def notify(msg):
52 def notify(msg):
53 """
53 """
54 Notification for migrations messages
54 Notification for migrations messages
55 """
55 """
56 ml = len(msg) + (4 * 2)
56 ml = len(msg) + (4 * 2)
57 print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
57 print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
58
58
59
59
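# Editor's note: notify() upper-cases the message and wraps it in a '*'
# banner sized to the text, so notify('upgrade done') prints:
#
#   ********************
#   *** UPGRADE DONE ***
#   ********************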
60 class DbManage(object):
60 class DbManage(object):
61
61
62 def __init__(self, log_sql, dbconf, root, tests=False,
62 def __init__(self, log_sql, dbconf, root, tests=False,
63 SESSION=None, cli_args=None):
63 SESSION=None, cli_args=None):
64 self.dbname = dbconf.split('/')[-1]
64 self.dbname = dbconf.split('/')[-1]
65 self.tests = tests
65 self.tests = tests
66 self.root = root
66 self.root = root
67 self.dburi = dbconf
67 self.dburi = dbconf
68 self.log_sql = log_sql
68 self.log_sql = log_sql
69 self.cli_args = cli_args or {}
69 self.cli_args = cli_args or {}
70 self.init_db(SESSION=SESSION)
70 self.init_db(SESSION=SESSION)
71 self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))
71 self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))
72
72
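# Editor's note -- a hedged usage sketch (the dburi and paths are
# assumptions, not part of the changeset):
#
#   dbmanage = DbManage(log_sql=False, dbconf='sqlite:///rhodecode.db',
#                       root='/opt/rhodecode', cli_args={'force_ask': True})
#   if not dbmanage.db_exists():
#       dbmanage.create_tables(override=True)
#       dbmanage.set_db_version()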
73 def db_exists(self):
73 def db_exists(self):
74 if not self.sa:
74 if not self.sa:
75 self.init_db()
75 self.init_db()
76 try:
76 try:
77 self.sa.query(RhodeCodeUi)\
77 self.sa.query(RhodeCodeUi)\
78 .filter(RhodeCodeUi.ui_key == '/')\
78 .filter(RhodeCodeUi.ui_key == '/')\
79 .scalar()
79 .scalar()
80 return True
80 return True
81 except Exception:
81 except Exception:
82 return False
82 return False
83 finally:
83 finally:
84 self.sa.rollback()
84 self.sa.rollback()
85
85
86 def get_ask_ok_func(self, param):
86 def get_ask_ok_func(self, param):
87 if param is not None:
87 if param is not None:
88 # return a lambda that ignores its arguments and always returns param
88 # return a lambda that ignores its arguments and always returns param
89 return lambda *args, **kwargs: param
89 return lambda *args, **kwargs: param
90 else:
90 else:
91 from rhodecode.lib.utils import ask_ok
91 from rhodecode.lib.utils import ask_ok
92 return ask_ok
92 return ask_ok
93
93
94 def init_db(self, SESSION=None):
94 def init_db(self, SESSION=None):
95 if SESSION:
95 if SESSION:
96 self.sa = SESSION
96 self.sa = SESSION
97 else:
97 else:
98 # init new sessions
98 # init new sessions
99 engine = create_engine(self.dburi, echo=self.log_sql)
99 engine = create_engine(self.dburi, echo=self.log_sql)
100 init_model(engine)
100 init_model(engine)
101 self.sa = Session()
101 self.sa = Session()
102
102
103 def create_tables(self, override=False):
103 def create_tables(self, override=False):
104 """
104 """
105 Create an auth database
105 Create an auth database
106 """
106 """
107
107
108 log.info("Existing database with the same name is going to be destroyed.")
108 log.info("Existing database with the same name is going to be destroyed.")
109 log.info("Setup command will run DROP ALL command on that database.")
109 log.info("Setup command will run DROP ALL command on that database.")
110 if self.tests:
110 if self.tests:
111 destroy = True
111 destroy = True
112 else:
112 else:
113 destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
113 destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
114 if not destroy:
114 if not destroy:
115 log.info('db tables bootstrap: Nothing done.')
115 log.info('db tables bootstrap: Nothing done.')
116 sys.exit(0)
116 sys.exit(0)
117 if destroy:
117 if destroy:
118 Base.metadata.drop_all()
118 Base.metadata.drop_all()
119
119
120 checkfirst = not override
120 checkfirst = not override
121 Base.metadata.create_all(checkfirst=checkfirst)
121 Base.metadata.create_all(checkfirst=checkfirst)
122 log.info('Created tables for %s', self.dbname)
122 log.info('Created tables for %s', self.dbname)
123
123
124 def set_db_version(self):
124 def set_db_version(self):
125 ver = DbMigrateVersion()
125 ver = DbMigrateVersion()
126 ver.version = __dbversion__
126 ver.version = __dbversion__
127 ver.repository_id = 'rhodecode_db_migrations'
127 ver.repository_id = 'rhodecode_db_migrations'
128 ver.repository_path = 'versions'
128 ver.repository_path = 'versions'
129 self.sa.add(ver)
129 self.sa.add(ver)
130 log.info('db version set to: %s', __dbversion__)
130 log.info('db version set to: %s', __dbversion__)
131
131
132 def run_post_migration_tasks(self):
132 def run_post_migration_tasks(self):
133 """
133 """
134 Run various tasks after the migrations have been applied
134 Run various tasks after the migrations have been applied
135 """
135 """
136 # delete cache keys on each upgrade
136 # delete cache keys on each upgrade
137 total = CacheKey.query().count()
137 total = CacheKey.query().count()
138 log.info("Deleting (%s) cache keys now...", total)
138 log.info("Deleting (%s) cache keys now...", total)
139 CacheKey.delete_all_cache()
139 CacheKey.delete_all_cache()
140
140
141 def upgrade(self, version=None):
141 def upgrade(self, version=None):
142 """
142 """
143 Upgrades the given database schema to the given revision,
143 Upgrades the given database schema to the given revision,
144 performing all required intermediate upgrade steps
144 performing all required intermediate upgrade steps
145
145
146 """
146 """
147
147
148 from rhodecode.lib.dbmigrate.migrate.versioning import api
148 from rhodecode.lib.dbmigrate.migrate.versioning import api
149 from rhodecode.lib.dbmigrate.migrate.exceptions import \
149 from rhodecode.lib.dbmigrate.migrate.exceptions import \
150 DatabaseNotControlledError
150 DatabaseNotControlledError
151
151
152 if 'sqlite' in self.dburi:
152 if 'sqlite' in self.dburi:
153 print(
153 print(
154 '********************** WARNING **********************\n'
154 '********************** WARNING **********************\n'
155 'Make sure your version of sqlite is at least 3.7.X. \n'
155 'Make sure your version of sqlite is at least 3.7.X. \n'
156 'Earlier versions are known to fail on some migrations\n'
156 'Earlier versions are known to fail on some migrations\n'
157 '*****************************************************\n')
157 '*****************************************************\n')
158
158
159 upgrade = self.ask_ok(
159 upgrade = self.ask_ok(
160 'You are about to perform a database upgrade. Make '
160 'You are about to perform a database upgrade. Make '
161 'sure you have backed up your database. '
161 'sure you have backed up your database. '
162 'Continue? [y/n]')
162 'Continue? [y/n]')
163 if not upgrade:
163 if not upgrade:
164 log.info('No upgrade performed')
164 log.info('No upgrade performed')
165 sys.exit(0)
165 sys.exit(0)
166
166
167 repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
167 repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
168 'rhodecode/lib/dbmigrate')
168 'rhodecode/lib/dbmigrate')
169 db_uri = self.dburi
169 db_uri = self.dburi
170
170
171 if version:
171 if version:
172 DbMigrateVersion.set_version(version)
172 DbMigrateVersion.set_version(version)
173
173
174 try:
174 try:
175 curr_version = api.db_version(db_uri, repository_path)
175 curr_version = api.db_version(db_uri, repository_path)
176 msg = ('Found current database under version '
176 msg = ('Found current database under version '
177 'control with version {}'.format(curr_version))
177 'control with version {}'.format(curr_version))
178
178
179 except (RuntimeError, DatabaseNotControlledError):
179 except (RuntimeError, DatabaseNotControlledError):
180 curr_version = 1
180 curr_version = 1
181 msg = ('Current database is not under version control. Setting '
181 msg = ('Current database is not under version control. Setting '
182 'as version %s' % curr_version)
182 'as version %s' % curr_version)
183 api.version_control(db_uri, repository_path, curr_version)
183 api.version_control(db_uri, repository_path, curr_version)
184
184
185 notify(msg)
185 notify(msg)
186
186
187
187
188 if curr_version == __dbversion__:
188 if curr_version == __dbversion__:
189 log.info('This database is already at the newest version')
189 log.info('This database is already at the newest version')
190 sys.exit(0)
190 sys.exit(0)
191
191
192 upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
192 upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
193 notify('attempting to upgrade database from '
193 notify('attempting to upgrade database from '
194 'version %s to version %s' % (curr_version, __dbversion__))
194 'version %s to version %s' % (curr_version, __dbversion__))
195
195
196 # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
196 # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
197 _step = None
197 _step = None
198 for step in upgrade_steps:
198 for step in upgrade_steps:
199 notify('performing upgrade step %s' % step)
199 notify('performing upgrade step %s' % step)
200 time.sleep(0.5)
200 time.sleep(0.5)
201
201
202 api.upgrade(db_uri, repository_path, step)
202 api.upgrade(db_uri, repository_path, step)
203 self.sa.rollback()
203 self.sa.rollback()
204 notify('schema upgrade for step %s completed' % (step,))
204 notify('schema upgrade for step %s completed' % (step,))
205
205
206 _step = step
206 _step = step
207
207
208 self.run_post_migration_tasks()
208 self.run_post_migration_tasks()
209 notify('upgrade to version %s successful' % _step)
209 notify('upgrade to version %s successful' % _step)
210
210
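# Editor's note (worked example): with curr_version = 105 and
# __dbversion__ = 108, upgrade_steps = range(106, 109), so the loop above
# applies migration steps 106, 107 and 108 in order -- one api.upgrade()
# call per step, with a session rollback after each.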
211 def fix_repo_paths(self):
211 def fix_repo_paths(self):
212 """
212 """
213 Fixes an old RhodeCode version path, rewriting it to the new format without a '*'
213 Fixes an old RhodeCode version path, rewriting it to the new format without a '*'
214 """
214 """
215
215
216 paths = self.sa.query(RhodeCodeUi)\
216 paths = self.sa.query(RhodeCodeUi)\
217 .filter(RhodeCodeUi.ui_key == '/')\
217 .filter(RhodeCodeUi.ui_key == '/')\
218 .scalar()
218 .scalar()
219
219
220 paths.ui_value = paths.ui_value.replace('*', '')
220 paths.ui_value = paths.ui_value.replace('*', '')
221
221
222 try:
222 try:
223 self.sa.add(paths)
223 self.sa.add(paths)
224 self.sa.commit()
224 self.sa.commit()
225 except Exception:
225 except Exception:
226 self.sa.rollback()
226 self.sa.rollback()
227 raise
227 raise
228
228
229 def fix_default_user(self):
229 def fix_default_user(self):
230 """
230 """
231 Fixes an old default user with some 'nicer' default values,
231 Fixes an old default user with some 'nicer' default values,
232 used mostly for anonymous access
232 used mostly for anonymous access
233 """
233 """
234 def_user = self.sa.query(User)\
234 def_user = self.sa.query(User)\
235 .filter(User.username == User.DEFAULT_USER)\
235 .filter(User.username == User.DEFAULT_USER)\
236 .one()
236 .one()
237
237
238 def_user.name = 'Anonymous'
238 def_user.name = 'Anonymous'
239 def_user.lastname = 'User'
239 def_user.lastname = 'User'
240 def_user.email = User.DEFAULT_USER_EMAIL
240 def_user.email = User.DEFAULT_USER_EMAIL
241
241
242 try:
242 try:
243 self.sa.add(def_user)
243 self.sa.add(def_user)
244 self.sa.commit()
244 self.sa.commit()
245 except Exception:
245 except Exception:
246 self.sa.rollback()
246 self.sa.rollback()
247 raise
247 raise
248
248
249 def fix_settings(self):
249 def fix_settings(self):
250 """
250 """
251 Fixes rhodecode settings and adds ga_code key for google analytics
251 Fixes rhodecode settings and adds ga_code key for google analytics
252 """
252 """
253
253
254 hgsettings3 = RhodeCodeSetting('ga_code', '')
254 hgsettings3 = RhodeCodeSetting('ga_code', '')
255
255
256 try:
256 try:
257 self.sa.add(hgsettings3)
257 self.sa.add(hgsettings3)
258 self.sa.commit()
258 self.sa.commit()
259 except Exception:
259 except Exception:
260 self.sa.rollback()
260 self.sa.rollback()
261 raise
261 raise
262
262
263 def create_admin_and_prompt(self):
263 def create_admin_and_prompt(self):
264
264
265 # defaults
265 # defaults
266 defaults = self.cli_args
266 defaults = self.cli_args
267 username = defaults.get('username')
267 username = defaults.get('username')
268 password = defaults.get('password')
268 password = defaults.get('password')
269 email = defaults.get('email')
269 email = defaults.get('email')
270
270
271 if username is None:
271 if username is None:
272 username = input('Specify admin username:')
272 username = input('Specify admin username:')
273 if password is None:
273 if password is None:
274 password = self._get_admin_password()
274 password = self._get_admin_password()
275 if not password:
275 if not password:
276 # second try
276 # second try
277 password = self._get_admin_password()
277 password = self._get_admin_password()
278 if not password:
278 if not password:
279 sys.exit()
279 sys.exit()
280 if email is None:
280 if email is None:
281 email = input('Specify admin email:')
281 email = input('Specify admin email:')
282 api_key = self.cli_args.get('api_key')
282 api_key = self.cli_args.get('api_key')
283 self.create_user(username, password, email, True,
283 self.create_user(username, password, email, True,
284 strict_creation_check=False,
284 strict_creation_check=False,
285 api_key=api_key)
285 api_key=api_key)
286
286
287 def _get_admin_password(self):
287 def _get_admin_password(self):
288 password = getpass.getpass('Specify admin password '
288 password = getpass.getpass('Specify admin password '
289 '(min 6 chars):')
289 '(min 6 chars):')
290 confirm = getpass.getpass('Confirm password:')
290 confirm = getpass.getpass('Confirm password:')
291
291
292 if password != confirm:
292 if password != confirm:
293 log.error('passwords do not match')
293 log.error('passwords do not match')
294 return False
294 return False
295 if len(password) < 6:
295 if len(password) < 6:
296 log.error('password is too short - use at least 6 characters')
296 log.error('password is too short - use at least 6 characters')
297 return False
297 return False
298
298
299 return password
299 return password
300
300
301 def create_test_admin_and_users(self):
301 def create_test_admin_and_users(self):
302 log.info('creating admin and regular test users')
302 log.info('creating admin and regular test users')
303 from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
303 from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
304 TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
304 TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
305 TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
305 TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
306 TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
306 TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
307 TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
307 TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
308
308
309 self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
309 self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
310 TEST_USER_ADMIN_EMAIL, True, api_key=True)
310 TEST_USER_ADMIN_EMAIL, True, api_key=True)
311
311
312 self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
312 self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
313 TEST_USER_REGULAR_EMAIL, False, api_key=True)
313 TEST_USER_REGULAR_EMAIL, False, api_key=True)
314
314
315 self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
315 self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
316 TEST_USER_REGULAR2_EMAIL, False, api_key=True)
316 TEST_USER_REGULAR2_EMAIL, False, api_key=True)
317
317
318 def create_ui_settings(self, repo_store_path):
318 def create_ui_settings(self, repo_store_path):
319 """
319 """
320 Creates ui settings, fills out hooks
320 Creates ui settings, fills out hooks
321 and disables dotencode
321 and disables dotencode
322 """
322 """
323 settings_model = SettingsModel(sa=self.sa)
323 settings_model = SettingsModel(sa=self.sa)
324 from rhodecode.lib.vcs.backends.hg import largefiles_store
324 from rhodecode.lib.vcs.backends.hg import largefiles_store
325 from rhodecode.lib.vcs.backends.git import lfs_store
325 from rhodecode.lib.vcs.backends.git import lfs_store
326
326
327 # Build HOOKS
327 # Build HOOKS
328 hooks = [
328 hooks = [
329 (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),
329 (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),
330
330
331 # HG
331 # HG
332 (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
332 (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
333 (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
333 (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
334 (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
334 (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
335 (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
335 (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
336 (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
336 (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
337 (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),
337 (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),
338
338
339 ]
339 ]
340
340
341 for key, value in hooks:
341 for key, value in hooks:
342 hook_obj = settings_model.get_ui_by_key(key)
342 hook_obj = settings_model.get_ui_by_key(key)
343 hooks2 = hook_obj if hook_obj else RhodeCodeUi()
343 hooks2 = hook_obj if hook_obj else RhodeCodeUi()
344 hooks2.ui_section = 'hooks'
344 hooks2.ui_section = 'hooks'
345 hooks2.ui_key = key
345 hooks2.ui_key = key
346 hooks2.ui_value = value
346 hooks2.ui_value = value
347 self.sa.add(hooks2)
347 self.sa.add(hooks2)
348
348
349 # enable largefiles
349 # enable largefiles
350 largefiles = RhodeCodeUi()
350 largefiles = RhodeCodeUi()
351 largefiles.ui_section = 'extensions'
351 largefiles.ui_section = 'extensions'
352 largefiles.ui_key = 'largefiles'
352 largefiles.ui_key = 'largefiles'
353 largefiles.ui_value = ''
353 largefiles.ui_value = ''
354 self.sa.add(largefiles)
354 self.sa.add(largefiles)
355
355
356 # set default largefiles cache dir, defaults to
356 # set default largefiles cache dir, defaults to
357 # /repo_store_location/.cache/largefiles
357 # /repo_store_location/.cache/largefiles
358 largefiles = RhodeCodeUi()
358 largefiles = RhodeCodeUi()
359 largefiles.ui_section = 'largefiles'
359 largefiles.ui_section = 'largefiles'
360 largefiles.ui_key = 'usercache'
360 largefiles.ui_key = 'usercache'
361 largefiles.ui_value = largefiles_store(repo_store_path)
361 largefiles.ui_value = largefiles_store(repo_store_path)
362
362
363 self.sa.add(largefiles)
363 self.sa.add(largefiles)
364
364
365 # set default lfs cache dir, defaults to
365 # set default lfs cache dir, defaults to
366 # /repo_store_location/.cache/lfs_store
366 # /repo_store_location/.cache/lfs_store
367 lfsstore = RhodeCodeUi()
367 lfsstore = RhodeCodeUi()
368 lfsstore.ui_section = 'vcs_git_lfs'
368 lfsstore.ui_section = 'vcs_git_lfs'
369 lfsstore.ui_key = 'store_location'
369 lfsstore.ui_key = 'store_location'
370 lfsstore.ui_value = lfs_store(repo_store_path)
370 lfsstore.ui_value = lfs_store(repo_store_path)
371
371
372 self.sa.add(lfsstore)
372 self.sa.add(lfsstore)
373
373
374 # register the hgsubversion extension, disabled by default
374 # register the hgsubversion extension, disabled by default
375 hgsubversion = RhodeCodeUi()
375 hgsubversion = RhodeCodeUi()
376 hgsubversion.ui_section = 'extensions'
376 hgsubversion.ui_section = 'extensions'
377 hgsubversion.ui_key = 'hgsubversion'
377 hgsubversion.ui_key = 'hgsubversion'
378 hgsubversion.ui_value = ''
378 hgsubversion.ui_value = ''
379 hgsubversion.ui_active = False
379 hgsubversion.ui_active = False
380 self.sa.add(hgsubversion)
380 self.sa.add(hgsubversion)
381
381
382 # register the hgevolve extension, disabled by default
382 # register the hgevolve extension, disabled by default
383 hgevolve = RhodeCodeUi()
383 hgevolve = RhodeCodeUi()
384 hgevolve.ui_section = 'extensions'
384 hgevolve.ui_section = 'extensions'
385 hgevolve.ui_key = 'evolve'
385 hgevolve.ui_key = 'evolve'
386 hgevolve.ui_value = ''
386 hgevolve.ui_value = ''
387 hgevolve.ui_active = False
387 hgevolve.ui_active = False
388 self.sa.add(hgevolve)
388 self.sa.add(hgevolve)
389
389
390 hgevolve = RhodeCodeUi()
390 hgevolve = RhodeCodeUi()
391 hgevolve.ui_section = 'experimental'
391 hgevolve.ui_section = 'experimental'
392 hgevolve.ui_key = 'evolution'
392 hgevolve.ui_key = 'evolution'
393 hgevolve.ui_value = ''
393 hgevolve.ui_value = ''
394 hgevolve.ui_active = False
394 hgevolve.ui_active = False
395 self.sa.add(hgevolve)
395 self.sa.add(hgevolve)
396
396
397 hgevolve = RhodeCodeUi()
397 hgevolve = RhodeCodeUi()
398 hgevolve.ui_section = 'experimental'
398 hgevolve.ui_section = 'experimental'
399 hgevolve.ui_key = 'evolution.exchange'
399 hgevolve.ui_key = 'evolution.exchange'
400 hgevolve.ui_value = ''
400 hgevolve.ui_value = ''
401 hgevolve.ui_active = False
401 hgevolve.ui_active = False
402 self.sa.add(hgevolve)
402 self.sa.add(hgevolve)
403
403
404 hgevolve = RhodeCodeUi()
404 hgevolve = RhodeCodeUi()
405 hgevolve.ui_section = 'extensions'
405 hgevolve.ui_section = 'extensions'
406 hgevolve.ui_key = 'topic'
406 hgevolve.ui_key = 'topic'
407 hgevolve.ui_value = ''
407 hgevolve.ui_value = ''
408 hgevolve.ui_active = False
408 hgevolve.ui_active = False
409 self.sa.add(hgevolve)
409 self.sa.add(hgevolve)
410
410
411 # register the hggit extension, disabled by default
411 # register the hggit extension, disabled by default
412 hggit = RhodeCodeUi()
412 hggit = RhodeCodeUi()
413 hggit.ui_section = 'extensions'
413 hggit.ui_section = 'extensions'
414 hggit.ui_key = 'hggit'
414 hggit.ui_key = 'hggit'
415 hggit.ui_value = ''
415 hggit.ui_value = ''
416 hggit.ui_active = False
416 hggit.ui_active = False
417 self.sa.add(hggit)
417 self.sa.add(hggit)
418
418
419 # set svn branch defaults
419 # set svn branch defaults
420 branches = ["/branches/*", "/trunk"]
420 branches = ["/branches/*", "/trunk"]
421 tags = ["/tags/*"]
421 tags = ["/tags/*"]
422
422
423 for branch in branches:
423 for branch in branches:
424 settings_model.create_ui_section_value(
424 settings_model.create_ui_section_value(
425 RhodeCodeUi.SVN_BRANCH_ID, branch)
425 RhodeCodeUi.SVN_BRANCH_ID, branch)
426
426
427 for tag in tags:
427 for tag in tags:
428 settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)
428 settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)
429
429
430 def create_auth_plugin_options(self, skip_existing=False):
430 def create_auth_plugin_options(self, skip_existing=False):
431 """
431 """
432 Create default auth plugin settings, and make it active
432 Create default auth plugin settings, and make it active
433
433
434 :param skip_existing:
434 :param skip_existing:
435 """
435 """
436 defaults = [
436 defaults = [
437 ('auth_plugins',
437 ('auth_plugins',
438 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
438 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
439 'list'),
439 'list'),
440
440
441 ('auth_authtoken_enabled',
441 ('auth_authtoken_enabled',
442 'True',
442 'True',
443 'bool'),
443 'bool'),
444
444
445 ('auth_rhodecode_enabled',
445 ('auth_rhodecode_enabled',
446 'True',
446 'True',
447 'bool'),
447 'bool'),
448 ]
448 ]
449 for k, v, t in defaults:
449 for k, v, t in defaults:
450 if (skip_existing and
450 if (skip_existing and
451 SettingsModel().get_setting_by_name(k) is not None):
451 SettingsModel().get_setting_by_name(k) is not None):
452 log.debug('Skipping option %s', k)
452 log.debug('Skipping option %s', k)
453 continue
453 continue
454 setting = RhodeCodeSetting(k, v, t)
454 setting = RhodeCodeSetting(k, v, t)
455 self.sa.add(setting)
455 self.sa.add(setting)
456
456
457 def create_default_options(self, skip_existing=False):
457 def create_default_options(self, skip_existing=False):
458 """Creates default settings"""
458 """Creates default settings"""
459
459
460 for k, v, t in [
460 for k, v, t in [
461 ('default_repo_enable_locking', False, 'bool'),
461 ('default_repo_enable_locking', False, 'bool'),
462 ('default_repo_enable_downloads', False, 'bool'),
462 ('default_repo_enable_downloads', False, 'bool'),
463 ('default_repo_enable_statistics', False, 'bool'),
463 ('default_repo_enable_statistics', False, 'bool'),
464 ('default_repo_private', False, 'bool'),
464 ('default_repo_private', False, 'bool'),
465 ('default_repo_type', 'hg', 'unicode')]:
465 ('default_repo_type', 'hg', 'unicode')]:
466
466
467 if (skip_existing and
467 if (skip_existing and
468 SettingsModel().get_setting_by_name(k) is not None):
468 SettingsModel().get_setting_by_name(k) is not None):
469 log.debug('Skipping option %s', k)
469 log.debug('Skipping option %s', k)
470 continue
470 continue
471 setting = RhodeCodeSetting(k, v, t)
471 setting = RhodeCodeSetting(k, v, t)
472 self.sa.add(setting)
472 self.sa.add(setting)
473
473
474 def fixup_groups(self):
474 def fixup_groups(self):
475 def_usr = User.get_default_user()
475 def_usr = User.get_default_user()
476 for g in RepoGroup.query().all():
476 for g in RepoGroup.query().all():
477 g.group_name = g.get_new_name(g.name)
477 g.group_name = g.get_new_name(g.name)
478 self.sa.add(g)
478 self.sa.add(g)
479 # get default perm
479 # get default perm
480 default = UserRepoGroupToPerm.query()\
480 default = UserRepoGroupToPerm.query()\
481 .filter(UserRepoGroupToPerm.group == g)\
481 .filter(UserRepoGroupToPerm.group == g)\
482 .filter(UserRepoGroupToPerm.user == def_usr)\
482 .filter(UserRepoGroupToPerm.user == def_usr)\
483 .scalar()
483 .scalar()
484
484
485 if default is None:
485 if default is None:
486 log.debug('missing default permission for group %s, adding it', g)
486 log.debug('missing default permission for group %s, adding it', g)
487 perm_obj = RepoGroupModel()._create_default_perms(g)
487 perm_obj = RepoGroupModel()._create_default_perms(g)
488 self.sa.add(perm_obj)
488 self.sa.add(perm_obj)
489
489
490 def reset_permissions(self, username):
490 def reset_permissions(self, username):
491 """
491 """
492 Resets permissions to the default state; useful when old systems had
492 Resets permissions to the default state; useful when old systems had
493 bad permissions that must be cleaned up
493 bad permissions that must be cleaned up
494
494
495 :param username:
495 :param username:
496 """
496 """
497 default_user = User.get_by_username(username)
497 default_user = User.get_by_username(username)
498 if not default_user:
498 if not default_user:
499 return
499 return
500
500
501 u2p = UserToPerm.query()\
501 u2p = UserToPerm.query()\
502 .filter(UserToPerm.user == default_user).all()
502 .filter(UserToPerm.user == default_user).all()
503 fixed = False
503 fixed = False
504 if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
504 if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
505 for p in u2p:
505 for p in u2p:
506 Session().delete(p)
506 Session().delete(p)
507 fixed = True
507 fixed = True
508 self.populate_default_permissions()
508 self.populate_default_permissions()
509 return fixed
509 return fixed
510
510
511 def config_prompt(self, test_repo_path='', retries=3):
512 defaults = self.cli_args
513 _path = defaults.get('repos_location')
514 if retries == 3:
515 log.info('Setting up repositories config')
516 
517 if _path is not None:
518 path = _path
519 elif not self.tests and not test_repo_path:
520 path = input(
520 path = eval(input(
521 'Enter a valid absolute path to store repositories. '
522 'All repositories in that path will be added automatically:'
523 )
523 ))
524 else:
525 path = test_repo_path
526 path_ok = True
527 
528 # check proper dir
529 if not os.path.isdir(path):
530 path_ok = False
531 log.error('Given path %s is not a valid directory', path)
532 
533 elif not os.path.isabs(path):
534 path_ok = False
535 log.error('Given path %s is not an absolute path', path)
536 
537 # check if path is at least readable.
538 if not os.access(path, os.R_OK):
539 path_ok = False
540 log.error('Given path %s is not readable', path)
541 
542 # check write access, warn user about non writeable paths
543 elif not os.access(path, os.W_OK) and path_ok:
544 log.warning('No write permission to given path %s', path)
545 
546 q = ('Given path %s is not writeable, do you want to '
547 'continue with read only mode ? [y/n]' % (path,))
548 if not self.ask_ok(q):
549 log.error('Canceled by user')
550 sys.exit(-1)
551 
552 if retries == 0:
553 sys.exit('max retries reached')
554 if not path_ok:
555 retries -= 1
556 return self.config_prompt(test_repo_path, retries)
557 
558 real_path = os.path.normpath(os.path.realpath(path))
559 
560 if real_path != os.path.normpath(path):
561 q = ('Path looks like a symlink, RhodeCode Enterprise will store '
562 'given path as %s ? [y/n]') % (real_path,)
563 if not self.ask_ok(q):
564 log.error('Canceled by user')
565 sys.exit(-1)
566 
567 return real_path
568 
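The validation ladder above (isdir, isabs, R_OK, then a W_OK warning, with a bounded retry count and realpath resolution) lifts out cleanly; a self-contained sketch of the same rules:

import os

def validate_repo_path(path):
    # mirrors config_prompt: must be an absolute, readable directory
    if not (os.path.isdir(path) and os.path.isabs(path)):
        return None
    if not os.access(path, os.R_OK):
        return None
    # non-writeable paths are still accepted (read-only mode), so only
    # resolve symlinks here, as config_prompt does with realpath()
    return os.path.normpath(os.path.realpath(path))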
569 def create_settings(self, path):
570 
571 self.create_ui_settings(path)
572 
573 ui_config = [
574 ('web', 'push_ssl', 'False'),
575 ('web', 'allow_archive', 'gz zip bz2'),
576 ('web', 'allow_push', '*'),
577 ('web', 'baseurl', '/'),
578 ('paths', '/', path),
579 ('phases', 'publish', 'True')
580 ]
581 for section, key, value in ui_config:
582 ui_conf = RhodeCodeUi()
583 setattr(ui_conf, 'ui_section', section)
584 setattr(ui_conf, 'ui_key', key)
585 setattr(ui_conf, 'ui_value', value)
586 self.sa.add(ui_conf)
587 
588 # rhodecode app settings
589 settings = [
590 ('realm', 'RhodeCode', 'unicode'),
591 ('title', '', 'unicode'),
592 ('pre_code', '', 'unicode'),
593 ('post_code', '', 'unicode'),
594 
595 # Visual
596 ('show_public_icon', True, 'bool'),
597 ('show_private_icon', True, 'bool'),
598 ('stylify_metatags', True, 'bool'),
599 ('dashboard_items', 100, 'int'),
600 ('admin_grid_items', 25, 'int'),
601 
602 ('markup_renderer', 'markdown', 'unicode'),
603 
604 ('repository_fields', True, 'bool'),
605 ('show_version', True, 'bool'),
606 ('show_revision_number', True, 'bool'),
607 ('show_sha_length', 12, 'int'),
608 
609 ('use_gravatar', False, 'bool'),
610 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
611 
612 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
613 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
614 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
615 ('support_url', '', 'unicode'),
616 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
617 
618 # VCS Settings
619 ('pr_merge_enabled', True, 'bool'),
620 ('use_outdated_comments', True, 'bool'),
621 ('diff_cache', True, 'bool'),
622 ]
623 
624 for key, val, type_ in settings:
625 sett = RhodeCodeSetting(key, val, type_)
626 self.sa.add(sett)
627 
628 self.create_auth_plugin_options()
629 self.create_default_options()
630 
631 log.info('created ui config')
632 
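Each settings triple is (name, value, type), the type tag deciding how the stored text is read back. A hypothetical coercion helper in that spirit (the real conversion lives on RhodeCodeSetting and may differ):

def coerce_setting(raw, type_):
    # 'bool' and 'int' round-trip through text storage; 'unicode' stays text
    if type_ == 'bool':
        return str(raw).lower() in ('true', '1', 'yes')
    if type_ == 'int':
        return int(raw)
    return str(raw)

assert coerce_setting('True', 'bool') is True
assert coerce_setting('25', 'int') == 25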
633 def create_user(self, username, password, email='', admin=False,
634 strict_creation_check=True, api_key=None):
635 log.info('creating user `%s`', username)
636 user = UserModel().create_or_update(
637 username, password, email, firstname=u'RhodeCode', lastname=u'Admin',
637 username, password, email, firstname='RhodeCode', lastname='Admin',
638 active=True, admin=admin, extern_type="rhodecode",
639 strict_creation_check=strict_creation_check)
640 
641 if api_key:
642 log.info('setting a new default auth token for user `%s`', username)
643 UserModel().add_auth_token(
644 user=user, lifetime_minutes=-1,
645 role=UserModel.auth_token_role.ROLE_ALL,
646 description=u'BUILTIN TOKEN')
646 description='BUILTIN TOKEN')
647 
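A hedged usage sketch: an installer step seeding the initial admin plus a never-expiring builtin token (`db` is an assumed DbManage instance; the credential values are placeholders):

def seed_admin(db, password, token):
    # a truthy api_key makes create_user attach the ROLE_ALL token above
    db.create_user('admin', password, email='admin@example.com',
                   admin=True, api_key=token)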
648 def create_default_user(self):
649 log.info('creating default user')
650 # create default user for handling default permissions.
651 user = UserModel().create_or_update(username=User.DEFAULT_USER,
652 password=str(uuid.uuid1())[:20],
653 email=User.DEFAULT_USER_EMAIL,
654 firstname=u'Anonymous',
654 firstname='Anonymous',
655 lastname=u'User',
655 lastname='User',
656 strict_creation_check=False)
657 # based on configuration options activate/de-activate this user which
658 # controlls anonymous access
658 # controls anonymous access
659 if self.cli_args.get('public_access') is False:
660 log.info('Public access disabled')
661 user.active = False
662 Session().add(user)
663 Session().commit()
664 
665 def create_permissions(self):
666 """
667 Creates all permissions defined in the system
668 """
669 # module.(access|create|change|delete)_[name]
670 # module.(none|read|write|admin)
671 log.info('creating permissions')
672 PermissionModel(self.sa).create_permissions()
673 
674 def populate_default_permissions(self):
675 """
676 Populate default permissions. It will create only the default
677 permissions that are missing, and not alter already defined ones
678 """
679 log.info('creating default user permissions')
680 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
@@ -1,2031 +1,2031 @@
1 """Diff Match and Patch
1 """Diff Match and Patch
2
2
3 Copyright 2006 Google Inc.
3 Copyright 2006 Google Inc.
4 http://code.google.com/p/google-diff-match-patch/
4 http://code.google.com/p/google-diff-match-patch/
5
5
6 Licensed under the Apache License, Version 2.0 (the "License");
6 Licensed under the Apache License, Version 2.0 (the "License");
7 you may not use this file except in compliance with the License.
7 you may not use this file except in compliance with the License.
8 You may obtain a copy of the License at
8 You may obtain a copy of the License at
9
9
10 http://www.apache.org/licenses/LICENSE-2.0
10 http://www.apache.org/licenses/LICENSE-2.0
11
11
12 Unless required by applicable law or agreed to in writing, software
12 Unless required by applicable law or agreed to in writing, software
13 distributed under the License is distributed on an "AS IS" BASIS,
13 distributed under the License is distributed on an "AS IS" BASIS,
14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 See the License for the specific language governing permissions and
15 See the License for the specific language governing permissions and
16 limitations under the License.
16 limitations under the License.
17 """
17 """
18
18
19 """Functions for diff, match and patch.
19 """Functions for diff, match and patch.
20
20
21 Computes the difference between two texts to create a patch.
21 Computes the difference between two texts to create a patch.
22 Applies the patch onto another text, allowing for errors.
22 Applies the patch onto another text, allowing for errors.
23 """
23 """
24
24
25 __author__ = "fraser@google.com (Neil Fraser)"
25 __author__ = "fraser@google.com (Neil Fraser)"
26
26
27 import math
27 import math
28 import re
28 import re
29 import sys
29 import sys
30 import time
30 import time
31 import urllib.request, urllib.parse, urllib.error
31 import urllib.request, urllib.parse, urllib.error
32
32
33
33
34 class diff_match_patch:
34 class diff_match_patch:
35 """Class containing the diff, match and patch methods.
35 """Class containing the diff, match and patch methods.
36
36
37 Also contains the behaviour settings.
37 Also contains the behaviour settings.
38 """
38 """
39
39
40 def __init__(self):
40 def __init__(self):
41 """Inits a diff_match_patch object with default settings.
41 """Inits a diff_match_patch object with default settings.
42 Redefine these in your program to override the defaults.
42 Redefine these in your program to override the defaults.
43 """
43 """
44
44
45 # Number of seconds to map a diff before giving up (0 for infinity).
45 # Number of seconds to map a diff before giving up (0 for infinity).
46 self.Diff_Timeout = 1.0
46 self.Diff_Timeout = 1.0
47 # Cost of an empty edit operation in terms of edit characters.
47 # Cost of an empty edit operation in terms of edit characters.
48 self.Diff_EditCost = 4
48 self.Diff_EditCost = 4
49 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
49 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
50 self.Match_Threshold = 0.5
50 self.Match_Threshold = 0.5
51 # How far to search for a match (0 = exact location, 1000+ = broad match).
51 # How far to search for a match (0 = exact location, 1000+ = broad match).
52 # A match this many characters away from the expected location will add
52 # A match this many characters away from the expected location will add
53 # 1.0 to the score (0.0 is a perfect match).
53 # 1.0 to the score (0.0 is a perfect match).
54 self.Match_Distance = 1000
54 self.Match_Distance = 1000
55 # When deleting a large block of text (over ~64 characters), how close do
55 # When deleting a large block of text (over ~64 characters), how close do
56 # the contents have to be to match the expected contents. (0.0 = perfection,
56 # the contents have to be to match the expected contents. (0.0 = perfection,
57 # 1.0 = very loose). Note that Match_Threshold controls how closely the
57 # 1.0 = very loose). Note that Match_Threshold controls how closely the
58 # end points of a delete need to match.
58 # end points of a delete need to match.
59 self.Patch_DeleteThreshold = 0.5
59 self.Patch_DeleteThreshold = 0.5
60 # Chunk size for context length.
60 # Chunk size for context length.
61 self.Patch_Margin = 4
61 self.Patch_Margin = 4
62
62
63 # The number of bits in an int.
63 # The number of bits in an int.
64 # Python has no maximum, thus to disable patch splitting set to 0.
64 # Python has no maximum, thus to disable patch splitting set to 0.
65 # However to avoid long patches in certain pathological cases, use 32.
65 # However to avoid long patches in certain pathological cases, use 32.
66 # Multiple short patches (using native ints) are much faster than long ones.
66 # Multiple short patches (using native ints) are much faster than long ones.
67 self.Match_MaxBits = 32
67 self.Match_MaxBits = 32
68
68
69 # DIFF FUNCTIONS
69 # DIFF FUNCTIONS
70
70
71 # The data structure representing a diff is an array of tuples:
71 # The data structure representing a diff is an array of tuples:
72 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
72 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
73 # which means: delete "Hello", add "Goodbye" and keep " world."
73 # which means: delete "Hello", add "Goodbye" and keep " world."
74 DIFF_DELETE = -1
74 DIFF_DELETE = -1
75 DIFF_INSERT = 1
75 DIFF_INSERT = 1
76 DIFF_EQUAL = 0
76 DIFF_EQUAL = 0
77
77
78 def diff_main(self, text1, text2, checklines=True, deadline=None):
78 def diff_main(self, text1, text2, checklines=True, deadline=None):
79 """Find the differences between two texts. Simplifies the problem by
79 """Find the differences between two texts. Simplifies the problem by
80 stripping any common prefix or suffix off the texts before diffing.
80 stripping any common prefix or suffix off the texts before diffing.
81
81
82 Args:
82 Args:
83 text1: Old string to be diffed.
83 text1: Old string to be diffed.
84 text2: New string to be diffed.
84 text2: New string to be diffed.
85 checklines: Optional speedup flag. If present and false, then don't run
85 checklines: Optional speedup flag. If present and false, then don't run
86 a line-level diff first to identify the changed areas.
86 a line-level diff first to identify the changed areas.
87 Defaults to true, which does a faster, slightly less optimal diff.
87 Defaults to true, which does a faster, slightly less optimal diff.
88 deadline: Optional time when the diff should be complete by. Used
88 deadline: Optional time when the diff should be complete by. Used
89 internally for recursive calls. Users should set DiffTimeout instead.
89 internally for recursive calls. Users should set DiffTimeout instead.
90
90
91 Returns:
91 Returns:
92 Array of changes.
92 Array of changes.
93 """
93 """
94 # Set a deadline by which time the diff must be complete.
94 # Set a deadline by which time the diff must be complete.
95 if deadline is None:
95 if deadline is None:
96 # Unlike in most languages, Python counts time in seconds.
96 # Unlike in most languages, Python counts time in seconds.
97 if self.Diff_Timeout <= 0:
97 if self.Diff_Timeout <= 0:
98 deadline = sys.maxsize
98 deadline = sys.maxsize
99 else:
99 else:
100 deadline = time.time() + self.Diff_Timeout
100 deadline = time.time() + self.Diff_Timeout
101
101
102 # Check for null inputs.
102 # Check for null inputs.
103 if text1 is None or text2 is None:
103 if text1 is None or text2 is None:
104 raise ValueError("Null inputs. (diff_main)")
104 raise ValueError("Null inputs. (diff_main)")
105
105
106 # Check for equality (speedup).
106 # Check for equality (speedup).
107 if text1 == text2:
107 if text1 == text2:
108 if text1:
108 if text1:
109 return [(self.DIFF_EQUAL, text1)]
109 return [(self.DIFF_EQUAL, text1)]
110 return []
110 return []
111
111
112 # Trim off common prefix (speedup).
112 # Trim off common prefix (speedup).
113 commonlength = self.diff_commonPrefix(text1, text2)
113 commonlength = self.diff_commonPrefix(text1, text2)
114 commonprefix = text1[:commonlength]
114 commonprefix = text1[:commonlength]
115 text1 = text1[commonlength:]
115 text1 = text1[commonlength:]
116 text2 = text2[commonlength:]
116 text2 = text2[commonlength:]
117
117
118 # Trim off common suffix (speedup).
118 # Trim off common suffix (speedup).
119 commonlength = self.diff_commonSuffix(text1, text2)
119 commonlength = self.diff_commonSuffix(text1, text2)
120 if commonlength == 0:
120 if commonlength == 0:
121 commonsuffix = ""
121 commonsuffix = ""
122 else:
122 else:
123 commonsuffix = text1[-commonlength:]
123 commonsuffix = text1[-commonlength:]
124 text1 = text1[:-commonlength]
124 text1 = text1[:-commonlength]
125 text2 = text2[:-commonlength]
125 text2 = text2[:-commonlength]
126
126
127 # Compute the diff on the middle block.
127 # Compute the diff on the middle block.
128 diffs = self.diff_compute(text1, text2, checklines, deadline)
128 diffs = self.diff_compute(text1, text2, checklines, deadline)
129
129
130 # Restore the prefix and suffix.
130 # Restore the prefix and suffix.
131 if commonprefix:
131 if commonprefix:
132 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
132 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
133 if commonsuffix:
133 if commonsuffix:
134 diffs.append((self.DIFF_EQUAL, commonsuffix))
134 diffs.append((self.DIFF_EQUAL, commonsuffix))
135 self.diff_cleanupMerge(diffs)
135 self.diff_cleanupMerge(diffs)
136 return diffs
136 return diffs
137
137
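A quick sanity check of the trim-then-recurse flow above (assuming the class is importable as diff_match_patch): the common prefix "ab" is stripped, leaving a one-character replacement for the middle block:

dmp = diff_match_patch()
assert dmp.diff_main("abc", "abd") == [
    (dmp.DIFF_EQUAL, "ab"),
    (dmp.DIFF_DELETE, "c"),
    (dmp.DIFF_INSERT, "d"),
]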
138 def diff_compute(self, text1, text2, checklines, deadline):
139 """Find the differences between two texts. Assumes that the texts do not
140 have any common prefix or suffix.
141 
142 Args:
143 text1: Old string to be diffed.
144 text2: New string to be diffed.
145 checklines: Speedup flag. If false, then don't run a line-level diff
146 first to identify the changed areas.
147 If true, then run a faster, slightly less optimal diff.
148 deadline: Time when the diff should be complete by.
149 
150 Returns:
151 Array of changes.
152 """
153 if not text1:
154 # Just add some text (speedup).
155 return [(self.DIFF_INSERT, text2)]
156 
157 if not text2:
158 # Just delete some text (speedup).
159 return [(self.DIFF_DELETE, text1)]
160 
161 if len(text1) > len(text2):
162 (longtext, shorttext) = (text1, text2)
163 else:
164 (shorttext, longtext) = (text1, text2)
165 i = longtext.find(shorttext)
166 if i != -1:
167 # Shorter text is inside the longer text (speedup).
168 diffs = [
169 (self.DIFF_INSERT, longtext[:i]),
170 (self.DIFF_EQUAL, shorttext),
171 (self.DIFF_INSERT, longtext[i + len(shorttext) :]),
172 ]
173 # Swap insertions for deletions if diff is reversed.
174 if len(text1) > len(text2):
175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
177 return diffs
178 
179 if len(shorttext) == 1:
180 # Single character string.
181 # After the previous speedup, the character can't be an equality.
182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
183 
184 # Check to see if the problem can be split in two.
185 hm = self.diff_halfMatch(text1, text2)
186 if hm:
187 # A half-match was found, sort out the return data.
188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
189 # Send both pairs off for separate processing.
190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
192 # Merge the results.
193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
194 
195 if checklines and len(text1) > 100 and len(text2) > 100:
196 return self.diff_lineMode(text1, text2, deadline)
197 
198 return self.diff_bisect(text1, text2, deadline)
199 
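The containment fast path in action: when one text sits wholly inside the other, the result is two insertions around a single equality, with no bisect needed:

dmp = diff_match_patch()
assert dmp.diff_main("abc", "xabcy") == [
    (dmp.DIFF_INSERT, "x"),
    (dmp.DIFF_EQUAL, "abc"),
    (dmp.DIFF_INSERT, "y"),
]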
200 def diff_lineMode(self, text1, text2, deadline):
201 """Do a quick line-level diff on both strings, then rediff the parts for
202 greater accuracy.
203 This speedup can produce non-minimal diffs.
204 
205 Args:
206 text1: Old string to be diffed.
207 text2: New string to be diffed.
208 deadline: Time when the diff should be complete by.
209 
210 Returns:
211 Array of changes.
212 """
213 
214 # Scan the text on a line-by-line basis first.
215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
216 
217 diffs = self.diff_main(text1, text2, False, deadline)
218 
219 # Convert the diff back to original text.
220 self.diff_charsToLines(diffs, linearray)
221 # Eliminate freak matches (e.g. blank lines)
222 self.diff_cleanupSemantic(diffs)
223 
224 # Rediff any replacement blocks, this time character-by-character.
225 # Add a dummy entry at the end.
226 diffs.append((self.DIFF_EQUAL, ""))
227 pointer = 0
228 count_delete = 0
229 count_insert = 0
230 text_delete = ""
231 text_insert = ""
232 while pointer < len(diffs):
233 if diffs[pointer][0] == self.DIFF_INSERT:
234 count_insert += 1
235 text_insert += diffs[pointer][1]
236 elif diffs[pointer][0] == self.DIFF_DELETE:
237 count_delete += 1
238 text_delete += diffs[pointer][1]
239 elif diffs[pointer][0] == self.DIFF_EQUAL:
240 # Upon reaching an equality, check for prior redundancies.
241 if count_delete >= 1 and count_insert >= 1:
242 # Delete the offending records and add the merged ones.
243 a = self.diff_main(text_delete, text_insert, False, deadline)
244 diffs[pointer - count_delete - count_insert : pointer] = a
245 pointer = pointer - count_delete - count_insert + len(a)
246 count_insert = 0
247 count_delete = 0
248 text_delete = ""
249 text_insert = ""
250 
251 pointer += 1
252 
253 diffs.pop() # Remove the dummy entry at the end.
254 
255 return diffs
256 
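The same encode / token-diff / decode pipeline, driven by hand to expose the mechanics (diff_lineMode itself only engages for texts over 100 characters):

dmp = diff_match_patch()
t1, t2, lines = dmp.diff_linesToChars("alpha\nbeta\n", "alpha\ngamma\n")
token_diffs = dmp.diff_main(t1, t2, False)
dmp.diff_charsToLines(token_diffs, lines)
assert token_diffs == [
    (dmp.DIFF_EQUAL, "alpha\n"),
    (dmp.DIFF_DELETE, "beta\n"),
    (dmp.DIFF_INSERT, "gamma\n"),
]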
257 def diff_bisect(self, text1, text2, deadline):
258 """Find the 'middle snake' of a diff, split the problem in two
259 and return the recursively constructed diff.
260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
261 
262 Args:
263 text1: Old string to be diffed.
264 text2: New string to be diffed.
265 deadline: Time at which to bail if not yet complete.
266 
267 Returns:
268 Array of diff tuples.
269 """
270 
271 # Cache the text lengths to prevent multiple calls.
272 text1_length = len(text1)
273 text2_length = len(text2)
274 max_d = (text1_length + text2_length + 1) // 2
275 v_offset = max_d
276 v_length = 2 * max_d
277 v1 = [-1] * v_length
278 v1[v_offset + 1] = 0
279 v2 = v1[:]
280 delta = text1_length - text2_length
281 # If the total number of characters is odd, then the front path will
282 # collide with the reverse path.
283 front = delta % 2 != 0
284 # Offsets for start and end of k loop.
285 # Prevents mapping of space beyond the grid.
286 k1start = 0
287 k1end = 0
288 k2start = 0
289 k2end = 0
290 for d in range(max_d):
291 # Bail out if deadline is reached.
292 if time.time() > deadline:
293 break
294 
295 # Walk the front path one step.
296 for k1 in range(-d + k1start, d + 1 - k1end, 2):
297 k1_offset = v_offset + k1
298 if k1 == -d or (k1 != d and v1[k1_offset - 1] < v1[k1_offset + 1]):
299 x1 = v1[k1_offset + 1]
300 else:
301 x1 = v1[k1_offset - 1] + 1
302 y1 = x1 - k1
303 while (
304 x1 < text1_length and y1 < text2_length and text1[x1] == text2[y1]
305 ):
306 x1 += 1
307 y1 += 1
308 v1[k1_offset] = x1
309 if x1 > text1_length:
310 # Ran off the right of the graph.
311 k1end += 2
312 elif y1 > text2_length:
313 # Ran off the bottom of the graph.
314 k1start += 2
315 elif front:
316 k2_offset = v_offset + delta - k1
317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
318 # Mirror x2 onto top-left coordinate system.
319 x2 = text1_length - v2[k2_offset]
320 if x1 >= x2:
321 # Overlap detected.
322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
323 
324 # Walk the reverse path one step.
325 for k2 in range(-d + k2start, d + 1 - k2end, 2):
326 k2_offset = v_offset + k2
327 if k2 == -d or (k2 != d and v2[k2_offset - 1] < v2[k2_offset + 1]):
328 x2 = v2[k2_offset + 1]
329 else:
330 x2 = v2[k2_offset - 1] + 1
331 y2 = x2 - k2
332 while (
333 x2 < text1_length
334 and y2 < text2_length
335 and text1[-x2 - 1] == text2[-y2 - 1]
336 ):
337 x2 += 1
338 y2 += 1
339 v2[k2_offset] = x2
340 if x2 > text1_length:
341 # Ran off the left of the graph.
342 k2end += 2
343 elif y2 > text2_length:
344 # Ran off the top of the graph.
345 k2start += 2
346 elif not front:
347 k1_offset = v_offset + delta - k2
348 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
349 x1 = v1[k1_offset]
350 y1 = v_offset + x1 - k1_offset
351 # Mirror x2 onto top-left coordinate system.
352 x2 = text1_length - x2
353 if x1 >= x2:
354 # Overlap detected.
355 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
356 
357 # Diff took too long and hit the deadline or
358 # number of diffs equals number of characters, no commonality at all.
359 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
360 
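A behaviour check mirroring the upstream test suite (sys.maxsize as the deadline disables the timeout): with no common prefix or suffix, the middle snake splits "cat" -> "map" into per-character edits:

import sys
dmp = diff_match_patch()
assert dmp.diff_bisect("cat", "map", sys.maxsize) == [
    (dmp.DIFF_DELETE, "c"), (dmp.DIFF_INSERT, "m"),
    (dmp.DIFF_EQUAL, "a"),
    (dmp.DIFF_DELETE, "t"), (dmp.DIFF_INSERT, "p"),
]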
361 def diff_bisectSplit(self, text1, text2, x, y, deadline):
362 """Given the location of the 'middle snake', split the diff in two parts
363 and recurse.
364 
365 Args:
366 text1: Old string to be diffed.
367 text2: New string to be diffed.
368 x: Index of split point in text1.
369 y: Index of split point in text2.
370 deadline: Time at which to bail if not yet complete.
371 
372 Returns:
373 Array of diff tuples.
374 """
375 text1a = text1[:x]
376 text2a = text2[:y]
377 text1b = text1[x:]
378 text2b = text2[y:]
379 
380 # Compute both diffs serially.
381 diffs = self.diff_main(text1a, text2a, False, deadline)
382 diffsb = self.diff_main(text1b, text2b, False, deadline)
383 
384 return diffs + diffsb
385 
386 def diff_linesToChars(self, text1, text2):
387 """Split two texts into an array of strings. Reduce the texts to a string
388 of hashes where each Unicode character represents one line.
389 
390 Args:
391 text1: First string.
392 text2: Second string.
393 
394 Returns:
395 Three element tuple, containing the encoded text1, the encoded text2 and
396 the array of unique strings. The zeroth element of the array of unique
397 strings is intentionally blank.
398 """
399 lineArray = [] # e.g. lineArray[4] == "Hello\n"
400 lineHash = {} # e.g. lineHash["Hello\n"] == 4
401 
402 # "\x00" is a valid character, but various debuggers don't like it.
403 # So we'll insert a junk entry to avoid generating a null character.
404 lineArray.append("")
405 
406 def diff_linesToCharsMunge(text):
407 """Split a text into an array of strings. Reduce the texts to a string
408 of hashes where each Unicode character represents one line.
409 Modifies lineArray and lineHash through being a closure.
410 
411 Args:
412 text: String to encode.
413 
414 Returns:
415 Encoded string.
416 """
417 chars = []
418 # Walk the text, pulling out a substring for each line.
419 # text.split('\n') would temporarily double our memory footprint.
420 # Modifying text would create many large strings to garbage collect.
421 lineStart = 0
422 lineEnd = -1
423 while lineEnd < len(text) - 1:
424 lineEnd = text.find("\n", lineStart)
425 if lineEnd == -1:
426 lineEnd = len(text) - 1
427 line = text[lineStart : lineEnd + 1]
428 lineStart = lineEnd + 1
429 
430 if line in lineHash:
431 chars.append(chr(lineHash[line]))
432 else:
433 lineArray.append(line)
434 lineHash[line] = len(lineArray) - 1
435 chars.append(chr(len(lineArray) - 1))
436 return "".join(chars)
437 
438 chars1 = diff_linesToCharsMunge(text1)
439 chars2 = diff_linesToCharsMunge(text2)
440 return (chars1, chars2, lineArray)
441 
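Each distinct line is interned once and stands in as a one-character token; a line shared by both texts gets the same token in both encodings:

dmp = diff_match_patch()
chars1, chars2, line_array = dmp.diff_linesToChars("a\nb\n", "a\nc\n")
assert (chars1, chars2) == ("\x01\x02", "\x01\x03")
assert line_array == ["", "a\n", "b\n", "c\n"]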
442 def diff_charsToLines(self, diffs, lineArray):
443 """Rehydrate the text in a diff from a string of line hashes to real lines
444 of text.
445 
446 Args:
447 diffs: Array of diff tuples.
448 lineArray: Array of unique strings.
449 """
450 for x in range(len(diffs)):
451 text = []
452 for char in diffs[x][1]:
453 text.append(lineArray[ord(char)])
454 diffs[x] = (diffs[x][0], "".join(text))
455 
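And the inverse direction, rehydrating tokens in place (continuing the line array from the previous sketch):

dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "\x01"), (dmp.DIFF_INSERT, "\x03")]
dmp.diff_charsToLines(diffs, ["", "a\n", "b\n", "c\n"])
assert diffs == [(dmp.DIFF_EQUAL, "a\n"), (dmp.DIFF_INSERT, "c\n")]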
456 def diff_commonPrefix(self, text1, text2):
457 """Determine the common prefix of two strings.
458 
459 Args:
460 text1: First string.
461 text2: Second string.
462 
463 Returns:
464 The number of characters common to the start of each string.
465 """
466 # Quick check for common null cases.
467 if not text1 or not text2 or text1[0] != text2[0]:
468 return 0
469 # Binary search.
470 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
471 pointermin = 0
472 pointermax = min(len(text1), len(text2))
473 pointermid = pointermax
474 pointerstart = 0
475 while pointermin < pointermid:
476 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
477 pointermin = pointermid
478 pointerstart = pointermin
479 else:
480 pointermax = pointermid
481 pointermid = (pointermax - pointermin) // 2 + pointermin
482 return pointermid
483 
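The binary search above costs O(log n) slice comparisons rather than a character-by-character scan; for example:

dmp = diff_match_patch()
assert dmp.diff_commonPrefix("1234abcdef", "1234xyz") == 4
assert dmp.diff_commonPrefix("abc", "xyz") == 0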
484 def diff_commonSuffix(self, text1, text2):
485 """Determine the common suffix of two strings.
486 
487 Args:
488 text1: First string.
489 text2: Second string.
490 
491 Returns:
492 The number of characters common to the end of each string.
493 """
494 # Quick check for common null cases.
495 if not text1 or not text2 or text1[-1] != text2[-1]:
496 return 0
497 # Binary search.
498 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
499 pointermin = 0
500 pointermax = min(len(text1), len(text2))
501 pointermid = pointermax
502 pointerend = 0
503 while pointermin < pointermid:
504 if (
505 text1[-pointermid : len(text1) - pointerend]
506 == text2[-pointermid : len(text2) - pointerend]
507 ):
508 pointermin = pointermid
509 pointerend = pointermin
510 else:
511 pointermax = pointermid
512 pointermid = (pointermax - pointermin) // 2 + pointermin
513 return pointermid
514 
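The mirror-image check on suffixes:

dmp = diff_match_patch()
assert dmp.diff_commonSuffix("abcdef1234", "xyz1234") == 4
assert dmp.diff_commonSuffix("abc", "xyz") == 0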
515 def diff_commonOverlap(self, text1, text2):
516 """Determine if the suffix of one string is the prefix of another.
517 
518 Args:
519 text1: First string.
520 text2: Second string.
521 
522 Returns:
523 The number of characters common to the end of the first
524 string and the start of the second string.
525 """
526 # Cache the text lengths to prevent multiple calls.
527 text1_length = len(text1)
528 text2_length = len(text2)
529 # Eliminate the null case.
530 if text1_length == 0 or text2_length == 0:
531 return 0
532 # Truncate the longer string.
533 if text1_length > text2_length:
534 text1 = text1[-text2_length:]
535 elif text1_length < text2_length:
536 text2 = text2[:text1_length]
537 text_length = min(text1_length, text2_length)
538 # Quick check for the worst case.
539 if text1 == text2:
540 return text_length
541 
542 # Start by looking for a single character match
543 # and increase length until no match is found.
544 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
545 best = 0
546 length = 1
547 while True:
548 pattern = text1[-length:]
549 found = text2.find(pattern)
550 if found == -1:
551 return best
552 length += found
553 if found == 0 or text1[-length:] == text2[:length]:
554 best = length
555 length += 1
556 
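This is the primitive behind cleanupSemantic's overlap extraction further down (<del>abcxxx</del><ins>xxxdef</ins> factoring out "xxx"):

dmp = diff_match_patch()
assert dmp.diff_commonOverlap("abcxxx", "xxxdef") == 3
assert dmp.diff_commonOverlap("abc", "abcd") == 3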
557 def diff_halfMatch(self, text1, text2):
558 """Do the two texts share a substring which is at least half the length of
559 the longer text?
560 This speedup can produce non-minimal diffs.
561 
562 Args:
563 text1: First string.
564 text2: Second string.
565 
566 Returns:
567 Five element Array, containing the prefix of text1, the suffix of text1,
568 the prefix of text2, the suffix of text2 and the common middle. Or None
569 if there was no match.
570 """
571 if self.Diff_Timeout <= 0:
572 # Don't risk returning a non-optimal diff if we have unlimited time.
573 return None
574 if len(text1) > len(text2):
575 (longtext, shorttext) = (text1, text2)
576 else:
577 (shorttext, longtext) = (text1, text2)
578 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
579 return None # Pointless.
580 
581 def diff_halfMatchI(longtext, shorttext, i):
582 """Does a substring of shorttext exist within longtext such that the
583 substring is at least half the length of longtext?
584 Closure, but does not reference any external variables.
585 
586 Args:
587 longtext: Longer string.
588 shorttext: Shorter string.
589 i: Start index of quarter length substring within longtext.
590 
591 Returns:
592 Five element Array, containing the prefix of longtext, the suffix of
593 longtext, the prefix of shorttext, the suffix of shorttext and the
594 common middle. Or None if there was no match.
595 """
596 seed = longtext[i : i + len(longtext) // 4]
597 best_common = ""
598 j = shorttext.find(seed)
599 while j != -1:
600 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
601 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
602 if len(best_common) < suffixLength + prefixLength:
603 best_common = (
604 shorttext[j - suffixLength : j]
605 + shorttext[j : j + prefixLength]
606 )
607 best_longtext_a = longtext[: i - suffixLength]
608 best_longtext_b = longtext[i + prefixLength :]
609 best_shorttext_a = shorttext[: j - suffixLength]
610 best_shorttext_b = shorttext[j + prefixLength :]
611 j = shorttext.find(seed, j + 1)
612 
613 if len(best_common) * 2 >= len(longtext):
614 return (
615 best_longtext_a,
616 best_longtext_b,
617 best_shorttext_a,
618 best_shorttext_b,
619 best_common,
620 )
621 else:
622 return None
623 
624 # First check if the second quarter is the seed for a half-match.
625 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
626 # Check again based on the third quarter.
627 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
628 if not hm1 and not hm2:
629 return None
630 elif not hm2:
631 hm = hm1
632 elif not hm1:
633 hm = hm2
634 else:
635 # Both matched. Select the longest.
636 if len(hm1[4]) > len(hm2[4]):
637 hm = hm1
638 else:
639 hm = hm2
640 
641 # A half-match was found, sort out the return data.
642 if len(text1) > len(text2):
643 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
644 else:
645 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
646 return (text1_a, text1_b, text2_a, text2_b, mid_common)
647 
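A half-match splits both texts around a shared middle that is at least half the longer text's length; it is only attempted while Diff_Timeout > 0:

dmp = diff_match_patch()
assert dmp.diff_halfMatch("1234567890", "a345678z") == (
    "12", "90", "a", "z", "345678")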
648 def diff_cleanupSemantic(self, diffs):
649 """Reduce the number of edits by eliminating semantically trivial
650 equalities.
651 
652 Args:
653 diffs: Array of diff tuples.
654 """
655 changes = False
656 equalities = [] # Stack of indices where equalities are found.
657 lastequality = None # Always equal to diffs[equalities[-1]][1]
658 pointer = 0 # Index of current position.
659 # Number of chars that changed prior to the equality.
660 length_insertions1, length_deletions1 = 0, 0
661 # Number of chars that changed after the equality.
662 length_insertions2, length_deletions2 = 0, 0
663 while pointer < len(diffs):
664 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
665 equalities.append(pointer)
666 length_insertions1, length_insertions2 = length_insertions2, 0
667 length_deletions1, length_deletions2 = length_deletions2, 0
668 lastequality = diffs[pointer][1]
669 else: # An insertion or deletion.
670 if diffs[pointer][0] == self.DIFF_INSERT:
671 length_insertions2 += len(diffs[pointer][1])
672 else:
673 length_deletions2 += len(diffs[pointer][1])
674 # Eliminate an equality that is smaller or equal to the edits on both
675 # sides of it.
676 if (
677 lastequality
678 and (
679 len(lastequality) <= max(length_insertions1, length_deletions1)
680 )
681 and (
682 len(lastequality) <= max(length_insertions2, length_deletions2)
683 )
684 ):
685 # Duplicate record.
686 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
687 # Change second copy to insert.
688 diffs[equalities[-1] + 1] = (
689 self.DIFF_INSERT,
690 diffs[equalities[-1] + 1][1],
691 )
692 # Throw away the equality we just deleted.
693 equalities.pop()
694 # Throw away the previous equality (it needs to be reevaluated).
695 if len(equalities):
696 equalities.pop()
697 if len(equalities):
698 pointer = equalities[-1]
699 else:
700 pointer = -1
701 # Reset the counters.
702 length_insertions1, length_deletions1 = 0, 0
703 length_insertions2, length_deletions2 = 0, 0
704 lastequality = None
705 changes = True
706 pointer += 1
707 
708 # Normalize the diff.
709 if changes:
710 self.diff_cleanupMerge(diffs)
711 self.diff_cleanupSemanticLossless(diffs)
712 
713 # Find any overlaps between deletions and insertions.
714 # e.g: <del>abcxxx</del><ins>xxxdef</ins>
715 # -> <del>abc</del>xxx<ins>def</ins>
716 # e.g: <del>xxxabc</del><ins>defxxx</ins>
717 # -> <ins>def</ins>xxx<del>abc</del>
718 # Only extract an overlap if it is as big as the edit ahead or behind it.
719 pointer = 1
720 while pointer < len(diffs):
721 if (
722 diffs[pointer - 1][0] == self.DIFF_DELETE
723 and diffs[pointer][0] == self.DIFF_INSERT
724 ):
725 deletion = diffs[pointer - 1][1]
726 insertion = diffs[pointer][1]
727 overlap_length1 = self.diff_commonOverlap(deletion, insertion)
728 overlap_length2 = self.diff_commonOverlap(insertion, deletion)
729 if overlap_length1 >= overlap_length2:
730 if (
731 overlap_length1 >= len(deletion) / 2.0
732 or overlap_length1 >= len(insertion) / 2.0
733 ):
734 # Overlap found. Insert an equality and trim the surrounding edits.
735 diffs.insert(
736 pointer, (self.DIFF_EQUAL, insertion[:overlap_length1])
737 )
738 diffs[pointer - 1] = (
739 self.DIFF_DELETE,
740 deletion[: len(deletion) - overlap_length1],
741 )
742 diffs[pointer + 1] = (
743 self.DIFF_INSERT,
744 insertion[overlap_length1:],
745 )
746 pointer += 1
747 else:
748 if (
749 overlap_length2 >= len(deletion) / 2.0
750 or overlap_length2 >= len(insertion) / 2.0
751 ):
752 # Reverse overlap found.
753 # Insert an equality and swap and trim the surrounding edits.
754 diffs.insert(
755 pointer, (self.DIFF_EQUAL, deletion[:overlap_length2])
756 )
757 diffs[pointer - 1] = (
758 self.DIFF_INSERT,
759 insertion[: len(insertion) - overlap_length2],
760 )
761 diffs[pointer + 1] = (
762 self.DIFF_DELETE,
763 deletion[overlap_length2:],
764 )
765 pointer += 1
765 pointer += 1
766 pointer += 1
766 pointer += 1
767 pointer += 1
767 pointer += 1
768
768
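    # Illustrative sketch (editor's addition, not part of the upstream file):
    # the overlap pass above turns an adjacent delete/insert pair that shares
    # an edge into an equality. Assuming an instance `dmp = diff_match_patch()`:
    #   diffs = [(dmp.DIFF_DELETE, "abcxxx"), (dmp.DIFF_INSERT, "xxxdef")]
    #   dmp.diff_cleanupSemantic(diffs)
    #   # diffs -> [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_EQUAL, "xxx"),
    #   #           (dmp.DIFF_INSERT, "def")]
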
    def diff_cleanupSemanticLossless(self, diffs):
        """Look for single edits surrounded on both sides by equalities
        which can be shifted sideways to align the edit to a word boundary.
        e.g: The c<ins>at c</ins>ame. -> The <ins>cat </ins>came.

        Args:
          diffs: Array of diff tuples.
        """

        def diff_cleanupSemanticScore(one, two):
            """Given two strings, compute a score representing whether the
            internal boundary falls on logical boundaries.
            Scores range from 6 (best) to 0 (worst).
            Closure, but does not reference any external variables.

            Args:
              one: First string.
              two: Second string.

            Returns:
              The score.
            """
            if not one or not two:
                # Edges are the best.
                return 6

            # Each port of this function behaves slightly differently due to
            # subtle differences in each language's definition of things like
            # 'whitespace'. Since this function's purpose is largely cosmetic,
            # the choice has been made to use each language's native features
            # rather than force total conformity.
            char1 = one[-1]
            char2 = two[0]
            nonAlphaNumeric1 = not char1.isalnum()
            nonAlphaNumeric2 = not char2.isalnum()
            whitespace1 = nonAlphaNumeric1 and char1.isspace()
            whitespace2 = nonAlphaNumeric2 and char2.isspace()
            lineBreak1 = whitespace1 and (char1 == "\r" or char1 == "\n")
            lineBreak2 = whitespace2 and (char2 == "\r" or char2 == "\n")
            blankLine1 = lineBreak1 and self.BLANKLINEEND.search(one)
            blankLine2 = lineBreak2 and self.BLANKLINESTART.match(two)

            if blankLine1 or blankLine2:
                # Five points for blank lines.
                return 5
            elif lineBreak1 or lineBreak2:
                # Four points for line breaks.
                return 4
            elif nonAlphaNumeric1 and not whitespace1 and whitespace2:
                # Three points for end of sentences.
                return 3
            elif whitespace1 or whitespace2:
                # Two points for whitespace.
                return 2
            elif nonAlphaNumeric1 or nonAlphaNumeric2:
                # One point for non-alphanumeric.
                return 1
            return 0

        pointer = 1
        # Intentionally ignore the first and last element (don't need checking).
        while pointer < len(diffs) - 1:
            if (
                diffs[pointer - 1][0] == self.DIFF_EQUAL
                and diffs[pointer + 1][0] == self.DIFF_EQUAL
            ):
                # This is a single edit surrounded by equalities.
                equality1 = diffs[pointer - 1][1]
                edit = diffs[pointer][1]
                equality2 = diffs[pointer + 1][1]

                # First, shift the edit as far left as possible.
                commonOffset = self.diff_commonSuffix(equality1, edit)
                if commonOffset:
                    commonString = edit[-commonOffset:]
                    equality1 = equality1[:-commonOffset]
                    edit = commonString + edit[:-commonOffset]
                    equality2 = commonString + equality2

                # Second, step character by character right, looking for the best fit.
                bestEquality1 = equality1
                bestEdit = edit
                bestEquality2 = equality2
                bestScore = diff_cleanupSemanticScore(
                    equality1, edit
                ) + diff_cleanupSemanticScore(edit, equality2)
                while edit and equality2 and edit[0] == equality2[0]:
                    equality1 += edit[0]
                    edit = edit[1:] + equality2[0]
                    equality2 = equality2[1:]
                    score = diff_cleanupSemanticScore(
                        equality1, edit
                    ) + diff_cleanupSemanticScore(edit, equality2)
                    # The >= encourages trailing rather than leading whitespace on edits.
                    if score >= bestScore:
                        bestScore = score
                        bestEquality1 = equality1
                        bestEdit = edit
                        bestEquality2 = equality2

                if diffs[pointer - 1][1] != bestEquality1:
                    # We have an improvement, save it back to the diff.
                    if bestEquality1:
                        diffs[pointer - 1] = (diffs[pointer - 1][0], bestEquality1)
                    else:
                        del diffs[pointer - 1]
                        pointer -= 1
                    diffs[pointer] = (diffs[pointer][0], bestEdit)
                    if bestEquality2:
                        diffs[pointer + 1] = (diffs[pointer + 1][0], bestEquality2)
                    else:
                        del diffs[pointer + 1]
                        pointer -= 1
            pointer += 1

    # Define some regex patterns for matching boundaries.
    BLANKLINEEND = re.compile(r"\n\r?\n$")
    BLANKLINESTART = re.compile(r"^\r?\n\r?\n")

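    # Illustrative sketch (editor's addition, not part of the upstream file),
    # showing the word-boundary alignment from the docstring above, with `dmp`
    # as an assumed diff_match_patch instance:
    #   diffs = [(dmp.DIFF_EQUAL, "The c"), (dmp.DIFF_INSERT, "at c"),
    #            (dmp.DIFF_EQUAL, "ame.")]
    #   dmp.diff_cleanupSemanticLossless(diffs)
    #   # diffs -> [(dmp.DIFF_EQUAL, "The "), (dmp.DIFF_INSERT, "cat "),
    #   #           (dmp.DIFF_EQUAL, "came.")]
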
    def diff_cleanupEfficiency(self, diffs):
        """Reduce the number of edits by eliminating operationally trivial
        equalities.

        Args:
          diffs: Array of diff tuples.
        """
        changes = False
        equalities = []  # Stack of indices where equalities are found.
        lastequality = None  # Always equal to diffs[equalities[-1]][1]
        pointer = 0  # Index of current position.
        pre_ins = False  # Is there an insertion operation before the last equality.
        pre_del = False  # Is there a deletion operation before the last equality.
        post_ins = False  # Is there an insertion operation after the last equality.
        post_del = False  # Is there a deletion operation after the last equality.
        while pointer < len(diffs):
            if diffs[pointer][0] == self.DIFF_EQUAL:  # Equality found.
                if len(diffs[pointer][1]) < self.Diff_EditCost and (
                    post_ins or post_del
                ):
                    # Candidate found.
                    equalities.append(pointer)
                    pre_ins = post_ins
                    pre_del = post_del
                    lastequality = diffs[pointer][1]
                else:
                    # Not a candidate, and can never become one.
                    equalities = []
                    lastequality = None

                post_ins = post_del = False
            else:  # An insertion or deletion.
                if diffs[pointer][0] == self.DIFF_DELETE:
                    post_del = True
                else:
                    post_ins = True

                # Five types to be split:
                # <ins>A</ins><del>B</del>XY<ins>C</ins><del>D</del>
                # <ins>A</ins>X<ins>C</ins><del>D</del>
                # <ins>A</ins><del>B</del>X<ins>C</ins>
                # <del>A</del>X<ins>C</ins><del>D</del>
                # <ins>A</ins><del>B</del>X<del>C</del>

                if lastequality and (
                    (pre_ins and pre_del and post_ins and post_del)
                    or (
                        (len(lastequality) < self.Diff_EditCost / 2)
                        and (pre_ins + pre_del + post_ins + post_del) == 3
                    )
                ):
                    # Duplicate record.
                    diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
                    # Change second copy to insert.
                    diffs[equalities[-1] + 1] = (
                        self.DIFF_INSERT,
                        diffs[equalities[-1] + 1][1],
                    )
                    equalities.pop()  # Throw away the equality we just deleted.
                    lastequality = None
                    if pre_ins and pre_del:
                        # No changes made which could affect previous entry, keep going.
                        post_ins = post_del = True
                        equalities = []
                    else:
                        if len(equalities):
                            equalities.pop()  # Throw away the previous equality.
                        if len(equalities):
                            pointer = equalities[-1]
                        else:
                            pointer = -1
                        post_ins = post_del = False
                    changes = True
            pointer += 1

        if changes:
            self.diff_cleanupMerge(diffs)

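    # Illustrative sketch (editor's addition, not part of the upstream file):
    # with the default Diff_EditCost of 4, a short equality walled in by edits
    # on both sides is folded into the edits (`dmp` is an assumed instance):
    #   diffs = [(dmp.DIFF_DELETE, "ab"), (dmp.DIFF_INSERT, "12"),
    #            (dmp.DIFF_EQUAL, "xyz"), (dmp.DIFF_DELETE, "cd"),
    #            (dmp.DIFF_INSERT, "34")]
    #   dmp.diff_cleanupEfficiency(diffs)
    #   # diffs -> [(dmp.DIFF_DELETE, "abxyzcd"), (dmp.DIFF_INSERT, "12xyz34")]
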
    def diff_cleanupMerge(self, diffs):
        """Reorder and merge like edit sections. Merge equalities.
        Any edit section can move as long as it doesn't cross an equality.

        Args:
          diffs: Array of diff tuples.
        """
        diffs.append((self.DIFF_EQUAL, ""))  # Add a dummy entry at the end.
        pointer = 0
        count_delete = 0
        count_insert = 0
        text_delete = ""
        text_insert = ""
        while pointer < len(diffs):
            if diffs[pointer][0] == self.DIFF_INSERT:
                count_insert += 1
                text_insert += diffs[pointer][1]
                pointer += 1
            elif diffs[pointer][0] == self.DIFF_DELETE:
                count_delete += 1
                text_delete += diffs[pointer][1]
                pointer += 1
            elif diffs[pointer][0] == self.DIFF_EQUAL:
                # Upon reaching an equality, check for prior redundancies.
                if count_delete + count_insert > 1:
                    if count_delete != 0 and count_insert != 0:
                        # Factor out any common prefixes.
                        commonlength = self.diff_commonPrefix(text_insert, text_delete)
                        if commonlength != 0:
                            x = pointer - count_delete - count_insert - 1
                            if x >= 0 and diffs[x][0] == self.DIFF_EQUAL:
                                diffs[x] = (
                                    diffs[x][0],
                                    diffs[x][1] + text_insert[:commonlength],
                                )
                            else:
                                diffs.insert(
                                    0, (self.DIFF_EQUAL, text_insert[:commonlength])
                                )
                                pointer += 1
                            text_insert = text_insert[commonlength:]
                            text_delete = text_delete[commonlength:]
                        # Factor out any common suffixes.
                        commonlength = self.diff_commonSuffix(text_insert, text_delete)
                        if commonlength != 0:
                            diffs[pointer] = (
                                diffs[pointer][0],
                                text_insert[-commonlength:] + diffs[pointer][1],
                            )
                            text_insert = text_insert[:-commonlength]
                            text_delete = text_delete[:-commonlength]
                    # Delete the offending records and add the merged ones.
                    if count_delete == 0:
                        diffs[pointer - count_insert : pointer] = [
                            (self.DIFF_INSERT, text_insert)
                        ]
                    elif count_insert == 0:
                        diffs[pointer - count_delete : pointer] = [
                            (self.DIFF_DELETE, text_delete)
                        ]
                    else:
                        diffs[pointer - count_delete - count_insert : pointer] = [
                            (self.DIFF_DELETE, text_delete),
                            (self.DIFF_INSERT, text_insert),
                        ]
                    pointer = pointer - count_delete - count_insert + 1
                    if count_delete != 0:
                        pointer += 1
                    if count_insert != 0:
                        pointer += 1
                elif pointer != 0 and diffs[pointer - 1][0] == self.DIFF_EQUAL:
                    # Merge this equality with the previous one.
                    diffs[pointer - 1] = (
                        diffs[pointer - 1][0],
                        diffs[pointer - 1][1] + diffs[pointer][1],
                    )
                    del diffs[pointer]
                else:
                    pointer += 1

                count_insert = 0
                count_delete = 0
                text_delete = ""
                text_insert = ""

        if diffs[-1][1] == "":
            diffs.pop()  # Remove the dummy entry at the end.

        # Second pass: look for single edits surrounded on both sides by equalities
        # which can be shifted sideways to eliminate an equality.
        # e.g: A<ins>BA</ins>C -> <ins>AB</ins>AC
        changes = False
        pointer = 1
        # Intentionally ignore the first and last element (don't need checking).
        while pointer < len(diffs) - 1:
            if (
                diffs[pointer - 1][0] == self.DIFF_EQUAL
                and diffs[pointer + 1][0] == self.DIFF_EQUAL
            ):
                # This is a single edit surrounded by equalities.
                if diffs[pointer][1].endswith(diffs[pointer - 1][1]):
                    # Shift the edit over the previous equality.
                    diffs[pointer] = (
                        diffs[pointer][0],
                        diffs[pointer - 1][1]
                        + diffs[pointer][1][: -len(diffs[pointer - 1][1])],
                    )
                    diffs[pointer + 1] = (
                        diffs[pointer + 1][0],
                        diffs[pointer - 1][1] + diffs[pointer + 1][1],
                    )
                    del diffs[pointer - 1]
                    changes = True
                elif diffs[pointer][1].startswith(diffs[pointer + 1][1]):
                    # Shift the edit over the next equality.
                    diffs[pointer - 1] = (
                        diffs[pointer - 1][0],
                        diffs[pointer - 1][1] + diffs[pointer + 1][1],
                    )
                    diffs[pointer] = (
                        diffs[pointer][0],
                        diffs[pointer][1][len(diffs[pointer + 1][1]) :]
                        + diffs[pointer + 1][1],
                    )
                    del diffs[pointer + 1]
                    changes = True
            pointer += 1

        # If shifts were made, the diff needs reordering and another shift sweep.
        if changes:
            self.diff_cleanupMerge(diffs)

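    # Illustrative sketch (editor's addition, not part of the upstream file):
    # the second pass above slides an edit over a neighbouring equality, e.g.
    # the A<ins>BA</ins>C case (`dmp` is an assumed instance):
    #   diffs = [(dmp.DIFF_EQUAL, "a"), (dmp.DIFF_INSERT, "ba"),
    #            (dmp.DIFF_EQUAL, "c")]
    #   dmp.diff_cleanupMerge(diffs)
    #   # diffs -> [(dmp.DIFF_INSERT, "ab"), (dmp.DIFF_EQUAL, "ac")]
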
    def diff_xIndex(self, diffs, loc):
        """loc is a location in text1, compute and return the equivalent location
        in text2. e.g. "The cat" vs "The big cat", 1->1, 5->8

        Args:
          diffs: Array of diff tuples.
          loc: Location within text1.

        Returns:
          Location within text2.
        """
        chars1 = 0
        chars2 = 0
        last_chars1 = 0
        last_chars2 = 0
        for x in range(len(diffs)):
            (op, text) = diffs[x]
            if op != self.DIFF_INSERT:  # Equality or deletion.
                chars1 += len(text)
            if op != self.DIFF_DELETE:  # Equality or insertion.
                chars2 += len(text)
            if chars1 > loc:  # Overshot the location.
                break
            last_chars1 = chars1
            last_chars2 = chars2

        if len(diffs) != x and diffs[x][0] == self.DIFF_DELETE:
            # The location was deleted.
            return last_chars2
        # Add the remaining len(character).
        return last_chars2 + (loc - last_chars1)

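    # Illustrative sketch (editor's addition, not part of the upstream file;
    # `dmp` is an assumed instance):
    #   dmp.diff_xIndex([(dmp.DIFF_DELETE, "a"), (dmp.DIFF_INSERT, "1234"),
    #                    (dmp.DIFF_EQUAL, "xyz")], 2)  # -> 5 (shifted by the edits)
    #   dmp.diff_xIndex([(dmp.DIFF_EQUAL, "a"), (dmp.DIFF_DELETE, "1234"),
    #                    (dmp.DIFF_EQUAL, "xyz")], 3)  # -> 1 (location was deleted)
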
    def diff_prettyHtml(self, diffs):
        """Convert a diff array into a pretty HTML report.

        Args:
          diffs: Array of diff tuples.

        Returns:
          HTML representation.
        """
        html = []
        for op, data in diffs:
            text = (
                data.replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace("\n", "&para;<br>")
            )
            if op == self.DIFF_INSERT:
                html.append('<ins style="background:#e6ffe6;">%s</ins>' % text)
            elif op == self.DIFF_DELETE:
                html.append('<del style="background:#ffe6e6;">%s</del>' % text)
            elif op == self.DIFF_EQUAL:
                html.append("<span>%s</span>" % text)
        return "".join(html)

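    # Illustrative sketch (editor's addition, not part of the upstream file;
    # `dmp` is an assumed instance):
    #   dmp.diff_prettyHtml([(dmp.DIFF_EQUAL, "a\n"), (dmp.DIFF_DELETE, "<B>b</B>"),
    #                        (dmp.DIFF_INSERT, "c&d")])
    #   # -> '<span>a&para;<br></span>'
    #   #    '<del style="background:#ffe6e6;">&lt;B&gt;b&lt;/B&gt;</del>'
    #   #    '<ins style="background:#e6ffe6;">c&amp;d</ins>'
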
    def diff_text1(self, diffs):
        """Compute and return the source text (all equalities and deletions).

        Args:
          diffs: Array of diff tuples.

        Returns:
          Source text.
        """
        text = []
        for op, data in diffs:
            if op != self.DIFF_INSERT:
                text.append(data)
        return "".join(text)

    def diff_text2(self, diffs):
        """Compute and return the destination text (all equalities and insertions).

        Args:
          diffs: Array of diff tuples.

        Returns:
          Destination text.
        """
        text = []
        for op, data in diffs:
            if op != self.DIFF_DELETE:
                text.append(data)
        return "".join(text)

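    # Illustrative sketch (editor's addition, not part of the upstream file):
    # the same diff yields both endpoints (`dmp` is an assumed instance):
    #   diffs = [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"),
    #            (dmp.DIFF_INSERT, "ed")]
    #   dmp.diff_text1(diffs)  # -> "jumps"
    #   dmp.diff_text2(diffs)  # -> "jumped"
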
    def diff_levenshtein(self, diffs):
        """Compute the Levenshtein distance; the number of inserted, deleted or
        substituted characters.

        Args:
          diffs: Array of diff tuples.

        Returns:
          Number of changes.
        """
        levenshtein = 0
        insertions = 0
        deletions = 0
        for op, data in diffs:
            if op == self.DIFF_INSERT:
                insertions += len(data)
            elif op == self.DIFF_DELETE:
                deletions += len(data)
            elif op == self.DIFF_EQUAL:
                # A deletion and an insertion is one substitution.
                levenshtein += max(insertions, deletions)
                insertions = 0
                deletions = 0
        levenshtein += max(insertions, deletions)
        return levenshtein

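    # Illustrative sketch (editor's addition, not part of the upstream file):
    # the overlapping delete/insert below counts as max(3, 4) = 4 changes
    # (`dmp` is an assumed instance):
    #   dmp.diff_levenshtein([(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "1234"),
    #                         (dmp.DIFF_EQUAL, "xyz")])  # -> 4
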
    def diff_toDelta(self, diffs):
        """Crush the diff into an encoded string which describes the operations
        required to transform text1 into text2.
        E.g. =3\t-2\t+ing -> Keep 3 chars, delete 2 chars, insert 'ing'.
        Operations are tab-separated. Inserted text is escaped using %xx notation.

        Args:
          diffs: Array of diff tuples.

        Returns:
          Delta text.
        """
        text = []
        for op, data in diffs:
            if op == self.DIFF_INSERT:
                # High ascii will raise UnicodeDecodeError. Use Unicode instead.
                data = data.encode("utf-8")
                text.append("+" + urllib.parse.quote(data, "!~*'();/?:@&=+$,# "))
            elif op == self.DIFF_DELETE:
                text.append("-%d" % len(data))
            elif op == self.DIFF_EQUAL:
                text.append("=%d" % len(data))
        return "\t".join(text)

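    # Illustrative sketch (editor's addition, not part of the upstream file;
    # `dmp` is an assumed instance):
    #   dmp.diff_toDelta([(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"),
    #                     (dmp.DIFF_INSERT, "ed")])  # -> "=4\t-1\t+ed"
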
    def diff_fromDelta(self, text1, delta):
        """Given the original text1, and an encoded string which describes the
        operations required to transform text1 into text2, compute the full diff.

        Args:
          text1: Source string for the diff.
          delta: Delta text.

        Returns:
          Array of diff tuples.

        Raises:
          ValueError: If invalid input.
        """
        if isinstance(delta, str):
            # Deltas should be composed of a subset of ascii chars, Unicode not
            # required. If this encode raises UnicodeEncodeError, delta is invalid.
            # Validate only; rebinding delta to bytes here would break the
            # str.split() below under Python 3.
            delta.encode("ascii")
        diffs = []
        pointer = 0  # Cursor in text1
        tokens = delta.split("\t")
        for token in tokens:
            if token == "":
                # Blank tokens are ok (from a trailing \t).
                continue
            # Each token begins with a one character parameter which specifies the
            # operation of this token (delete, insert, equality).
            param = token[1:]
            if token[0] == "+":
                param = urllib.parse.unquote(param)
                diffs.append((self.DIFF_INSERT, param))
            elif token[0] == "-" or token[0] == "=":
                try:
                    n = int(param)
                except ValueError:
                    raise ValueError("Invalid number in diff_fromDelta: " + param)
                if n < 0:
                    raise ValueError("Negative number in diff_fromDelta: " + param)
                text = text1[pointer : pointer + n]
                pointer += n
                if token[0] == "=":
                    diffs.append((self.DIFF_EQUAL, text))
                else:
                    diffs.append((self.DIFF_DELETE, text))
            else:
                # Anything else is an error.
                raise ValueError(
                    "Invalid diff operation in diff_fromDelta: " + token[0]
                )
        if pointer != len(text1):
            raise ValueError(
                "Delta length (%d) does not equal source text length (%d)."
                % (pointer, len(text1))
            )
        return diffs

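    # Illustrative sketch (editor's addition, not part of the upstream file):
    # the inverse of diff_toDelta above (`dmp` is an assumed instance):
    #   dmp.diff_fromDelta("jumps", "=4\t-1\t+ed")
    #   # -> [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"),
    #   #     (dmp.DIFF_INSERT, "ed")]
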
    # MATCH FUNCTIONS

    def match_main(self, text, pattern, loc):
        """Locate the best instance of 'pattern' in 'text' near 'loc'.

        Args:
          text: The text to search.
          pattern: The pattern to search for.
          loc: The location to search around.

        Returns:
          Best match index or -1.
        """
        # Check for null inputs.
        if text is None or pattern is None:
            raise ValueError("Null inputs. (match_main)")

        loc = max(0, min(loc, len(text)))
        if text == pattern:
            # Shortcut (potentially not guaranteed by the algorithm)
            return 0
        elif not text:
            # Nothing to match.
            return -1
        elif text[loc : loc + len(pattern)] == pattern:
            # Perfect match at the perfect spot! (Includes case of null pattern)
            return loc
        else:
            # Do a fuzzy compare.
            match = self.match_bitap(text, pattern, loc)
            return match

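    # Illustrative sketch (editor's addition, not part of the upstream file;
    # assumes `dmp = diff_match_patch()` with the default Match_Threshold of 0.5):
    #   dmp.match_main("abcdef", "de", 3)    # -> 3 (exact match at the spot)
    #   dmp.match_main("abcdef", "defy", 4)  # -> 3 (fuzzy match via match_bitap)
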
    def match_bitap(self, text, pattern, loc):
        """Locate the best instance of 'pattern' in 'text' near 'loc' using the
        Bitap algorithm.

        Args:
          text: The text to search.
          pattern: The pattern to search for.
          loc: The location to search around.

        Returns:
          Best match index or -1.
        """
        # Python doesn't have a maxint limit, so ignore this check.
        # if self.Match_MaxBits != 0 and len(pattern) > self.Match_MaxBits:
        #   raise ValueError("Pattern too long for this application.")

        # Initialise the alphabet.
        s = self.match_alphabet(pattern)

        def match_bitapScore(e, x):
            """Compute and return the score for a match with e errors and x location.
            Accesses loc and pattern through being a closure.

            Args:
              e: Number of errors in match.
              x: Location of match.

            Returns:
              Overall score for match (0.0 = good, 1.0 = bad).
            """
            accuracy = float(e) / len(pattern)
            proximity = abs(loc - x)
            if not self.Match_Distance:
                # Dodge divide by zero error.
                return 1.0 if proximity else accuracy
            return accuracy + (proximity / float(self.Match_Distance))

        # Highest score beyond which we give up.
        score_threshold = self.Match_Threshold
        # Is there a nearby exact match? (speedup)
        best_loc = text.find(pattern, loc)
        if best_loc != -1:
            score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
            # What about in the other direction? (speedup)
            best_loc = text.rfind(pattern, loc + len(pattern))
            if best_loc != -1:
                score_threshold = min(match_bitapScore(0, best_loc), score_threshold)

        # Initialise the bit arrays.
        matchmask = 1 << (len(pattern) - 1)
        best_loc = -1

        bin_max = len(pattern) + len(text)
        # Empty initialization added to appease pychecker.
        last_rd = None
        for d in range(len(pattern)):
            # Scan for the best match each iteration allows for one more error.
            # Run a binary search to determine how far from 'loc' we can stray at
            # this error level.
            bin_min = 0
            bin_mid = bin_max
            while bin_min < bin_mid:
                if match_bitapScore(d, loc + bin_mid) <= score_threshold:
                    bin_min = bin_mid
                else:
                    bin_max = bin_mid
                bin_mid = (bin_max - bin_min) // 2 + bin_min

            # Use the result from this iteration as the maximum for the next.
            bin_max = bin_mid
            start = max(1, loc - bin_mid + 1)
            finish = min(loc + bin_mid, len(text)) + len(pattern)

            rd = [0] * (finish + 2)
            rd[finish + 1] = (1 << d) - 1
            for j in range(finish, start - 1, -1):
                if len(text) <= j - 1:
                    # Out of range.
                    charMatch = 0
                else:
                    charMatch = s.get(text[j - 1], 0)
                if d == 0:  # First pass: exact match.
                    rd[j] = ((rd[j + 1] << 1) | 1) & charMatch
                else:  # Subsequent passes: fuzzy match.
                    rd[j] = (
                        (((rd[j + 1] << 1) | 1) & charMatch)
                        | (((last_rd[j + 1] | last_rd[j]) << 1) | 1)
                        | last_rd[j + 1]
                    )
                if rd[j] & matchmask:
                    score = match_bitapScore(d, j - 1)
                    # This match will almost certainly be better than any existing match.
                    # But check anyway.
                    if score <= score_threshold:
                        # Told you so.
                        score_threshold = score
                        best_loc = j - 1
                        if best_loc > loc:
                            # When passing loc, don't exceed our current distance from loc.
                            start = max(1, 2 * loc - best_loc)
                        else:
                            # Already passed loc, downhill from here on in.
                            break
            # No hope for a (better) match at greater error levels.
            if match_bitapScore(d + 1, loc) > score_threshold:
                break
            last_rd = rd
        return best_loc

    def match_alphabet(self, pattern):
        """Initialise the alphabet for the Bitap algorithm.

        Args:
          pattern: The text to encode.

        Returns:
          Hash of character locations.
        """
        s = {}
        for char in pattern:
            s[char] = 0
        for i in range(len(pattern)):
            s[pattern[i]] |= 1 << (len(pattern) - i - 1)
        return s

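    # Illustrative sketch (editor's addition, not part of the upstream file):
    # each character maps to a bitmask of its positions, most significant bit
    # first (`dmp` is an assumed instance):
    #   dmp.match_alphabet("abc")  # -> {"a": 4, "b": 2, "c": 1}
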
    # PATCH FUNCTIONS

    def patch_addContext(self, patch, text):
        """Increase the context until it is unique,
        but don't let the pattern expand beyond Match_MaxBits.

        Args:
          patch: The patch to grow.
          text: Source text.
        """
        if len(text) == 0:
            return
        pattern = text[patch.start2 : patch.start2 + patch.length1]
        padding = 0

        # Look for the first and last matches of pattern in text. If two different
        # matches are found, increase the pattern length.
        while text.find(pattern) != text.rfind(pattern) and (
            self.Match_MaxBits == 0
            or len(pattern) < self.Match_MaxBits - self.Patch_Margin - self.Patch_Margin
        ):
            padding += self.Patch_Margin
            pattern = text[
                max(0, patch.start2 - padding) : patch.start2 + patch.length1 + padding
            ]
        # Add one chunk for good luck.
        padding += self.Patch_Margin

        # Add the prefix.
        prefix = text[max(0, patch.start2 - padding) : patch.start2]
        if prefix:
            patch.diffs[:0] = [(self.DIFF_EQUAL, prefix)]
        # Add the suffix.
        suffix = text[
            patch.start2 + patch.length1 : patch.start2 + patch.length1 + padding
        ]
        if suffix:
            patch.diffs.append((self.DIFF_EQUAL, suffix))

        # Roll back the start points.
        patch.start1 -= len(prefix)
        patch.start2 -= len(prefix)
        # Extend lengths.
        patch.length1 += len(prefix) + len(suffix)
        patch.length2 += len(prefix) + len(suffix)

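    # Illustrative sketch (editor's addition, not part of the upstream file;
    # offsets assume the default Patch_Margin of 4, and patch_fromText is
    # defined elsewhere in this class):
    #   p = dmp.patch_fromText("@@ -21,4 +21,10 @@\n-jump\n+somersault\n")[0]
    #   dmp.patch_addContext(p, "The quick brown fox jumps over the lazy dog.")
    #   str(p)  # -> "@@ -17,12 +17,18 @@\n fox \n-jump\n+somersault\n s ov\n"
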
1494 def patch_make(self, a, b=None, c=None):
1494 def patch_make(self, a, b=None, c=None):
1495 """Compute a list of patches to turn text1 into text2.
1495 """Compute a list of patches to turn text1 into text2.
1496 Use diffs if provided, otherwise compute it ourselves.
1496 Use diffs if provided, otherwise compute it ourselves.
1497 There are four ways to call this function, depending on what data is
1497 There are four ways to call this function, depending on what data is
1498 available to the caller:
1498 available to the caller:
1499 Method 1:
1499 Method 1:
1500 a = text1, b = text2
1500 a = text1, b = text2
1501 Method 2:
1501 Method 2:
1502 a = diffs
1502 a = diffs
1503 Method 3 (optimal):
1503 Method 3 (optimal):
1504 a = text1, b = diffs
1504 a = text1, b = diffs
1505 Method 4 (deprecated, use method 3):
1505 Method 4 (deprecated, use method 3):
1506 a = text1, b = text2, c = diffs
1506 a = text1, b = text2, c = diffs
1507
1507
1508 Args:
1508 Args:
1509 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1509 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1510 text2 (method 2).
1510 text2 (method 2).
1511 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1511 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1512 text2 (method 3) or undefined (method 2).
1512 text2 (method 3) or undefined (method 2).
1513 c: Array of diff tuples for text1 to text2 (method 4) or
1513 c: Array of diff tuples for text1 to text2 (method 4) or
1514 undefined (methods 1,2,3).
1514 undefined (methods 1,2,3).
1515
1515
1516 Returns:
1516 Returns:
1517 Array of Patch objects.
1517 Array of Patch objects.
1518 """
1518 """
1519 text1 = None
1519 text1 = None
1520 diffs = None
1520 diffs = None
1521 # Note that texts may arrive as 'str' or 'unicode'.
1521 # Note that texts may arrive as 'str' or 'unicode'.
1522 if isinstance(a, str) and isinstance(b, str) and c is None:
1522 if isinstance(a, str) and isinstance(b, str) and c is None:
1523 # Method 1: text1, text2
1523 # Method 1: text1, text2
1524 # Compute diffs from text1 and text2.
1524 # Compute diffs from text1 and text2.
1525 text1 = a
1525 text1 = a
1526 diffs = self.diff_main(text1, b, True)
1526 diffs = self.diff_main(text1, b, True)
1527 if len(diffs) > 2:
1527 if len(diffs) > 2:
1528 self.diff_cleanupSemantic(diffs)
1528 self.diff_cleanupSemantic(diffs)
1529 self.diff_cleanupEfficiency(diffs)
1529 self.diff_cleanupEfficiency(diffs)
1530 elif isinstance(a, list) and b is None and c is None:
1530 elif isinstance(a, list) and b is None and c is None:
1531 # Method 2: diffs
1531 # Method 2: diffs
1532 # Compute text1 from diffs.
1532 # Compute text1 from diffs.
1533 diffs = a
1533 diffs = a
1534 text1 = self.diff_text1(diffs)
1534 text1 = self.diff_text1(diffs)
1535 elif isinstance(a, str) and isinstance(b, list) and c is None:
1535 elif isinstance(a, str) and isinstance(b, list) and c is None:
1536 # Method 3: text1, diffs
1536 # Method 3: text1, diffs
1537 text1 = a
1537 text1 = a
1538 diffs = b
1538 diffs = b
1539 elif isinstance(a, str) and isinstance(b, str) and isinstance(c, list):
1539 elif isinstance(a, str) and isinstance(b, str) and isinstance(c, list):
1540 # Method 4: text1, text2, diffs
1540 # Method 4: text1, text2, diffs
1541 # text2 is not used.
1541 # text2 is not used.
1542 text1 = a
1542 text1 = a
1543 diffs = c
1543 diffs = c
1544 else:
1544 else:
1545 raise ValueError("Unknown call format to patch_make.")
1545 raise ValueError("Unknown call format to patch_make.")
1546
1546
1547 if not diffs:
1547 if not diffs:
1548 return [] # Get rid of the None case.
1548 return [] # Get rid of the None case.
1549 patches = []
1549 patches = []
1550 patch = patch_obj()
1550 patch = patch_obj()
1551 char_count1 = 0 # Number of characters into the text1 string.
1551 char_count1 = 0 # Number of characters into the text1 string.
1552 char_count2 = 0 # Number of characters into the text2 string.
1552 char_count2 = 0 # Number of characters into the text2 string.
1553 prepatch_text = text1 # Recreate the patches to determine context info.
1553 prepatch_text = text1 # Recreate the patches to determine context info.
1554 postpatch_text = text1
1554 postpatch_text = text1
1555 for x in range(len(diffs)):
1555 for x in range(len(diffs)):
1556 (diff_type, diff_text) = diffs[x]
1556 (diff_type, diff_text) = diffs[x]
1557 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1557 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1558 # A new patch starts here.
1558 # A new patch starts here.
1559 patch.start1 = char_count1
1559 patch.start1 = char_count1
1560 patch.start2 = char_count2
1560 patch.start2 = char_count2
1561 if diff_type == self.DIFF_INSERT:
1561 if diff_type == self.DIFF_INSERT:
1562 # Insertion
1562 # Insertion
1563 patch.diffs.append(diffs[x])
1563 patch.diffs.append(diffs[x])
1564 patch.length2 += len(diff_text)
1564 patch.length2 += len(diff_text)
1565 postpatch_text = (
1565 postpatch_text = (
1566 postpatch_text[:char_count2]
1566 postpatch_text[:char_count2]
1567 + diff_text
1567 + diff_text
1568 + postpatch_text[char_count2:]
1568 + postpatch_text[char_count2:]
1569 )
1569 )
1570 elif diff_type == self.DIFF_DELETE:
1570 elif diff_type == self.DIFF_DELETE:
1571 # Deletion.
1571 # Deletion.
1572 patch.length1 += len(diff_text)
1572 patch.length1 += len(diff_text)
1573 patch.diffs.append(diffs[x])
1573 patch.diffs.append(diffs[x])
1574 postpatch_text = (
1574 postpatch_text = (
1575 postpatch_text[:char_count2]
1575 postpatch_text[:char_count2]
1576 + postpatch_text[char_count2 + len(diff_text) :]
1576 + postpatch_text[char_count2 + len(diff_text) :]
1577 )
1577 )
1578 elif (
1578 elif (
1579 diff_type == self.DIFF_EQUAL
1579 diff_type == self.DIFF_EQUAL
1580 and len(diff_text) <= 2 * self.Patch_Margin
1580 and len(diff_text) <= 2 * self.Patch_Margin
1581 and len(patch.diffs) != 0
1581 and len(patch.diffs) != 0
1582 and len(diffs) != x + 1
1582 and len(diffs) != x + 1
1583 ):
1583 ):
1584 # Small equality inside a patch.
1584 # Small equality inside a patch.
1585 patch.diffs.append(diffs[x])
1585 patch.diffs.append(diffs[x])
1586 patch.length1 += len(diff_text)
1586 patch.length1 += len(diff_text)
1587 patch.length2 += len(diff_text)
1587 patch.length2 += len(diff_text)
1588
1588
1589 if diff_type == self.DIFF_EQUAL and len(diff_text) >= 2 * self.Patch_Margin:
1589 if diff_type == self.DIFF_EQUAL and len(diff_text) >= 2 * self.Patch_Margin:
1590 # Time for a new patch.
1590 # Time for a new patch.
1591 if len(patch.diffs) != 0:
1591 if len(patch.diffs) != 0:
1592 self.patch_addContext(patch, prepatch_text)
1592 self.patch_addContext(patch, prepatch_text)
1593 patches.append(patch)
1593 patches.append(patch)
1594 patch = patch_obj()
1594 patch = patch_obj()
1595 # Unlike Unidiff, our patch lists have a rolling context.
1595 # Unlike Unidiff, our patch lists have a rolling context.
1596 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1596 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1597 # Update prepatch text & pos to reflect the application of the
1597 # Update prepatch text & pos to reflect the application of the
1598 # just completed patch.
1598 # just completed patch.
1599 prepatch_text = postpatch_text
1599 prepatch_text = postpatch_text
1600 char_count1 = char_count2
1600 char_count1 = char_count2
1601
1601
1602 # Update the current character count.
1602 # Update the current character count.
1603 if diff_type != self.DIFF_INSERT:
1603 if diff_type != self.DIFF_INSERT:
1604 char_count1 += len(diff_text)
1604 char_count1 += len(diff_text)
1605 if diff_type != self.DIFF_DELETE:
1605 if diff_type != self.DIFF_DELETE:
1606 char_count2 += len(diff_text)
1606 char_count2 += len(diff_text)
1607
1607
1608 # Pick up the leftover patch if not empty.
1608 # Pick up the leftover patch if not empty.
1609 if len(patch.diffs) != 0:
1609 if len(patch.diffs) != 0:
1610 self.patch_addContext(patch, prepatch_text)
1610 self.patch_addContext(patch, prepatch_text)
1611 patches.append(patch)
1611 patches.append(patch)
1612 return patches
1612 return patches
1613
1613
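A minimal usage sketch of the patch pipeline built above (assuming only the
diff_match_patch class and patch_obj defined in this module; the sample texts
are illustrative):

    dmp = diff_match_patch()
    # patch_make diffs the two texts and groups the edits into patch_obj
    # instances with a rolling margin of context around each change.
    patches = dmp.patch_make("The quick brown fox.", "The slow brown fox.")
    for p in patches:
        print(p.start1, p.length1, p.start2, p.length2)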
1614 def patch_deepCopy(self, patches):
1614 def patch_deepCopy(self, patches):
1615 """Given an array of patches, return another array that is identical.
1615 """Given an array of patches, return another array that is identical.
1616
1616
1617 Args:
1617 Args:
1618 patches: Array of Patch objects.
1618 patches: Array of Patch objects.
1619
1619
1620 Returns:
1620 Returns:
1621 Array of Patch objects.
1621 Array of Patch objects.
1622 """
1622 """
1623 patchesCopy = []
1623 patchesCopy = []
1624 for patch in patches:
1624 for patch in patches:
1625 patchCopy = patch_obj()
1625 patchCopy = patch_obj()
1626 # No need to deep copy the tuples since they are immutable.
1626 # No need to deep copy the tuples since they are immutable.
1627 patchCopy.diffs = patch.diffs[:]
1627 patchCopy.diffs = patch.diffs[:]
1628 patchCopy.start1 = patch.start1
1628 patchCopy.start1 = patch.start1
1629 patchCopy.start2 = patch.start2
1629 patchCopy.start2 = patch.start2
1630 patchCopy.length1 = patch.length1
1630 patchCopy.length1 = patch.length1
1631 patchCopy.length2 = patch.length2
1631 patchCopy.length2 = patch.length2
1632 patchesCopy.append(patchCopy)
1632 patchesCopy.append(patchCopy)
1633 return patchesCopy
1633 return patchesCopy
1634
1634
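Because patch_deepCopy only shallow-copies the immutable diff tuples, mutating
a copy never touches the original; a quick sketch:

    copies = dmp.patch_deepCopy(patches)
    copies[0].diffs.append((dmp.DIFF_EQUAL, "extra"))
    # The original patch list is unchanged.
    assert len(copies[0].diffs) == len(patches[0].diffs) + 1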
1635 def patch_apply(self, patches, text):
1635 def patch_apply(self, patches, text):
1636 """Merge a set of patches onto the text. Return a patched text, as well
1636 """Merge a set of patches onto the text. Return a patched text, as well
1637 as a list of true/false values indicating which patches were applied.
1637 as a list of true/false values indicating which patches were applied.
1638
1638
1639 Args:
1639 Args:
1640 patches: Array of Patch objects.
1640 patches: Array of Patch objects.
1641 text: Old text.
1641 text: Old text.
1642
1642
1643 Returns:
1643 Returns:
1644 Two element Array, containing the new text and an array of boolean values.
1644 Two element Array, containing the new text and an array of boolean values.
1645 """
1645 """
1646 if not patches:
1646 if not patches:
1647 return (text, [])
1647 return (text, [])
1648
1648
1649 # Deep copy the patches so that no changes are made to originals.
1649 # Deep copy the patches so that no changes are made to originals.
1650 patches = self.patch_deepCopy(patches)
1650 patches = self.patch_deepCopy(patches)
1651
1651
1652 nullPadding = self.patch_addPadding(patches)
1652 nullPadding = self.patch_addPadding(patches)
1653 text = nullPadding + text + nullPadding
1653 text = nullPadding + text + nullPadding
1654 self.patch_splitMax(patches)
1654 self.patch_splitMax(patches)
1655
1655
1656 # delta keeps track of the offset between the expected and actual location
1656 # delta keeps track of the offset between the expected and actual location
1657 # of the previous patch. If there are patches expected at positions 10 and
1657 # of the previous patch. If there are patches expected at positions 10 and
1658 # 20, but the first patch was found at 12, delta is 2 and the second patch
1658 # 20, but the first patch was found at 12, delta is 2 and the second patch
1659 # has an effective expected position of 22.
1659 # has an effective expected position of 22.
1660 delta = 0
1660 delta = 0
1661 results = []
1661 results = []
1662 for patch in patches:
1662 for patch in patches:
1663 expected_loc = patch.start2 + delta
1663 expected_loc = patch.start2 + delta
1664 text1 = self.diff_text1(patch.diffs)
1664 text1 = self.diff_text1(patch.diffs)
1665 end_loc = -1
1665 end_loc = -1
1666 if len(text1) > self.Match_MaxBits:
1666 if len(text1) > self.Match_MaxBits:
1667 # patch_splitMax will only provide an oversized pattern in the case of
1667 # patch_splitMax will only provide an oversized pattern in the case of
1668 # a monster delete.
1668 # a monster delete.
1669 start_loc = self.match_main(
1669 start_loc = self.match_main(
1670 text, text1[: self.Match_MaxBits], expected_loc
1670 text, text1[: self.Match_MaxBits], expected_loc
1671 )
1671 )
1672 if start_loc != -1:
1672 if start_loc != -1:
1673 end_loc = self.match_main(
1673 end_loc = self.match_main(
1674 text,
1674 text,
1675 text1[-self.Match_MaxBits :],
1675 text1[-self.Match_MaxBits :],
1676 expected_loc + len(text1) - self.Match_MaxBits,
1676 expected_loc + len(text1) - self.Match_MaxBits,
1677 )
1677 )
1678 if end_loc == -1 or start_loc >= end_loc:
1678 if end_loc == -1 or start_loc >= end_loc:
1679 # Can't find valid trailing context. Drop this patch.
1679 # Can't find valid trailing context. Drop this patch.
1680 start_loc = -1
1680 start_loc = -1
1681 else:
1681 else:
1682 start_loc = self.match_main(text, text1, expected_loc)
1682 start_loc = self.match_main(text, text1, expected_loc)
1683 if start_loc == -1:
1683 if start_loc == -1:
1684 # No match found. :(
1684 # No match found. :(
1685 results.append(False)
1685 results.append(False)
1686 # Subtract the delta for this failed patch from subsequent patches.
1686 # Subtract the delta for this failed patch from subsequent patches.
1687 delta -= patch.length2 - patch.length1
1687 delta -= patch.length2 - patch.length1
1688 else:
1688 else:
1689 # Found a match. :)
1689 # Found a match. :)
1690 results.append(True)
1690 results.append(True)
1691 delta = start_loc - expected_loc
1691 delta = start_loc - expected_loc
1692 if end_loc == -1:
1692 if end_loc == -1:
1693 text2 = text[start_loc : start_loc + len(text1)]
1693 text2 = text[start_loc : start_loc + len(text1)]
1694 else:
1694 else:
1695 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1695 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1696 if text1 == text2:
1696 if text1 == text2:
1697 # Perfect match, just shove the replacement text in.
1697 # Perfect match, just shove the replacement text in.
1698 text = (
1698 text = (
1699 text[:start_loc]
1699 text[:start_loc]
1700 + self.diff_text2(patch.diffs)
1700 + self.diff_text2(patch.diffs)
1701 + text[start_loc + len(text1) :]
1701 + text[start_loc + len(text1) :]
1702 )
1702 )
1703 else:
1703 else:
1704 # Imperfect match.
1704 # Imperfect match.
1705 # Run a diff to get a framework of equivalent indices.
1705 # Run a diff to get a framework of equivalent indices.
1706 diffs = self.diff_main(text1, text2, False)
1706 diffs = self.diff_main(text1, text2, False)
1707 if (
1707 if (
1708 len(text1) > self.Match_MaxBits
1708 len(text1) > self.Match_MaxBits
1709 and self.diff_levenshtein(diffs) / float(len(text1))
1709 and self.diff_levenshtein(diffs) / float(len(text1))
1710 > self.Patch_DeleteThreshold
1710 > self.Patch_DeleteThreshold
1711 ):
1711 ):
1712 # The end points match, but the content is unacceptably bad.
1712 # The end points match, but the content is unacceptably bad.
1713 results[-1] = False
1713 results[-1] = False
1714 else:
1714 else:
1715 self.diff_cleanupSemanticLossless(diffs)
1715 self.diff_cleanupSemanticLossless(diffs)
1716 index1 = 0
1716 index1 = 0
1717 for op, data in patch.diffs:
1717 for op, data in patch.diffs:
1718 if op != self.DIFF_EQUAL:
1718 if op != self.DIFF_EQUAL:
1719 index2 = self.diff_xIndex(diffs, index1)
1719 index2 = self.diff_xIndex(diffs, index1)
1720 if op == self.DIFF_INSERT: # Insertion
1720 if op == self.DIFF_INSERT: # Insertion
1721 text = (
1721 text = (
1722 text[: start_loc + index2]
1722 text[: start_loc + index2]
1723 + data
1723 + data
1724 + text[start_loc + index2 :]
1724 + text[start_loc + index2 :]
1725 )
1725 )
1726 elif op == self.DIFF_DELETE: # Deletion
1726 elif op == self.DIFF_DELETE: # Deletion
1727 text = (
1727 text = (
1728 text[: start_loc + index2]
1728 text[: start_loc + index2]
1729 + text[
1729 + text[
1730 start_loc
1730 start_loc
1731 + self.diff_xIndex(diffs, index1 + len(data)) :
1731 + self.diff_xIndex(diffs, index1 + len(data)) :
1732 ]
1732 ]
1733 )
1733 )
1734 if op != self.DIFF_DELETE:
1734 if op != self.DIFF_DELETE:
1735 index1 += len(data)
1735 index1 += len(data)
1736 # Strip the padding off.
1736 # Strip the padding off.
1737 text = text[len(nullPadding) : -len(nullPadding)]
1737 text = text[len(nullPadding) : -len(nullPadding)]
1738 return (text, results)
1738 return (text, results)
1739
1739
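A hedged usage sketch for patch_apply, continuing the example texts from
above (applying the patches to the unmodified source text should succeed):

    new_text, applied = dmp.patch_apply(patches, "The quick brown fox.")
    print(new_text)  # "The slow brown fox."
    print(applied)   # one boolean per patch, e.g. [True]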
1740 def patch_addPadding(self, patches):
1740 def patch_addPadding(self, patches):
1741 """Add some padding on text start and end so that edges can match
1741 """Add some padding on text start and end so that edges can match
1742 something. Intended to be called only from within patch_apply.
1742 something. Intended to be called only from within patch_apply.
1743
1743
1744 Args:
1744 Args:
1745 patches: Array of Patch objects.
1745 patches: Array of Patch objects.
1746
1746
1747 Returns:
1747 Returns:
1748 The padding string added to each side.
1748 The padding string added to each side.
1749 """
1749 """
1750 paddingLength = self.Patch_Margin
1750 paddingLength = self.Patch_Margin
1751 nullPadding = ""
1751 nullPadding = ""
1752 for x in range(1, paddingLength + 1):
1752 for x in range(1, paddingLength + 1):
1753 nullPadding += chr(x)
1753 nullPadding += chr(x)
1754
1754
1755 # Bump all the patches forward.
1755 # Bump all the patches forward.
1756 for patch in patches:
1756 for patch in patches:
1757 patch.start1 += paddingLength
1757 patch.start1 += paddingLength
1758 patch.start2 += paddingLength
1758 patch.start2 += paddingLength
1759
1759
1760 # Add some padding on start of first diff.
1760 # Add some padding on start of first diff.
1761 patch = patches[0]
1761 patch = patches[0]
1762 diffs = patch.diffs
1762 diffs = patch.diffs
1763 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1763 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1764 # Add nullPadding equality.
1764 # Add nullPadding equality.
1765 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1765 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1766 patch.start1 -= paddingLength # Should be 0.
1766 patch.start1 -= paddingLength # Should be 0.
1767 patch.start2 -= paddingLength # Should be 0.
1767 patch.start2 -= paddingLength # Should be 0.
1768 patch.length1 += paddingLength
1768 patch.length1 += paddingLength
1769 patch.length2 += paddingLength
1769 patch.length2 += paddingLength
1770 elif paddingLength > len(diffs[0][1]):
1770 elif paddingLength > len(diffs[0][1]):
1771 # Grow first equality.
1771 # Grow first equality.
1772 extraLength = paddingLength - len(diffs[0][1])
1772 extraLength = paddingLength - len(diffs[0][1])
1773 newText = nullPadding[len(diffs[0][1]) :] + diffs[0][1]
1773 newText = nullPadding[len(diffs[0][1]) :] + diffs[0][1]
1774 diffs[0] = (diffs[0][0], newText)
1774 diffs[0] = (diffs[0][0], newText)
1775 patch.start1 -= extraLength
1775 patch.start1 -= extraLength
1776 patch.start2 -= extraLength
1776 patch.start2 -= extraLength
1777 patch.length1 += extraLength
1777 patch.length1 += extraLength
1778 patch.length2 += extraLength
1778 patch.length2 += extraLength
1779
1779
1780 # Add some padding on end of last diff.
1780 # Add some padding on end of last diff.
1781 patch = patches[-1]
1781 patch = patches[-1]
1782 diffs = patch.diffs
1782 diffs = patch.diffs
1783 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1783 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1784 # Add nullPadding equality.
1784 # Add nullPadding equality.
1785 diffs.append((self.DIFF_EQUAL, nullPadding))
1785 diffs.append((self.DIFF_EQUAL, nullPadding))
1786 patch.length1 += paddingLength
1786 patch.length1 += paddingLength
1787 patch.length2 += paddingLength
1787 patch.length2 += paddingLength
1788 elif paddingLength > len(diffs[-1][1]):
1788 elif paddingLength > len(diffs[-1][1]):
1789 # Grow last equality.
1789 # Grow last equality.
1790 extraLength = paddingLength - len(diffs[-1][1])
1790 extraLength = paddingLength - len(diffs[-1][1])
1791 newText = diffs[-1][1] + nullPadding[:extraLength]
1791 newText = diffs[-1][1] + nullPadding[:extraLength]
1792 diffs[-1] = (diffs[-1][0], newText)
1792 diffs[-1] = (diffs[-1][0], newText)
1793 patch.length1 += extraLength
1793 patch.length1 += extraLength
1794 patch.length2 += extraLength
1794 patch.length2 += extraLength
1795
1795
1796 return nullPadding
1796 return nullPadding
1797
1797
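With the library default Patch_Margin of 4 (an assumption; the class
attributes are defined earlier in this module), the padding is the control
characters chr(1) through chr(4):

    pad = "".join(chr(x) for x in range(1, dmp.Patch_Margin + 1))
    assert pad == "\x01\x02\x03\x04"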
1798 def patch_splitMax(self, patches):
1798 def patch_splitMax(self, patches):
1799 """Look through the patches and break up any which are longer than the
1799 """Look through the patches and break up any which are longer than the
1800 maximum limit of the match algorithm.
1800 maximum limit of the match algorithm.
1801 Intended to be called only from within patch_apply.
1801 Intended to be called only from within patch_apply.
1802
1802
1803 Args:
1803 Args:
1804 patches: Array of Patch objects.
1804 patches: Array of Patch objects.
1805 """
1805 """
1806 patch_size = self.Match_MaxBits
1806 patch_size = self.Match_MaxBits
1807 if patch_size == 0:
1807 if patch_size == 0:
1808 # Python has the option of not splitting strings due to its ability
1808 # Python has the option of not splitting strings due to its ability
1809 # to handle integers of arbitrary precision.
1809 # to handle integers of arbitrary precision.
1810 return
1810 return
1811 for x in range(len(patches)):
1811 for x in range(len(patches)):
1812 if patches[x].length1 <= patch_size:
1812 if patches[x].length1 <= patch_size:
1813 continue
1813 continue
1814 bigpatch = patches[x]
1814 bigpatch = patches[x]
1815 # Remove the big old patch.
1815 # Remove the big old patch.
1816 del patches[x]
1816 del patches[x]
1817 x -= 1
1817 x -= 1
1818 start1 = bigpatch.start1
1818 start1 = bigpatch.start1
1819 start2 = bigpatch.start2
1819 start2 = bigpatch.start2
1820 precontext = ""
1820 precontext = ""
1821 while len(bigpatch.diffs) != 0:
1821 while len(bigpatch.diffs) != 0:
1822 # Create one of several smaller patches.
1822 # Create one of several smaller patches.
1823 patch = patch_obj()
1823 patch = patch_obj()
1824 empty = True
1824 empty = True
1825 patch.start1 = start1 - len(precontext)
1825 patch.start1 = start1 - len(precontext)
1826 patch.start2 = start2 - len(precontext)
1826 patch.start2 = start2 - len(precontext)
1827 if precontext:
1827 if precontext:
1828 patch.length1 = patch.length2 = len(precontext)
1828 patch.length1 = patch.length2 = len(precontext)
1829 patch.diffs.append((self.DIFF_EQUAL, precontext))
1829 patch.diffs.append((self.DIFF_EQUAL, precontext))
1830
1830
1831 while (
1831 while (
1832 len(bigpatch.diffs) != 0
1832 len(bigpatch.diffs) != 0
1833 and patch.length1 < patch_size - self.Patch_Margin
1833 and patch.length1 < patch_size - self.Patch_Margin
1834 ):
1834 ):
1835 (diff_type, diff_text) = bigpatch.diffs[0]
1835 (diff_type, diff_text) = bigpatch.diffs[0]
1836 if diff_type == self.DIFF_INSERT:
1836 if diff_type == self.DIFF_INSERT:
1837 # Insertions are harmless.
1837 # Insertions are harmless.
1838 patch.length2 += len(diff_text)
1838 patch.length2 += len(diff_text)
1839 start2 += len(diff_text)
1839 start2 += len(diff_text)
1840 patch.diffs.append(bigpatch.diffs.pop(0))
1840 patch.diffs.append(bigpatch.diffs.pop(0))
1841 empty = False
1841 empty = False
1842 elif (
1842 elif (
1843 diff_type == self.DIFF_DELETE
1843 diff_type == self.DIFF_DELETE
1844 and len(patch.diffs) == 1
1844 and len(patch.diffs) == 1
1845 and patch.diffs[0][0] == self.DIFF_EQUAL
1845 and patch.diffs[0][0] == self.DIFF_EQUAL
1846 and len(diff_text) > 2 * patch_size
1846 and len(diff_text) > 2 * patch_size
1847 ):
1847 ):
1848 # This is a large deletion. Let it pass in one chunk.
1848 # This is a large deletion. Let it pass in one chunk.
1849 patch.length1 += len(diff_text)
1849 patch.length1 += len(diff_text)
1850 start1 += len(diff_text)
1850 start1 += len(diff_text)
1851 empty = False
1851 empty = False
1852 patch.diffs.append((diff_type, diff_text))
1852 patch.diffs.append((diff_type, diff_text))
1853 del bigpatch.diffs[0]
1853 del bigpatch.diffs[0]
1854 else:
1854 else:
1855 # Deletion or equality. Only take as much as we can stomach.
1855 # Deletion or equality. Only take as much as we can stomach.
1856 diff_text = diff_text[
1856 diff_text = diff_text[
1857 : patch_size - patch.length1 - self.Patch_Margin
1857 : patch_size - patch.length1 - self.Patch_Margin
1858 ]
1858 ]
1859 patch.length1 += len(diff_text)
1859 patch.length1 += len(diff_text)
1860 start1 += len(diff_text)
1860 start1 += len(diff_text)
1861 if diff_type == self.DIFF_EQUAL:
1861 if diff_type == self.DIFF_EQUAL:
1862 patch.length2 += len(diff_text)
1862 patch.length2 += len(diff_text)
1863 start2 += len(diff_text)
1863 start2 += len(diff_text)
1864 else:
1864 else:
1865 empty = False
1865 empty = False
1866
1866
1867 patch.diffs.append((diff_type, diff_text))
1867 patch.diffs.append((diff_type, diff_text))
1868 if diff_text == bigpatch.diffs[0][1]:
1868 if diff_text == bigpatch.diffs[0][1]:
1869 del bigpatch.diffs[0]
1869 del bigpatch.diffs[0]
1870 else:
1870 else:
1871 bigpatch.diffs[0] = (
1871 bigpatch.diffs[0] = (
1872 bigpatch.diffs[0][0],
1872 bigpatch.diffs[0][0],
1873 bigpatch.diffs[0][1][len(diff_text) :],
1873 bigpatch.diffs[0][1][len(diff_text) :],
1874 )
1874 )
1875
1875
1876 # Compute the head context for the next patch.
1876 # Compute the head context for the next patch.
1877 precontext = self.diff_text2(patch.diffs)
1877 precontext = self.diff_text2(patch.diffs)
1878 precontext = precontext[-self.Patch_Margin :]
1878 precontext = precontext[-self.Patch_Margin :]
1879 # Append the end context for this patch.
1879 # Append the end context for this patch.
1880 postcontext = self.diff_text1(bigpatch.diffs)[: self.Patch_Margin]
1880 postcontext = self.diff_text1(bigpatch.diffs)[: self.Patch_Margin]
1881 if postcontext:
1881 if postcontext:
1882 patch.length1 += len(postcontext)
1882 patch.length1 += len(postcontext)
1883 patch.length2 += len(postcontext)
1883 patch.length2 += len(postcontext)
1884 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1884 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1885 patch.diffs[-1] = (
1885 patch.diffs[-1] = (
1886 self.DIFF_EQUAL,
1886 self.DIFF_EQUAL,
1887 patch.diffs[-1][1] + postcontext,
1887 patch.diffs[-1][1] + postcontext,
1888 )
1888 )
1889 else:
1889 else:
1890 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1890 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1891
1891
1892 if not empty:
1892 if not empty:
1893 x += 1
1893 x += 1
1894 patches.insert(x, patch)
1894 patches.insert(x, patch)
1895
1895
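A small sketch of the splitting behaviour (Match_MaxBits defaults to 32, an
assumption based on the bitap matcher's word size; the exact chunk sizes are
an implementation detail):

    big = dmp.patch_make("a" * 100, "b" * 100)
    dmp.patch_splitMax(big)
    # Each resulting patch now fits within the match algorithm's window.
    print([(p.length1, p.length2) for p in big])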
1896 def patch_toText(self, patches):
1896 def patch_toText(self, patches):
1897 """Take a list of patches and return a textual representation.
1897 """Take a list of patches and return a textual representation.
1898
1898
1899 Args:
1899 Args:
1900 patches: Array of Patch objects.
1900 patches: Array of Patch objects.
1901
1901
1902 Returns:
1902 Returns:
1903 Text representation of patches.
1903 Text representation of patches.
1904 """
1904 """
1905 text = []
1905 text = []
1906 for patch in patches:
1906 for patch in patches:
1907 text.append(str(patch))
1907 text.append(str(patch))
1908 return "".join(text)
1908 return "".join(text)
1909
1909
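patch_toText and patch_fromText below form a lossless round trip over the
emitted text format; a quick sketch:

    text_form = dmp.patch_toText(patches)
    restored = dmp.patch_fromText(text_form)
    assert dmp.patch_toText(restored) == text_form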
1910 def patch_fromText(self, textline):
1910 def patch_fromText(self, textline):
1911 """Parse a textual representation of patches and return a list of patch
1911 """Parse a textual representation of patches and return a list of patch
1912 objects.
1912 objects.
1913
1913
1914 Args:
1914 Args:
1915 textline: Text representation of patches.
1915 textline: Text representation of patches.
1916
1916
1917 Returns:
1917 Returns:
1918 Array of Patch objects.
1918 Array of Patch objects.
1919
1919
1920 Raises:
1920 Raises:
1921 ValueError: If invalid input.
1921 ValueError: If invalid input.
1922 """
1922 """
1923 if type(textline) == unicode:
1923 if type(textline) == str:
1924 # Patches should be composed of a subset of ascii chars, Unicode not
1924 # Patches should be composed of a subset of ascii chars, Unicode not
1925 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1925 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1926 textline.encode("ascii")  # validate only; keep str for split() below
1926 textline.encode("ascii")  # validate only; keep str for split() below
1927 patches = []
1927 patches = []
1928 if not textline:
1928 if not textline:
1929 return patches
1929 return patches
1930 text = textline.split("\n")
1930 text = textline.split("\n")
1931 while len(text) != 0:
1931 while len(text) != 0:
1932 m = re.match(r"^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1932 m = re.match(r"^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1933 if not m:
1933 if not m:
1934 raise ValueError("Invalid patch string: " + text[0])
1934 raise ValueError("Invalid patch string: " + text[0])
1935 patch = patch_obj()
1935 patch = patch_obj()
1936 patches.append(patch)
1936 patches.append(patch)
1937 patch.start1 = int(m.group(1))
1937 patch.start1 = int(m.group(1))
1938 if m.group(2) == "":
1938 if m.group(2) == "":
1939 patch.start1 -= 1
1939 patch.start1 -= 1
1940 patch.length1 = 1
1940 patch.length1 = 1
1941 elif m.group(2) == "0":
1941 elif m.group(2) == "0":
1942 patch.length1 = 0
1942 patch.length1 = 0
1943 else:
1943 else:
1944 patch.start1 -= 1
1944 patch.start1 -= 1
1945 patch.length1 = int(m.group(2))
1945 patch.length1 = int(m.group(2))
1946
1946
1947 patch.start2 = int(m.group(3))
1947 patch.start2 = int(m.group(3))
1948 if m.group(4) == "":
1948 if m.group(4) == "":
1949 patch.start2 -= 1
1949 patch.start2 -= 1
1950 patch.length2 = 1
1950 patch.length2 = 1
1951 elif m.group(4) == "0":
1951 elif m.group(4) == "0":
1952 patch.length2 = 0
1952 patch.length2 = 0
1953 else:
1953 else:
1954 patch.start2 -= 1
1954 patch.start2 -= 1
1955 patch.length2 = int(m.group(4))
1955 patch.length2 = int(m.group(4))
1956
1956
1957 del text[0]
1957 del text[0]
1958
1958
1959 while len(text) != 0:
1959 while len(text) != 0:
1960 if text[0]:
1960 if text[0]:
1961 sign = text[0][0]
1961 sign = text[0][0]
1962 else:
1962 else:
1963 sign = ""
1963 sign = ""
1964 line = urllib.parse.unquote(text[0][1:])
1964 line = urllib.parse.unquote(text[0][1:])
1965 # urllib.parse.unquote() already returns str on Python 3; no decode needed
1965 # urllib.parse.unquote() already returns str on Python 3; no decode needed
1966 if sign == "+":
1966 if sign == "+":
1967 # Insertion.
1967 # Insertion.
1968 patch.diffs.append((self.DIFF_INSERT, line))
1968 patch.diffs.append((self.DIFF_INSERT, line))
1969 elif sign == "-":
1969 elif sign == "-":
1970 # Deletion.
1970 # Deletion.
1971 patch.diffs.append((self.DIFF_DELETE, line))
1971 patch.diffs.append((self.DIFF_DELETE, line))
1972 elif sign == " ":
1972 elif sign == " ":
1973 # Minor equality.
1973 # Minor equality.
1974 patch.diffs.append((self.DIFF_EQUAL, line))
1974 patch.diffs.append((self.DIFF_EQUAL, line))
1975 elif sign == "@":
1975 elif sign == "@":
1976 # Start of next patch.
1976 # Start of next patch.
1977 break
1977 break
1978 elif sign == "":
1978 elif sign == "":
1979 # Blank line? Whatever.
1979 # Blank line? Whatever.
1980 pass
1980 pass
1981 else:
1981 else:
1982 # WTF?
1982 # WTF?
1983 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1983 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1984 del text[0]
1984 del text[0]
1985 return patches
1985 return patches
1986
1986
1987
1987
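A hedged example of the wire format patch_fromText accepts (the hunk body is
%xx-escaped, one sign-prefixed line per diff):

    sample = "@@ -1,4 +1,4 @@\n-spam\n+eggs\n"
    parsed = dmp.patch_fromText(sample)
    assert parsed[0].diffs == [(dmp.DIFF_DELETE, "spam"),
                               (dmp.DIFF_INSERT, "eggs")]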
1988 class patch_obj:
1988 class patch_obj:
1989 """Class representing one patch operation."""
1989 """Class representing one patch operation."""
1990
1990
1991 def __init__(self):
1991 def __init__(self):
1992 """Initializes with an empty list of diffs."""
1992 """Initializes with an empty list of diffs."""
1993 self.diffs = []
1993 self.diffs = []
1994 self.start1 = None
1994 self.start1 = None
1995 self.start2 = None
1995 self.start2 = None
1996 self.length1 = 0
1996 self.length1 = 0
1997 self.length2 = 0
1997 self.length2 = 0
1998
1998
1999 def __str__(self):
1999 def __str__(self):
2000 """Emmulate GNU diff's format.
2000 """Emmulate GNU diff's format.
2001 Header: @@ -382,8 +481,9 @@
2001 Header: @@ -382,8 +481,9 @@
2002 Indices are printed as 1-based, not 0-based.
2002 Indices are printed as 1-based, not 0-based.
2003
2003
2004 Returns:
2004 Returns:
2005 The GNU diff string.
2005 The GNU diff string.
2006 """
2006 """
2007 if self.length1 == 0:
2007 if self.length1 == 0:
2008 coords1 = str(self.start1) + ",0"
2008 coords1 = str(self.start1) + ",0"
2009 elif self.length1 == 1:
2009 elif self.length1 == 1:
2010 coords1 = str(self.start1 + 1)
2010 coords1 = str(self.start1 + 1)
2011 else:
2011 else:
2012 coords1 = str(self.start1 + 1) + "," + str(self.length1)
2012 coords1 = str(self.start1 + 1) + "," + str(self.length1)
2013 if self.length2 == 0:
2013 if self.length2 == 0:
2014 coords2 = str(self.start2) + ",0"
2014 coords2 = str(self.start2) + ",0"
2015 elif self.length2 == 1:
2015 elif self.length2 == 1:
2016 coords2 = str(self.start2 + 1)
2016 coords2 = str(self.start2 + 1)
2017 else:
2017 else:
2018 coords2 = str(self.start2 + 1) + "," + str(self.length2)
2018 coords2 = str(self.start2 + 1) + "," + str(self.length2)
2019 text = ["@@ -", coords1, " +", coords2, " @@\n"]
2019 text = ["@@ -", coords1, " +", coords2, " @@\n"]
2020 # Escape the body of the patch with %xx notation.
2020 # Escape the body of the patch with %xx notation.
2021 for op, data in self.diffs:
2021 for op, data in self.diffs:
2022 if op == diff_match_patch.DIFF_INSERT:
2022 if op == diff_match_patch.DIFF_INSERT:
2023 text.append("+")
2023 text.append("+")
2024 elif op == diff_match_patch.DIFF_DELETE:
2024 elif op == diff_match_patch.DIFF_DELETE:
2025 text.append("-")
2025 text.append("-")
2026 elif op == diff_match_patch.DIFF_EQUAL:
2026 elif op == diff_match_patch.DIFF_EQUAL:
2027 text.append(" ")
2027 text.append(" ")
2028 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
2028 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
2029 data = data.encode("utf-8")
2029 data = data.encode("utf-8")
2030 text.append(urllib.parse.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
2030 text.append(urllib.parse.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
2031 return "".join(text)
2031 return "".join(text)
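A short sketch of the header format __str__ emits, matching the docstring's
"@@ -382,8 +481,9 @@" example (field values chosen to reproduce it):

    p = patch_obj()
    p.start1, p.start2 = 381, 480
    p.length1, p.length2 = 8, 9
    p.diffs = [(diff_match_patch.DIFF_EQUAL, "abcdefgh"),
               (diff_match_patch.DIFF_INSERT, "i")]
    print(str(p))  # "@@ -382,8 +481,9 @@" then " abcdefgh" and "+i"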
@@ -1,1272 +1,1271 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Set of diffing helpers, previously part of vcs
23 Set of diffing helpers, previously part of vcs
24 """
24 """
25
25
26 import os
26 import os
27 import re
27 import re
28 import bz2
28 import bz2
29 import gzip
29 import gzip
30 import time
30 import time
31
31
32 import collections
32 import collections
33 import difflib
33 import difflib
34 import logging
34 import logging
35 import pickle
35 import pickle
36 from itertools import tee
36 from itertools import tee
37
37
38 from rhodecode.lib.vcs.exceptions import VCSError
38 from rhodecode.lib.vcs.exceptions import VCSError
39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
40 from rhodecode.lib.utils2 import safe_unicode, safe_str
40 from rhodecode.lib.utils2 import safe_unicode, safe_str
41
41
42 log = logging.getLogger(__name__)
42 log = logging.getLogger(__name__)
43
43
44 # define max context; a file with more than this number of lines is unusable
44 # define max context; a file with more than this number of lines is unusable
45 # in a browser anyway
45 # in a browser anyway
46 MAX_CONTEXT = 20 * 1024
46 MAX_CONTEXT = 20 * 1024
47 DEFAULT_CONTEXT = 3
47 DEFAULT_CONTEXT = 3
48
48
49
49
50 def get_diff_context(request):
50 def get_diff_context(request):
51 return MAX_CONTEXT if request.GET.get('fullcontext', '') == '1' else DEFAULT_CONTEXT
51 return MAX_CONTEXT if request.GET.get('fullcontext', '') == '1' else DEFAULT_CONTEXT
52
52
53
53
54 def get_diff_whitespace_flag(request):
54 def get_diff_whitespace_flag(request):
55 return request.GET.get('ignorews', '') == '1'
55 return request.GET.get('ignorews', '') == '1'
56
56
57
57
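A hedged sketch of the two request helpers (`request` here is a hypothetical
Pyramid-style request for a URL such as /diff?fullcontext=1&ignorews=1):

    context = get_diff_context(request)            # MAX_CONTEXT, i.e. 20480
    ignore_ws = get_diff_whitespace_flag(request)  # True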
58 class OPS(object):
58 class OPS(object):
59 ADD = 'A'
59 ADD = 'A'
60 MOD = 'M'
60 MOD = 'M'
61 DEL = 'D'
61 DEL = 'D'
62
62
63
63
64 def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
64 def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
65 """
65 """
66 Returns git style diff between given ``filenode_old`` and ``filenode_new``.
66 Returns git style diff between given ``filenode_old`` and ``filenode_new``.
67
67
68 :param ignore_whitespace: ignore whitespaces in diff
68 :param ignore_whitespace: ignore whitespace in the diff
68 :param ignore_whitespace: ignore whitespace in the diff
69 """
70 # make sure we pass in default context
70 # make sure we pass in default context
71 context = context or 3
71 context = context or 3
72 # protect against IntOverflow when passing HUGE context
72 # protect against IntOverflow when passing HUGE context
73 if context > MAX_CONTEXT:
73 if context > MAX_CONTEXT:
74 context = MAX_CONTEXT
74 context = MAX_CONTEXT
75
75
76 submodules = filter(lambda o: isinstance(o, SubModuleNode),
76 submodules = [o for o in [filenode_new, filenode_old] if isinstance(o, SubModuleNode)]
77 [filenode_new, filenode_old])
78 if submodules:
77 if submodules:
79 return ''
78 return ''
80
79
81 for filenode in (filenode_old, filenode_new):
80 for filenode in (filenode_old, filenode_new):
82 if not isinstance(filenode, FileNode):
81 if not isinstance(filenode, FileNode):
83 raise VCSError(
82 raise VCSError(
84 "Given object should be FileNode object, not %s"
83 "Given object should be FileNode object, not %s"
85 % filenode.__class__)
84 % filenode.__class__)
86
85
87 repo = filenode_new.commit.repository
86 repo = filenode_new.commit.repository
88 old_commit = filenode_old.commit or repo.EMPTY_COMMIT
87 old_commit = filenode_old.commit or repo.EMPTY_COMMIT
89 new_commit = filenode_new.commit
88 new_commit = filenode_new.commit
90
89
91 vcs_gitdiff = repo.get_diff(
90 vcs_gitdiff = repo.get_diff(
92 old_commit, new_commit, filenode_new.path,
91 old_commit, new_commit, filenode_new.path,
93 ignore_whitespace, context, path1=filenode_old.path)
92 ignore_whitespace, context, path1=filenode_old.path)
94 return vcs_gitdiff
93 return vcs_gitdiff
95
94
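A hedged usage sketch; `old_commit`, `new_commit` and the 'setup.py' path are
hypothetical, and both nodes are assumed to be FileNode instances:

    old_node = old_commit.get_node('setup.py')
    new_node = new_commit.get_node('setup.py')
    diff = get_gitdiff(old_node, new_node, ignore_whitespace=False, context=5)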
96 NEW_FILENODE = 1
95 NEW_FILENODE = 1
97 DEL_FILENODE = 2
96 DEL_FILENODE = 2
98 MOD_FILENODE = 3
97 MOD_FILENODE = 3
99 RENAMED_FILENODE = 4
98 RENAMED_FILENODE = 4
100 COPIED_FILENODE = 5
99 COPIED_FILENODE = 5
101 CHMOD_FILENODE = 6
100 CHMOD_FILENODE = 6
102 BIN_FILENODE = 7
101 BIN_FILENODE = 7
103
102
104
103
105 class LimitedDiffContainer(object):
104 class LimitedDiffContainer(object):
106
105
107 def __init__(self, diff_limit, cur_diff_size, diff):
106 def __init__(self, diff_limit, cur_diff_size, diff):
108 self.diff = diff
107 self.diff = diff
109 self.diff_limit = diff_limit
108 self.diff_limit = diff_limit
110 self.cur_diff_size = cur_diff_size
109 self.cur_diff_size = cur_diff_size
111
110
112 def __getitem__(self, key):
111 def __getitem__(self, key):
113 return self.diff.__getitem__(key)
112 return self.diff.__getitem__(key)
114
113
115 def __iter__(self):
114 def __iter__(self):
116 for l in self.diff:
115 for l in self.diff:
117 yield l
116 yield l
118
117
119
118
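A minimal sketch of how the container is meant to be used (`parsed_files` is a
hypothetical list produced by the diff parser): it behaves like the wrapped
list but carries the limit metadata alongside it.

    limited = LimitedDiffContainer(
        diff_limit=1024, cur_diff_size=4096, diff=parsed_files)
    for file_info in limited:
        print(file_info['filename'])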
120 class Action(object):
119 class Action(object):
121 """
120 """
122 Contains constants for the action value of the lines in a parsed diff.
121 Contains constants for the action value of the lines in a parsed diff.
123 """
122 """
124
123
125 ADD = 'add'
124 ADD = 'add'
126 DELETE = 'del'
125 DELETE = 'del'
127 UNMODIFIED = 'unmod'
126 UNMODIFIED = 'unmod'
128
127
129 CONTEXT = 'context'
128 CONTEXT = 'context'
130 OLD_NO_NL = 'old-no-nl'
129 OLD_NO_NL = 'old-no-nl'
131 NEW_NO_NL = 'new-no-nl'
130 NEW_NO_NL = 'new-no-nl'
132
131
133
132
134 class DiffProcessor(object):
133 class DiffProcessor(object):
135 """
134 """
136 Give it a unified or git diff and it returns a list of the files that were
135 Give it a unified or git diff and it returns a list of the files that were
137 mentioned in the diff together with a dict of meta information that
136 mentioned in the diff together with a dict of meta information that
138 can be used to render it in an HTML template.
137 can be used to render it in an HTML template.
139
138
140 .. note:: Unicode handling
139 .. note:: Unicode handling
141
140
142 The original diffs are a byte sequence and can contain filenames
141 The original diffs are a byte sequence and can contain filenames
143 in mixed encodings. This class generally returns `unicode` objects
142 in mixed encodings. This class generally returns `unicode` objects
144 since the result is intended for presentation to the user.
143 since the result is intended for presentation to the user.
145
144
146 """
145 """
147 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
146 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
148 _newline_marker = re.compile(r'^\\ No newline at end of file')
147 _newline_marker = re.compile(r'^\\ No newline at end of file')
149
148
150 # used for inline highlighter word split
149 # used for inline highlighter word split
151 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
150 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
152
151
153 # collapse ranges of commits over given number
152 # collapse ranges of commits over given number
154 _collapse_commits_over = 5
153 _collapse_commits_over = 5
155
154
156 def __init__(self, diff, format='gitdiff', diff_limit=None,
155 def __init__(self, diff, format='gitdiff', diff_limit=None,
157 file_limit=None, show_full_diff=True):
156 file_limit=None, show_full_diff=True):
158 """
157 """
159 :param diff: A `Diff` object representing a diff from a vcs backend
158 :param diff: A `Diff` object representing a diff from a vcs backend
160 :param format: format of diff passed, `udiff` or `gitdiff`
159 :param format: format of diff passed, `udiff` or `gitdiff`
161 :param diff_limit: defines the size of a diff that is considered "big";
160 :param diff_limit: defines the size of a diff that is considered "big";
162 beyond that size the cut-off is triggered; set to None
161 beyond that size the cut-off is triggered; set to None
163 to show the full diff
162 to show the full diff
164 """
163 """
165 self._diff = diff
164 self._diff = diff
166 self._format = format
165 self._format = format
167 self.adds = 0
166 self.adds = 0
168 self.removes = 0
167 self.removes = 0
169 # calculate diff size
168 # calculate diff size
170 self.diff_limit = diff_limit
169 self.diff_limit = diff_limit
171 self.file_limit = file_limit
170 self.file_limit = file_limit
172 self.show_full_diff = show_full_diff
171 self.show_full_diff = show_full_diff
173 self.cur_diff_size = 0
172 self.cur_diff_size = 0
174 self.parsed = False
173 self.parsed = False
175 self.parsed_diff = []
174 self.parsed_diff = []
176
175
177 log.debug('Initialized DiffProcessor with %s mode', format)
176 log.debug('Initialized DiffProcessor with %s mode', format)
178 if format == 'gitdiff':
177 if format == 'gitdiff':
179 self.differ = self._highlight_line_difflib
178 self.differ = self._highlight_line_difflib
180 self._parser = self._parse_gitdiff
179 self._parser = self._parse_gitdiff
181 else:
180 else:
182 self.differ = self._highlight_line_udiff
181 self.differ = self._highlight_line_udiff
183 self._parser = self._new_parse_gitdiff
182 self._parser = self._new_parse_gitdiff
184
183
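A hedged construction sketch (`vcs_diff` is a hypothetical Diff object from a
vcs backend; any format other than 'gitdiff' selects the newer parser):

    processor = DiffProcessor(
        vcs_diff, format='newdiff', diff_limit=1024 * 1024,
        file_limit=256 * 1024, show_full_diff=False)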
185 def _copy_iterator(self):
184 def _copy_iterator(self):
186 """
185 """
187 make a fresh copy of the generator; we should not iterate through
186 make a fresh copy of the generator; we should not iterate through
188 the original as it is needed for repeating operations on
187 the original as it is needed for repeating operations on
189 this instance of DiffProcessor
188 this instance of DiffProcessor
190 """
189 """
191 self.__udiff, iterator_copy = tee(self.__udiff)
190 self.__udiff, iterator_copy = tee(self.__udiff)
192 return iterator_copy
191 return iterator_copy
193
192
194 def _escaper(self, string):
193 def _escaper(self, string):
195 """
194 """
196 Escapes special HTML characters in a diff line and checks the diff size limit
195 Escapes special HTML characters in a diff line and checks the diff size limit
197
196
198 :param string:
197 :param string:
199 """
198 """
200 self.cur_diff_size += len(string)
199 self.cur_diff_size += len(string)
201
200
202 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
201 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
203 raise DiffLimitExceeded('Diff Limit Exceeded')
202 raise DiffLimitExceeded('Diff Limit Exceeded')
204
203
205 return string \
204 return string \
206 .replace('&', '&amp;')\
205 .replace('&', '&amp;')\
207 .replace('<', '&lt;')\
206 .replace('<', '&lt;')\
208 .replace('>', '&gt;')
207 .replace('>', '&gt;')
209
208
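The escaping itself is plain HTML entity replacement, applied only while the
accumulated diff size stays under the configured limit; a sketch:

    assert processor._escaper('a < b & c > d') == 'a &lt; b &amp; c &gt; d'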
210 def _line_counter(self, l):
209 def _line_counter(self, l):
211 """
210 """
212 Checks each line and bumps total adds/removes for this diff
211 Checks each line and bumps total adds/removes for this diff
213
212
214 :param l:
213 :param l:
215 """
214 """
216 if l.startswith('+') and not l.startswith('+++'):
215 if l.startswith('+') and not l.startswith('+++'):
217 self.adds += 1
216 self.adds += 1
218 elif l.startswith('-') and not l.startswith('---'):
217 elif l.startswith('-') and not l.startswith('---'):
219 self.removes += 1
218 self.removes += 1
220 return safe_unicode(l)
219 return safe_unicode(l)
221
220
222 def _highlight_line_difflib(self, line, next_):
221 def _highlight_line_difflib(self, line, next_):
223 """
222 """
224 Highlight inline changes in both lines.
223 Highlight inline changes in both lines.
225 """
224 """
226
225
227 if line['action'] == Action.DELETE:
226 if line['action'] == Action.DELETE:
228 old, new = line, next_
227 old, new = line, next_
229 else:
228 else:
230 old, new = next_, line
229 old, new = next_, line
231
230
232 oldwords = self._token_re.split(old['line'])
231 oldwords = self._token_re.split(old['line'])
233 newwords = self._token_re.split(new['line'])
232 newwords = self._token_re.split(new['line'])
234 sequence = difflib.SequenceMatcher(None, oldwords, newwords)
233 sequence = difflib.SequenceMatcher(None, oldwords, newwords)
235
234
236 oldfragments, newfragments = [], []
235 oldfragments, newfragments = [], []
237 for tag, i1, i2, j1, j2 in sequence.get_opcodes():
236 for tag, i1, i2, j1, j2 in sequence.get_opcodes():
238 oldfrag = ''.join(oldwords[i1:i2])
237 oldfrag = ''.join(oldwords[i1:i2])
239 newfrag = ''.join(newwords[j1:j2])
238 newfrag = ''.join(newwords[j1:j2])
240 if tag != 'equal':
239 if tag != 'equal':
241 if oldfrag:
240 if oldfrag:
242 oldfrag = '<del>%s</del>' % oldfrag
241 oldfrag = '<del>%s</del>' % oldfrag
243 if newfrag:
242 if newfrag:
244 newfrag = '<ins>%s</ins>' % newfrag
243 newfrag = '<ins>%s</ins>' % newfrag
245 oldfragments.append(oldfrag)
244 oldfragments.append(oldfrag)
246 newfragments.append(newfrag)
245 newfragments.append(newfrag)
247
246
248 old['line'] = "".join(oldfragments)
247 old['line'] = "".join(oldfragments)
249 new['line'] = "".join(newfragments)
248 new['line'] = "".join(newfragments)
250
249
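A sketch of the inline highlighting on a delete/add pair (the dicts mirror the
line structure produced by the parsers below):

    old = {'action': Action.DELETE, 'line': 'color = red'}
    new = {'action': Action.ADD, 'line': 'color = blue'}
    processor._highlight_line_difflib(old, new)
    # old['line'] -> 'color = <del>red</del>'
    # new['line'] -> 'color = <ins>blue</ins>'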
251 def _highlight_line_udiff(self, line, next_):
250 def _highlight_line_udiff(self, line, next_):
252 """
251 """
253 Highlight inline changes in both lines.
252 Highlight inline changes in both lines.
254 """
253 """
255 start = 0
254 start = 0
256 limit = min(len(line['line']), len(next_['line']))
255 limit = min(len(line['line']), len(next_['line']))
257 while start < limit and line['line'][start] == next_['line'][start]:
256 while start < limit and line['line'][start] == next_['line'][start]:
258 start += 1
257 start += 1
259 end = -1
258 end = -1
260 limit -= start
259 limit -= start
261 while -end <= limit and line['line'][end] == next_['line'][end]:
260 while -end <= limit and line['line'][end] == next_['line'][end]:
262 end -= 1
261 end -= 1
263 end += 1
262 end += 1
264 if start or end:
263 if start or end:
265 def do(l):
264 def do(l):
266 last = end + len(l['line'])
265 last = end + len(l['line'])
267 if l['action'] == Action.ADD:
266 if l['action'] == Action.ADD:
268 tag = 'ins'
267 tag = 'ins'
269 else:
268 else:
270 tag = 'del'
269 tag = 'del'
271 l['line'] = '%s<%s>%s</%s>%s' % (
270 l['line'] = '%s<%s>%s</%s>%s' % (
272 l['line'][:start],
271 l['line'][:start],
273 tag,
272 tag,
274 l['line'][start:last],
273 l['line'][start:last],
275 tag,
274 tag,
276 l['line'][last:]
275 l['line'][last:]
277 )
276 )
278 do(line)
277 do(line)
279 do(next_)
278 do(next_)
280
279
281 def _clean_line(self, line, command):
280 def _clean_line(self, line, command):
282 if command in ['+', '-', ' ']:
281 if command in ['+', '-', ' ']:
283 # only modify the line if it's actually a diff thing
282 # only modify the line if it's actually a diff thing
284 line = line[1:]
283 line = line[1:]
285 return line
284 return line
286
285
287 def _parse_gitdiff(self, inline_diff=True):
286 def _parse_gitdiff(self, inline_diff=True):
288 _files = []
287 _files = []
289 diff_container = lambda arg: arg
288 diff_container = lambda arg: arg
290
289
291 for chunk in self._diff.chunks():
290 for chunk in self._diff.chunks():
292 head = chunk.header
291 head = chunk.header
293
292
294 diff = map(self._escaper, self.diff_splitter(chunk.diff))
293 diff = map(self._escaper, self.diff_splitter(chunk.diff))
295 raw_diff = chunk.raw
294 raw_diff = chunk.raw
296 limited_diff = False
295 limited_diff = False
297 exceeds_limit = False
296 exceeds_limit = False
298
297
299 op = None
298 op = None
300 stats = {
299 stats = {
301 'added': 0,
300 'added': 0,
302 'deleted': 0,
301 'deleted': 0,
303 'binary': False,
302 'binary': False,
304 'ops': {},
303 'ops': {},
305 }
304 }
306
305
307 if head['deleted_file_mode']:
306 if head['deleted_file_mode']:
308 op = OPS.DEL
307 op = OPS.DEL
309 stats['binary'] = True
308 stats['binary'] = True
310 stats['ops'][DEL_FILENODE] = 'deleted file'
309 stats['ops'][DEL_FILENODE] = 'deleted file'
311
310
312 elif head['new_file_mode']:
311 elif head['new_file_mode']:
313 op = OPS.ADD
312 op = OPS.ADD
314 stats['binary'] = True
313 stats['binary'] = True
315 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
314 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
316 else: # modify operation, can be copy, rename or chmod
315 else: # modify operation, can be copy, rename or chmod
317
316
318 # CHMOD
317 # CHMOD
319 if head['new_mode'] and head['old_mode']:
318 if head['new_mode'] and head['old_mode']:
320 op = OPS.MOD
319 op = OPS.MOD
321 stats['binary'] = True
320 stats['binary'] = True
322 stats['ops'][CHMOD_FILENODE] = (
321 stats['ops'][CHMOD_FILENODE] = (
323 'modified file chmod %s => %s' % (
322 'modified file chmod %s => %s' % (
324 head['old_mode'], head['new_mode']))
323 head['old_mode'], head['new_mode']))
325 # RENAME
324 # RENAME
326 if head['rename_from'] != head['rename_to']:
325 if head['rename_from'] != head['rename_to']:
327 op = OPS.MOD
326 op = OPS.MOD
328 stats['binary'] = True
327 stats['binary'] = True
329 stats['ops'][RENAMED_FILENODE] = (
328 stats['ops'][RENAMED_FILENODE] = (
330 'file renamed from %s to %s' % (
329 'file renamed from %s to %s' % (
331 head['rename_from'], head['rename_to']))
330 head['rename_from'], head['rename_to']))
332 # COPY
331 # COPY
333 if head.get('copy_from') and head.get('copy_to'):
332 if head.get('copy_from') and head.get('copy_to'):
334 op = OPS.MOD
333 op = OPS.MOD
335 stats['binary'] = True
334 stats['binary'] = True
336 stats['ops'][COPIED_FILENODE] = (
335 stats['ops'][COPIED_FILENODE] = (
337 'file copied from %s to %s' % (
336 'file copied from %s to %s' % (
338 head['copy_from'], head['copy_to']))
337 head['copy_from'], head['copy_to']))
339
338
340 # If our new parsed headers didn't match anything, fall back to
339 # If our new parsed headers didn't match anything, fall back to
341 # old-style detection
340 # old-style detection
342 if op is None:
341 if op is None:
343 if not head['a_file'] and head['b_file']:
342 if not head['a_file'] and head['b_file']:
344 op = OPS.ADD
343 op = OPS.ADD
345 stats['binary'] = True
344 stats['binary'] = True
346 stats['ops'][NEW_FILENODE] = 'new file'
345 stats['ops'][NEW_FILENODE] = 'new file'
347
346
348 elif head['a_file'] and not head['b_file']:
347 elif head['a_file'] and not head['b_file']:
349 op = OPS.DEL
348 op = OPS.DEL
350 stats['binary'] = True
349 stats['binary'] = True
351 stats['ops'][DEL_FILENODE] = 'deleted file'
350 stats['ops'][DEL_FILENODE] = 'deleted file'
352
351
353 # it's neither ADD nor DELETE
352 # it's neither ADD nor DELETE
354 if op is None:
353 if op is None:
355 op = OPS.MOD
354 op = OPS.MOD
356 stats['binary'] = True
355 stats['binary'] = True
357 stats['ops'][MOD_FILENODE] = 'modified file'
356 stats['ops'][MOD_FILENODE] = 'modified file'
358
357
359 # a real non-binary diff
358 # a real non-binary diff
360 if head['a_file'] or head['b_file']:
359 if head['a_file'] or head['b_file']:
361 try:
360 try:
362 raw_diff, chunks, _stats = self._parse_lines(diff)
361 raw_diff, chunks, _stats = self._parse_lines(diff)
363 stats['binary'] = False
362 stats['binary'] = False
364 stats['added'] = _stats[0]
363 stats['added'] = _stats[0]
365 stats['deleted'] = _stats[1]
364 stats['deleted'] = _stats[1]
366 # explicit mark that it's a modified file
365 # explicit mark that it's a modified file
367 if op == OPS.MOD:
366 if op == OPS.MOD:
368 stats['ops'][MOD_FILENODE] = 'modified file'
367 stats['ops'][MOD_FILENODE] = 'modified file'
369 exceeds_limit = len(raw_diff) > self.file_limit
368 exceeds_limit = len(raw_diff) > self.file_limit
370
369
371 # changed from the _escaper function so we validate the size of
370 # changed from the _escaper function so we validate the size of
372 # each file instead of the whole diff;
371 # each file instead of the whole diff;
373 # the diff will hide big files but still show small ones;
372 # the diff will hide big files but still show small ones;
374 # from testing, big files are fairly safe to parse
373 # from testing, big files are fairly safe to parse
375 # but the browser is the bottleneck
374 # but the browser is the bottleneck
376 if not self.show_full_diff and exceeds_limit:
375 if not self.show_full_diff and exceeds_limit:
377 raise DiffLimitExceeded('File Limit Exceeded')
376 raise DiffLimitExceeded('File Limit Exceeded')
378
377
379 except DiffLimitExceeded:
378 except DiffLimitExceeded:
380 diff_container = lambda _diff: \
379 diff_container = lambda _diff: \
381 LimitedDiffContainer(
380 LimitedDiffContainer(
382 self.diff_limit, self.cur_diff_size, _diff)
381 self.diff_limit, self.cur_diff_size, _diff)
383
382
384 exceeds_limit = len(raw_diff) > self.file_limit
383 exceeds_limit = len(raw_diff) > self.file_limit
385 limited_diff = True
384 limited_diff = True
386 chunks = []
385 chunks = []
387
386
388 else: # GIT format binary patch, or possibly empty diff
387 else: # GIT format binary patch, or possibly empty diff
389 if head['bin_patch']:
388 if head['bin_patch']:
390 # the operation was already extracted above; simply mark that
389 # the operation was already extracted above; simply mark that
391 # this is a diff we won't show for binary files
390 # this is a diff we won't show for binary files
392 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
391 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
393 chunks = []
392 chunks = []
394
393
395 if chunks and not self.show_full_diff and op == OPS.DEL:
394 if chunks and not self.show_full_diff and op == OPS.DEL:
396 # if not in full-diff mode, hide deleted file contents
395 # if not in full-diff mode, hide deleted file contents
397 # TODO: anderson: if the view is not too big, there is no way
396 # TODO: anderson: if the view is not too big, there is no way
398 # to see the content of the file
397 # to see the content of the file
399 chunks = []
398 chunks = []
400
399
401 chunks.insert(0, [{
400 chunks.insert(0, [{
402 'old_lineno': '',
401 'old_lineno': '',
403 'new_lineno': '',
402 'new_lineno': '',
404 'action': Action.CONTEXT,
403 'action': Action.CONTEXT,
405 'line': msg,
404 'line': msg,
406 } for _op, msg in stats['ops'].items()
405 } for _op, msg in stats['ops'].items()
407 if _op not in [MOD_FILENODE]])
406 if _op not in [MOD_FILENODE]])
408
407
409 _files.append({
408 _files.append({
410 'filename': safe_unicode(head['b_path']),
409 'filename': safe_unicode(head['b_path']),
411 'old_revision': head['a_blob_id'],
410 'old_revision': head['a_blob_id'],
412 'new_revision': head['b_blob_id'],
411 'new_revision': head['b_blob_id'],
413 'chunks': chunks,
412 'chunks': chunks,
414 'raw_diff': safe_unicode(raw_diff),
413 'raw_diff': safe_unicode(raw_diff),
415 'operation': op,
414 'operation': op,
416 'stats': stats,
415 'stats': stats,
417 'exceeds_limit': exceeds_limit,
416 'exceeds_limit': exceeds_limit,
418 'is_limited_diff': limited_diff,
417 'is_limited_diff': limited_diff,
419 })
418 })
420
419
421 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
420 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
422 OPS.DEL: 2}.get(info['operation'])
421 OPS.DEL: 2}.get(info['operation'])
423
422
424 if not inline_diff:
423 if not inline_diff:
425 return diff_container(sorted(_files, key=sorter))
424 return diff_container(sorted(_files, key=sorter))
426
425
427 # highlight inline changes
426 # highlight inline changes
428 for diff_data in _files:
427 for diff_data in _files:
429 for chunk in diff_data['chunks']:
428 for chunk in diff_data['chunks']:
430 lineiter = iter(chunk)
429 lineiter = iter(chunk)
431 try:
430 try:
432 while 1:
431 while 1:
433 line = next(lineiter)
432 line = next(lineiter)
434 if line['action'] not in (
433 if line['action'] not in (
435 Action.UNMODIFIED, Action.CONTEXT):
434 Action.UNMODIFIED, Action.CONTEXT):
436 nextline = next(lineiter)
435 nextline = next(lineiter)
437 if nextline['action'] in ['unmod', 'context'] or \
436 if nextline['action'] in ['unmod', 'context'] or \
438 nextline['action'] == line['action']:
437 nextline['action'] == line['action']:
439 continue
438 continue
440 self.differ(line, nextline)
439 self.differ(line, nextline)
441 except StopIteration:
440 except StopIteration:
442 pass
441 pass
443
442
444 return diff_container(sorted(_files, key=sorter))
443 return diff_container(sorted(_files, key=sorter))
445
444
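A hedged sketch of consuming the parsed result (the per-file keys match the
dicts assembled above):

    files = processor._parse_gitdiff(inline_diff=True)
    for f in files:
        print(f['filename'], f['operation'],
              f['stats']['added'], f['stats']['deleted'])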
446 def _check_large_diff(self):
445 def _check_large_diff(self):
447 if self.diff_limit:
446 if self.diff_limit:
448 log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)
447 log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)
449 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
448 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
450 raise DiffLimitExceeded('Diff Limit `%s` Exceeded', self.diff_limit)
449 raise DiffLimitExceeded('Diff Limit `%s` Exceeded', self.diff_limit)
451
450
452 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
451 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
453 def _new_parse_gitdiff(self, inline_diff=True):
452 def _new_parse_gitdiff(self, inline_diff=True):
454 _files = []
453 _files = []
455
454
456 # this can be overridden later to a LimitedDiffContainer type
455 # this can be overridden later to a LimitedDiffContainer type
457 diff_container = lambda arg: arg
456 diff_container = lambda arg: arg
458
457
459 for chunk in self._diff.chunks():
458 for chunk in self._diff.chunks():
460 head = chunk.header
459 head = chunk.header
461 log.debug('parsing diff %r', head)
460 log.debug('parsing diff %r', head)
462
461
463 raw_diff = chunk.raw
462 raw_diff = chunk.raw
464 limited_diff = False
463 limited_diff = False
465 exceeds_limit = False
464 exceeds_limit = False
466
465
467 op = None
466 op = None
468 stats = {
467 stats = {
469 'added': 0,
468 'added': 0,
470 'deleted': 0,
469 'deleted': 0,
471 'binary': False,
470 'binary': False,
472 'old_mode': None,
471 'old_mode': None,
473 'new_mode': None,
472 'new_mode': None,
474 'ops': {},
473 'ops': {},
475 }
474 }
476 if head['old_mode']:
475 if head['old_mode']:
477 stats['old_mode'] = head['old_mode']
476 stats['old_mode'] = head['old_mode']
478 if head['new_mode']:
477 if head['new_mode']:
479 stats['new_mode'] = head['new_mode']
478 stats['new_mode'] = head['new_mode']
480 if head['b_mode']:
479 if head['b_mode']:
481 stats['new_mode'] = head['b_mode']
480 stats['new_mode'] = head['b_mode']
482
481
483 # delete file
482 # delete file
484 if head['deleted_file_mode']:
483 if head['deleted_file_mode']:
485 op = OPS.DEL
484 op = OPS.DEL
486 stats['binary'] = True
485 stats['binary'] = True
487 stats['ops'][DEL_FILENODE] = 'deleted file'
486 stats['ops'][DEL_FILENODE] = 'deleted file'
488
487
489 # new file
488 # new file
490 elif head['new_file_mode']:
489 elif head['new_file_mode']:
491 op = OPS.ADD
490 op = OPS.ADD
492 stats['binary'] = True
491 stats['binary'] = True
493 stats['old_mode'] = None
492 stats['old_mode'] = None
494 stats['new_mode'] = head['new_file_mode']
493 stats['new_mode'] = head['new_file_mode']
495 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
494 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
496
495
497 # modify operation, can be copy, rename or chmod
496 # modify operation, can be copy, rename or chmod
498 else:
497 else:
499 # CHMOD
498 # CHMOD
500 if head['new_mode'] and head['old_mode']:
499 if head['new_mode'] and head['old_mode']:
501 op = OPS.MOD
500 op = OPS.MOD
502 stats['binary'] = True
501 stats['binary'] = True
503 stats['ops'][CHMOD_FILENODE] = (
502 stats['ops'][CHMOD_FILENODE] = (
504 'modified file chmod %s => %s' % (
503 'modified file chmod %s => %s' % (
505 head['old_mode'], head['new_mode']))
504 head['old_mode'], head['new_mode']))
506
505
507 # RENAME
506 # RENAME
508 if head['rename_from'] != head['rename_to']:
507 if head['rename_from'] != head['rename_to']:
509 op = OPS.MOD
508 op = OPS.MOD
510 stats['binary'] = True
509 stats['binary'] = True
511 stats['renamed'] = (head['rename_from'], head['rename_to'])
510 stats['renamed'] = (head['rename_from'], head['rename_to'])
512 stats['ops'][RENAMED_FILENODE] = (
511 stats['ops'][RENAMED_FILENODE] = (
513 'file renamed from %s to %s' % (
512 'file renamed from %s to %s' % (
514 head['rename_from'], head['rename_to']))
513 head['rename_from'], head['rename_to']))
515 # COPY
514 # COPY
516 if head.get('copy_from') and head.get('copy_to'):
515 if head.get('copy_from') and head.get('copy_to'):
517 op = OPS.MOD
516 op = OPS.MOD
518 stats['binary'] = True
517 stats['binary'] = True
519 stats['copied'] = (head['copy_from'], head['copy_to'])
518 stats['copied'] = (head['copy_from'], head['copy_to'])
520 stats['ops'][COPIED_FILENODE] = (
519 stats['ops'][COPIED_FILENODE] = (
521 'file copied from %s to %s' % (
520 'file copied from %s to %s' % (
522 head['copy_from'], head['copy_to']))
521 head['copy_from'], head['copy_to']))
523
522
524 # If our new parsed headers didn't match anything fallback to
523 # If our new parsed headers didn't match anything fallback to
525 # old style detection
524 # old style detection
526 if op is None:
525 if op is None:
527 if not head['a_file'] and head['b_file']:
526 if not head['a_file'] and head['b_file']:
528 op = OPS.ADD
527 op = OPS.ADD
529 stats['binary'] = True
528 stats['binary'] = True
530 stats['new_file'] = True
529 stats['new_file'] = True
531 stats['ops'][NEW_FILENODE] = 'new file'
530 stats['ops'][NEW_FILENODE] = 'new file'
532
531
533 elif head['a_file'] and not head['b_file']:
532 elif head['a_file'] and not head['b_file']:
534 op = OPS.DEL
533 op = OPS.DEL
535 stats['binary'] = True
534 stats['binary'] = True
536 stats['ops'][DEL_FILENODE] = 'deleted file'
535 stats['ops'][DEL_FILENODE] = 'deleted file'
537
536
538 # it's not ADD not DELETE
537 # it's not ADD not DELETE
539 if op is None:
538 if op is None:
540 op = OPS.MOD
539 op = OPS.MOD
541 stats['binary'] = True
540 stats['binary'] = True
542 stats['ops'][MOD_FILENODE] = 'modified file'
541 stats['ops'][MOD_FILENODE] = 'modified file'
543
542
544 # a real non-binary diff
543 # a real non-binary diff
545 if head['a_file'] or head['b_file']:
544 if head['a_file'] or head['b_file']:
546 # simulate splitlines, so we keep the line end part
545 # simulate splitlines, so we keep the line end part
547 diff = self.diff_splitter(chunk.diff)
546 diff = self.diff_splitter(chunk.diff)
548
547
549 # append each file to the diff size
548 # append each file to the diff size
550 raw_chunk_size = len(raw_diff)
549 raw_chunk_size = len(raw_diff)
551
550
552 exceeds_limit = raw_chunk_size > self.file_limit
551 exceeds_limit = raw_chunk_size > self.file_limit
553 self.cur_diff_size += raw_chunk_size
552 self.cur_diff_size += raw_chunk_size
554
553
555 try:
554 try:
556 # Check each file instead of the whole diff.
555 # Check each file instead of the whole diff.
557 # Diff will hide big files but still show small ones.
556 # Diff will hide big files but still show small ones.
558 # From the tests big files are fairly safe to be parsed
557 # From the tests big files are fairly safe to be parsed
559 # but the browser is the bottleneck.
558 # but the browser is the bottleneck.
560 if not self.show_full_diff and exceeds_limit:
559 if not self.show_full_diff and exceeds_limit:
561 log.debug('File `%s` exceeds current file_limit of %s',
560 log.debug('File `%s` exceeds current file_limit of %s',
562 safe_unicode(head['b_path']), self.file_limit)
561 safe_unicode(head['b_path']), self.file_limit)
563 raise DiffLimitExceeded(
562 raise DiffLimitExceeded(
564 'File Limit %s Exceeded', self.file_limit)
563 'File Limit %s Exceeded', self.file_limit)
565
564
566 self._check_large_diff()
565 self._check_large_diff()
567
566
568 raw_diff, chunks, _stats = self._new_parse_lines(diff)
567 raw_diff, chunks, _stats = self._new_parse_lines(diff)
569 stats['binary'] = False
568 stats['binary'] = False
570 stats['added'] = _stats[0]
569 stats['added'] = _stats[0]
571 stats['deleted'] = _stats[1]
570 stats['deleted'] = _stats[1]
572 # explicit mark that it's a modified file
571 # explicit mark that it's a modified file
573 if op == OPS.MOD:
572 if op == OPS.MOD:
574 stats['ops'][MOD_FILENODE] = 'modified file'
573 stats['ops'][MOD_FILENODE] = 'modified file'
575
574
576 except DiffLimitExceeded:
575 except DiffLimitExceeded:
577 diff_container = lambda _diff: \
576 diff_container = lambda _diff: \
578 LimitedDiffContainer(
577 LimitedDiffContainer(
579 self.diff_limit, self.cur_diff_size, _diff)
578 self.diff_limit, self.cur_diff_size, _diff)
580
579
581 limited_diff = True
580 limited_diff = True
582 chunks = []
581 chunks = []
583
582
584 else: # GIT format binary patch, or possibly empty diff
583 else: # GIT format binary patch, or possibly empty diff
585 if head['bin_patch']:
584 if head['bin_patch']:
586 # we have operation already extracted, but we mark simply
585 # we have operation already extracted, but we mark simply
587 # it's a diff we wont show for binary files
586 # it's a diff we wont show for binary files
588 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
587 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
589 chunks = []
588 chunks = []
590
589
591 # Hide content of deleted node by setting empty chunks
590 # Hide content of deleted node by setting empty chunks
592 if chunks and not self.show_full_diff and op == OPS.DEL:
591 if chunks and not self.show_full_diff and op == OPS.DEL:
593 # if not full diff mode show deleted file contents
592 # if not full diff mode show deleted file contents
594 # TODO: anderson: if the view is not too big, there is no way
593 # TODO: anderson: if the view is not too big, there is no way
595 # to see the content of the file
594 # to see the content of the file
596 chunks = []
595 chunks = []
597
596
598 chunks.insert(
597 chunks.insert(
599 0, [{'old_lineno': '',
598 0, [{'old_lineno': '',
600 'new_lineno': '',
599 'new_lineno': '',
601 'action': Action.CONTEXT,
600 'action': Action.CONTEXT,
602 'line': msg,
601 'line': msg,
603 } for _op, msg in stats['ops'].items()
602 } for _op, msg in stats['ops'].items()
604 if _op not in [MOD_FILENODE]])
603 if _op not in [MOD_FILENODE]])
605
604
606 original_filename = safe_unicode(head['a_path'])
605 original_filename = safe_unicode(head['a_path'])
607 _files.append({
606 _files.append({
608 'original_filename': original_filename,
607 'original_filename': original_filename,
609 'filename': safe_unicode(head['b_path']),
608 'filename': safe_unicode(head['b_path']),
610 'old_revision': head['a_blob_id'],
609 'old_revision': head['a_blob_id'],
611 'new_revision': head['b_blob_id'],
610 'new_revision': head['b_blob_id'],
612 'chunks': chunks,
611 'chunks': chunks,
613 'raw_diff': safe_unicode(raw_diff),
612 'raw_diff': safe_unicode(raw_diff),
614 'operation': op,
613 'operation': op,
615 'stats': stats,
614 'stats': stats,
616 'exceeds_limit': exceeds_limit,
615 'exceeds_limit': exceeds_limit,
617 'is_limited_diff': limited_diff,
616 'is_limited_diff': limited_diff,
618 })
617 })
619
618
620 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
619 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
621 OPS.DEL: 2}.get(info['operation'])
620 OPS.DEL: 2}.get(info['operation'])
622
621
623 return diff_container(sorted(_files, key=sorter))
622 return diff_container(sorted(_files, key=sorter))
624
623
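    # A parsed entry produced above, sketched with illustrative values (the
    # keys are exactly the ones assembled in _new_parse_gitdiff; the values
    # here are made up for the example):
    #
    #   {'original_filename': u'setup.py', 'filename': u'setup.py',
    #    'old_revision': 'abc123...', 'new_revision': 'def456...',
    #    'operation': OPS.MOD, 'chunks': [...], 'raw_diff': u'diff --git ...',
    #    'stats': {'added': 2, 'deleted': 1, 'binary': False, ...},
    #    'exceeds_limit': False, 'is_limited_diff': False}
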
    # FIXME: NEWDIFFS: dan: this gets replaced by _new_parse_lines
    def _parse_lines(self, diff_iter):
        """
        Parse the diff and return data for the template.
        """

        stats = [0, 0]
        chunks = []
        raw_diff = []

        try:
            line = next(diff_iter)

            while line:
                raw_diff.append(line)
                lines = []
                chunks.append(lines)

                match = self._chunk_re.match(line)

                if not match:
                    break

                gr = match.groups()
                (old_line, old_end,
                 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
                old_line -= 1
                new_line -= 1

                context = len(gr) == 5
                old_end += old_line
                new_end += new_line

                if context:
                    # skip context only if it's the first line
                    if int(gr[0]) > 1:
                        lines.append({
                            'old_lineno': '...',
                            'new_lineno': '...',
                            'action': Action.CONTEXT,
                            'line': line,
                        })

                line = next(diff_iter)

                while old_line < old_end or new_line < new_end:
                    command = ' '
                    if line:
                        command = line[0]

                    affects_old = affects_new = False

                    # ignore those if we don't expect them
                    if command in '#@':
                        continue
                    elif command == '+':
                        affects_new = True
                        action = Action.ADD
                        stats[0] += 1
                    elif command == '-':
                        affects_old = True
                        action = Action.DELETE
                        stats[1] += 1
                    else:
                        affects_old = affects_new = True
                        action = Action.UNMODIFIED

                    if not self._newline_marker.match(line):
                        old_line += affects_old
                        new_line += affects_new
                        lines.append({
                            'old_lineno': affects_old and old_line or '',
                            'new_lineno': affects_new and new_line or '',
                            'action': action,
                            'line': self._clean_line(line, command)
                        })
                        raw_diff.append(line)

                    line = next(diff_iter)

                    if self._newline_marker.match(line):
                        # we need to append to lines, since this is not
                        # counted in the line specs of the diff
                        lines.append({
                            'old_lineno': '...',
                            'new_lineno': '...',
                            'action': Action.CONTEXT,
                            'line': self._clean_line(line, command)
                        })

        except StopIteration:
            pass
        return ''.join(raw_diff), chunks, stats

    # FIXME: NEWDIFFS: dan: this replaces _parse_lines
    def _new_parse_lines(self, diff_iter):
        """
        Parse the diff and return data for the template.
        """

        stats = [0, 0]
        chunks = []
        raw_diff = []

        try:
            line = next(diff_iter)

            while line:
                raw_diff.append(line)
                # match a hunk header, e.g. '@@ -0,0 +1 @@\n'
                match = self._chunk_re.match(line)

                if not match:
                    break

                gr = match.groups()
                (old_line, old_end,
                 new_line, new_end) = [int(x or 1) for x in gr[:-1]]

                lines = []
                hunk = {
                    'section_header': gr[-1],
                    'source_start': old_line,
                    'source_length': old_end,
                    'target_start': new_line,
                    'target_length': new_end,
                    'lines': lines,
                }
                chunks.append(hunk)

                old_line -= 1
                new_line -= 1

                context = len(gr) == 5
                old_end += old_line
                new_end += new_line

                line = next(diff_iter)

                while old_line < old_end or new_line < new_end:
                    command = ' '
                    if line:
                        command = line[0]

                    affects_old = affects_new = False

                    # ignore those if we don't expect them
                    if command in '#@':
                        continue
                    elif command == '+':
                        affects_new = True
                        action = Action.ADD
                        stats[0] += 1
                    elif command == '-':
                        affects_old = True
                        action = Action.DELETE
                        stats[1] += 1
                    else:
                        affects_old = affects_new = True
                        action = Action.UNMODIFIED

                    if not self._newline_marker.match(line):
                        old_line += affects_old
                        new_line += affects_new
                        lines.append({
                            'old_lineno': affects_old and old_line or '',
                            'new_lineno': affects_new and new_line or '',
                            'action': action,
                            'line': self._clean_line(line, command)
                        })
                        raw_diff.append(line)

                    line = next(diff_iter)

                    if self._newline_marker.match(line):
                        # we need to append to lines, since this is not
                        # counted in the line specs of the diff
                        if affects_old:
                            action = Action.OLD_NO_NL
                        elif affects_new:
                            action = Action.NEW_NO_NL
                        else:
                            raise Exception('invalid context for no newline')

                        lines.append({
                            'old_lineno': None,
                            'new_lineno': None,
                            'action': action,
                            'line': self._clean_line(line, command)
                        })

        except StopIteration:
            pass

        return ''.join(raw_diff), chunks, stats

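    # Each hunk appended to `chunks` above is a dict of this shape (values
    # are illustrative only):
    #
    #   {'section_header': 'def foo():', 'source_start': 10,
    #    'source_length': 7, 'target_start': 10, 'target_length': 8,
    #    'lines': [{'old_lineno': 10, 'new_lineno': 10,
    #               'action': Action.UNMODIFIED, 'line': 'body'}, ...]}
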
    def _safe_id(self, idstring):
        """Make a string safe for including in an id attribute.

        The HTML spec says that id attributes 'must begin with
        a letter ([A-Za-z]) and may be followed by any number
        of letters, digits ([0-9]), hyphens ("-"), underscores
        ("_"), colons (":"), and periods (".")'. These regexps
        are slightly over-zealous, in that they remove colons
        and periods unnecessarily.

        Whitespace is transformed into underscores, and then
        anything which is not a hyphen or a character that
        matches \w (alphanumerics and underscore) is removed.

        """
        # Transform all whitespace to underscore
        idstring = re.sub(r'\s', "_", '%s' % idstring)
        # Remove everything that is not a hyphen or a member of \w
        idstring = re.sub(r'(?!-)\W', "", idstring).lower()
        return idstring

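    # For example, following the two substitutions above:
    #
    #   self._safe_id('My File v2.txt')  -> 'my_file_v2txt'
    #   self._safe_id('a/b c')           -> 'ab_c'
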
    @classmethod
    def diff_splitter(cls, string):
        """
        Diff split that emulates .splitlines() but works only on \n
        """
        if not string:
            return
        elif string == '\n':
            yield '\n'
        else:

            has_newline = string.endswith('\n')
            elements = string.split('\n')
            if has_newline:
                # skip the last element, as it's the empty string left
                # behind by the trailing newline
                elements = elements[:-1]

            len_elements = len(elements)

            for cnt, line in enumerate(elements, start=1):
                last_line = cnt == len_elements
                if last_line and not has_newline:
                    yield safe_unicode(line)
                else:
                    yield safe_unicode(line) + '\n'

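    # Unlike str.splitlines(), only '\n' acts as a separator and the
    # terminator is kept on every line except an unterminated last one:
    #
    #   list(DiffProcessor.diff_splitter('a\nb'))    -> ['a\n', 'b']
    #   list(DiffProcessor.diff_splitter('a\nb\n'))  -> ['a\n', 'b\n']
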
    def prepare(self, inline_diff=True):
        """
        Prepare the passed udiff for HTML rendering.

        :return: A list of dicts with diff information.
        """
        parsed = self._parser(inline_diff=inline_diff)
        self.parsed = True
        self.parsed_diff = parsed
        return parsed

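    # A minimal usage sketch (the constructor call is illustrative only;
    # a DiffProcessor is normally built from a vcs diff object elsewhere):
    #
    #   processor = DiffProcessor(vcs_diff)
    #   parsed = processor.prepare()                    # list of file dicts
    #   html = processor.as_html(enable_comments=True)  # or None if empty
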
    def as_raw(self, diff_lines=None):
        """
        Returns raw diff as a byte string
        """
        return self._diff.raw

    def as_html(self, table_class='code-difftable', line_class='line',
                old_lineno_class='lineno old', new_lineno_class='lineno new',
                code_class='code', enable_comments=False, parsed_lines=None):
        """
        Return the given diff as an HTML table with customized CSS classes
        """
        # TODO(marcink): not sure how to pass in translator
        # here in an efficient way, leave the _ for proper gettext extraction
        _ = lambda s: s

        def _link_to_if(condition, label, url):
            """
            Generates a link if the condition is met, or just the label if not.
            """

            if condition:
                return '''<a href="%(url)s" class="tooltip"
                    title="%(title)s">%(label)s</a>''' % {
                    'title': _('Click to select line'),
                    'url': url,
                    'label': label
                }
            else:
                return label

        if not self.parsed:
            self.prepare()

        diff_lines = self.parsed_diff
        if parsed_lines:
            diff_lines = parsed_lines

        _html_empty = True
        _html = []
        _html.append('''<table class="%(table_class)s">\n''' % {
            'table_class': table_class
        })

        for diff in diff_lines:
            for line in diff['chunks']:
                _html_empty = False
                for change in line:
                    _html.append('''<tr class="%(lc)s %(action)s">\n''' % {
                        'lc': line_class,
                        'action': change['action']
                    })
                    anchor_old_id = ''
                    anchor_new_id = ''
                    anchor_old = "%(filename)s_o%(oldline_no)s" % {
                        'filename': self._safe_id(diff['filename']),
                        'oldline_no': change['old_lineno']
                    }
                    anchor_new = "%(filename)s_n%(oldline_no)s" % {
                        'filename': self._safe_id(diff['filename']),
                        'oldline_no': change['new_lineno']
                    }
                    cond_old = (change['old_lineno'] != '...' and
                                change['old_lineno'])
                    cond_new = (change['new_lineno'] != '...' and
                                change['new_lineno'])
                    if cond_old:
                        anchor_old_id = 'id="%s"' % anchor_old
                    if cond_new:
                        anchor_new_id = 'id="%s"' % anchor_new

                    if change['action'] != Action.CONTEXT:
                        anchor_link = True
                    else:
                        anchor_link = False

                    ###########################################################
                    # COMMENT ICONS
                    ###########################################################
                    _html.append('''\t<td class="add-comment-line"><span class="add-comment-content">''')

                    if enable_comments and change['action'] != Action.CONTEXT:
                        _html.append('''<a href="#"><span class="icon-comment-add"></span></a>''')

                    _html.append('''</span></td><td class="comment-toggle tooltip" title="Toggle Comment Thread"><i class="icon-comment"></i></td>\n''')

                    ###########################################################
                    # OLD LINE NUMBER
                    ###########################################################
                    _html.append('''\t<td %(a_id)s class="%(olc)s">''' % {
                        'a_id': anchor_old_id,
                        'olc': old_lineno_class
                    })

                    _html.append('''%(link)s''' % {
                        'link': _link_to_if(anchor_link, change['old_lineno'],
                                            '#%s' % anchor_old)
                    })
                    _html.append('''</td>\n''')
                    ###########################################################
                    # NEW LINE NUMBER
                    ###########################################################

                    _html.append('''\t<td %(a_id)s class="%(nlc)s">''' % {
                        'a_id': anchor_new_id,
                        'nlc': new_lineno_class
                    })

                    _html.append('''%(link)s''' % {
                        'link': _link_to_if(anchor_link, change['new_lineno'],
                                            '#%s' % anchor_new)
                    })
                    _html.append('''</td>\n''')
                    ###########################################################
                    # CODE
                    ###########################################################
                    code_classes = [code_class]
                    if (not enable_comments or
                            change['action'] == Action.CONTEXT):
                        code_classes.append('no-comment')
                    _html.append('\t<td class="%s">' % ' '.join(code_classes))
                    _html.append('''\n\t\t<pre>%(code)s</pre>\n''' % {
                        'code': change['line']
                    })

                    _html.append('''\t</td>''')
                    _html.append('''\n</tr>\n''')
        _html.append('''</table>''')
        if _html_empty:
            return None
        return ''.join(_html)

    def stat(self):
        """
        Returns a tuple of (added, removed) line counts for this instance
        """
        return self.adds, self.removes

    def get_context_of_line(
            self, path, diff_line=None, context_before=3, context_after=3):
        """
        Returns the context lines for the specified diff line.

        :type diff_line: :class:`DiffLineNumber`
        """
        assert self.parsed, "DiffProcessor is not initialized."

        if None not in diff_line:
            raise ValueError(
                "Cannot specify both line numbers: {}".format(diff_line))

        file_diff = self._get_file_diff(path)
        chunk, idx = self._find_chunk_line_index(file_diff, diff_line)

        first_line_to_include = max(idx - context_before, 0)
        first_line_after_context = idx + context_after + 1
        context_lines = chunk[first_line_to_include:first_line_after_context]

        line_contents = [
            _context_line(line) for line in context_lines
            if _is_diff_content(line)]
        # TODO: johbo: Interim fixup, the diff chunks drop the final newline.
        # Once they are fixed, we can drop this line here.
        if line_contents:
            line_contents[-1] = (
                line_contents[-1][0], line_contents[-1][1].rstrip('\n') + '\n')
        return line_contents

    def find_context(self, path, context, offset=0):
        """
        Finds the given `context` inside of the diff.

        Use the parameter `offset` to specify which offset the target line has
        inside of the given `context`. This way the correct diff line will be
        returned.

        :param offset: Shall be used to specify the offset of the main line
            within the given `context`.
        """
        if offset < 0 or offset >= len(context):
            raise ValueError(
                "Only positive values up to the length of the context "
                "minus one are allowed.")

        matches = []
        file_diff = self._get_file_diff(path)

        for chunk in file_diff['chunks']:
            context_iter = iter(context)
            for line_idx, line in enumerate(chunk):
                try:
                    if _context_line(line) == next(context_iter):
                        continue
                except StopIteration:
                    matches.append((line_idx, chunk))
                context_iter = iter(context)

            # Increment the position and trigger StopIteration
            # if we had a match at the end
            line_idx += 1
            try:
                next(context_iter)
            except StopIteration:
                matches.append((line_idx, chunk))

        effective_offset = len(context) - offset
        found_at_diff_lines = [
            _line_to_diff_line_number(chunk[idx - effective_offset])
            for idx, chunk in matches]

        return found_at_diff_lines

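    # Illustrative call, using the DiffLineNumber namedtuple defined at
    # module level below; exactly one of `old`/`new` may be set:
    #
    #   processor.get_context_of_line(
    #       'setup.py', diff_line=DiffLineNumber(old=None, new=42))
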
    def _get_file_diff(self, path):
        for file_diff in self.parsed_diff:
            if file_diff['filename'] == path:
                break
        else:
            raise FileNotInDiffException("File {} not in diff".format(path))
        return file_diff

    def _find_chunk_line_index(self, file_diff, diff_line):
        for chunk in file_diff['chunks']:
            for idx, line in enumerate(chunk):
                if line['old_lineno'] == diff_line.old:
                    return chunk, idx
                if line['new_lineno'] == diff_line.new:
                    return chunk, idx
        raise LineNotInDiffException(
            "The line {} is not part of the diff.".format(diff_line))


def _is_diff_content(line):
    return line['action'] in (
        Action.UNMODIFIED, Action.ADD, Action.DELETE)


def _context_line(line):
    return (line['action'], line['line'])


DiffLineNumber = collections.namedtuple('DiffLineNumber', ['old', 'new'])


def _line_to_diff_line_number(line):
    new_line_no = line['new_lineno'] or None
    old_line_no = line['old_lineno'] or None
    return DiffLineNumber(old=old_line_no, new=new_line_no)


class FileNotInDiffException(Exception):
    """
    Raised when the context for a missing file is requested.

    If you request the context for a line in a file which is not part of the
    given diff, then this exception is raised.
    """


class LineNotInDiffException(Exception):
    """
    Raised when the context for a missing line is requested.

    If you request the context for a line in a file and this line is not
    part of the given diff, then this exception is raised.
    """


class DiffLimitExceeded(Exception):
    pass


# NOTE(marcink): if diffs.mako changes, this probably
# needs a bump to the next version
CURRENT_DIFF_VERSION = 'v5'


def _cleanup_cache_file(cached_diff_file):
    # remove the file so we don't keep a "damaged" cache around
    try:
        os.remove(cached_diff_file)
    except Exception:
        log.exception('Failed to cleanup path %s', cached_diff_file)


def _get_compression_mode(cached_diff_file):
    mode = 'bz2'
    if 'mode:plain' in cached_diff_file:
        mode = 'plain'
    elif 'mode:gzip' in cached_diff_file:
        mode = 'gzip'
    return mode


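# The mode is inferred from a marker embedded in the cache file name; the
# file names below are purely illustrative:
#
#   _get_compression_mode('diff_v5_mode:gzip_abc')   -> 'gzip'
#   _get_compression_mode('diff_v5_mode:plain_abc')  -> 'plain'
#   _get_compression_mode('diff_v5_abc')             -> 'bz2' (the default)

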
def cache_diff(cached_diff_file, diff, commits):
    compression_mode = _get_compression_mode(cached_diff_file)

    struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': diff,
        'commits': commits
    }

    start = time.time()
    try:
        if compression_mode == 'plain':
            with open(cached_diff_file, 'wb') as f:
                pickle.dump(struct, f)
        elif compression_mode == 'gzip':
            with gzip.GzipFile(cached_diff_file, 'wb') as f:
                pickle.dump(struct, f)
        else:
            with bz2.BZ2File(cached_diff_file, 'wb') as f:
                pickle.dump(struct, f)
    except Exception:
        log.warning('Failed to save cache', exc_info=True)
        _cleanup_cache_file(cached_diff_file)

    log.debug('Saved diff cache under %s in %.4fs', cached_diff_file, time.time() - start)


def load_cached_diff(cached_diff_file):
    compression_mode = _get_compression_mode(cached_diff_file)

    default_struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': None,
        'commits': None
    }

    has_cache = os.path.isfile(cached_diff_file)
    if not has_cache:
        log.debug('Reading diff cache file failed %s', cached_diff_file)
        return default_struct

    data = None

    start = time.time()
    try:
        if compression_mode == 'plain':
            with open(cached_diff_file, 'rb') as f:
                data = pickle.load(f)
        elif compression_mode == 'gzip':
            with gzip.GzipFile(cached_diff_file, 'rb') as f:
                data = pickle.load(f)
        else:
            with bz2.BZ2File(cached_diff_file, 'rb') as f:
                data = pickle.load(f)
    except Exception:
        log.warning('Failed to read diff cache file', exc_info=True)

    if not data:
        data = default_struct

    if not isinstance(data, dict):
        # old version of the data?
        data = default_struct

    # check version
    if data.get('version') != CURRENT_DIFF_VERSION:
        # purge cache
        _cleanup_cache_file(cached_diff_file)
        return default_struct

    log.debug('Loaded diff cache from %s in %.4fs', cached_diff_file, time.time() - start)

    return data


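# Illustrative round trip (the 'mode:gzip' marker in the name selects the
# gzip branch in both helpers; the path and values are examples only):
#
#   path = '/tmp/diffcache/mykey_mode:gzip'
#   cache_diff(path, diff=parsed_diff, commits=commit_ids)
#   load_cached_diff(path)  # -> {'version': 'v5', 'diff': ..., 'commits': ...}

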
def generate_diff_cache_key(*args):
    """
    Helper to generate a cache key using arguments
    """
    def arg_mapper(input_param):
        input_param = safe_str(input_param)
        # we cannot allow '/' in arguments, since it would allow
        # subdirectory usage
        input_param = input_param.replace('/', '_')
        return input_param or None  # prevent empty string arguments

    return '_'.join([
        '{}' for i in range(len(args))]).format(*map(arg_mapper, args))


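# For example, after arg_mapper normalizes each argument:
#
#   generate_diff_cache_key('v1', 'raw', 'abc/def')  -> 'v1_raw_abc_def'

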
def diff_cache_exist(cache_storage, *args):
    """
    Based on all generated arguments, check and return a cache path
    """
    args = list(args) + ['mode:gzip']
    cache_key = generate_diff_cache_key(*args)
    cache_file_path = os.path.join(cache_storage, cache_key)
    # prevent path traversal attacks via params that contain e.g. '../../'
    if not os.path.abspath(cache_file_path).startswith(cache_storage):
        raise ValueError('Final path must be within {}'.format(cache_storage))

    return cache_file_path
@@ -1,444 +1,444 b''
# Copyright (c) Django Software Foundation and individual contributors.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of Django nor the names of its contributors may be used
#    to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
For definitions of the different versions of RSS, see:
http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss
"""


import datetime
import io

import pytz
from six.moves.urllib import parse as urlparse

from rhodecode.lib.feedgenerator import datetime_safe
from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text


#### The following code comes from ``django.utils.feedgenerator`` ####


def rfc2822_date(date):
    # We can't use strftime() because it produces locale-dependent results,
    # so we have to map English month and day names manually
    months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',)
    days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
    # Support datetime objects older than 1900
    date = datetime_safe.new_datetime(date)
    # We do this ourselves to be timezone aware; email.Utils is not tz aware.
    dow = days[date.weekday()]
    month = months[date.month - 1]
    time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month))

    offset = date.utcoffset()
    # Historically, this function assumes that naive datetimes are in UTC.
    if offset is None:
        return time_str + '-0000'
    else:
        timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
        hour, minute = divmod(timezone, 60)
        return time_str + '%+03d%02d' % (hour, minute)


def rfc3339_date(date):
    # Support datetime objects older than 1900
    date = datetime_safe.new_datetime(date)
    time_str = date.strftime('%Y-%m-%dT%H:%M:%S')

    offset = date.utcoffset()
    # Historically, this function assumes that naive datetimes are in UTC.
    if offset is None:
        return time_str + 'Z'
    else:
        timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
        hour, minute = divmod(timezone, 60)
        return time_str + '%+03d:%02d' % (hour, minute)


def get_tag_uri(url, date):
    """
    Creates a TagURI.

    See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id
    """
    bits = urlparse.urlparse(url)
    d = ''
    if date is not None:
        d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d')
    return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment)


98 class SyndicationFeed(object):
99 """Base class for all syndication feeds. Subclasses should provide write()"""
99 """Base class for all syndication feeds. Subclasses should provide write()"""
100
100
101 def __init__(self, title, link, description, language=None, author_email=None,
101 def __init__(self, title, link, description, language=None, author_email=None,
102 author_name=None, author_link=None, subtitle=None, categories=None,
102 author_name=None, author_link=None, subtitle=None, categories=None,
103 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
103 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
104 def to_unicode(s):
104 def to_unicode(s):
105 return force_text(s, strings_only=True)
105 return force_text(s, strings_only=True)
106 if categories:
106 if categories:
107 categories = [force_text(c) for c in categories]
107 categories = [force_text(c) for c in categories]
108 if ttl is not None:
108 if ttl is not None:
109 # Force ints to unicode
109 # Force ints to unicode
110 ttl = force_text(ttl)
110 ttl = force_text(ttl)
111 self.feed = {
111 self.feed = {
112 'title': to_unicode(title),
112 'title': to_unicode(title),
113 'link': iri_to_uri(link),
113 'link': iri_to_uri(link),
114 'description': to_unicode(description),
114 'description': to_unicode(description),
115 'language': to_unicode(language),
115 'language': to_unicode(language),
116 'author_email': to_unicode(author_email),
116 'author_email': to_unicode(author_email),
117 'author_name': to_unicode(author_name),
117 'author_name': to_unicode(author_name),
118 'author_link': iri_to_uri(author_link),
118 'author_link': iri_to_uri(author_link),
119 'subtitle': to_unicode(subtitle),
119 'subtitle': to_unicode(subtitle),
120 'categories': categories or (),
120 'categories': categories or (),
121 'feed_url': iri_to_uri(feed_url),
121 'feed_url': iri_to_uri(feed_url),
122 'feed_copyright': to_unicode(feed_copyright),
122 'feed_copyright': to_unicode(feed_copyright),
123 'id': feed_guid or link,
123 'id': feed_guid or link,
124 'ttl': ttl,
124 'ttl': ttl,
125 }
125 }
126 self.feed.update(kwargs)
126 self.feed.update(kwargs)
127 self.items = []
127 self.items = []
128
128
129 def add_item(self, title, link, description, author_email=None,
129 def add_item(self, title, link, description, author_email=None,
130 author_name=None, author_link=None, pubdate=None, comments=None,
130 author_name=None, author_link=None, pubdate=None, comments=None,
131 unique_id=None, unique_id_is_permalink=None, enclosure=None,
131 unique_id=None, unique_id_is_permalink=None, enclosure=None,
132 categories=(), item_copyright=None, ttl=None, updateddate=None,
132 categories=(), item_copyright=None, ttl=None, updateddate=None,
133 enclosures=None, **kwargs):
133 enclosures=None, **kwargs):
134 """
134 """
135 Adds an item to the feed. All args are expected to be Python Unicode
135 Adds an item to the feed. All args are expected to be Python Unicode
136 objects except pubdate and updateddate, which are datetime.datetime
136 objects except pubdate and updateddate, which are datetime.datetime
137 objects, and enclosures, which is an iterable of instances of the
137 objects, and enclosures, which is an iterable of instances of the
138 Enclosure class.
138 Enclosure class.
139 """
139 """
140 def to_unicode(s):
140 def to_unicode(s):
141 return force_text(s, strings_only=True)
141 return force_text(s, strings_only=True)
142 if categories:
142 if categories:
143 categories = [to_unicode(c) for c in categories]
143 categories = [to_unicode(c) for c in categories]
144 if ttl is not None:
144 if ttl is not None:
145 # Force ints to unicode
145 # Force ints to unicode
146 ttl = force_text(ttl)
146 ttl = force_text(ttl)
147 if enclosure is None:
147 if enclosure is None:
148 enclosures = [] if enclosures is None else enclosures
148 enclosures = [] if enclosures is None else enclosures
149
149
150 item = {
150 item = {
151 'title': to_unicode(title),
151 'title': to_unicode(title),
152 'link': iri_to_uri(link),
152 'link': iri_to_uri(link),
153 'description': to_unicode(description),
153 'description': to_unicode(description),
154 'author_email': to_unicode(author_email),
154 'author_email': to_unicode(author_email),
155 'author_name': to_unicode(author_name),
155 'author_name': to_unicode(author_name),
156 'author_link': iri_to_uri(author_link),
156 'author_link': iri_to_uri(author_link),
157 'pubdate': pubdate,
157 'pubdate': pubdate,
158 'updateddate': updateddate,
158 'updateddate': updateddate,
159 'comments': to_unicode(comments),
159 'comments': to_unicode(comments),
160 'unique_id': to_unicode(unique_id),
160 'unique_id': to_unicode(unique_id),
161 'unique_id_is_permalink': unique_id_is_permalink,
161 'unique_id_is_permalink': unique_id_is_permalink,
162 'enclosures': enclosures,
162 'enclosures': enclosures,
163 'categories': categories or (),
163 'categories': categories or (),
164 'item_copyright': to_unicode(item_copyright),
164 'item_copyright': to_unicode(item_copyright),
165 'ttl': ttl,
165 'ttl': ttl,
166 }
166 }
167 item.update(kwargs)
167 item.update(kwargs)
168 self.items.append(item)
168 self.items.append(item)
169
169
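# Illustrative sketch (not part of the original module): adding an item
# once a concrete SyndicationFeed subclass exists. All values here are
# hypothetical; only title, link and description are required.
def _example_add_item(feed):
    import datetime
    feed.add_item(
        title='Commit abc123 pushed',
        link='https://code.example.com/repo/changeset/abc123',
        description='<p>fixed a bug</p>',
        pubdate=datetime.datetime(2020, 1, 1, 12, 0),
        unique_id='https://code.example.com/repo/changeset/abc123',
    )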
170 def num_items(self):
170 def num_items(self):
171 return len(self.items)
171 return len(self.items)
172
172
173 def root_attributes(self):
173 def root_attributes(self):
174 """
174 """
175 Return extra attributes to place on the root (i.e. feed/channel) element.
175 Return extra attributes to place on the root (i.e. feed/channel) element.
176 Called from write().
176 Called from write().
177 """
177 """
178 return {}
178 return {}
179
179
180 def add_root_elements(self, handler):
180 def add_root_elements(self, handler):
181 """
181 """
182 Add elements in the root (i.e. feed/channel) element. Called
182 Add elements in the root (i.e. feed/channel) element. Called
183 from write().
183 from write().
184 """
184 """
185 pass
185 pass
186
186
187 def item_attributes(self, item):
187 def item_attributes(self, item):
188 """
188 """
189 Return extra attributes to place on each item (i.e. item/entry) element.
189 Return extra attributes to place on each item (i.e. item/entry) element.
190 """
190 """
191 return {}
191 return {}
192
192
193 def add_item_elements(self, handler, item):
193 def add_item_elements(self, handler, item):
194 """
194 """
195 Add elements on each item (i.e. item/entry) element.
195 Add elements on each item (i.e. item/entry) element.
196 """
196 """
197 pass
197 pass
198
198
199 def write(self, outfile, encoding):
199 def write(self, outfile, encoding):
200 """
200 """
201 Outputs the feed in the given encoding to outfile, which is a file-like
201 Outputs the feed in the given encoding to outfile, which is a file-like
202 object. Subclasses should override this.
202 object. Subclasses should override this.
203 """
203 """
204 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
204 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
205
205
206 def writeString(self, encoding):
206 def writeString(self, encoding):
207 """
207 """
208 Returns the feed in the given encoding as a string.
208 Returns the feed in the given encoding as a string.
209 """
209 """
210 s = StringIO()
210 s = io.StringIO()
211 self.write(s, encoding)
211 self.write(s, encoding)
212 return s.getvalue()
212 return s.getvalue()
213
213
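# Illustrative sketch: write() streams into any file-like object, and
# writeString() is the same thing routed through an in-memory buffer.
# The path and the 'utf-8' encoding are hypothetical choices.
def _example_serialize(feed, path):
    # long form: stream the XML straight to a file
    with open(path, 'w') as outfile:
        feed.write(outfile, 'utf-8')
    # short form: collect the XML into a string
    return feed.writeString('utf-8')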
214 def latest_post_date(self):
214 def latest_post_date(self):
215 """
215 """
216 Returns the latest item's pubdate or updateddate. If no items
216 Returns the latest item's pubdate or updateddate. If no items
217 have either of these attributes, this returns the current UTC date/time.
217 have either of these attributes, this returns the current UTC date/time.
218 """
218 """
219 latest_date = None
219 latest_date = None
220 date_keys = ('updateddate', 'pubdate')
220 date_keys = ('updateddate', 'pubdate')
221
221
222 for item in self.items:
222 for item in self.items:
223 for date_key in date_keys:
223 for date_key in date_keys:
224 item_date = item.get(date_key)
224 item_date = item.get(date_key)
225 if item_date:
225 if item_date:
226 if latest_date is None or item_date > latest_date:
226 if latest_date is None or item_date > latest_date:
227 latest_date = item_date
227 latest_date = item_date
228
228
229 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
229 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
231
231
232
232
233 class Enclosure(object):
233 class Enclosure(object):
234 """Represents an RSS enclosure"""
234 """Represents an RSS enclosure"""
235 def __init__(self, url, length, mime_type):
235 def __init__(self, url, length, mime_type):
236 """All args are expected to be Python Unicode objects"""
236 """All args are expected to be Python Unicode objects"""
237 self.length, self.mime_type = length, mime_type
237 self.length, self.mime_type = length, mime_type
238 self.url = iri_to_uri(url)
238 self.url = iri_to_uri(url)
239
239
240
240
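# Illustrative sketch (hypothetical values): enclosure metadata is kept
# as strings, length included, matching the unicode contract above.
def _example_enclosure():
    return Enclosure('https://example.com/audio.mp3', '8192', 'audio/mpeg')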
241 class RssFeed(SyndicationFeed):
241 class RssFeed(SyndicationFeed):
242 content_type = 'application/rss+xml; charset=utf-8'
242 content_type = 'application/rss+xml; charset=utf-8'
243
243
244 def write(self, outfile, encoding):
244 def write(self, outfile, encoding):
245 handler = SimplerXMLGenerator(outfile, encoding)
245 handler = SimplerXMLGenerator(outfile, encoding)
246 handler.startDocument()
246 handler.startDocument()
247 handler.startElement("rss", self.rss_attributes())
247 handler.startElement("rss", self.rss_attributes())
248 handler.startElement("channel", self.root_attributes())
248 handler.startElement("channel", self.root_attributes())
249 self.add_root_elements(handler)
249 self.add_root_elements(handler)
250 self.write_items(handler)
250 self.write_items(handler)
251 self.endChannelElement(handler)
251 self.endChannelElement(handler)
252 handler.endElement("rss")
252 handler.endElement("rss")
253
253
254 def rss_attributes(self):
254 def rss_attributes(self):
255 return {"version": self._version,
255 return {"version": self._version,
256 "xmlns:atom": "http://www.w3.org/2005/Atom"}
256 "xmlns:atom": "http://www.w3.org/2005/Atom"}
257
257
258 def write_items(self, handler):
258 def write_items(self, handler):
259 for item in self.items:
259 for item in self.items:
260 handler.startElement('item', self.item_attributes(item))
260 handler.startElement('item', self.item_attributes(item))
261 self.add_item_elements(handler, item)
261 self.add_item_elements(handler, item)
262 handler.endElement("item")
262 handler.endElement("item")
263
263
264 def add_root_elements(self, handler):
264 def add_root_elements(self, handler):
265 handler.addQuickElement("title", self.feed['title'])
265 handler.addQuickElement("title", self.feed['title'])
266 handler.addQuickElement("link", self.feed['link'])
266 handler.addQuickElement("link", self.feed['link'])
267 handler.addQuickElement("description", self.feed['description'])
267 handler.addQuickElement("description", self.feed['description'])
268 if self.feed['feed_url'] is not None:
268 if self.feed['feed_url'] is not None:
269 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
269 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
270 if self.feed['language'] is not None:
270 if self.feed['language'] is not None:
271 handler.addQuickElement("language", self.feed['language'])
271 handler.addQuickElement("language", self.feed['language'])
272 for cat in self.feed['categories']:
272 for cat in self.feed['categories']:
273 handler.addQuickElement("category", cat)
273 handler.addQuickElement("category", cat)
274 if self.feed['feed_copyright'] is not None:
274 if self.feed['feed_copyright'] is not None:
275 handler.addQuickElement("copyright", self.feed['feed_copyright'])
275 handler.addQuickElement("copyright", self.feed['feed_copyright'])
276 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
276 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
277 if self.feed['ttl'] is not None:
277 if self.feed['ttl'] is not None:
278 handler.addQuickElement("ttl", self.feed['ttl'])
278 handler.addQuickElement("ttl", self.feed['ttl'])
279
279
280 def endChannelElement(self, handler):
280 def endChannelElement(self, handler):
281 handler.endElement("channel")
281 handler.endElement("channel")
282
282
283
283
284 class RssUserland091Feed(RssFeed):
284 class RssUserland091Feed(RssFeed):
285 _version = "0.91"
285 _version = "0.91"
286
286
287 def add_item_elements(self, handler, item):
287 def add_item_elements(self, handler, item):
288 handler.addQuickElement("title", item['title'])
288 handler.addQuickElement("title", item['title'])
289 handler.addQuickElement("link", item['link'])
289 handler.addQuickElement("link", item['link'])
290 if item['description'] is not None:
290 if item['description'] is not None:
291 handler.addQuickElement("description", item['description'])
291 handler.addQuickElement("description", item['description'])
292
292
293
293
294 class Rss201rev2Feed(RssFeed):
294 class Rss201rev2Feed(RssFeed):
295 # Spec: http://blogs.law.harvard.edu/tech/rss
295 # Spec: http://blogs.law.harvard.edu/tech/rss
296 _version = "2.0"
296 _version = "2.0"
297
297
298 def add_item_elements(self, handler, item):
298 def add_item_elements(self, handler, item):
299 handler.addQuickElement("title", item['title'])
299 handler.addQuickElement("title", item['title'])
300 handler.addQuickElement("link", item['link'])
300 handler.addQuickElement("link", item['link'])
301 if item['description'] is not None:
301 if item['description'] is not None:
302 handler.addQuickElement("description", item['description'])
302 handler.addQuickElement("description", item['description'])
303
303
304 # Author information.
304 # Author information.
305 if item["author_name"] and item["author_email"]:
305 if item["author_name"] and item["author_email"]:
306 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
306 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
307 elif item["author_email"]:
307 elif item["author_email"]:
308 handler.addQuickElement("author", item["author_email"])
308 handler.addQuickElement("author", item["author_email"])
309 elif item["author_name"]:
309 elif item["author_name"]:
310 handler.addQuickElement(
310 handler.addQuickElement(
311 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
311 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
312 )
312 )
313
313
314 if item['pubdate'] is not None:
314 if item['pubdate'] is not None:
315 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
315 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
316 if item['comments'] is not None:
316 if item['comments'] is not None:
317 handler.addQuickElement("comments", item['comments'])
317 handler.addQuickElement("comments", item['comments'])
318 if item['unique_id'] is not None:
318 if item['unique_id'] is not None:
319 guid_attrs = {}
319 guid_attrs = {}
320 if isinstance(item.get('unique_id_is_permalink'), bool):
320 if isinstance(item.get('unique_id_is_permalink'), bool):
321 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
321 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
322 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
322 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
323 if item['ttl'] is not None:
323 if item['ttl'] is not None:
324 handler.addQuickElement("ttl", item['ttl'])
324 handler.addQuickElement("ttl", item['ttl'])
325
325
326 # Enclosure.
326 # Enclosure.
327 if item['enclosures']:
327 if item['enclosures']:
328 enclosures = list(item['enclosures'])
328 enclosures = list(item['enclosures'])
329 if len(enclosures) > 1:
329 if len(enclosures) > 1:
330 raise ValueError(
330 raise ValueError(
331 "RSS feed items may only have one enclosure, see "
331 "RSS feed items may only have one enclosure, see "
332 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
332 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
333 )
333 )
334 enclosure = enclosures[0]
334 enclosure = enclosures[0]
335 handler.addQuickElement('enclosure', '', {
335 handler.addQuickElement('enclosure', '', {
336 'url': enclosure.url,
336 'url': enclosure.url,
337 'length': enclosure.length,
337 'length': enclosure.length,
338 'type': enclosure.mime_type,
338 'type': enclosure.mime_type,
339 })
339 })
340
340
341 # Categories.
341 # Categories.
342 for cat in item['categories']:
342 for cat in item['categories']:
343 handler.addQuickElement("category", cat)
343 handler.addQuickElement("category", cat)
344
344
345
345
346 class Atom1Feed(SyndicationFeed):
346 class Atom1Feed(SyndicationFeed):
347 # Spec: https://tools.ietf.org/html/rfc4287
347 # Spec: https://tools.ietf.org/html/rfc4287
348 content_type = 'application/atom+xml; charset=utf-8'
348 content_type = 'application/atom+xml; charset=utf-8'
349 ns = "http://www.w3.org/2005/Atom"
349 ns = "http://www.w3.org/2005/Atom"
350
350
351 def write(self, outfile, encoding):
351 def write(self, outfile, encoding):
352 handler = SimplerXMLGenerator(outfile, encoding)
352 handler = SimplerXMLGenerator(outfile, encoding)
353 handler.startDocument()
353 handler.startDocument()
354 handler.startElement('feed', self.root_attributes())
354 handler.startElement('feed', self.root_attributes())
355 self.add_root_elements(handler)
355 self.add_root_elements(handler)
356 self.write_items(handler)
356 self.write_items(handler)
357 handler.endElement("feed")
357 handler.endElement("feed")
358
358
359 def root_attributes(self):
359 def root_attributes(self):
360 if self.feed['language'] is not None:
360 if self.feed['language'] is not None:
361 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
361 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
362 else:
362 else:
363 return {"xmlns": self.ns}
363 return {"xmlns": self.ns}
364
364
365 def add_root_elements(self, handler):
365 def add_root_elements(self, handler):
366 handler.addQuickElement("title", self.feed['title'])
366 handler.addQuickElement("title", self.feed['title'])
367 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
367 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
368 if self.feed['feed_url'] is not None:
368 if self.feed['feed_url'] is not None:
369 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
369 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
370 handler.addQuickElement("id", self.feed['id'])
370 handler.addQuickElement("id", self.feed['id'])
371 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
371 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
372 if self.feed['author_name'] is not None:
372 if self.feed['author_name'] is not None:
373 handler.startElement("author", {})
373 handler.startElement("author", {})
374 handler.addQuickElement("name", self.feed['author_name'])
374 handler.addQuickElement("name", self.feed['author_name'])
375 if self.feed['author_email'] is not None:
375 if self.feed['author_email'] is not None:
376 handler.addQuickElement("email", self.feed['author_email'])
376 handler.addQuickElement("email", self.feed['author_email'])
377 if self.feed['author_link'] is not None:
377 if self.feed['author_link'] is not None:
378 handler.addQuickElement("uri", self.feed['author_link'])
378 handler.addQuickElement("uri", self.feed['author_link'])
379 handler.endElement("author")
379 handler.endElement("author")
380 if self.feed['subtitle'] is not None:
380 if self.feed['subtitle'] is not None:
381 handler.addQuickElement("subtitle", self.feed['subtitle'])
381 handler.addQuickElement("subtitle", self.feed['subtitle'])
382 for cat in self.feed['categories']:
382 for cat in self.feed['categories']:
383 handler.addQuickElement("category", "", {"term": cat})
383 handler.addQuickElement("category", "", {"term": cat})
384 if self.feed['feed_copyright'] is not None:
384 if self.feed['feed_copyright'] is not None:
385 handler.addQuickElement("rights", self.feed['feed_copyright'])
385 handler.addQuickElement("rights", self.feed['feed_copyright'])
386
386
387 def write_items(self, handler):
387 def write_items(self, handler):
388 for item in self.items:
388 for item in self.items:
389 handler.startElement("entry", self.item_attributes(item))
389 handler.startElement("entry", self.item_attributes(item))
390 self.add_item_elements(handler, item)
390 self.add_item_elements(handler, item)
391 handler.endElement("entry")
391 handler.endElement("entry")
392
392
393 def add_item_elements(self, handler, item):
393 def add_item_elements(self, handler, item):
394 handler.addQuickElement("title", item['title'])
394 handler.addQuickElement("title", item['title'])
395 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
395 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
396
396
397 if item['pubdate'] is not None:
397 if item['pubdate'] is not None:
398 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
398 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
399
399
400 if item['updateddate'] is not None:
400 if item['updateddate'] is not None:
401 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
401 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
402
402
403 # Author information.
403 # Author information.
404 if item['author_name'] is not None:
404 if item['author_name'] is not None:
405 handler.startElement("author", {})
405 handler.startElement("author", {})
406 handler.addQuickElement("name", item['author_name'])
406 handler.addQuickElement("name", item['author_name'])
407 if item['author_email'] is not None:
407 if item['author_email'] is not None:
408 handler.addQuickElement("email", item['author_email'])
408 handler.addQuickElement("email", item['author_email'])
409 if item['author_link'] is not None:
409 if item['author_link'] is not None:
410 handler.addQuickElement("uri", item['author_link'])
410 handler.addQuickElement("uri", item['author_link'])
411 handler.endElement("author")
411 handler.endElement("author")
412
412
413 # Unique ID.
413 # Unique ID.
414 if item['unique_id'] is not None:
414 if item['unique_id'] is not None:
415 unique_id = item['unique_id']
415 unique_id = item['unique_id']
416 else:
416 else:
417 unique_id = get_tag_uri(item['link'], item['pubdate'])
417 unique_id = get_tag_uri(item['link'], item['pubdate'])
418 handler.addQuickElement("id", unique_id)
418 handler.addQuickElement("id", unique_id)
419
419
420 # Summary.
420 # Summary.
421 if item['description'] is not None:
421 if item['description'] is not None:
422 handler.addQuickElement("summary", item['description'], {"type": "html"})
422 handler.addQuickElement("summary", item['description'], {"type": "html"})
423
423
424 # Enclosures.
424 # Enclosures.
425 for enclosure in item['enclosures']:
425 for enclosure in item['enclosures']:
426 handler.addQuickElement('link', '', {
426 handler.addQuickElement('link', '', {
427 'rel': 'enclosure',
427 'rel': 'enclosure',
428 'href': enclosure.url,
428 'href': enclosure.url,
429 'length': enclosure.length,
429 'length': enclosure.length,
430 'type': enclosure.mime_type,
430 'type': enclosure.mime_type,
431 })
431 })
432
432
433 # Categories.
433 # Categories.
434 for cat in item['categories']:
434 for cat in item['categories']:
435 handler.addQuickElement("category", "", {"term": cat})
435 handler.addQuickElement("category", "", {"term": cat})
436
436
437 # Rights.
437 # Rights.
438 if item['item_copyright'] is not None:
438 if item['item_copyright'] is not None:
439 handler.addQuickElement("rights", item['item_copyright'])
439 handler.addQuickElement("rights", item['item_copyright'])
440
440
441
441
442 # This isolates the decision of what the system default is, so calling code can
442 # This isolates the decision of what the system default is, so calling code can
443 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
443 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
444 DefaultFeed = Rss201rev2Feed No newline at end of file
444 DefaultFeed = Rss201rev2Feed
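# Illustrative end-to-end sketch of the generator API above, assuming
# the usual title/link/description constructor implied by the feed
# dict; all values are hypothetical.
def _example_default_feed():
    feed = DefaultFeed(
        title='demo repo feed',
        link='https://code.example.com/demo',
        description='latest changesets')
    feed.add_item(
        title='commit abc123',
        link='https://code.example.com/demo/changeset/abc123',
        description='fixed a bug')
    return feed.writeString('utf-8')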
@@ -1,538 +1,538 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2013-2020 RhodeCode GmbH
3 # Copyright (C) 2013-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Set of hooks run by RhodeCode Enterprise
23 Set of hooks run by RhodeCode Enterprise
24 """
24 """
25
25
26 import os
26 import os
27 import logging
27 import logging
28
28
29 import rhodecode
29 import rhodecode
30 from rhodecode import events
30 from rhodecode import events
31 from rhodecode.lib import helpers as h
31 from rhodecode.lib import helpers as h
32 from rhodecode.lib import audit_logger
32 from rhodecode.lib import audit_logger
33 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
33 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
34 from rhodecode.lib.exceptions import (
34 from rhodecode.lib.exceptions import (
35 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
35 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
36 from rhodecode.model.db import Repository, User
36 from rhodecode.model.db import Repository, User
37 from rhodecode.lib.statsd_client import StatsdClient
37 from rhodecode.lib.statsd_client import StatsdClient
38
38
39 log = logging.getLogger(__name__)
39 log = logging.getLogger(__name__)
40
40
41
41
42 class HookResponse(object):
42 class HookResponse(object):
43 def __init__(self, status, output):
43 def __init__(self, status, output):
44 self.status = status
44 self.status = status
45 self.output = output
45 self.output = output
46
46
47 def __add__(self, other):
47 def __add__(self, other):
48 other_status = getattr(other, 'status', 0)
48 other_status = getattr(other, 'status', 0)
49 new_status = max(self.status, other_status)
49 new_status = max(self.status, other_status)
50 other_output = getattr(other, 'output', '')
50 other_output = getattr(other, 'output', '')
51 new_output = self.output + other_output
51 new_output = self.output + other_output
52
52
53 return HookResponse(new_status, new_output)
53 return HookResponse(new_status, new_output)
54
54
55 def __bool__(self):
55 def __bool__(self):
56 return self.status == 0
56 return self.status == 0
57
57
58
58
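# Illustrative sketch of the combination semantics defined above: the
# merged status is the max (worst) of the two, outputs concatenate, and
# truthiness means "status == 0", so a failure anywhere wins.
def _example_combine_responses():
    ok = HookResponse(0, 'lock check passed\n')
    rejected = HookResponse(1, 'rejected by rule\n')
    merged = ok + rejected
    assert merged.status == 1  # worst status wins
    assert not merged          # non-zero status is falsy
    return merged.output       # outputs are concatenated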
59 def is_shadow_repo(extras):
59 def is_shadow_repo(extras):
60 """
60 """
61 Returns ``True`` if this is an action executed against a shadow repository.
61 Returns ``True`` if this is an action executed against a shadow repository.
62 """
62 """
63 return extras['is_shadow_repo']
63 return extras['is_shadow_repo']
64
64
65
65
66 def _get_scm_size(alias, root_path):
66 def _get_scm_size(alias, root_path):
67
67
68 if not alias.startswith('.'):
68 if not alias.startswith('.'):
69 alias += '.'
69 alias += '.'
70
70
71 size_scm, size_root = 0, 0
71 size_scm, size_root = 0, 0
72 for path, unused_dirs, files in os.walk(safe_str(root_path)):
72 for path, unused_dirs, files in os.walk(safe_str(root_path)):
73 if path.find(alias) != -1:
73 if path.find(alias) != -1:
74 for f in files:
74 for f in files:
75 try:
75 try:
76 size_scm += os.path.getsize(os.path.join(path, f))
76 size_scm += os.path.getsize(os.path.join(path, f))
77 except OSError:
77 except OSError:
78 pass
78 pass
79 else:
79 else:
80 for f in files:
80 for f in files:
81 try:
81 try:
82 size_root += os.path.getsize(os.path.join(path, f))
82 size_root += os.path.getsize(os.path.join(path, f))
83 except OSError:
83 except OSError:
84 pass
84 pass
85
85
86 size_scm_f = h.format_byte_size_binary(size_scm)
86 size_scm_f = h.format_byte_size_binary(size_scm)
87 size_root_f = h.format_byte_size_binary(size_root)
87 size_root_f = h.format_byte_size_binary(size_root)
88 size_total_f = h.format_byte_size_binary(size_root + size_scm)
88 size_total_f = h.format_byte_size_binary(size_root + size_scm)
89
89
90 return size_scm_f, size_root_f, size_total_f
90 return size_scm_f, size_root_f, size_total_f
91
91
92
92
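# Illustrative call (hypothetical path): returns human-readable binary
# sizes for the VCS metadata, the remaining repository data, and their sum.
def _example_repo_sizes():
    return _get_scm_size('.git', '/srv/repos/demo')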
93 # actual hooks, called by Mercurial internally and by Git via our Python hooks
93 # actual hooks, called by Mercurial internally and by Git via our Python hooks
94 def repo_size(extras):
94 def repo_size(extras):
95 """Present size of repository after push."""
95 """Present size of repository after push."""
96 repo = Repository.get_by_repo_name(extras.repository)
96 repo = Repository.get_by_repo_name(extras.repository)
97 vcs_part = safe_str(u'.%s' % repo.repo_type)
97 vcs_part = safe_str('.%s' % repo.repo_type)
98 size_vcs, size_root, size_total = _get_scm_size(vcs_part,
98 size_vcs, size_root, size_total = _get_scm_size(vcs_part,
99 repo.repo_full_path)
99 repo.repo_full_path)
100 msg = ('Repository `%s` size summary %s:%s repo:%s total:%s\n'
100 msg = ('Repository `%s` size summary %s:%s repo:%s total:%s\n'
101 % (repo.repo_name, vcs_part, size_vcs, size_root, size_total))
101 % (repo.repo_name, vcs_part, size_vcs, size_root, size_total))
102 return HookResponse(0, msg)
102 return HookResponse(0, msg)
103
103
104
104
105 def pre_push(extras):
105 def pre_push(extras):
106 """
106 """
107 Hook executed before pushing code.
107 Hook executed before pushing code.
108
108
109 It bans pushing when the repository is locked.
109 It bans pushing when the repository is locked.
110 """
110 """
111
111
112 user = User.get_by_username(extras.username)
112 user = User.get_by_username(extras.username)
113 output = ''
113 output = ''
114 if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
114 if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
115 locked_by = User.get(extras.locked_by[0]).username
115 locked_by = User.get(extras.locked_by[0]).username
116 reason = extras.locked_by[2]
116 reason = extras.locked_by[2]
117 # this exception is interpreted in git/hg middlewares and based
117 # this exception is interpreted in git/hg middlewares and based
118 # on that a proper return code is served to the client
118 # on that a proper return code is served to the client
119 _http_ret = HTTPLockedRC(
119 _http_ret = HTTPLockedRC(
120 _locked_by_explanation(extras.repository, locked_by, reason))
120 _locked_by_explanation(extras.repository, locked_by, reason))
121 if str(_http_ret.code).startswith('2'):
121 if str(_http_ret.code).startswith('2'):
122 # 2xx Codes don't raise exceptions
122 # 2xx Codes don't raise exceptions
123 output = _http_ret.title
123 output = _http_ret.title
124 else:
124 else:
125 raise _http_ret
125 raise _http_ret
126
126
127 hook_response = ''
127 hook_response = ''
128 if not is_shadow_repo(extras):
128 if not is_shadow_repo(extras):
129 if extras.commit_ids and extras.check_branch_perms:
129 if extras.commit_ids and extras.check_branch_perms:
130
130
131 auth_user = user.AuthUser()
131 auth_user = user.AuthUser()
132 repo = Repository.get_by_repo_name(extras.repository)
132 repo = Repository.get_by_repo_name(extras.repository)
133 affected_branches = []
133 affected_branches = []
134 if repo.repo_type == 'hg':
134 if repo.repo_type == 'hg':
135 for entry in extras.commit_ids:
135 for entry in extras.commit_ids:
136 if entry['type'] == 'branch':
136 if entry['type'] == 'branch':
137 is_forced = bool(entry['multiple_heads'])
137 is_forced = bool(entry['multiple_heads'])
138 affected_branches.append([entry['name'], is_forced])
138 affected_branches.append([entry['name'], is_forced])
139 elif repo.repo_type == 'git':
139 elif repo.repo_type == 'git':
140 for entry in extras.commit_ids:
140 for entry in extras.commit_ids:
141 if entry['type'] == 'heads':
141 if entry['type'] == 'heads':
142 is_forced = bool(entry['pruned_sha'])
142 is_forced = bool(entry['pruned_sha'])
143 affected_branches.append([entry['name'], is_forced])
143 affected_branches.append([entry['name'], is_forced])
144
144
145 for branch_name, is_forced in affected_branches:
145 for branch_name, is_forced in affected_branches:
146
146
147 rule, branch_perm = auth_user.get_rule_and_branch_permission(
147 rule, branch_perm = auth_user.get_rule_and_branch_permission(
148 extras.repository, branch_name)
148 extras.repository, branch_name)
149 if not branch_perm:
149 if not branch_perm:
150 # no branch permission found for this branch, just keep checking
150 # no branch permission found for this branch, just keep checking
151 continue
151 continue
152
152
153 if branch_perm == 'branch.push_force':
153 if branch_perm == 'branch.push_force':
154 continue
154 continue
155 elif branch_perm == 'branch.push' and is_forced is False:
155 elif branch_perm == 'branch.push' and is_forced is False:
156 continue
156 continue
157 elif branch_perm == 'branch.push' and is_forced is True:
157 elif branch_perm == 'branch.push' and is_forced is True:
158 halt_message = 'Branch `{}` changes rejected by rule {}. ' \
158 halt_message = 'Branch `{}` changes rejected by rule {}. ' \
159 'FORCE PUSH FORBIDDEN.'.format(branch_name, rule)
159 'FORCE PUSH FORBIDDEN.'.format(branch_name, rule)
160 else:
160 else:
161 halt_message = 'Branch `{}` changes rejected by rule {}.'.format(
161 halt_message = 'Branch `{}` changes rejected by rule {}.'.format(
162 branch_name, rule)
162 branch_name, rule)
163
163
164 if halt_message:
164 if halt_message:
165 _http_ret = HTTPBranchProtected(halt_message)
165 _http_ret = HTTPBranchProtected(halt_message)
166 raise _http_ret
166 raise _http_ret
167
167
168 # Propagate to external components. This is done after checking the
168 # Propagate to external components. This is done after checking the
169 # lock, for consistent behavior.
169 # lock, for consistent behavior.
170 hook_response = pre_push_extension(
170 hook_response = pre_push_extension(
171 repo_store_path=Repository.base_path(), **extras)
171 repo_store_path=Repository.base_path(), **extras)
172 events.trigger(events.RepoPrePushEvent(
172 events.trigger(events.RepoPrePushEvent(
173 repo_name=extras.repository, extras=extras))
173 repo_name=extras.repository, extras=extras))
174
174
175 return HookResponse(0, output) + hook_response
175 return HookResponse(0, output) + hook_response
176
176
177
177
178 def pre_pull(extras):
178 def pre_pull(extras):
179 """
179 """
180 Hook executed before pulling the code.
180 Hook executed before pulling the code.
181
181
182 It bans pulling when the repository is locked.
182 It bans pulling when the repository is locked.
183 """
183 """
184
184
185 output = ''
185 output = ''
186 if extras.locked_by[0]:
186 if extras.locked_by[0]:
187 locked_by = User.get(extras.locked_by[0]).username
187 locked_by = User.get(extras.locked_by[0]).username
188 reason = extras.locked_by[2]
188 reason = extras.locked_by[2]
189 # this exception is interpreted in git/hg middlewares and based
189 # this exception is interpreted in git/hg middlewares and based
190 # on that a proper return code is served to the client
190 # on that a proper return code is served to the client
191 _http_ret = HTTPLockedRC(
191 _http_ret = HTTPLockedRC(
192 _locked_by_explanation(extras.repository, locked_by, reason))
192 _locked_by_explanation(extras.repository, locked_by, reason))
193 if str(_http_ret.code).startswith('2'):
193 if str(_http_ret.code).startswith('2'):
194 # 2xx Codes don't raise exceptions
194 # 2xx Codes don't raise exceptions
195 output = _http_ret.title
195 output = _http_ret.title
196 else:
196 else:
197 raise _http_ret
197 raise _http_ret
198
198
199 # Propagate to external components. This is done after checking the
199 # Propagate to external components. This is done after checking the
200 # lock, for consistent behavior.
200 # lock, for consistent behavior.
201 hook_response = ''
201 hook_response = ''
202 if not is_shadow_repo(extras):
202 if not is_shadow_repo(extras):
203 extras.hook_type = extras.hook_type or 'pre_pull'
203 extras.hook_type = extras.hook_type or 'pre_pull'
204 hook_response = pre_pull_extension(
204 hook_response = pre_pull_extension(
205 repo_store_path=Repository.base_path(), **extras)
205 repo_store_path=Repository.base_path(), **extras)
206 events.trigger(events.RepoPrePullEvent(
206 events.trigger(events.RepoPrePullEvent(
207 repo_name=extras.repository, extras=extras))
207 repo_name=extras.repository, extras=extras))
208
208
209 return HookResponse(0, output) + hook_response
209 return HookResponse(0, output) + hook_response
210
210
211
211
212 def post_pull(extras):
212 def post_pull(extras):
213 """Hook executed after client pulls the code."""
213 """Hook executed after client pulls the code."""
214
214
215 audit_user = audit_logger.UserWrap(
215 audit_user = audit_logger.UserWrap(
216 username=extras.username,
216 username=extras.username,
217 ip_addr=extras.ip)
217 ip_addr=extras.ip)
218 repo = audit_logger.RepoWrap(repo_name=extras.repository)
218 repo = audit_logger.RepoWrap(repo_name=extras.repository)
219 audit_logger.store(
219 audit_logger.store(
220 'user.pull', action_data={'user_agent': extras.user_agent},
220 'user.pull', action_data={'user_agent': extras.user_agent},
221 user=audit_user, repo=repo, commit=True)
221 user=audit_user, repo=repo, commit=True)
222
222
223 statsd = StatsdClient.statsd
223 statsd = StatsdClient.statsd
224 if statsd:
224 if statsd:
225 statsd.incr('rhodecode_pull_total', tags=[
225 statsd.incr('rhodecode_pull_total', tags=[
226 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
226 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
227 ])
227 ])
228 output = ''
228 output = ''
229 # make_lock is tri-state: False, True, None. We only set the lock on True
229 # make_lock is tri-state: False, True, None. We only set the lock on True
230 if extras.make_lock is True and not is_shadow_repo(extras):
230 if extras.make_lock is True and not is_shadow_repo(extras):
231 user = User.get_by_username(extras.username)
231 user = User.get_by_username(extras.username)
232 Repository.lock(Repository.get_by_repo_name(extras.repository),
232 Repository.lock(Repository.get_by_repo_name(extras.repository),
233 user.user_id,
233 user.user_id,
234 lock_reason=Repository.LOCK_PULL)
234 lock_reason=Repository.LOCK_PULL)
235 msg = 'Made lock on repo `%s`' % (extras.repository,)
235 msg = 'Made lock on repo `%s`' % (extras.repository,)
236 output += msg
236 output += msg
237
237
238 if extras.locked_by[0]:
238 if extras.locked_by[0]:
239 locked_by = User.get(extras.locked_by[0]).username
239 locked_by = User.get(extras.locked_by[0]).username
240 reason = extras.locked_by[2]
240 reason = extras.locked_by[2]
241 _http_ret = HTTPLockedRC(
241 _http_ret = HTTPLockedRC(
242 _locked_by_explanation(extras.repository, locked_by, reason))
242 _locked_by_explanation(extras.repository, locked_by, reason))
243 if str(_http_ret.code).startswith('2'):
243 if str(_http_ret.code).startswith('2'):
244 # 2xx Codes don't raise exceptions
244 # 2xx Codes don't raise exceptions
245 output += _http_ret.title
245 output += _http_ret.title
246
246
247 # Propagate to external components.
247 # Propagate to external components.
248 hook_response = ''
248 hook_response = ''
249 if not is_shadow_repo(extras):
249 if not is_shadow_repo(extras):
250 extras.hook_type = extras.hook_type or 'post_pull'
250 extras.hook_type = extras.hook_type or 'post_pull'
251 hook_response = post_pull_extension(
251 hook_response = post_pull_extension(
252 repo_store_path=Repository.base_path(), **extras)
252 repo_store_path=Repository.base_path(), **extras)
253 events.trigger(events.RepoPullEvent(
253 events.trigger(events.RepoPullEvent(
254 repo_name=extras.repository, extras=extras))
254 repo_name=extras.repository, extras=extras))
255
255
256 return HookResponse(0, output) + hook_response
256 return HookResponse(0, output) + hook_response
257
257
258
258
259 def post_push(extras):
259 def post_push(extras):
260 """Hook executed after user pushes to the repository."""
260 """Hook executed after user pushes to the repository."""
261 commit_ids = extras.commit_ids
261 commit_ids = extras.commit_ids
262
262
263 # log the push call
263 # log the push call
264 audit_user = audit_logger.UserWrap(
264 audit_user = audit_logger.UserWrap(
265 username=extras.username, ip_addr=extras.ip)
265 username=extras.username, ip_addr=extras.ip)
266 repo = audit_logger.RepoWrap(repo_name=extras.repository)
266 repo = audit_logger.RepoWrap(repo_name=extras.repository)
267 audit_logger.store(
267 audit_logger.store(
268 'user.push', action_data={
268 'user.push', action_data={
269 'user_agent': extras.user_agent,
269 'user_agent': extras.user_agent,
270 'commit_ids': commit_ids[:400]},
270 'commit_ids': commit_ids[:400]},
271 user=audit_user, repo=repo, commit=True)
271 user=audit_user, repo=repo, commit=True)
272
272
273 statsd = StatsdClient.statsd
273 statsd = StatsdClient.statsd
274 if statsd:
274 if statsd:
275 statsd.incr('rhodecode_push_total', tags=[
275 statsd.incr('rhodecode_push_total', tags=[
276 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
276 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
277 ])
277 ])
278
278
279 # Propagate to external components.
279 # Propagate to external components.
280 output = ''
280 output = ''
281 # make_lock is tri-state: False, True, None. We only release the lock on False
281 # make_lock is tri-state: False, True, None. We only release the lock on False
282 if extras.make_lock is False and not is_shadow_repo(extras):
282 if extras.make_lock is False and not is_shadow_repo(extras):
283 Repository.unlock(Repository.get_by_repo_name(extras.repository))
283 Repository.unlock(Repository.get_by_repo_name(extras.repository))
284 msg = 'Released lock on repo `{}`\n'.format(safe_str(extras.repository))
284 msg = 'Released lock on repo `{}`\n'.format(safe_str(extras.repository))
285 output += msg
285 output += msg
286
286
287 if extras.locked_by[0]:
287 if extras.locked_by[0]:
288 locked_by = User.get(extras.locked_by[0]).username
288 locked_by = User.get(extras.locked_by[0]).username
289 reason = extras.locked_by[2]
289 reason = extras.locked_by[2]
290 _http_ret = HTTPLockedRC(
290 _http_ret = HTTPLockedRC(
291 _locked_by_explanation(extras.repository, locked_by, reason))
291 _locked_by_explanation(extras.repository, locked_by, reason))
292 # TODO: johbo: if not?
292 # TODO: johbo: if not?
293 if str(_http_ret.code).startswith('2'):
293 if str(_http_ret.code).startswith('2'):
294 # 2xx Codes don't raise exceptions
294 # 2xx Codes don't raise exceptions
295 output += _http_ret.title
295 output += _http_ret.title
296
296
297 if extras.new_refs:
297 if extras.new_refs:
298 tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
298 tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
299 safe_str(extras.server_url), safe_str(extras.repository))
299 safe_str(extras.server_url), safe_str(extras.repository))
300
300
301 for branch_name in extras.new_refs['branches']:
301 for branch_name in extras.new_refs['branches']:
302 output += 'RhodeCode: open pull request link: {}\n'.format(
302 output += 'RhodeCode: open pull request link: {}\n'.format(
303 tmpl.format(ref_type='branch', ref_name=safe_str(branch_name)))
303 tmpl.format(ref_type='branch', ref_name=safe_str(branch_name)))
304
304
305 for book_name in extras.new_refs['bookmarks']:
305 for book_name in extras.new_refs['bookmarks']:
306 output += 'RhodeCode: open pull request link: {}\n'.format(
306 output += 'RhodeCode: open pull request link: {}\n'.format(
307 tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name)))
307 tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name)))
308
308
309 hook_response = ''
309 hook_response = ''
310 if not is_shadow_repo(extras):
310 if not is_shadow_repo(extras):
311 hook_response = post_push_extension(
311 hook_response = post_push_extension(
312 repo_store_path=Repository.base_path(),
312 repo_store_path=Repository.base_path(),
313 **extras)
313 **extras)
314 events.trigger(events.RepoPushEvent(
314 events.trigger(events.RepoPushEvent(
315 repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))
315 repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))
316
316
317 output += 'RhodeCode: push completed\n'
317 output += 'RhodeCode: push completed\n'
318 return HookResponse(0, output) + hook_response
318 return HookResponse(0, output) + hook_response
319
319
320
320
321 def _locked_by_explanation(repo_name, user_name, reason):
321 def _locked_by_explanation(repo_name, user_name, reason):
322 message = (
322 message = (
323 'Repository `%s` locked by user `%s`. Reason: `%s`'
323 'Repository `%s` locked by user `%s`. Reason: `%s`'
324 % (repo_name, user_name, reason))
324 % (repo_name, user_name, reason))
325 return message
325 return message
326
326
327
327
328 def check_allowed_create_user(user_dict, created_by, **kwargs):
328 def check_allowed_create_user(user_dict, created_by, **kwargs):
329 # pre create hooks
329 # pre create hooks
330 if pre_create_user.is_active():
330 if pre_create_user.is_active():
331 hook_result = pre_create_user(created_by=created_by, **user_dict)
331 hook_result = pre_create_user(created_by=created_by, **user_dict)
332 allowed = hook_result.status == 0
332 allowed = hook_result.status == 0
333 if not allowed:
333 if not allowed:
334 reason = hook_result.output
334 reason = hook_result.output
335 raise UserCreationError(reason)
335 raise UserCreationError(reason)
336
336
337
337
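# Illustrative rcextensions-side sketch (the module layout is an
# assumption): a PRE_CREATE_USER_HOOK that returns a non-zero status
# makes check_allowed_create_user() above raise UserCreationError with
# the hook's output as the reason. The signature mirrors the
# pre_create_user kwargs_keys defined further below.
def _example_pre_create_user(username, password, email, firstname,
                             lastname, active, admin, created_by):
    if username == 'root':  # hypothetical policy
        return HookResponse(1, 'username `root` is reserved')
    return HookResponse(0, '')
# in rcextensions: PRE_CREATE_USER_HOOK = _example_pre_create_user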
338 class ExtensionCallback(object):
338 class ExtensionCallback(object):
339 """
339 """
340 Forwards a given call to rcextensions, sanitizes keyword arguments.
340 Forwards a given call to rcextensions, sanitizes keyword arguments.
341
341
342 Checks if there is an extension active for that hook. If there
342 Checks if there is an extension active for that hook. If there
343 is, it forwards all `kwargs_keys` keyword arguments to the
343 is, it forwards all `kwargs_keys` keyword arguments to the
344 extension callback.
344 extension callback.
345 """
345 """
346
346
347 def __init__(self, hook_name, kwargs_keys):
347 def __init__(self, hook_name, kwargs_keys):
348 self._hook_name = hook_name
348 self._hook_name = hook_name
349 self._kwargs_keys = set(kwargs_keys)
349 self._kwargs_keys = set(kwargs_keys)
350
350
351 def __call__(self, *args, **kwargs):
351 def __call__(self, *args, **kwargs):
352 log.debug('Calling extension callback for `%s`', self._hook_name)
352 log.debug('Calling extension callback for `%s`', self._hook_name)
353 callback = self._get_callback()
353 callback = self._get_callback()
354 if not callback:
354 if not callback:
355 log.debug('extension callback `%s` not found, skipping...', self._hook_name)
355 log.debug('extension callback `%s` not found, skipping...', self._hook_name)
356 return
356 return
357
357
358 kwargs_to_pass = {}
358 kwargs_to_pass = {}
359 for key in self._kwargs_keys:
359 for key in self._kwargs_keys:
360 try:
360 try:
361 kwargs_to_pass[key] = kwargs[key]
361 kwargs_to_pass[key] = kwargs[key]
362 except KeyError:
362 except KeyError:
363 log.error('Failed to fetch %s key from given kwargs. '
363 log.error('Failed to fetch %s key from given kwargs. '
364 'Expected keys: %s', key, self._kwargs_keys)
364 'Expected keys: %s', key, self._kwargs_keys)
365 raise
365 raise
366
366
367 # backward compat for the removed api_key of old hooks. This way it works
367 # backward compat for the removed api_key of old hooks. This way it works
368 # with older rcextensions that require api_key to be present
368 # with older rcextensions that require api_key to be present
369 if self._hook_name in ['CREATE_USER_HOOK', 'DELETE_USER_HOOK']:
369 if self._hook_name in ['CREATE_USER_HOOK', 'DELETE_USER_HOOK']:
370 kwargs_to_pass['api_key'] = '_DEPRECATED_'
370 kwargs_to_pass['api_key'] = '_DEPRECATED_'
371 return callback(**kwargs_to_pass)
371 return callback(**kwargs_to_pass)
372
372
373 def is_active(self):
373 def is_active(self):
374 return hasattr(rhodecode.EXTENSIONS, self._hook_name)
374 return hasattr(rhodecode.EXTENSIONS, self._hook_name)
375
375
376 def _get_callback(self):
376 def _get_callback(self):
377 return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
377 return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
378
378
379
379
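# Illustrative sketch: a hook is "active" when rhodecode.EXTENSIONS
# exposes an attribute named after hook_name. Keys in **kwargs beyond
# the declared kwargs_keys are dropped before the callback is invoked,
# while missing declared keys raise KeyError. The extras dict and its
# contents here are assumptions for the example.
def _example_call_extension(extras):
    demo = ExtensionCallback(
        hook_name='PRE_PUSH_HOOK',
        kwargs_keys=('repository', 'username'))
    if demo.is_active():
        return demo(**extras)  # extras must carry at least these keys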
380 pre_pull_extension = ExtensionCallback(
380 pre_pull_extension = ExtensionCallback(
381 hook_name='PRE_PULL_HOOK',
381 hook_name='PRE_PULL_HOOK',
382 kwargs_keys=(
382 kwargs_keys=(
383 'server_url', 'config', 'scm', 'username', 'ip', 'action',
383 'server_url', 'config', 'scm', 'username', 'ip', 'action',
384 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
384 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
385
385
386
386
387 post_pull_extension = ExtensionCallback(
387 post_pull_extension = ExtensionCallback(
388 hook_name='PULL_HOOK',
388 hook_name='PULL_HOOK',
389 kwargs_keys=(
389 kwargs_keys=(
390 'server_url', 'config', 'scm', 'username', 'ip', 'action',
390 'server_url', 'config', 'scm', 'username', 'ip', 'action',
391 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
391 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
392
392
393
393
394 pre_push_extension = ExtensionCallback(
394 pre_push_extension = ExtensionCallback(
395 hook_name='PRE_PUSH_HOOK',
395 hook_name='PRE_PUSH_HOOK',
396 kwargs_keys=(
396 kwargs_keys=(
397 'server_url', 'config', 'scm', 'username', 'ip', 'action',
397 'server_url', 'config', 'scm', 'username', 'ip', 'action',
398 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
398 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
399
399
400
400
401 post_push_extension = ExtensionCallback(
401 post_push_extension = ExtensionCallback(
402 hook_name='PUSH_HOOK',
402 hook_name='PUSH_HOOK',
403 kwargs_keys=(
403 kwargs_keys=(
404 'server_url', 'config', 'scm', 'username', 'ip', 'action',
404 'server_url', 'config', 'scm', 'username', 'ip', 'action',
405 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
405 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
406
406
407
407
408 pre_create_user = ExtensionCallback(
408 pre_create_user = ExtensionCallback(
409 hook_name='PRE_CREATE_USER_HOOK',
409 hook_name='PRE_CREATE_USER_HOOK',
410 kwargs_keys=(
410 kwargs_keys=(
411 'username', 'password', 'email', 'firstname', 'lastname', 'active',
411 'username', 'password', 'email', 'firstname', 'lastname', 'active',
412 'admin', 'created_by'))
412 'admin', 'created_by'))
413
413
414
414
415 create_pull_request = ExtensionCallback(
415 create_pull_request = ExtensionCallback(
416 hook_name='CREATE_PULL_REQUEST',
416 hook_name='CREATE_PULL_REQUEST',
417 kwargs_keys=(
417 kwargs_keys=(
418 'server_url', 'config', 'scm', 'username', 'ip', 'action',
418 'server_url', 'config', 'scm', 'username', 'ip', 'action',
419 'repository', 'pull_request_id', 'url', 'title', 'description',
419 'repository', 'pull_request_id', 'url', 'title', 'description',
420 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
420 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
421 'mergeable', 'source', 'target', 'author', 'reviewers'))
421 'mergeable', 'source', 'target', 'author', 'reviewers'))
422
422
423
423
424 merge_pull_request = ExtensionCallback(
424 merge_pull_request = ExtensionCallback(
425 hook_name='MERGE_PULL_REQUEST',
425 hook_name='MERGE_PULL_REQUEST',
426 kwargs_keys=(
426 kwargs_keys=(
427 'server_url', 'config', 'scm', 'username', 'ip', 'action',
427 'server_url', 'config', 'scm', 'username', 'ip', 'action',
428 'repository', 'pull_request_id', 'url', 'title', 'description',
428 'repository', 'pull_request_id', 'url', 'title', 'description',
429 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
429 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
430 'mergeable', 'source', 'target', 'author', 'reviewers'))
430 'mergeable', 'source', 'target', 'author', 'reviewers'))
431
431
432
432
433 close_pull_request = ExtensionCallback(
433 close_pull_request = ExtensionCallback(
434 hook_name='CLOSE_PULL_REQUEST',
434 hook_name='CLOSE_PULL_REQUEST',
435 kwargs_keys=(
435 kwargs_keys=(
436 'server_url', 'config', 'scm', 'username', 'ip', 'action',
436 'server_url', 'config', 'scm', 'username', 'ip', 'action',
437 'repository', 'pull_request_id', 'url', 'title', 'description',
437 'repository', 'pull_request_id', 'url', 'title', 'description',
438 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
438 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
439 'mergeable', 'source', 'target', 'author', 'reviewers'))
439 'mergeable', 'source', 'target', 'author', 'reviewers'))
440
440
441
441
442 review_pull_request = ExtensionCallback(
442 review_pull_request = ExtensionCallback(
443 hook_name='REVIEW_PULL_REQUEST',
443 hook_name='REVIEW_PULL_REQUEST',
444 kwargs_keys=(
444 kwargs_keys=(
445 'server_url', 'config', 'scm', 'username', 'ip', 'action',
445 'server_url', 'config', 'scm', 'username', 'ip', 'action',
446 'repository', 'pull_request_id', 'url', 'title', 'description',
446 'repository', 'pull_request_id', 'url', 'title', 'description',
447 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
447 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
448 'mergeable', 'source', 'target', 'author', 'reviewers'))
448 'mergeable', 'source', 'target', 'author', 'reviewers'))
449
449
450
450
451 comment_pull_request = ExtensionCallback(
451 comment_pull_request = ExtensionCallback(
452 hook_name='COMMENT_PULL_REQUEST',
452 hook_name='COMMENT_PULL_REQUEST',
453 kwargs_keys=(
453 kwargs_keys=(
454 'server_url', 'config', 'scm', 'username', 'ip', 'action',
454 'server_url', 'config', 'scm', 'username', 'ip', 'action',
455 'repository', 'pull_request_id', 'url', 'title', 'description',
455 'repository', 'pull_request_id', 'url', 'title', 'description',
456 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
456 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
457 'mergeable', 'source', 'target', 'author', 'reviewers'))
457 'mergeable', 'source', 'target', 'author', 'reviewers'))
458
458
459
459
460 comment_edit_pull_request = ExtensionCallback(
460 comment_edit_pull_request = ExtensionCallback(
461 hook_name='COMMENT_EDIT_PULL_REQUEST',
461 hook_name='COMMENT_EDIT_PULL_REQUEST',
462 kwargs_keys=(
462 kwargs_keys=(
463 'server_url', 'config', 'scm', 'username', 'ip', 'action',
463 'server_url', 'config', 'scm', 'username', 'ip', 'action',
464 'repository', 'pull_request_id', 'url', 'title', 'description',
464 'repository', 'pull_request_id', 'url', 'title', 'description',
465 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
465 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
466 'mergeable', 'source', 'target', 'author', 'reviewers'))
466 'mergeable', 'source', 'target', 'author', 'reviewers'))
467
467
468
468
469 update_pull_request = ExtensionCallback(
469 update_pull_request = ExtensionCallback(
470 hook_name='UPDATE_PULL_REQUEST',
470 hook_name='UPDATE_PULL_REQUEST',
471 kwargs_keys=(
471 kwargs_keys=(
472 'server_url', 'config', 'scm', 'username', 'ip', 'action',
472 'server_url', 'config', 'scm', 'username', 'ip', 'action',
473 'repository', 'pull_request_id', 'url', 'title', 'description',
473 'repository', 'pull_request_id', 'url', 'title', 'description',
474 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
474 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
475 'mergeable', 'source', 'target', 'author', 'reviewers'))
475 'mergeable', 'source', 'target', 'author', 'reviewers'))
476
476
477
477
478 create_user = ExtensionCallback(
478 create_user = ExtensionCallback(
479 hook_name='CREATE_USER_HOOK',
479 hook_name='CREATE_USER_HOOK',
480 kwargs_keys=(
480 kwargs_keys=(
481 'username', 'full_name_or_username', 'full_contact', 'user_id',
481 'username', 'full_name_or_username', 'full_contact', 'user_id',
482 'name', 'firstname', 'short_contact', 'admin', 'lastname',
482 'name', 'firstname', 'short_contact', 'admin', 'lastname',
483 'ip_addresses', 'extern_type', 'extern_name',
483 'ip_addresses', 'extern_type', 'extern_name',
484 'email', 'api_keys', 'last_login',
484 'email', 'api_keys', 'last_login',
485 'full_name', 'active', 'password', 'emails',
485 'full_name', 'active', 'password', 'emails',
486 'inherit_default_permissions', 'created_by', 'created_on'))
486 'inherit_default_permissions', 'created_by', 'created_on'))
487
487
488
488
489 delete_user = ExtensionCallback(
489 delete_user = ExtensionCallback(
490 hook_name='DELETE_USER_HOOK',
490 hook_name='DELETE_USER_HOOK',
491 kwargs_keys=(
491 kwargs_keys=(
492 'username', 'full_name_or_username', 'full_contact', 'user_id',
492 'username', 'full_name_or_username', 'full_contact', 'user_id',
493 'name', 'firstname', 'short_contact', 'admin', 'lastname',
493 'name', 'firstname', 'short_contact', 'admin', 'lastname',
494 'ip_addresses',
494 'ip_addresses',
495 'email', 'last_login',
495 'email', 'last_login',
496 'full_name', 'active', 'password', 'emails',
496 'full_name', 'active', 'password', 'emails',
497 'inherit_default_permissions', 'deleted_by'))
497 'inherit_default_permissions', 'deleted_by'))
498
498
499
499
500 create_repository = ExtensionCallback(
500 create_repository = ExtensionCallback(
501 hook_name='CREATE_REPO_HOOK',
501 hook_name='CREATE_REPO_HOOK',
502 kwargs_keys=(
502 kwargs_keys=(
503 'repo_name', 'repo_type', 'description', 'private', 'created_on',
503 'repo_name', 'repo_type', 'description', 'private', 'created_on',
504 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
504 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
505 'clone_uri', 'fork_id', 'group_id', 'created_by'))
505 'clone_uri', 'fork_id', 'group_id', 'created_by'))
506
506
507
507
508 delete_repository = ExtensionCallback(
508 delete_repository = ExtensionCallback(
509 hook_name='DELETE_REPO_HOOK',
509 hook_name='DELETE_REPO_HOOK',
510 kwargs_keys=(
510 kwargs_keys=(
511 'repo_name', 'repo_type', 'description', 'private', 'created_on',
511 'repo_name', 'repo_type', 'description', 'private', 'created_on',
512 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
512 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
513 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
513 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
514
514
515
515
516 comment_commit_repository = ExtensionCallback(
516 comment_commit_repository = ExtensionCallback(
517 hook_name='COMMENT_COMMIT_REPO_HOOK',
517 hook_name='COMMENT_COMMIT_REPO_HOOK',
518 kwargs_keys=(
518 kwargs_keys=(
519 'repo_name', 'repo_type', 'description', 'private', 'created_on',
519 'repo_name', 'repo_type', 'description', 'private', 'created_on',
520 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
520 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
521 'clone_uri', 'fork_id', 'group_id',
521 'clone_uri', 'fork_id', 'group_id',
522 'repository', 'created_by', 'comment', 'commit'))
522 'repository', 'created_by', 'comment', 'commit'))
523
523
524 comment_edit_commit_repository = ExtensionCallback(
524 comment_edit_commit_repository = ExtensionCallback(
525 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
525 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
526 kwargs_keys=(
526 kwargs_keys=(
527 'repo_name', 'repo_type', 'description', 'private', 'created_on',
527 'repo_name', 'repo_type', 'description', 'private', 'created_on',
528 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
528 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
529 'clone_uri', 'fork_id', 'group_id',
529 'clone_uri', 'fork_id', 'group_id',
530 'repository', 'created_by', 'comment', 'commit'))
530 'repository', 'created_by', 'comment', 'commit'))
531
531
532
532
533 create_repository_group = ExtensionCallback(
533 create_repository_group = ExtensionCallback(
534 hook_name='CREATE_REPO_GROUP_HOOK',
534 hook_name='CREATE_REPO_GROUP_HOOK',
535 kwargs_keys=(
535 kwargs_keys=(
536 'group_name', 'group_parent_id', 'group_description',
536 'group_name', 'group_parent_id', 'group_description',
537 'group_id', 'user_id', 'created_by', 'created_on',
537 'group_id', 'user_id', 'created_by', 'created_on',
538 'enable_locking'))
538 'enable_locking'))
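# A minimal consumer sketch (assumption: these callbacks resolve to same-named
# attributes in an `rcextensions` module; the wiring below is illustrative,
# not the canonical integration):
#
#   def _create_repo_hook(*args, **kwargs):
#       # receives exactly the kwargs_keys listed for CREATE_REPO_HOOK above
#       print('repo created:', kwargs.get('repo_name'))
#
#   CREATE_REPO_HOOK = _create_repo_hook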
@@ -1,187 +1,187 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import sys
21 import sys
22 import logging
22 import logging
23
23
24
24
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(30, 38)
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = list(range(30, 38))
26
26
27 # Sequences
27 # Sequences
28 RESET_SEQ = "\033[0m"
28 RESET_SEQ = "\033[0m"
29 COLOR_SEQ = "\033[0;%dm"
29 COLOR_SEQ = "\033[0;%dm"
30 BOLD_SEQ = "\033[1m"
30 BOLD_SEQ = "\033[1m"
31
31
32 COLORS = {
32 COLORS = {
33 'CRITICAL': MAGENTA,
33 'CRITICAL': MAGENTA,
34 'ERROR': RED,
34 'ERROR': RED,
35 'WARNING': CYAN,
35 'WARNING': CYAN,
36 'INFO': GREEN,
36 'INFO': GREEN,
37 'DEBUG': BLUE,
37 'DEBUG': BLUE,
38 'SQL': YELLOW
38 'SQL': YELLOW
39 }
39 }
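# Usage sketch for the raw sequences (assumes an ANSI-capable terminal):
#
#   print(COLOR_SEQ % COLORS['ERROR'] + 'something failed' + RESET_SEQ)
#
# COLOR_SEQ % RED expands to "\033[0;31m"; RESET_SEQ restores the default style.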
40
40
41
41
42 def _inject_req_id(record, with_prefix=True):
42 def _inject_req_id(record, with_prefix=True):
43 from pyramid.threadlocal import get_current_request
43 from pyramid.threadlocal import get_current_request
44 dummy = '00000000-0000-0000-0000-000000000000'
44 dummy = '00000000-0000-0000-0000-000000000000'
45 req_id = None
45 req_id = None
46
46
47 req = get_current_request()
47 req = get_current_request()
48 if req:
48 if req:
49 req_id = getattr(req, 'req_id', None)
49 req_id = getattr(req, 'req_id', None)
50 if with_prefix:
50 if with_prefix:
51 req_id = 'req_id:%-36s' % (req_id or dummy)
51 req_id = 'req_id:%-36s' % (req_id or dummy)
52 else:
52 else:
53 req_id = (req_id or dummy)
53 req_id = (req_id or dummy)
54 record.req_id = req_id
54 record.req_id = req_id
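# Example (sketch): outside an active pyramid request the dummy id is used.
#
#   rec = logging.LogRecord('rc', logging.INFO, __file__, 0, 'msg', (), None)
#   _inject_req_id(rec)
#   rec.req_id  # -> 'req_id:00000000-0000-0000-0000-000000000000'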
55
55
56
56
57 def _add_log_to_debug_bucket(formatted_record):
57 def _add_log_to_debug_bucket(formatted_record):
58 from pyramid.threadlocal import get_current_request
58 from pyramid.threadlocal import get_current_request
59 req = get_current_request()
59 req = get_current_request()
60 if req:
60 if req:
61 req.req_id_bucket.append(formatted_record)
61 req.req_id_bucket.append(formatted_record)
62
62
63
63
64 def one_space_trim(s):
64 def one_space_trim(s):
65 if s.find("  ") == -1:
65 if s.find("  ") == -1:
66 return s
66 return s
67 else:
67 else:
68 s = s.replace('  ', ' ')
68 s = s.replace('  ', ' ')
69 return one_space_trim(s)
69 return one_space_trim(s)
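# Example: the recursion collapses any run of spaces down to a single one.
#
#   one_space_trim('SELECT   *   FROM   foo')  # -> 'SELECT * FROM foo'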
70
70
71
71
72 def format_sql(sql):
72 def format_sql(sql):
73 sql = sql.replace('\n', '')
73 sql = sql.replace('\n', '')
74 sql = one_space_trim(sql)
74 sql = one_space_trim(sql)
75 sql = sql\
75 sql = sql\
76 .replace(',', ',\n\t')\
76 .replace(',', ',\n\t')\
77 .replace('SELECT', '\n\tSELECT \n\t')\
77 .replace('SELECT', '\n\tSELECT \n\t')\
78 .replace('UPDATE', '\n\tUPDATE \n\t')\
78 .replace('UPDATE', '\n\tUPDATE \n\t')\
79 .replace('DELETE', '\n\tDELETE \n\t')\
79 .replace('DELETE', '\n\tDELETE \n\t')\
80 .replace('FROM', '\n\tFROM')\
80 .replace('FROM', '\n\tFROM')\
81 .replace('ORDER BY', '\n\tORDER BY')\
81 .replace('ORDER BY', '\n\tORDER BY')\
82 .replace('LIMIT', '\n\tLIMIT')\
82 .replace('LIMIT', '\n\tLIMIT')\
83 .replace('WHERE', '\n\tWHERE')\
83 .replace('WHERE', '\n\tWHERE')\
84 .replace('AND', '\n\tAND')\
84 .replace('AND', '\n\tAND')\
85 .replace('LEFT', '\n\tLEFT')\
85 .replace('LEFT', '\n\tLEFT')\
86 .replace('INNER', '\n\tINNER')\
86 .replace('INNER', '\n\tINNER')\
87 .replace('INSERT', '\n\tINSERT')\
87 .replace('INSERT', '\n\tINSERT')\
88 .replace('DELETE', '\n\tDELETE')
88 .replace('DELETE', '\n\tDELETE')
89 return sql
89 return sql
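# Example (sketch; exact whitespace is approximate):
#
#   format_sql("SELECT a, b FROM t WHERE a = 1 LIMIT 10")
#
# returns roughly:
#
#   SELECT
#       a,
#       b
#   FROM t
#   WHERE a = 1
#   LIMIT 10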
90
90
91
91
92 class ExceptionAwareFormatter(logging.Formatter):
92 class ExceptionAwareFormatter(logging.Formatter):
93 """
93 """
94 Extended logging formatter which prints out remote tracebacks.
94 Extended logging formatter which prints out remote tracebacks.
95 """
95 """
96
96
97 def formatException(self, ei):
97 def formatException(self, ei):
98 ex_type, ex_value, ex_tb = ei
98 ex_type, ex_value, ex_tb = ei
99
99
100 local_tb = logging.Formatter.formatException(self, ei)
100 local_tb = logging.Formatter.formatException(self, ei)
101 if hasattr(ex_value, '_vcs_server_traceback'):
101 if hasattr(ex_value, '_vcs_server_traceback'):
102
102
103 def formatRemoteTraceback(remote_tb_lines):
103 def formatRemoteTraceback(remote_tb_lines):
104 result = ["\n +--- This exception occured remotely on VCSServer - Remote traceback:\n\n"]
104 result = ["\n +--- This exception occured remotely on VCSServer - Remote traceback:\n\n"]
105 result.append(remote_tb_lines)
105 result.append(remote_tb_lines)
106 result.append("\n +--- End of remote traceback\n")
106 result.append("\n +--- End of remote traceback\n")
107 return result
107 return result
108
108
109 try:
109 try:
110 if ex_type is not None and ex_value is None and ex_tb is None:
110 if ex_type is not None and ex_value is None and ex_tb is None:
111 # possible old (3.x) call syntax where caller is only
111 # possible old (3.x) call syntax where caller is only
112 # providing exception object
112 # providing exception object
113 if type(ex_type) is not type:
113 if type(ex_type) is not type:
114 raise TypeError(
114 raise TypeError(
115 "invalid argument: ex_type should be an exception "
115 "invalid argument: ex_type should be an exception "
116 "type, or just supply no arguments at all")
116 "type, or just supply no arguments at all")
117 if ex_type is None and ex_tb is None:
117 if ex_type is None and ex_tb is None:
118 ex_type, ex_value, ex_tb = sys.exc_info()
118 ex_type, ex_value, ex_tb = sys.exc_info()
119
119
120 remote_tb = getattr(ex_value, "_vcs_server_traceback", None)
120 remote_tb = getattr(ex_value, "_vcs_server_traceback", None)
121
121
122 if remote_tb:
122 if remote_tb:
123 remote_tb = formatRemoteTraceback(remote_tb)
123 remote_tb = formatRemoteTraceback(remote_tb)
124 return local_tb + ''.join(remote_tb)
124 return local_tb + ''.join(remote_tb)
125 finally:
125 finally:
126 # clean up cycle to traceback, to allow proper GC
126 # clean up cycle to traceback, to allow proper GC
127 del ex_type, ex_value, ex_tb
127 del ex_type, ex_value, ex_tb
128
128
129 return local_tb
129 return local_tb
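# Wiring sketch (assumption: in practice this formatter is usually selected
# via the logging section of the .ini file rather than configured in code):
#
#   handler = logging.StreamHandler()
#   handler.setFormatter(ExceptionAwareFormatter('%(asctime)s %(message)s'))
#   logging.getLogger('rhodecode').addHandler(handler)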
130
130
131
131
132 class RequestTrackingFormatter(ExceptionAwareFormatter):
132 class RequestTrackingFormatter(ExceptionAwareFormatter):
133 def format(self, record):
133 def format(self, record):
134 _inject_req_id(record)
134 _inject_req_id(record)
135 def_record = logging.Formatter.format(self, record)
135 def_record = logging.Formatter.format(self, record)
136 _add_log_to_debug_bucket(def_record)
136 _add_log_to_debug_bucket(def_record)
137 return def_record
137 return def_record
138
138
139
139
140 class ColorFormatter(ExceptionAwareFormatter):
140 class ColorFormatter(ExceptionAwareFormatter):
141
141
142 def format(self, record):
142 def format(self, record):
143 """
143 """
144 Colorizes the formatted record using the COLORS entry for the record's levelname
144 Colorizes the formatted record using the COLORS entry for the record's levelname
145 """
145 """
146 def_record = super(ColorFormatter, self).format(record)
146 def_record = super(ColorFormatter, self).format(record)
147
147
148 levelname = record.levelname
148 levelname = record.levelname
149 start = COLOR_SEQ % (COLORS[levelname])
149 start = COLOR_SEQ % (COLORS[levelname])
150 end = RESET_SEQ
150 end = RESET_SEQ
151
151
152 colored_record = ''.join([start, def_record, end])
152 colored_record = ''.join([start, def_record, end])
153 return colored_record
153 return colored_record
154
154
155
155
156 class ColorRequestTrackingFormatter(RequestTrackingFormatter):
156 class ColorRequestTrackingFormatter(RequestTrackingFormatter):
157
157
158 def format(self, record):
158 def format(self, record):
159 """
159 """
160 Colorizes the formatted record using the COLORS entry for the record's levelname
160 Colorizes the formatted record using the COLORS entry for the record's levelname
161 """
161 """
162 def_record = super(ColorRequestTrackingFormatter, self).format(record)
162 def_record = super(ColorRequestTrackingFormatter, self).format(record)
163
163
164 levelname = record.levelname
164 levelname = record.levelname
165 start = COLOR_SEQ % (COLORS[levelname])
165 start = COLOR_SEQ % (COLORS[levelname])
166 end = RESET_SEQ
166 end = RESET_SEQ
167
167
168 colored_record = ''.join([start, def_record, end])
168 colored_record = ''.join([start, def_record, end])
169 return colored_record
169 return colored_record
170
170
171
171
172 class ColorFormatterSql(logging.Formatter):
172 class ColorFormatterSql(logging.Formatter):
173
173
174 def format(self, record):
174 def format(self, record):
175 """
175 """
176 Colorizes the formatted SQL record with the 'SQL' color from the COLORS mapping
176 Colorizes the formatted SQL record with the 'SQL' color from the COLORS mapping
177 """
177 """
178
178
179 start = COLOR_SEQ % (COLORS['SQL'])
179 start = COLOR_SEQ % (COLORS['SQL'])
180 def_record = format_sql(logging.Formatter.format(self, record))
180 def_record = format_sql(logging.Formatter.format(self, record))
181 end = RESET_SEQ
181 end = RESET_SEQ
182
182
183 colored_record = ''.join([start, def_record, end])
183 colored_record = ''.join([start, def_record, end])
184 return colored_record
184 return colored_record
185
185
186 # marcink: needs to stay with this name for backward .ini compatibility
186 # marcink: needs to stay with this name for backward .ini compatibility
187 Pyro4AwareFormatter = ExceptionAwareFormatter
187 Pyro4AwareFormatter = ExceptionAwareFormatter
@@ -1,580 +1,580 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Renderer for markup languages with the ability to parse using rst or markdown
23 Renderer for markup languages with the ability to parse using rst or markdown
24 """
24 """
25
25
26 import re
26 import re
27 import os
27 import os
28 import lxml
28 import lxml
29 import logging
29 import logging
30 import urllib.parse
30 import urllib.parse
31 import bleach
31 import bleach
32
32
33 from mako.lookup import TemplateLookup
33 from mako.lookup import TemplateLookup
34 from mako.template import Template as MakoTemplate
34 from mako.template import Template as MakoTemplate
35
35
36 from docutils.core import publish_parts
36 from docutils.core import publish_parts
37 from docutils.parsers.rst import directives
37 from docutils.parsers.rst import directives
38 from docutils import writers
38 from docutils import writers
39 from docutils.writers import html4css1
39 from docutils.writers import html4css1
40 import markdown
40 import markdown
41
41
42 from rhodecode.lib.markdown_ext import GithubFlavoredMarkdownExtension
42 from rhodecode.lib.markdown_ext import GithubFlavoredMarkdownExtension
43 from rhodecode.lib.utils2 import (safe_unicode, md5_safe, MENTIONS_REGEX)
43 from rhodecode.lib.utils2 import (safe_unicode, md5_safe, MENTIONS_REGEX)
44
44
45 log = logging.getLogger(__name__)
45 log = logging.getLogger(__name__)
46
46
47 # default renderer used to generate automated comments
47 # default renderer used to generate automated comments
48 DEFAULT_COMMENTS_RENDERER = 'rst'
48 DEFAULT_COMMENTS_RENDERER = 'rst'
49
49
50 try:
50 try:
51 from lxml.html import fromstring
51 from lxml.html import fromstring
52 from lxml.html import tostring
52 from lxml.html import tostring
53 except ImportError:
53 except ImportError:
54 log.exception('Failed to import lxml')
54 log.exception('Failed to import lxml')
55 fromstring = None
55 fromstring = None
56 tostring = None
56 tostring = None
57
57
58
58
59 class CustomHTMLTranslator(writers.html4css1.HTMLTranslator):
59 class CustomHTMLTranslator(writers.html4css1.HTMLTranslator):
60 """
60 """
61 Custom HTML Translator used for sandboxing potential
61 Custom HTML Translator used for sandboxing potential
62 JS injections in ref links
62 JS injections in ref links
63 """
63 """
64 def visit_literal_block(self, node):
64 def visit_literal_block(self, node):
65 self.body.append(self.starttag(node, 'pre', CLASS='codehilite literal-block'))
65 self.body.append(self.starttag(node, 'pre', CLASS='codehilite literal-block'))
66
66
67 def visit_reference(self, node):
67 def visit_reference(self, node):
68 if 'refuri' in node.attributes:
68 if 'refuri' in node.attributes:
69 refuri = node['refuri']
69 refuri = node['refuri']
70 if ':' in refuri:
70 if ':' in refuri:
71 prefix, link = refuri.lstrip().split(':', 1)
71 prefix, link = refuri.lstrip().split(':', 1)
72 prefix = prefix or ''
72 prefix = prefix or ''
73
73
74 if prefix.lower() == 'javascript':
74 if prefix.lower() == 'javascript':
75 # we don't allow javascript type of refs...
75 # we don't allow javascript type of refs...
76 node['refuri'] = 'javascript:alert("SandBoxedJavascript")'
76 node['refuri'] = 'javascript:alert("SandBoxedJavascript")'
77
77
78 # old style class requires this...
78 # old style class requires this...
79 return html4css1.HTMLTranslator.visit_reference(self, node)
79 return html4css1.HTMLTranslator.visit_reference(self, node)
80
80
81
81
82 class RhodeCodeWriter(writers.html4css1.Writer):
82 class RhodeCodeWriter(writers.html4css1.Writer):
83 def __init__(self):
83 def __init__(self):
84 writers.Writer.__init__(self)
84 writers.Writer.__init__(self)
85 self.translator_class = CustomHTMLTranslator
85 self.translator_class = CustomHTMLTranslator
86
86
87
87
88 def relative_links(html_source, server_paths):
88 def relative_links(html_source, server_paths):
89 if not html_source:
89 if not html_source:
90 return html_source
90 return html_source
91
91
92 if not (fromstring and tostring):
92 if not (fromstring and tostring):
93 return html_source
93 return html_source
94
94
95 try:
95 try:
96 doc = lxml.html.fromstring(html_source)
96 doc = lxml.html.fromstring(html_source)
97 except Exception:
97 except Exception:
98 return html_source
98 return html_source
99
99
100 for el in doc.cssselect('img, video'):
100 for el in doc.cssselect('img, video'):
101 src = el.attrib.get('src')
101 src = el.attrib.get('src')
102 if src:
102 if src:
103 el.attrib['src'] = relative_path(src, server_paths['raw'])
103 el.attrib['src'] = relative_path(src, server_paths['raw'])
104
104
105 for el in doc.cssselect('a:not(.gfm)'):
105 for el in doc.cssselect('a:not(.gfm)'):
106 src = el.attrib.get('href')
106 src = el.attrib.get('href')
107 if src:
107 if src:
108 raw_mode = el.attrib['href'].endswith('?raw=1')
108 raw_mode = el.attrib['href'].endswith('?raw=1')
109 if raw_mode:
109 if raw_mode:
110 el.attrib['href'] = relative_path(src, server_paths['raw'])
110 el.attrib['href'] = relative_path(src, server_paths['raw'])
111 else:
111 else:
112 el.attrib['href'] = relative_path(src, server_paths['standard'])
112 el.attrib['href'] = relative_path(src, server_paths['standard'])
113
113
114 return lxml.html.tostring(doc)
114 return lxml.html.tostring(doc)
115
115
116
116
117 def relative_path(path, request_path, is_repo_file=None):
117 def relative_path(path, request_path, is_repo_file=None):
118 """
118 """
119 relative link support, path is a rel path, and request_path is current
119 relative link support, path is a rel path, and request_path is current
120 server path (not absolute)
120 server path (not absolute)
121
121
122 e.g.
122 e.g.
123
123
124 path = '../logo.png'
124 path = '../logo.png'
125 request_path= '/repo/files/path/file.md'
125 request_path= '/repo/files/path/file.md'
126 produces: '/repo/files/logo.png'
126 produces: '/repo/files/logo.png'
127 """
127 """
128 # TODO(marcink): unicode/str support ?
128 # TODO(marcink): unicode/str support ?
129 # maybe=> safe_unicode(urllib.quote(safe_str(final_path), '/:'))
129 # maybe=> safe_unicode(urllib.quote(safe_str(final_path), '/:'))
130
130
131 def dummy_check(p):
131 def dummy_check(p):
132 return True # assume default is a valid file path
132 return True # assume default is a valid file path
133
133
134 is_repo_file = is_repo_file or dummy_check
134 is_repo_file = is_repo_file or dummy_check
135 if not path:
135 if not path:
136 return request_path
136 return request_path
137
137
138 path = safe_unicode(path)
138 path = safe_unicode(path)
139 request_path = safe_unicode(request_path)
139 request_path = safe_unicode(request_path)
140
140
141 if path.startswith((u'data:', u'javascript:', u'#', u':')):
141 if path.startswith(('data:', 'javascript:', '#', ':')):
142 # skip data, anchor, invalid links
142 # skip data, anchor, invalid links
143 return path
143 return path
144
144
145 is_absolute = bool(urllib.parse.urlparse(path).netloc)
145 is_absolute = bool(urllib.parse.urlparse(path).netloc)
146 if is_absolute:
146 if is_absolute:
147 return path
147 return path
148
148
149 if not request_path:
149 if not request_path:
150 return path
150 return path
151
151
152 if path.startswith(u'/'):
152 if path.startswith('/'):
153 path = path[1:]
153 path = path[1:]
154
154
155 if path.startswith(u'./'):
155 if path.startswith('./'):
156 path = path[2:]
156 path = path[2:]
157
157
158 parts = request_path.split('/')
158 parts = request_path.split('/')
159 # compute how deep we need to traverse the request_path
159 # compute how deep we need to traverse the request_path
160 depth = 0
160 depth = 0
161
161
162 if is_repo_file(request_path):
162 if is_repo_file(request_path):
163 # if request path is a VALID file, we use a relative path with
163 # if request path is a VALID file, we use a relative path with
164 # one level up
164 # one level up
165 depth += 1
165 depth += 1
166
166
167 while path.startswith(u'../'):
167 while path.startswith('../'):
168 depth += 1
168 depth += 1
169 path = path[3:]
169 path = path[3:]
170
170
171 if depth > 0:
171 if depth > 0:
172 parts = parts[:-depth]
172 parts = parts[:-depth]
173
173
174 parts.append(path)
174 parts.append(path)
175 final_path = u'/'.join(parts).lstrip(u'/')
175 final_path = '/'.join(parts).lstrip('/')
176
176
177 return u'/' + final_path
177 return '/' + final_path
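# Worked example matching the docstring: the default check treats the request
# path as a valid file, so one extra level (the file itself) is stripped.
#
#   relative_path('../logo.png', '/repo/files/path/file.md')
#   # depth=2 -> parts ['', 'repo', 'files'] -> '/repo/files/logo.png'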
178
178
179
179
180 _cached_markdown_renderer = None
180 _cached_markdown_renderer = None
181
181
182
182
183 def get_markdown_renderer(extensions, output_format):
183 def get_markdown_renderer(extensions, output_format):
184 global _cached_markdown_renderer
184 global _cached_markdown_renderer
185
185
186 if _cached_markdown_renderer is None:
186 if _cached_markdown_renderer is None:
187 _cached_markdown_renderer = markdown.Markdown(
187 _cached_markdown_renderer = markdown.Markdown(
188 extensions=extensions,
188 extensions=extensions,
189 enable_attributes=False, output_format=output_format)
189 enable_attributes=False, output_format=output_format)
190 return _cached_markdown_renderer
190 return _cached_markdown_renderer
191
191
192
192
193 _cached_markdown_renderer_flavored = None
193 _cached_markdown_renderer_flavored = None
194
194
195
195
196 def get_markdown_renderer_flavored(extensions, output_format):
196 def get_markdown_renderer_flavored(extensions, output_format):
197 global _cached_markdown_renderer_flavored
197 global _cached_markdown_renderer_flavored
198
198
199 if _cached_markdown_renderer_flavored is None:
199 if _cached_markdown_renderer_flavored is None:
200 _cached_markdown_renderer_flavored = markdown.Markdown(
200 _cached_markdown_renderer_flavored = markdown.Markdown(
201 extensions=extensions + [GithubFlavoredMarkdownExtension()],
201 extensions=extensions + [GithubFlavoredMarkdownExtension()],
202 enable_attributes=False, output_format=output_format)
202 enable_attributes=False, output_format=output_format)
203 return _cached_markdown_renderer_flavored
203 return _cached_markdown_renderer_flavored
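# Note: both factories cache the first renderer for the process lifetime, so
# later calls with different extensions/output_format still return the
# originally configured instance. Usage sketch:
#
#   md = get_markdown_renderer(MarkupRenderer.extensions, 'html4')
#   html = md.convert('# hello')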
204
204
205
205
206 class MarkupRenderer(object):
206 class MarkupRenderer(object):
207 RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES = ['include', 'meta', 'raw']
207 RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES = ['include', 'meta', 'raw']
208
208
209 MARKDOWN_PAT = re.compile(r'\.(md|mkdn?|mdown|markdown)$', re.IGNORECASE)
209 MARKDOWN_PAT = re.compile(r'\.(md|mkdn?|mdown|markdown)$', re.IGNORECASE)
210 RST_PAT = re.compile(r'\.re?st$', re.IGNORECASE)
210 RST_PAT = re.compile(r'\.re?st$', re.IGNORECASE)
211 JUPYTER_PAT = re.compile(r'\.(ipynb)$', re.IGNORECASE)
211 JUPYTER_PAT = re.compile(r'\.(ipynb)$', re.IGNORECASE)
212 PLAIN_PAT = re.compile(r'^readme$', re.IGNORECASE)
212 PLAIN_PAT = re.compile(r'^readme$', re.IGNORECASE)
213
213
214 URL_PAT = re.compile(r'(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]'
214 URL_PAT = re.compile(r'(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]'
215 r'|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)')
215 r'|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)')
216
216
217 MENTION_PAT = re.compile(MENTIONS_REGEX)
217 MENTION_PAT = re.compile(MENTIONS_REGEX)
218
218
219 extensions = ['markdown.extensions.codehilite', 'markdown.extensions.extra',
219 extensions = ['markdown.extensions.codehilite', 'markdown.extensions.extra',
220 'markdown.extensions.def_list', 'markdown.extensions.sane_lists']
220 'markdown.extensions.def_list', 'markdown.extensions.sane_lists']
221
221
222 output_format = 'html4'
222 output_format = 'html4'
223
223
224 # extensions together with weights; a lower weight sorts first, so these
224 # extensions together with weights; a lower weight sorts first, so these
225 # control the order in which extensions are matched against readme names.
225 # control the order in which extensions are matched against readme names.
226 PLAIN_EXTS = [
226 PLAIN_EXTS = [
227 # prefer no extension
227 # prefer no extension
228 ('', 0),  # special case that renders README names without an extension
228 ('', 0),  # special case that renders README names without an extension
229 ('.text', 2), ('.TEXT', 2),
229 ('.text', 2), ('.TEXT', 2),
230 ('.txt', 3), ('.TXT', 3)
230 ('.txt', 3), ('.TXT', 3)
231 ]
231 ]
232
232
233 RST_EXTS = [
233 RST_EXTS = [
234 ('.rst', 1), ('.rest', 1),
234 ('.rst', 1), ('.rest', 1),
235 ('.RST', 2), ('.REST', 2)
235 ('.RST', 2), ('.REST', 2)
236 ]
236 ]
237
237
238 MARKDOWN_EXTS = [
238 MARKDOWN_EXTS = [
239 ('.md', 1), ('.MD', 1),
239 ('.md', 1), ('.MD', 1),
240 ('.mkdn', 2), ('.MKDN', 2),
240 ('.mkdn', 2), ('.MKDN', 2),
241 ('.mdown', 3), ('.MDOWN', 3),
241 ('.mdown', 3), ('.MDOWN', 3),
242 ('.markdown', 4), ('.MARKDOWN', 4)
242 ('.markdown', 4), ('.MARKDOWN', 4)
243 ]
243 ]
244
244
245 def _detect_renderer(self, source, filename=None):
245 def _detect_renderer(self, source, filename=None):
246 """
246 """
247 runs detection of what renderer should be used for generating html
247 runs detection of what renderer should be used for generating html
248 from a markup language
248 from a markup language
249
249
250 filename can also be a renderer name given explicitly
250 filename can also be a renderer name given explicitly
251
251
252 :param source:
252 :param source:
253 :param filename:
253 :param filename:
254 """
254 """
255
255
256 if MarkupRenderer.MARKDOWN_PAT.findall(filename):
256 if MarkupRenderer.MARKDOWN_PAT.findall(filename):
257 detected_renderer = 'markdown'
257 detected_renderer = 'markdown'
258 elif MarkupRenderer.RST_PAT.findall(filename):
258 elif MarkupRenderer.RST_PAT.findall(filename):
259 detected_renderer = 'rst'
259 detected_renderer = 'rst'
260 elif MarkupRenderer.JUPYTER_PAT.findall(filename):
260 elif MarkupRenderer.JUPYTER_PAT.findall(filename):
261 detected_renderer = 'jupyter'
261 detected_renderer = 'jupyter'
262 elif MarkupRenderer.PLAIN_PAT.findall(filename):
262 elif MarkupRenderer.PLAIN_PAT.findall(filename):
263 detected_renderer = 'plain'
263 detected_renderer = 'plain'
264 else:
264 else:
265 detected_renderer = 'plain'
265 detected_renderer = 'plain'
266
266
267 return getattr(MarkupRenderer, detected_renderer)
267 return getattr(MarkupRenderer, detected_renderer)
268
268
269 @classmethod
269 @classmethod
270 def bleach_clean(cls, text):
270 def bleach_clean(cls, text):
271 from .bleach_whitelist import markdown_attrs, markdown_tags
271 from .bleach_whitelist import markdown_attrs, markdown_tags
272 allowed_tags = markdown_tags
272 allowed_tags = markdown_tags
273 allowed_attrs = markdown_attrs
273 allowed_attrs = markdown_attrs
274
274
275 try:
275 try:
276 return bleach.clean(text, tags=allowed_tags, attributes=allowed_attrs)
276 return bleach.clean(text, tags=allowed_tags, attributes=allowed_attrs)
277 except Exception:
277 except Exception:
278 return 'UNPARSEABLE TEXT'
278 return 'UNPARSEABLE TEXT'
279
279
280 @classmethod
280 @classmethod
281 def renderer_from_filename(cls, filename, exclude):
281 def renderer_from_filename(cls, filename, exclude):
282 """
282 """
283 Detect renderer markdown/rst from filename and optionally use exclude
283 Detect renderer markdown/rst from filename and optionally use exclude
284 list to remove some options. This is mostly used in helpers.
284 list to remove some options. This is mostly used in helpers.
285 Returns None when no renderer can be detected.
285 Returns None when no renderer can be detected.
286 """
286 """
287 def _filter(elements):
287 def _filter(elements):
288 if isinstance(exclude, (list, tuple)):
288 if isinstance(exclude, (list, tuple)):
289 return [x for x in elements if x not in exclude]
289 return [x for x in elements if x not in exclude]
290 return elements
290 return elements
291
291
292 if filename.endswith(
292 if filename.endswith(
293 tuple(_filter([x[0] for x in cls.MARKDOWN_EXTS if x[0]]))):
293 tuple(_filter([x[0] for x in cls.MARKDOWN_EXTS if x[0]]))):
294 return 'markdown'
294 return 'markdown'
295 if filename.endswith(tuple(_filter([x[0] for x in cls.RST_EXTS if x[0]]))):
295 if filename.endswith(tuple(_filter([x[0] for x in cls.RST_EXTS if x[0]]))):
296 return 'rst'
296 return 'rst'
297
297
298 return None
298 return None
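# Examples:
#
#   MarkupRenderer.renderer_from_filename('README.md', exclude=None)          # -> 'markdown'
#   MarkupRenderer.renderer_from_filename('docs/index.rst', exclude=['.rst'])  # -> None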
299
299
300 def render(self, source, filename=None):
300 def render(self, source, filename=None):
301 """
301 """
302 Renders the given source using a renderer detected from the filename;
302 Renders the given source using a renderer detected from the filename;
303 renderers are detected based on file extension or mimetype.
303 renderers are detected based on file extension or mimetype.
304 As a last resort it falls back to simple HTML, replacing newlines with <br/>
304 As a last resort it falls back to simple HTML, replacing newlines with <br/>
305
305
306 :param filename:
306 :param filename:
307 :param source:
307 :param source:
308 """
308 """
309
309
310 renderer = self._detect_renderer(source, filename)
310 renderer = self._detect_renderer(source, filename)
311 readme_data = renderer(source)
311 readme_data = renderer(source)
312 return readme_data
312 return readme_data
313
313
314 @classmethod
314 @classmethod
315 def _flavored_markdown(cls, text):
315 def _flavored_markdown(cls, text):
316 """
316 """
317 Github style flavored markdown
317 Github style flavored markdown
318
318
319 :param text:
319 :param text:
320 """
320 """
321
321
322 # Extract pre blocks.
322 # Extract pre blocks.
323 extractions = {}
323 extractions = {}
324
324
325 def pre_extraction_callback(matchobj):
325 def pre_extraction_callback(matchobj):
326 digest = md5_safe(matchobj.group(0))
326 digest = md5_safe(matchobj.group(0))
327 extractions[digest] = matchobj.group(0)
327 extractions[digest] = matchobj.group(0)
328 return "{gfm-extraction-%s}" % digest
328 return "{gfm-extraction-%s}" % digest
329 pattern = re.compile(r'<pre>.*?</pre>', re.MULTILINE | re.DOTALL)
329 pattern = re.compile(r'<pre>.*?</pre>', re.MULTILINE | re.DOTALL)
330 text = re.sub(pattern, pre_extraction_callback, text)
330 text = re.sub(pattern, pre_extraction_callback, text)
331
331
332 # Prevent foo_bar_baz from ending up with an italic word in the middle.
332 # Prevent foo_bar_baz from ending up with an italic word in the middle.
333 def italic_callback(matchobj):
333 def italic_callback(matchobj):
334 s = matchobj.group(0)
334 s = matchobj.group(0)
335 if list(s).count('_') >= 2:
335 if list(s).count('_') >= 2:
336 return s.replace('_', r'\_')
336 return s.replace('_', r'\_')
337 return s
337 return s
338 text = re.sub(r'^(?! {4}|\t)\w+_\w+_\w[\w_]*', italic_callback, text)
338 text = re.sub(r'^(?! {4}|\t)\w+_\w+_\w[\w_]*', italic_callback, text)
339
339
340 # Insert pre block extractions.
340 # Insert pre block extractions.
341 def pre_insert_callback(matchobj):
341 def pre_insert_callback(matchobj):
342 return '\n\n' + extractions[matchobj.group(1)]
342 return '\n\n' + extractions[matchobj.group(1)]
343 text = re.sub(r'\{gfm-extraction-([0-9a-f]{32})\}',
343 text = re.sub(r'\{gfm-extraction-([0-9a-f]{32})\}',
344 pre_insert_callback, text)
344 pre_insert_callback, text)
345
345
346 return text
346 return text
347
347
348 @classmethod
348 @classmethod
349 def urlify_text(cls, text):
349 def urlify_text(cls, text):
350 def url_func(match_obj):
350 def url_func(match_obj):
351 url_full = match_obj.groups()[0]
351 url_full = match_obj.groups()[0]
352 return '<a href="%(url)s">%(url)s</a>' % ({'url': url_full})
352 return '<a href="%(url)s">%(url)s</a>' % ({'url': url_full})
353
353
354 return cls.URL_PAT.sub(url_func, text)
354 return cls.URL_PAT.sub(url_func, text)
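# Example:
#
#   MarkupRenderer.urlify_text('see https://rhodecode.com for details')
#   # -> 'see <a href="https://rhodecode.com">https://rhodecode.com</a> for details'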
355
355
356 @classmethod
356 @classmethod
357 def convert_mentions(cls, text, mode):
357 def convert_mentions(cls, text, mode):
358 mention_pat = cls.MENTION_PAT
358 mention_pat = cls.MENTION_PAT
359
359
360 def wrapp(match_obj):
360 def wrapp(match_obj):
361 uname = match_obj.groups()[0]
361 uname = match_obj.groups()[0]
362 hovercard_url = "pyroutes.url('hovercard_username', {'username': '%s'});" % uname
362 hovercard_url = "pyroutes.url('hovercard_username', {'username': '%s'});" % uname
363
363
364 if mode == 'markdown':
364 if mode == 'markdown':
365 tmpl = '<strong class="tooltip-hovercard" data-hovercard-alt="{uname}" data-hovercard-url="{hovercard_url}">@{uname}</strong>'
365 tmpl = '<strong class="tooltip-hovercard" data-hovercard-alt="{uname}" data-hovercard-url="{hovercard_url}">@{uname}</strong>'
366 elif mode == 'rst':
366 elif mode == 'rst':
367 tmpl = ' **@{uname}** '
367 tmpl = ' **@{uname}** '
368 else:
368 else:
369 raise ValueError('mode must be rst or markdown')
369 raise ValueError('mode must be rst or markdown')
370
370
371 return tmpl.format(**{'uname': uname,
371 return tmpl.format(**{'uname': uname,
372 'hovercard_url': hovercard_url})
372 'hovercard_url': hovercard_url})
373
373
374 return mention_pat.sub(wrapp, text).strip()
374 return mention_pat.sub(wrapp, text).strip()
375
375
376 @classmethod
376 @classmethod
377 def plain(cls, source, universal_newline=True, leading_newline=True):
377 def plain(cls, source, universal_newline=True, leading_newline=True):
378 source = safe_unicode(source)
378 source = safe_unicode(source)
379 if universal_newline:
379 if universal_newline:
380 newline = '\n'
380 newline = '\n'
381 source = newline.join(source.splitlines())
381 source = newline.join(source.splitlines())
382
382
383 rendered_source = cls.urlify_text(source)
383 rendered_source = cls.urlify_text(source)
384 source = ''
384 source = ''
385 if leading_newline:
385 if leading_newline:
386 source += '<br />'
386 source += '<br />'
387 source += rendered_source.replace("\n", '<br />')
387 source += rendered_source.replace("\n", '<br />')
388
388
389 rendered = cls.bleach_clean(source)
389 rendered = cls.bleach_clean(source)
390 return rendered
390 return rendered
391
391
392 @classmethod
392 @classmethod
393 def markdown(cls, source, safe=True, flavored=True, mentions=False,
393 def markdown(cls, source, safe=True, flavored=True, mentions=False,
394 clean_html=True):
394 clean_html=True):
395 """
395 """
396 returns markdown rendered code cleaned by the bleach library
396 returns markdown rendered code cleaned by the bleach library
397 """
397 """
398
398
399 if flavored:
399 if flavored:
400 markdown_renderer = get_markdown_renderer_flavored(
400 markdown_renderer = get_markdown_renderer_flavored(
401 cls.extensions, cls.output_format)
401 cls.extensions, cls.output_format)
402 else:
402 else:
403 markdown_renderer = get_markdown_renderer(
403 markdown_renderer = get_markdown_renderer(
404 cls.extensions, cls.output_format)
404 cls.extensions, cls.output_format)
405
405
406 if mentions:
406 if mentions:
407 mention_hl = cls.convert_mentions(source, mode='markdown')
407 mention_hl = cls.convert_mentions(source, mode='markdown')
408 # mentions were already converted above; re-render with mentions=False
408 # mentions were already converted above; re-render with mentions=False
409 return cls.markdown(mention_hl, safe=safe, flavored=flavored,
409 return cls.markdown(mention_hl, safe=safe, flavored=flavored,
410 mentions=False)
410 mentions=False)
411
411
412 source = safe_unicode(source)
412 source = safe_unicode(source)
413
413
414 try:
414 try:
415 if flavored:
415 if flavored:
416 source = cls._flavored_markdown(source)
416 source = cls._flavored_markdown(source)
417 rendered = markdown_renderer.convert(source)
417 rendered = markdown_renderer.convert(source)
418 except Exception:
418 except Exception:
419 log.exception('Error when rendering Markdown')
419 log.exception('Error when rendering Markdown')
420 if safe:
420 if safe:
421 log.debug('Fallback to render in plain mode')
421 log.debug('Fallback to render in plain mode')
422 rendered = cls.plain(source)
422 rendered = cls.plain(source)
423 else:
423 else:
424 raise
424 raise
425
425
426 if clean_html:
426 if clean_html:
427 rendered = cls.bleach_clean(rendered)
427 rendered = cls.bleach_clean(rendered)
428 return rendered
428 return rendered
429
429
430 @classmethod
430 @classmethod
431 def rst(cls, source, safe=True, mentions=False, clean_html=False):
431 def rst(cls, source, safe=True, mentions=False, clean_html=False):
432 if mentions:
432 if mentions:
433 mention_hl = cls.convert_mentions(source, mode='rst')
433 mention_hl = cls.convert_mentions(source, mode='rst')
434 # mentions were already converted above; re-render with mentions=False
434 # mentions were already converted above; re-render with mentions=False
435 return cls.rst(mention_hl, safe=safe, mentions=False)
435 return cls.rst(mention_hl, safe=safe, mentions=False)
436
436
437 source = safe_unicode(source)
437 source = safe_unicode(source)
438 try:
438 try:
439 docutils_settings = dict(
439 docutils_settings = dict(
440 [(alias, None) for alias in
440 [(alias, None) for alias in
441 cls.RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES])
441 cls.RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES])
442
442
443 docutils_settings.update({
443 docutils_settings.update({
444 'input_encoding': 'unicode',
444 'input_encoding': 'unicode',
445 'report_level': 4,
445 'report_level': 4,
446 'syntax_highlight': 'short',
446 'syntax_highlight': 'short',
447 })
447 })
448
448
449 for k, v in docutils_settings.items():
449 for k, v in docutils_settings.items():
450 directives.register_directive(k, v)
450 directives.register_directive(k, v)
451
451
452 parts = publish_parts(source=source,
452 parts = publish_parts(source=source,
453 writer=RhodeCodeWriter(),
453 writer=RhodeCodeWriter(),
454 settings_overrides=docutils_settings)
454 settings_overrides=docutils_settings)
455 rendered = parts["fragment"]
455 rendered = parts["fragment"]
456 if clean_html:
456 if clean_html:
457 rendered = cls.bleach_clean(rendered)
457 rendered = cls.bleach_clean(rendered)
458 return parts['html_title'] + rendered
458 return parts['html_title'] + rendered
459 except Exception:
459 except Exception:
460 log.exception('Error when rendering RST')
460 log.exception('Error when rendering RST')
461 if safe:
461 if safe:
462 log.debug('Fallback to render in plain mode')
462 log.debug('Fallback to render in plain mode')
463 return cls.plain(source)
463 return cls.plain(source)
464 else:
464 else:
465 raise
465 raise
466
466
467 @classmethod
467 @classmethod
468 def jupyter(cls, source, safe=True):
468 def jupyter(cls, source, safe=True):
469 from rhodecode.lib import helpers
469 from rhodecode.lib import helpers
470
470
471 from traitlets.config import Config
471 from traitlets.config import Config
472 import nbformat
472 import nbformat
473 from nbconvert import HTMLExporter
473 from nbconvert import HTMLExporter
474 from nbconvert.preprocessors import Preprocessor
474 from nbconvert.preprocessors import Preprocessor
475
475
476 class CustomHTMLExporter(HTMLExporter):
476 class CustomHTMLExporter(HTMLExporter):
477 def _template_file_default(self):
477 def _template_file_default(self):
478 return 'basic'
478 return 'basic'
479
479
480 class Sandbox(Preprocessor):
480 class Sandbox(Preprocessor):
481
481
482 def preprocess(self, nb, resources):
482 def preprocess(self, nb, resources):
483 sandbox_text = 'SandBoxed(IPython.core.display.Javascript object)'
483 sandbox_text = 'SandBoxed(IPython.core.display.Javascript object)'
484 for cell in nb['cells']:
484 for cell in nb['cells']:
485 if not safe:
485 if not safe:
486 continue
486 continue
487
487
488 if 'outputs' in cell:
488 if 'outputs' in cell:
489 for cell_output in cell['outputs']:
489 for cell_output in cell['outputs']:
490 if 'data' in cell_output:
490 if 'data' in cell_output:
491 if 'application/javascript' in cell_output['data']:
491 if 'application/javascript' in cell_output['data']:
492 cell_output['data']['text/plain'] = sandbox_text
492 cell_output['data']['text/plain'] = sandbox_text
493 cell_output['data'].pop('application/javascript', None)
493 cell_output['data'].pop('application/javascript', None)
494
494
495 if 'source' in cell and cell['cell_type'] == 'markdown':
495 if 'source' in cell and cell['cell_type'] == 'markdown':
496 # sanitize similar like in markdown
496 # sanitize similar like in markdown
497 cell['source'] = cls.bleach_clean(cell['source'])
497 cell['source'] = cls.bleach_clean(cell['source'])
498
498
499 return nb, resources
499 return nb, resources
500
500
501 def _sanitize_resources(input_resources):
501 def _sanitize_resources(input_resources):
502 """
502 """
503 Skip/sanitize some of the CSS generated and included in jupyter
503 Skip/sanitize some of the CSS generated and included in jupyter
504 so it doesn't mess up the UI too much
504 so it doesn't mess up the UI too much
505 """
505 """
506
506
507 # TODO(marcink): probably we should replace this with whole custom
507 # TODO(marcink): probably we should replace this with whole custom
508 # CSS set that doesn't screw up, but jupyter generated html has some
508 # CSS set that doesn't screw up, but jupyter generated html has some
509 # special markers, so it requires Custom HTML exporter template with
509 # special markers, so it requires Custom HTML exporter template with
510 # _default_template_path_default, to achieve that
510 # _default_template_path_default, to achieve that
511
511
512 # strip the reset CSS
512 # strip the reset CSS
513 input_resources[0] = input_resources[0][input_resources[0].find('/*! Source'):]
513 input_resources[0] = input_resources[0][input_resources[0].find('/*! Source'):]
514 return input_resources
514 return input_resources
515
515
516 def as_html(notebook):
516 def as_html(notebook):
517 conf = Config()
517 conf = Config()
518 conf.CustomHTMLExporter.preprocessors = [Sandbox]
518 conf.CustomHTMLExporter.preprocessors = [Sandbox]
519 html_exporter = CustomHTMLExporter(config=conf)
519 html_exporter = CustomHTMLExporter(config=conf)
520
520
521 (body, resources) = html_exporter.from_notebook_node(notebook)
521 (body, resources) = html_exporter.from_notebook_node(notebook)
522 header = '<!-- ## IPYTHON NOTEBOOK RENDERING ## -->'
522 header = '<!-- ## IPYTHON NOTEBOOK RENDERING ## -->'
523 js = MakoTemplate(r'''
523 js = MakoTemplate(r'''
524 <!-- MathJax configuration -->
524 <!-- MathJax configuration -->
525 <script type="text/x-mathjax-config">
525 <script type="text/x-mathjax-config">
526 MathJax.Hub.Config({
526 MathJax.Hub.Config({
527 jax: ["input/TeX","output/HTML-CSS", "output/PreviewHTML"],
527 jax: ["input/TeX","output/HTML-CSS", "output/PreviewHTML"],
528 extensions: ["tex2jax.js","MathMenu.js","MathZoom.js", "fast-preview.js", "AssistiveMML.js", "[Contrib]/a11y/accessibility-menu.js"],
528 extensions: ["tex2jax.js","MathMenu.js","MathZoom.js", "fast-preview.js", "AssistiveMML.js", "[Contrib]/a11y/accessibility-menu.js"],
529 TeX: {
529 TeX: {
530 extensions: ["AMSmath.js","AMSsymbols.js","noErrors.js","noUndefined.js"]
530 extensions: ["AMSmath.js","AMSsymbols.js","noErrors.js","noUndefined.js"]
531 },
531 },
532 tex2jax: {
532 tex2jax: {
533 inlineMath: [ ['$','$'], ["\\(","\\)"] ],
533 inlineMath: [ ['$','$'], ["\\(","\\)"] ],
534 displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
534 displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
535 processEscapes: true,
535 processEscapes: true,
536 processEnvironments: true
536 processEnvironments: true
537 },
537 },
538 // Center justify equations in code and markdown cells. Elsewhere
538 // Center justify equations in code and markdown cells. Elsewhere
539 // we use CSS to left justify single line equations in code cells.
539 // we use CSS to left justify single line equations in code cells.
540 displayAlign: 'center',
540 displayAlign: 'center',
541 "HTML-CSS": {
541 "HTML-CSS": {
542 styles: {'.MathJax_Display': {"margin": 0}},
542 styles: {'.MathJax_Display': {"margin": 0}},
543 linebreaks: { automatic: true },
543 linebreaks: { automatic: true },
544 availableFonts: ["STIX", "TeX"]
544 availableFonts: ["STIX", "TeX"]
545 },
545 },
546 showMathMenu: false
546 showMathMenu: false
547 });
547 });
548 </script>
548 </script>
549 <!-- End of MathJax configuration -->
549 <!-- End of MathJax configuration -->
550 <script src="${h.asset('js/src/math_jax/MathJax.js')}"></script>
550 <script src="${h.asset('js/src/math_jax/MathJax.js')}"></script>
551 ''').render(h=helpers)
551 ''').render(h=helpers)
552
552
553 css = MakoTemplate(r'''
553 css = MakoTemplate(r'''
554 <link rel="stylesheet" type="text/css" href="${h.asset('css/style-ipython.css', ver=ver)}" media="screen"/>
554 <link rel="stylesheet" type="text/css" href="${h.asset('css/style-ipython.css', ver=ver)}" media="screen"/>
555 ''').render(h=helpers, ver='ver1')
555 ''').render(h=helpers, ver='ver1')
556
556
557 body = '\n'.join([header, css, js, body])
557 body = '\n'.join([header, css, js, body])
558 return body, resources
558 return body, resources
559
559
560 notebook = nbformat.reads(source, as_version=4)
560 notebook = nbformat.reads(source, as_version=4)
561 (body, resources) = as_html(notebook)
561 (body, resources) = as_html(notebook)
562 return body
562 return body
563
563
564
564
565 class RstTemplateRenderer(object):
565 class RstTemplateRenderer(object):
566
566
567 def __init__(self):
567 def __init__(self):
568 base = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
568 base = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
569 rst_template_dirs = [os.path.join(base, 'templates', 'rst_templates')]
569 rst_template_dirs = [os.path.join(base, 'templates', 'rst_templates')]
570 self.template_store = TemplateLookup(
570 self.template_store = TemplateLookup(
571 directories=rst_template_dirs,
571 directories=rst_template_dirs,
572 input_encoding='utf-8',
572 input_encoding='utf-8',
573 imports=['from rhodecode.lib import helpers as h'])
573 imports=['from rhodecode.lib import helpers as h'])
574
574
575 def _get_template(self, templatename):
575 def _get_template(self, templatename):
576 return self.template_store.get_template(templatename)
576 return self.template_store.get_template(templatename)
577
577
578 def render(self, template_name, **kwargs):
578 def render(self, template_name, **kwargs):
579 template = self._get_template(template_name)
579 template = self._get_template(template_name)
580 return template.render(**kwargs)
580 return template.render(**kwargs)
@@ -1,156 +1,156 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 SimpleGit middleware for handling git protocol requests (push/clone etc.).
22 SimpleGit middleware for handling git protocol requests (push/clone etc.).
23 It's implemented with a basic auth function.
23 It's implemented with a basic auth function.
24 """
24 """
25 import os
25 import os
26 import re
26 import re
27 import logging
27 import logging
28 import urllib.parse
28 import urllib.parse
29
29
30 import rhodecode
30 import rhodecode
31 from rhodecode.lib import utils
31 from rhodecode.lib import utils
32 from rhodecode.lib import utils2
32 from rhodecode.lib import utils2
33 from rhodecode.lib.middleware import simplevcs
33 from rhodecode.lib.middleware import simplevcs
34
34
35 log = logging.getLogger(__name__)
35 log = logging.getLogger(__name__)
36
36
37
37
38 GIT_PROTO_PAT = re.compile(
38 GIT_PROTO_PAT = re.compile(
39 r'^/(.+)/(info/refs|info/lfs/(.+)|git-upload-pack|git-receive-pack)')
39 r'^/(.+)/(info/refs|info/lfs/(.+)|git-upload-pack|git-receive-pack)')
40 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
40 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
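# Example matches (sketch):
#
#   GIT_PROTO_PAT.match('/group/repo/git-upload-pack').group(1)       # -> 'group/repo'
#   GIT_LFS_PROTO_PAT.match('/repo/info/lfs/objects/batch').group(3)  # -> 'objects/batch'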
41
41
42
42
43 def default_lfs_store():
43 def default_lfs_store():
44 """
44 """
45 Default LFS store location; it's consistent with Mercurial's large file
45 Default LFS store location; it's consistent with Mercurial's large file
46 store, which lives in .cache/largefiles
46 store, which lives in .cache/largefiles
47 """
47 """
48 from rhodecode.lib.vcs.backends.git import lfs_store
48 from rhodecode.lib.vcs.backends.git import lfs_store
49 user_home = os.path.expanduser("~")
49 user_home = os.path.expanduser("~")
50 return lfs_store(user_home)
50 return lfs_store(user_home)
51
51
52
52
53 class SimpleGit(simplevcs.SimpleVCS):
53 class SimpleGit(simplevcs.SimpleVCS):
54
54
55 SCM = 'git'
55 SCM = 'git'
56
56
57 def _get_repository_name(self, environ):
57 def _get_repository_name(self, environ):
58 """
58 """
59 Gets repository name out of PATH_INFO header
59 Gets repository name out of PATH_INFO header
60
60
61 :param environ: environ where PATH_INFO is stored
61 :param environ: environ where PATH_INFO is stored
62 """
62 """
63 repo_name = GIT_PROTO_PAT.match(environ['PATH_INFO']).group(1)
63 repo_name = GIT_PROTO_PAT.match(environ['PATH_INFO']).group(1)
64 # for GIT LFS, and bare format strip .git suffix from names
64 # for GIT LFS, and bare format strip .git suffix from names
65 if repo_name.endswith('.git'):
65 if repo_name.endswith('.git'):
66 repo_name = repo_name[:-4]
66 repo_name = repo_name[:-4]
67 return repo_name
67 return repo_name
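# Example: for PATH_INFO '/myrepo.git/git-upload-pack' the pattern yields
# 'myrepo.git' and the suffix strip returns 'myrepo':
#
#   self._get_repository_name({'PATH_INFO': '/myrepo.git/git-upload-pack'})
#   # -> 'myrepo'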
68
68
69 def _get_lfs_action(self, path, request_method):
69 def _get_lfs_action(self, path, request_method):
70 """
70 """
71 Return an action based on the LFS request type.
71 Return an action based on the LFS request type.
72 Those routes are handled inside vcsserver app.
72 Those routes are handled inside vcsserver app.
73
73
74 batch -> POST to /info/lfs/objects/batch => PUSH/PULL
74 batch -> POST to /info/lfs/objects/batch => PUSH/PULL
75 batch is based on the `operation` field,
75 batch is based on the `operation` field,
76 which can be download or upload, but those are only
76 which can be download or upload, but those are only
77 instructions to fetch, so we always return pull
77 instructions to fetch, so we always return pull
78
78
79 download -> GET to /info/lfs/{oid} => PULL
79 download -> GET to /info/lfs/{oid} => PULL
80 upload -> PUT to /info/lfs/{oid} => PUSH
80 upload -> PUT to /info/lfs/{oid} => PUSH
81
81
82 verification -> POST to /info/lfs/verify => PULL
82 verification -> POST to /info/lfs/verify => PULL
83
83
84 """
84 """
85
85
86 match_obj = GIT_LFS_PROTO_PAT.match(path)
86 match_obj = GIT_LFS_PROTO_PAT.match(path)
87 _parts = match_obj.groups()
87 _parts = match_obj.groups()
88 repo_name, path, operation = _parts
88 repo_name, path, operation = _parts
89 log.debug(
89 log.debug(
90 'LFS: detecting operation based on following '
90 'LFS: detecting operation based on following '
91 'data: %s, req_method:%s', _parts, request_method)
91 'data: %s, req_method:%s', _parts, request_method)
92
92
93 if operation == 'verify':
93 if operation == 'verify':
94 return 'pull'
94 return 'pull'
95 elif operation == 'objects/batch':
95 elif operation == 'objects/batch':
96 # batch sends back instructions for API to dl/upl we report it
96 # batch sends back instructions for API to dl/upl we report it
97 # as pull
97 # as pull
98 if request_method == 'POST':
98 if request_method == 'POST':
99 return 'pull'
99 return 'pull'
100
100
101 elif operation:
101 elif operation:
102 # probably an OID; upload is a PUT, download a GET
102 # probably an OID; upload is a PUT, download a GET
103 if request_method == 'GET':
103 if request_method == 'GET':
104 return 'pull'
104 return 'pull'
105 else:
105 else:
106 return 'push'
106 return 'push'
107
107
108 # if nothing matched, require 'push' as the action to stay restrictive
108 # if nothing matched, require 'push' as the action to stay restrictive
109 return 'push'
109 return 'push'
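# Mapping sketch (app is a SimpleGit instance):
#
#   app._get_lfs_action('/repo/info/lfs/objects/batch', 'POST')  # -> 'pull'
#   app._get_lfs_action('/repo/info/lfs/verify', 'POST')         # -> 'pull'
#   app._get_lfs_action('/repo/info/lfs/some-oid', 'PUT')        # -> 'push'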
110
110
111 _ACTION_MAPPING = {
111 _ACTION_MAPPING = {
112 'git-receive-pack': 'push',
112 'git-receive-pack': 'push',
113 'git-upload-pack': 'pull',
113 'git-upload-pack': 'pull',
114 }
114 }
115
115
116 def _get_action(self, environ):
116 def _get_action(self, environ):
117 """
117 """
118 Maps git request commands into a pull or push command.
118 Maps git request commands into a pull or push command.
119 In case of unknown/unexpected data, it returns 'pull' to be safe.
119 In case of unknown/unexpected data, it returns 'pull' to be safe.
120
120
121 :param environ:
121 :param environ:
122 """
122 """
123 path = environ['PATH_INFO']
123 path = environ['PATH_INFO']
124
124
125 if path.endswith('/info/refs'):
125 if path.endswith('/info/refs'):
126 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'])
126 query = urllib.parse.parse_qs(environ['QUERY_STRING'])
127 service_cmd = query.get('service', [''])[0]
127 service_cmd = query.get('service', [''])[0]
128 return self._ACTION_MAPPING.get(service_cmd, 'pull')
128 return self._ACTION_MAPPING.get(service_cmd, 'pull')
129
129
130 elif GIT_LFS_PROTO_PAT.match(environ['PATH_INFO']):
130 elif GIT_LFS_PROTO_PAT.match(environ['PATH_INFO']):
131 return self._get_lfs_action(
131 return self._get_lfs_action(
132 environ['PATH_INFO'], environ['REQUEST_METHOD'])
132 environ['PATH_INFO'], environ['REQUEST_METHOD'])
133
133
134 elif path.endswith('/git-receive-pack'):
134 elif path.endswith('/git-receive-pack'):
135 return 'push'
135 return 'push'
136 elif path.endswith('/git-upload-pack'):
136 elif path.endswith('/git-upload-pack'):
137 return 'pull'
137 return 'pull'
138
138
139 return 'pull'
139 return 'pull'
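# Example with a minimal (hypothetical) WSGI environ:
#
#   environ = {'PATH_INFO': '/repo/info/refs',
#              'QUERY_STRING': 'service=git-receive-pack',
#              'REQUEST_METHOD': 'GET'}
#   app._get_action(environ)  # -> 'push'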
140
140
141 def _create_wsgi_app(self, repo_path, repo_name, config):
141 def _create_wsgi_app(self, repo_path, repo_name, config):
142 return self.scm_app.create_git_wsgi_app(
142 return self.scm_app.create_git_wsgi_app(
143 repo_path, repo_name, config)
143 repo_path, repo_name, config)
144
144
145 def _create_config(self, extras, repo_name, scheme='http'):
145 def _create_config(self, extras, repo_name, scheme='http'):
146 extras['git_update_server_info'] = utils2.str2bool(
146 extras['git_update_server_info'] = utils2.str2bool(
147 rhodecode.CONFIG.get('git_update_server_info'))
147 rhodecode.CONFIG.get('git_update_server_info'))
148
148
149 config = utils.make_db_config(repo=repo_name)
149 config = utils.make_db_config(repo=repo_name)
150 custom_store = config.get('vcs_git_lfs', 'store_location')
150 custom_store = config.get('vcs_git_lfs', 'store_location')
151
151
152 extras['git_lfs_enabled'] = utils2.str2bool(
152 extras['git_lfs_enabled'] = utils2.str2bool(
153 config.get('vcs_git_lfs', 'enabled'))
153 config.get('vcs_git_lfs', 'enabled'))
154 extras['git_lfs_store_path'] = custom_store or default_lfs_store()
154 extras['git_lfs_store_path'] = custom_store or default_lfs_store()
155 extras['git_lfs_http_scheme'] = scheme
155 extras['git_lfs_http_scheme'] = scheme
156 return extras
156 return extras
@@ -1,160 +1,159 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 SimpleHg middleware for handling mercurial protocol requests
22 SimpleHg middleware for handling mercurial protocol requests
23 (push/clone etc.). It's implemented with a basic auth function
23 (push/clone etc.). It's implemented with a basic auth function
24 """
24 """
25
25
26 import logging
26 import logging
27 import urllib.parse
27 import urllib.parse
28 import urllib.request, urllib.parse, urllib.error
28 import urllib.request, urllib.parse, urllib.error
29
29
30 from rhodecode.lib import utils
30 from rhodecode.lib import utils
31 from rhodecode.lib.ext_json import json
31 from rhodecode.lib.ext_json import json
32 from rhodecode.lib.middleware import simplevcs
32 from rhodecode.lib.middleware import simplevcs
33
33
34 log = logging.getLogger(__name__)
34 log = logging.getLogger(__name__)
35
35
36
36
37 class SimpleHg(simplevcs.SimpleVCS):
37 class SimpleHg(simplevcs.SimpleVCS):
38
38
39 SCM = 'hg'
39 SCM = 'hg'
40
40
41 def _get_repository_name(self, environ):
41 def _get_repository_name(self, environ):
42 """
42 """
43 Gets the repository name out of the PATH_INFO value
43 Gets the repository name out of the PATH_INFO value
44
44
45 :param environ: environ where PATH_INFO is stored
45 :param environ: environ where PATH_INFO is stored
46 """
46 """
47 repo_name = environ['PATH_INFO']
47 repo_name = environ['PATH_INFO']
48 if repo_name and repo_name.startswith('/'):
48 if repo_name and repo_name.startswith('/'):
49 # remove only the first leading /
49 # remove only the first leading /
50 repo_name = repo_name[1:]
50 repo_name = repo_name[1:]
51 return repo_name.rstrip('/')
51 return repo_name.rstrip('/')
52
52
53 _ACTION_MAPPING = {
53 _ACTION_MAPPING = {
54 'changegroup': 'pull',
54 'changegroup': 'pull',
55 'changegroupsubset': 'pull',
55 'changegroupsubset': 'pull',
56 'getbundle': 'pull',
56 'getbundle': 'pull',
57 'stream_out': 'pull',
57 'stream_out': 'pull',
58 'listkeys': 'pull',
58 'listkeys': 'pull',
59 'between': 'pull',
59 'between': 'pull',
60 'branchmap': 'pull',
60 'branchmap': 'pull',
61 'branches': 'pull',
61 'branches': 'pull',
62 'clonebundles': 'pull',
62 'clonebundles': 'pull',
63 'capabilities': 'pull',
63 'capabilities': 'pull',
64 'debugwireargs': 'pull',
64 'debugwireargs': 'pull',
65 'heads': 'pull',
65 'heads': 'pull',
66 'lookup': 'pull',
66 'lookup': 'pull',
67 'hello': 'pull',
67 'hello': 'pull',
68 'known': 'pull',
68 'known': 'pull',
69
69
70 # largefiles
70 # largefiles
71 'putlfile': 'push',
71 'putlfile': 'push',
72 'getlfile': 'pull',
72 'getlfile': 'pull',
73 'statlfile': 'pull',
73 'statlfile': 'pull',
74 'lheads': 'pull',
74 'lheads': 'pull',
75
75
76 # evolve
76 # evolve
77 'evoext_obshashrange_v1': 'pull',
77 'evoext_obshashrange_v1': 'pull',
78 'evoext_obshash': 'pull',
78 'evoext_obshash': 'pull',
79 'evoext_obshash1': 'pull',
79 'evoext_obshash1': 'pull',
80
80
81 'unbundle': 'push',
81 'unbundle': 'push',
82 'pushkey': 'push',
82 'pushkey': 'push',
83 }
83 }
84
84
85 @classmethod
85 @classmethod
86 def _get_xarg_headers(cls, environ):
86 def _get_xarg_headers(cls, environ):
87 i = 1
87 i = 1
88 chunks = [] # gather chunks stored in multiple 'hgarg_N'
88 chunks = [] # gather chunks stored in multiple 'hgarg_N'
89 while True:
89 while True:
90 head = environ.get('HTTP_X_HGARG_{}'.format(i))
90 head = environ.get('HTTP_X_HGARG_{}'.format(i))
91 if not head:
91 if not head:
92 break
92 break
93 i += 1
93 i += 1
94 chunks.append(urllib.parse.unquote_plus(head))
94 chunks.append(urllib.parse.unquote_plus(head))
95 full_arg = ''.join(chunks)
95 full_arg = ''.join(chunks)
96 pref = 'cmds='
96 pref = 'cmds='
97 if full_arg.startswith(pref):
97 if full_arg.startswith(pref):
98 # strip the cmds= header defining our batch commands
98 # strip the cmds= header defining our batch commands
99 full_arg = full_arg[len(pref):]
99 full_arg = full_arg[len(pref):]
100 cmds = full_arg.split(';')
100 cmds = full_arg.split(';')
101 return cmds
101 return cmds
102
102
103 @classmethod
103 @classmethod
104 def _get_batch_cmd(cls, environ):
104 def _get_batch_cmd(cls, environ):
105 """
105 """
106 Handle commands sent via the batch command. Those are ';'-separated
106 Handle commands sent via the batch command. Those are ';'-separated
107 commands the server needs to execute. We need to extract those, and
107 commands the server needs to execute. We need to extract those, and
108 map them through our _ACTION_MAPPING to get all push/pull commands
108 map them through our _ACTION_MAPPING to get all push/pull commands
109 specified in the batch.
109 specified in the batch.
110 """
110 """
111 default = 'push'
111 default = 'push'
112 batch_cmds = []
112 batch_cmds = []
113 try:
113 try:
114 cmds = cls._get_xarg_headers(environ)
114 cmds = cls._get_xarg_headers(environ)
115 for pair in cmds:
115 for pair in cmds:
116 parts = pair.split(' ', 1)
116 parts = pair.split(' ', 1)
117 if len(parts) != 2:
117 if len(parts) != 2:
118 continue
118 continue
119 # each entry should be in the format `key ARGS`
119 # each entry should be in the format `key ARGS`
120 cmd, args = parts
120 cmd, args = parts
121 action = cls._ACTION_MAPPING.get(cmd, default)
121 action = cls._ACTION_MAPPING.get(cmd, default)
122 batch_cmds.append(action)
122 batch_cmds.append(action)
123 except Exception:
123 except Exception:
124 log.exception('Failed to extract batch commands operations')
124 log.exception('Failed to extract batch commands operations')
125
125
126 # in case we failed (e.g. malformed data), assume it's a PUSH sub-command
126 # in case we failed (e.g. malformed data), assume it's a PUSH sub-command
127 # for safety
127 # for safety
128 return batch_cmds or [default]
128 return batch_cmds or [default]
129
129
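To make the chunk reassembly in _get_xarg_headers concrete, a minimal reproduction with a hand-built WSGI environ; the header values are invented for illustration:

import urllib.parse

environ = {
    'HTTP_X_HGARG_1': 'cmds=heads+%3B',     # one logical header, split into
    'HTTP_X_HGARG_2': 'known+nodes%3Dabc',  # numbered chunks by the hg client
}

chunks, i = [], 1
while environ.get('HTTP_X_HGARG_{}'.format(i)):
    chunks.append(urllib.parse.unquote_plus(environ['HTTP_X_HGARG_{}'.format(i)]))
    i += 1
full_arg = ''.join(chunks)                # 'cmds=heads ;known nodes=abc'
cmds = full_arg[len('cmds='):].split(';')
print(cmds)  # ['heads ', 'known nodes=abc'] -> both map to 'pull' in _ACTION_MAPPING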
130 def _get_action(self, environ):
130 def _get_action(self, environ):
131 """
131 """
132 Maps mercurial request commands into a pull or push command.
132 Maps mercurial request commands into a pull or push command.
133 In case of unknown/unexpected data, it returns 'push' to be safe.
133 In case of unknown/unexpected data, it returns 'push' to be safe.
134
134
135 :param environ:
135 :param environ:
136 """
136 """
137 default = 'push'
137 default = 'push'
138 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'],
138 query = urllib.parse.parse_qs(environ['QUERY_STRING'], keep_blank_values=True)
139 keep_blank_values=True)
140
139
141 if 'cmd' in query:
140 if 'cmd' in query:
142 cmd = query['cmd'][0]
141 cmd = query['cmd'][0]
143 if cmd == 'batch':
142 if cmd == 'batch':
144 cmds = self._get_batch_cmd(environ)
143 cmds = self._get_batch_cmd(environ)
145 if 'push' in cmds:
144 if 'push' in cmds:
146 return 'push'
145 return 'push'
147 else:
146 else:
148 return 'pull'
147 return 'pull'
149 return self._ACTION_MAPPING.get(cmd, default)
148 return self._ACTION_MAPPING.get(cmd, default)
150
149
151 return default
150 return default
152
151
153 def _create_wsgi_app(self, repo_path, repo_name, config):
152 def _create_wsgi_app(self, repo_path, repo_name, config):
154 return self.scm_app.create_hg_wsgi_app(repo_path, repo_name, config)
153 return self.scm_app.create_hg_wsgi_app(repo_path, repo_name, config)
155
154
156 def _create_config(self, extras, repo_name, scheme='http'):
155 def _create_config(self, extras, repo_name, scheme='http'):
157 config = utils.make_db_config(repo=repo_name)
156 config = utils.make_db_config(repo=repo_name)
158 config.set('rhodecode', 'RC_SCM_DATA', json.dumps(extras))
157 config.set('rhodecode', 'RC_SCM_DATA', json.dumps(extras))
159
158
160 return config.serialize()
159 return config.serialize()
@@ -1,679 +1,679 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2014-2020 RhodeCode GmbH
3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 SimpleVCS middleware for handling protocol requests (push/clone etc.).
22 SimpleVCS middleware for handling protocol requests (push/clone etc.).
23 It's implemented with a basic auth function
23 It's implemented with a basic auth function
24 """
24 """
25
25
26 import os
26 import os
27 import re
27 import re
28 import io
28 import logging
29 import logging
29 import importlib
30 import importlib
30 from functools import wraps
31 from functools import wraps
31 from io import StringIO
32 from lxml import etree
32 from lxml import etree
33
33
34 import time
34 import time
35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
36
36
37 from pyramid.httpexceptions import (
37 from pyramid.httpexceptions import (
38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
39 from zope.cachedescriptors.property import Lazy as LazyProperty
39 from zope.cachedescriptors.property import Lazy as LazyProperty
40
40
41 import rhodecode
41 import rhodecode
42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
43 from rhodecode.lib import rc_cache
43 from rhodecode.lib import rc_cache
44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
45 from rhodecode.lib.base import (
45 from rhodecode.lib.base import (
46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
49 from rhodecode.lib.middleware import appenlight
49 from rhodecode.lib.middleware import appenlight
50 from rhodecode.lib.middleware.utils import scm_app_http
50 from rhodecode.lib.middleware.utils import scm_app_http
51 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
51 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
52 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool, safe_unicode
52 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool, safe_unicode
53 from rhodecode.lib.vcs.conf import settings as vcs_settings
53 from rhodecode.lib.vcs.conf import settings as vcs_settings
54 from rhodecode.lib.vcs.backends import base
54 from rhodecode.lib.vcs.backends import base
55
55
56 from rhodecode.model import meta
56 from rhodecode.model import meta
57 from rhodecode.model.db import User, Repository, PullRequest
57 from rhodecode.model.db import User, Repository, PullRequest
58 from rhodecode.model.scm import ScmModel
58 from rhodecode.model.scm import ScmModel
59 from rhodecode.model.pull_request import PullRequestModel
59 from rhodecode.model.pull_request import PullRequestModel
60 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
60 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
61
61
62 log = logging.getLogger(__name__)
62 log = logging.getLogger(__name__)
63
63
64
64
65 def extract_svn_txn_id(acl_repo_name, data):
65 def extract_svn_txn_id(acl_repo_name, data):
66 """
66 """
67 Helper method for extraction of svn txn_id from submitted XML data during
67 Helper method for extraction of svn txn_id from submitted XML data during
68 POST operations
68 POST operations
69 """
69 """
70 try:
70 try:
71 root = etree.fromstring(data)
71 root = etree.fromstring(data)
72 pat = re.compile(r'/txn/(?P<txn_id>.*)')
72 pat = re.compile(r'/txn/(?P<txn_id>.*)')
73 for el in root:
73 for el in root:
74 if el.tag == '{DAV:}source':
74 if el.tag == '{DAV:}source':
75 for sub_el in el:
75 for sub_el in el:
76 if sub_el.tag == '{DAV:}href':
76 if sub_el.tag == '{DAV:}href':
77 match = pat.search(sub_el.text)
77 match = pat.search(sub_el.text)
78 if match:
78 if match:
79 svn_tx_id = match.groupdict()['txn_id']
79 svn_tx_id = match.groupdict()['txn_id']
80 txn_id = rc_cache.utils.compute_key_from_params(
80 txn_id = rc_cache.utils.compute_key_from_params(
81 acl_repo_name, svn_tx_id)
81 acl_repo_name, svn_tx_id)
82 return txn_id
82 return txn_id
83 except Exception:
83 except Exception:
84 log.exception('Failed to extract txn_id')
84 log.exception('Failed to extract txn_id')
85
85
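For context, a sketch of the kind of DAV MERGE body this helper parses; the repository name and transaction path are invented for illustration:

merge_body = b'''<?xml version="1.0" encoding="utf-8"?>
<D:merge xmlns:D="DAV:">
  <D:source>
    <D:href>/svn/myrepo/!svn/txn/123-abc</D:href>
  </D:source>
</D:merge>'''

# yields a cache key computed from 'myrepo' and the extracted id '123-abc'
txn_id = extract_svn_txn_id('myrepo', merge_body)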
86
86
87 def initialize_generator(factory):
87 def initialize_generator(factory):
88 """
88 """
89 Initializes the returned generator by draining its first element.
89 Initializes the returned generator by draining its first element.
90
90
91 This can be used to give a generator an initializer, which is the code
91 This can be used to give a generator an initializer, which is the code
92 up to the first yield statement. This decorator enforces that the first
92 up to the first yield statement. This decorator enforces that the first
93 produced element has the value ``"__init__"`` to make its special
93 produced element has the value ``"__init__"`` to make its special
94 purpose very explicit in the calling code.
94 purpose very explicit in the calling code.
95 """
95 """
96
96
97 @wraps(factory)
97 @wraps(factory)
98 def wrapper(*args, **kwargs):
98 def wrapper(*args, **kwargs):
99 gen = factory(*args, **kwargs)
99 gen = factory(*args, **kwargs)
100 try:
100 try:
101 init = next(gen)
101 init = next(gen)
102 except StopIteration:
102 except StopIteration:
103 raise ValueError('Generator must yield at least one element.')
103 raise ValueError('Generator must yield at least one element.')
104 if init != "__init__":
104 if init != "__init__":
105 raise ValueError('First yielded element must be "__init__".')
105 raise ValueError('First yielded element must be "__init__".')
106 return gen
106 return gen
107 return wrapper
107 return wrapper
108
108
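A short usage sketch of the decorator above; responder is a hypothetical generator, not taken from the codebase:

@initialize_generator
def responder():
    print('setup code runs eagerly, before the WSGI server iterates')
    yield '__init__'   # drained by the decorator, never reaches the caller
    yield b'chunk-1'
    yield b'chunk-2'

gen = responder()      # the setup above has already executed at this point
print(list(gen))       # [b'chunk-1', b'chunk-2']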
109
109
110 class SimpleVCS(object):
110 class SimpleVCS(object):
111 """Common functionality for SCM HTTP handlers."""
111 """Common functionality for SCM HTTP handlers."""
112
112
113 SCM = 'unknown'
113 SCM = 'unknown'
114
114
115 acl_repo_name = None
115 acl_repo_name = None
116 url_repo_name = None
116 url_repo_name = None
117 vcs_repo_name = None
117 vcs_repo_name = None
118 rc_extras = {}
118 rc_extras = {}
119
119
120 # We have to handle requests to shadow repositories differently than requests
120 # We have to handle requests to shadow repositories differently than requests
121 # to normal repositories. Therefore we have to distinguish them. To do this
121 # to normal repositories. Therefore we have to distinguish them. To do this
122 # we use this regex which will match only on URLs pointing to shadow
122 # we use this regex which will match only on URLs pointing to shadow
123 # repositories.
123 # repositories.
124 shadow_repo_re = re.compile(
124 shadow_repo_re = re.compile(
125 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
125 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
126 '(?P<target>{slug_pat})/' # target repo
126 '(?P<target>{slug_pat})/' # target repo
127 'pull-request/(?P<pr_id>\d+)/' # pull request
127 'pull-request/(?P<pr_id>\\d+)/' # pull request
128 'repository$' # shadow repo
128 'repository$' # shadow repo
129 .format(slug_pat=SLUG_RE.pattern))
129 .format(slug_pat=SLUG_RE.pattern))
130
130
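A small sketch of the URLs this regex matches, using a simplified slug pattern for illustration (the real SLUG_RE is stricter):

import re

slug = r'[a-zA-Z0-9._-]+'
shadow_re = re.compile(
    r'(?P<groups>(?:{slug}/)*)(?P<target>{slug})/'
    r'pull-request/(?P<pr_id>\d+)/repository$'.format(slug=slug))

m = shadow_re.match('RepoGroup/MyRepo/pull-request/3/repository')
print(m.groupdict())  # {'groups': 'RepoGroup/', 'target': 'MyRepo', 'pr_id': '3'}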
131 def __init__(self, config, registry):
131 def __init__(self, config, registry):
132 self.registry = registry
132 self.registry = registry
133 self.config = config
133 self.config = config
134 # re-populated by specialized middleware
134 # re-populated by specialized middleware
135 self.repo_vcs_config = base.Config()
135 self.repo_vcs_config = base.Config()
136
136
137 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
137 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
138 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
138 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
139
139
140 # authenticate this VCS request using authfunc
140 # authenticate this VCS request using authfunc
141 auth_ret_code_detection = \
141 auth_ret_code_detection = \
142 str2bool(self.config.get('auth_ret_code_detection', False))
142 str2bool(self.config.get('auth_ret_code_detection', False))
143 self.authenticate = BasicAuth(
143 self.authenticate = BasicAuth(
144 '', authenticate, registry, config.get('auth_ret_code'),
144 '', authenticate, registry, config.get('auth_ret_code'),
145 auth_ret_code_detection, rc_realm=realm)
145 auth_ret_code_detection, rc_realm=realm)
146 self.ip_addr = '0.0.0.0'
146 self.ip_addr = '0.0.0.0'
147
147
148 @LazyProperty
148 @LazyProperty
149 def global_vcs_config(self):
149 def global_vcs_config(self):
150 try:
150 try:
151 return VcsSettingsModel().get_ui_settings_as_config_obj()
151 return VcsSettingsModel().get_ui_settings_as_config_obj()
152 except Exception:
152 except Exception:
153 return base.Config()
153 return base.Config()
154
154
155 @property
155 @property
156 def base_path(self):
156 def base_path(self):
157 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
157 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
158
158
159 if not settings_path:
159 if not settings_path:
160 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
160 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
161
161
162 if not settings_path:
162 if not settings_path:
163 # try, maybe it was passed in explicitly as a config option
163 # try, maybe it was passed in explicitly as a config option
164 settings_path = self.config.get('base_path')
164 settings_path = self.config.get('base_path')
165
165
166 if not settings_path:
166 if not settings_path:
167 raise ValueError('FATAL: base_path is empty')
167 raise ValueError('FATAL: base_path is empty')
168 return settings_path
168 return settings_path
169
169
170 def set_repo_names(self, environ):
170 def set_repo_names(self, environ):
171 """
171 """
172 This will populate the attributes acl_repo_name, url_repo_name,
172 This will populate the attributes acl_repo_name, url_repo_name,
173 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
173 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
174 shadow) repositories all names are equal. In case of requests to a
174 shadow) repositories all names are equal. In case of requests to a
175 shadow repository the acl-name points to the target repo of the pull
175 shadow repository the acl-name points to the target repo of the pull
176 request and the vcs-name points to the shadow repo file system path.
176 request and the vcs-name points to the shadow repo file system path.
177 The url-name is always the URL used by the vcs client program.
177 The url-name is always the URL used by the vcs client program.
178
178
179 Example in case of a shadow repo:
179 Example in case of a shadow repo:
180 acl_repo_name = RepoGroup/MyRepo
180 acl_repo_name = RepoGroup/MyRepo
181 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
181 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
182 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3
182 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3
183 """
183 """
184 # First we set the repo name from URL for all attributes. This is the
184 # First we set the repo name from URL for all attributes. This is the
185 # default if handling normal (non shadow) repo requests.
185 # default if handling normal (non shadow) repo requests.
186 self.url_repo_name = self._get_repository_name(environ)
186 self.url_repo_name = self._get_repository_name(environ)
187 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
187 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
188 self.is_shadow_repo = False
188 self.is_shadow_repo = False
189
189
190 # Check if this is a request to a shadow repository.
190 # Check if this is a request to a shadow repository.
191 match = self.shadow_repo_re.match(self.url_repo_name)
191 match = self.shadow_repo_re.match(self.url_repo_name)
192 if match:
192 if match:
193 match_dict = match.groupdict()
193 match_dict = match.groupdict()
194
194
195 # Build acl repo name from regex match.
195 # Build acl repo name from regex match.
196 acl_repo_name = safe_unicode('{groups}{target}'.format(
196 acl_repo_name = safe_unicode('{groups}{target}'.format(
197 groups=match_dict['groups'] or '',
197 groups=match_dict['groups'] or '',
198 target=match_dict['target']))
198 target=match_dict['target']))
199
199
200 # Retrieve pull request instance by ID from regex match.
200 # Retrieve pull request instance by ID from regex match.
201 pull_request = PullRequest.get(match_dict['pr_id'])
201 pull_request = PullRequest.get(match_dict['pr_id'])
202
202
203 # Only proceed if we got a pull request and if acl repo name from
203 # Only proceed if we got a pull request and if acl repo name from
204 # URL equals the target repo name of the pull request.
204 # URL equals the target repo name of the pull request.
205 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
205 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
206
206
207 # Get file system path to shadow repository.
207 # Get file system path to shadow repository.
208 workspace_id = PullRequestModel()._workspace_id(pull_request)
208 workspace_id = PullRequestModel()._workspace_id(pull_request)
209 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
209 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
210
210
211 # Store names for later usage.
211 # Store names for later usage.
212 self.vcs_repo_name = vcs_repo_name
212 self.vcs_repo_name = vcs_repo_name
213 self.acl_repo_name = acl_repo_name
213 self.acl_repo_name = acl_repo_name
214 self.is_shadow_repo = True
214 self.is_shadow_repo = True
215
215
216 log.debug('Setting all VCS repository names: %s', {
216 log.debug('Setting all VCS repository names: %s', {
217 'acl_repo_name': self.acl_repo_name,
217 'acl_repo_name': self.acl_repo_name,
218 'url_repo_name': self.url_repo_name,
218 'url_repo_name': self.url_repo_name,
219 'vcs_repo_name': self.vcs_repo_name,
219 'vcs_repo_name': self.vcs_repo_name,
220 })
220 })
221
221
222 @property
222 @property
223 def scm_app(self):
223 def scm_app(self):
224 custom_implementation = self.config['vcs.scm_app_implementation']
224 custom_implementation = self.config['vcs.scm_app_implementation']
225 if custom_implementation == 'http':
225 if custom_implementation == 'http':
226 log.debug('Using HTTP implementation of scm app.')
226 log.debug('Using HTTP implementation of scm app.')
227 scm_app_impl = scm_app_http
227 scm_app_impl = scm_app_http
228 else:
228 else:
229 log.debug('Using custom implementation of scm_app: "{}"'.format(
229 log.debug('Using custom implementation of scm_app: "{}"'.format(
230 custom_implementation))
230 custom_implementation))
231 scm_app_impl = importlib.import_module(custom_implementation)
231 scm_app_impl = importlib.import_module(custom_implementation)
232 return scm_app_impl
232 return scm_app_impl
233
233
234 def _get_by_id(self, repo_name):
234 def _get_by_id(self, repo_name):
235 """
235 """
236 Gets the special pattern _<ID> from the clone URL and tries to replace it
236 Gets the special pattern _<ID> from the clone URL and tries to replace it
237 with a repository name, to support permanent _<ID> URLs
237 with a repository name, to support permanent _<ID> URLs
238 """
238 """
239
239
240 data = repo_name.split('/')
240 data = repo_name.split('/')
241 if len(data) >= 2:
241 if len(data) >= 2:
242 from rhodecode.model.repo import RepoModel
242 from rhodecode.model.repo import RepoModel
243 by_id_match = RepoModel().get_repo_by_id(repo_name)
243 by_id_match = RepoModel().get_repo_by_id(repo_name)
244 if by_id_match:
244 if by_id_match:
245 data[1] = by_id_match.repo_name
245 data[1] = by_id_match.repo_name
246
246
247 return safe_str('/'.join(data))
247 return safe_str('/'.join(data))
248
248
249 def _invalidate_cache(self, repo_name):
249 def _invalidate_cache(self, repo_name):
250 """
250 """
251 Sets the cache for this repository to be invalidated on the next access
251 Sets the cache for this repository to be invalidated on the next access
252
252
253 :param repo_name: full repo name, also a cache key
253 :param repo_name: full repo name, also a cache key
254 """
254 """
255 ScmModel().mark_for_invalidation(repo_name)
255 ScmModel().mark_for_invalidation(repo_name)
256
256
257 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
257 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
258 db_repo = Repository.get_by_repo_name(repo_name)
258 db_repo = Repository.get_by_repo_name(repo_name)
259 if not db_repo:
259 if not db_repo:
260 log.debug('Repository `%s` not found inside the database.',
260 log.debug('Repository `%s` not found inside the database.',
261 repo_name)
261 repo_name)
262 return False
262 return False
263
263
264 if db_repo.repo_type != scm_type:
264 if db_repo.repo_type != scm_type:
265 log.warning(
265 log.warning(
266 'Repository `%s` has incorrect scm_type, expected %s got %s',
266 'Repository `%s` has incorrect scm_type, expected %s got %s',
267 repo_name, db_repo.repo_type, scm_type)
267 repo_name, db_repo.repo_type, scm_type)
268 return False
268 return False
269
269
270 config = db_repo._config
270 config = db_repo._config
271 config.set('extensions', 'largefiles', '')
271 config.set('extensions', 'largefiles', '')
272 return is_valid_repo(
272 return is_valid_repo(
273 repo_name, base_path,
273 repo_name, base_path,
274 explicit_scm=scm_type, expect_scm=scm_type, config=config)
274 explicit_scm=scm_type, expect_scm=scm_type, config=config)
275
275
276 def valid_and_active_user(self, user):
276 def valid_and_active_user(self, user):
277 """
277 """
278 Checks that the user is not empty and, if it is actually an object,
278 Checks that the user is not empty and, if it is actually an object,
279 checks whether it is active.
279 checks whether it is active.
280
280
281 :param user: user object or None
281 :param user: user object or None
282 :return: boolean
282 :return: boolean
283 """
283 """
284 if user is None:
284 if user is None:
285 return False
285 return False
286
286
287 elif user.active:
287 elif user.active:
288 return True
288 return True
289
289
290 return False
290 return False
291
291
292 @property
292 @property
293 def is_shadow_repo_dir(self):
293 def is_shadow_repo_dir(self):
294 return os.path.isdir(self.vcs_repo_name)
294 return os.path.isdir(self.vcs_repo_name)
295
295
296 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
296 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
297 plugin_id='', plugin_cache_active=False, cache_ttl=0):
297 plugin_id='', plugin_cache_active=False, cache_ttl=0):
298 """
298 """
299 Checks permissions using the action (push/pull), user and repository
299 Checks permissions using the action (push/pull), user and repository
300 name. If plugin_cache_active and cache_ttl are set, it will use the
300 name. If plugin_cache_active and cache_ttl are set, it will use the
301 plugin which authenticated the user to store the cached permission
301 plugin which authenticated the user to store the cached permission
302 result for cache_ttl seconds
302 result for cache_ttl seconds
303
303
304 :param action: push or pull action
304 :param action: push or pull action
305 :param user: user instance
305 :param user: user instance
306 :param repo_name: repository name
306 :param repo_name: repository name
307 """
307 """
308
308
309 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
309 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
310 plugin_id, plugin_cache_active, cache_ttl)
310 plugin_id, plugin_cache_active, cache_ttl)
311
311
312 user_id = user.user_id
312 user_id = user.user_id
313 cache_namespace_uid = 'cache_user_auth.{}'.format(user_id)
313 cache_namespace_uid = 'cache_user_auth.{}'.format(user_id)
314 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
314 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
315
315
316 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
316 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
317 expiration_time=cache_ttl,
317 expiration_time=cache_ttl,
318 condition=plugin_cache_active)
318 condition=plugin_cache_active)
319 def compute_perm_vcs(
319 def compute_perm_vcs(
320 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
320 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
321
321
322 log.debug('auth: calculating permission access now...')
322 log.debug('auth: calculating permission access now...')
323 # check IP
323 # check IP
324 inherit = user.inherit_default_permissions
324 inherit = user.inherit_default_permissions
325 ip_allowed = AuthUser.check_ip_allowed(
325 ip_allowed = AuthUser.check_ip_allowed(
326 user_id, ip_addr, inherit_from_default=inherit)
326 user_id, ip_addr, inherit_from_default=inherit)
327 if ip_allowed:
327 if ip_allowed:
328 log.info('Access for IP:%s allowed', ip_addr)
328 log.info('Access for IP:%s allowed', ip_addr)
329 else:
329 else:
330 return False
330 return False
331
331
332 if action == 'push':
332 if action == 'push':
333 perms = ('repository.write', 'repository.admin')
333 perms = ('repository.write', 'repository.admin')
334 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
334 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
335 return False
335 return False
336
336
337 else:
337 else:
338 # any other action need at least read permission
338 # any other action need at least read permission
339 perms = (
339 perms = (
340 'repository.read', 'repository.write', 'repository.admin')
340 'repository.read', 'repository.write', 'repository.admin')
341 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
341 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
342 return False
342 return False
343
343
344 return True
344 return True
345
345
346 start = time.time()
346 start = time.time()
347 log.debug('Running plugin `%s` permissions check', plugin_id)
347 log.debug('Running plugin `%s` permissions check', plugin_id)
348
348
349 # for environ based auth, password can be empty, but then the validation is
349 # for environ based auth, password can be empty, but then the validation is
350 # on the server that fills in the env data needed for authentication
350 # on the server that fills in the env data needed for authentication
351 perm_result = compute_perm_vcs(
351 perm_result = compute_perm_vcs(
352 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
352 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
353
353
354 auth_time = time.time() - start
354 auth_time = time.time() - start
355 log.debug('Permissions for plugin `%s` completed in %.4fs, '
355 log.debug('Permissions for plugin `%s` completed in %.4fs, '
356 'expiration time of fetched cache %.1fs.',
356 'expiration time of fetched cache %.1fs.',
357 plugin_id, auth_time, cache_ttl)
357 plugin_id, auth_time, cache_ttl)
358
358
359 return perm_result
359 return perm_result
360
360
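For orientation, a simplified sketch of the conditional-cache pattern used above; the namespace, TTL and decorated function are made up for illustration, and a working rc_cache setup is assumed:

from rhodecode.lib import rc_cache

uid = 'cache_user_auth.42'  # hypothetical per-user namespace
region = rc_cache.get_or_create_region('cache_perms', uid)

@region.conditional_cache_on_arguments(namespace=uid,
                                       expiration_time=300,  # like cache_ttl, in seconds
                                       condition=True)       # False bypasses caching entirely
def compute_flag(cache_name, user_id, repo_name):
    # all positional arguments form the cache key, so the 'cache_name'
    # discriminator keeps different cached computations apart
    return user_id > 0 and bool(repo_name)

compute_flag('vcs_permissions', 42, 'RepoGroup/MyRepo')  # computed once, then cached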
361 def _get_http_scheme(self, environ):
361 def _get_http_scheme(self, environ):
362 try:
362 try:
363 return environ['wsgi.url_scheme']
363 return environ['wsgi.url_scheme']
364 except Exception:
364 except Exception:
365 log.exception('Failed to read http scheme')
365 log.exception('Failed to read http scheme')
366 return 'http'
366 return 'http'
367
367
368 def _check_ssl(self, environ, start_response):
368 def _check_ssl(self, environ, start_response):
369 """
369 """
370 Checks the SSL requirement flag and returns False if SSL is required
370 Checks the SSL requirement flag and returns False if SSL is required
371 but not present, True otherwise
371 but not present, True otherwise
372 """
372 """
373 org_proto = environ['wsgi._org_proto']
373 org_proto = environ['wsgi._org_proto']
374 # check if SSL is required! if it is missing, this is a bad request!
374 # check if SSL is required! if it is missing, this is a bad request!
375 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
375 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
376 if require_ssl and org_proto == 'http':
376 if require_ssl and org_proto == 'http':
377 log.debug(
377 log.debug(
378 'Bad request: detected protocol is `%s` and '
378 'Bad request: detected protocol is `%s` and '
379 'SSL/HTTPS is required.', org_proto)
379 'SSL/HTTPS is required.', org_proto)
380 return False
380 return False
381 return True
381 return True
382
382
383 def _get_default_cache_ttl(self):
383 def _get_default_cache_ttl(self):
384 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
384 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
385 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
385 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
386 plugin_settings = plugin.get_settings()
386 plugin_settings = plugin.get_settings()
387 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
387 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
388 plugin_settings) or (False, 0)
388 plugin_settings) or (False, 0)
389 return plugin_cache_active, cache_ttl
389 return plugin_cache_active, cache_ttl
390
390
391 def __call__(self, environ, start_response):
391 def __call__(self, environ, start_response):
392 try:
392 try:
393 return self._handle_request(environ, start_response)
393 return self._handle_request(environ, start_response)
394 except Exception:
394 except Exception:
395 log.exception("Exception while handling request")
395 log.exception("Exception while handling request")
396 appenlight.track_exception(environ)
396 appenlight.track_exception(environ)
397 return HTTPInternalServerError()(environ, start_response)
397 return HTTPInternalServerError()(environ, start_response)
398 finally:
398 finally:
399 meta.Session.remove()
399 meta.Session.remove()
400
400
401 def _handle_request(self, environ, start_response):
401 def _handle_request(self, environ, start_response):
402 if not self._check_ssl(environ, start_response):
402 if not self._check_ssl(environ, start_response):
403 reason = ('SSL required, while RhodeCode was unable '
403 reason = ('SSL required, while RhodeCode was unable '
404 'to detect this as an SSL request')
404 'to detect this as an SSL request')
405 log.debug('User not allowed to proceed, %s', reason)
405 log.debug('User not allowed to proceed, %s', reason)
406 return HTTPNotAcceptable(reason)(environ, start_response)
406 return HTTPNotAcceptable(reason)(environ, start_response)
407
407
408 if not self.url_repo_name:
408 if not self.url_repo_name:
409 log.warning('Repository name is empty: %s', self.url_repo_name)
409 log.warning('Repository name is empty: %s', self.url_repo_name)
410 # failed to get repo name, we fail now
410 # failed to get repo name, we fail now
411 return HTTPNotFound()(environ, start_response)
411 return HTTPNotFound()(environ, start_response)
412 log.debug('Extracted repo name is %s', self.url_repo_name)
412 log.debug('Extracted repo name is %s', self.url_repo_name)
413
413
414 ip_addr = get_ip_addr(environ)
414 ip_addr = get_ip_addr(environ)
415 user_agent = get_user_agent(environ)
415 user_agent = get_user_agent(environ)
416 username = None
416 username = None
417
417
418 # skip passing error to error controller
418 # skip passing error to error controller
419 environ['pylons.status_code_redirect'] = True
419 environ['pylons.status_code_redirect'] = True
420
420
421 # ======================================================================
421 # ======================================================================
422 # GET ACTION PULL or PUSH
422 # GET ACTION PULL or PUSH
423 # ======================================================================
423 # ======================================================================
424 action = self._get_action(environ)
424 action = self._get_action(environ)
425
425
426 # ======================================================================
426 # ======================================================================
427 # Check if this is a request to a shadow repository of a pull request.
427 # Check if this is a request to a shadow repository of a pull request.
428 # In this case only pull action is allowed.
428 # In this case only pull action is allowed.
429 # ======================================================================
429 # ======================================================================
430 if self.is_shadow_repo and action != 'pull':
430 if self.is_shadow_repo and action != 'pull':
431 reason = 'Only pull action is allowed for shadow repositories.'
431 reason = 'Only pull action is allowed for shadow repositories.'
432 log.debug('User not allowed to proceed, %s', reason)
432 log.debug('User not allowed to proceed, %s', reason)
433 return HTTPNotAcceptable(reason)(environ, start_response)
433 return HTTPNotAcceptable(reason)(environ, start_response)
434
434
435 # Check if the shadow repo actually exists, in case someone refers
435 # Check if the shadow repo actually exists, in case someone refers
436 # to it after it has been deleted because of a successful merge.
436 # to it after it has been deleted because of a successful merge.
437 if self.is_shadow_repo and not self.is_shadow_repo_dir:
437 if self.is_shadow_repo and not self.is_shadow_repo_dir:
438 log.debug(
438 log.debug(
439 'Shadow repo detected, and shadow repo dir `%s` is missing',
439 'Shadow repo detected, and shadow repo dir `%s` is missing',
440 self.vcs_repo_name)
440 self.vcs_repo_name)
441 return HTTPNotFound()(environ, start_response)
441 return HTTPNotFound()(environ, start_response)
442
442
443 # ======================================================================
443 # ======================================================================
444 # CHECK ANONYMOUS PERMISSION
444 # CHECK ANONYMOUS PERMISSION
445 # ======================================================================
445 # ======================================================================
446 detect_force_push = False
446 detect_force_push = False
447 check_branch_perms = False
447 check_branch_perms = False
448 if action in ['pull', 'push']:
448 if action in ['pull', 'push']:
449 user_obj = anonymous_user = User.get_default_user()
449 user_obj = anonymous_user = User.get_default_user()
450 auth_user = user_obj.AuthUser()
450 auth_user = user_obj.AuthUser()
451 username = anonymous_user.username
451 username = anonymous_user.username
452 if anonymous_user.active:
452 if anonymous_user.active:
453 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
453 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
454 # ONLY check permissions if the user is activated
454 # ONLY check permissions if the user is activated
455 anonymous_perm = self._check_permission(
455 anonymous_perm = self._check_permission(
456 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
456 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
457 plugin_id='anonymous_access',
457 plugin_id='anonymous_access',
458 plugin_cache_active=plugin_cache_active,
458 plugin_cache_active=plugin_cache_active,
459 cache_ttl=cache_ttl,
459 cache_ttl=cache_ttl,
460 )
460 )
461 else:
461 else:
462 anonymous_perm = False
462 anonymous_perm = False
463
463
464 if not anonymous_user.active or not anonymous_perm:
464 if not anonymous_user.active or not anonymous_perm:
465 if not anonymous_user.active:
465 if not anonymous_user.active:
466 log.debug('Anonymous access is disabled, running '
466 log.debug('Anonymous access is disabled, running '
467 'authentication')
467 'authentication')
468
468
469 if not anonymous_perm:
469 if not anonymous_perm:
470 log.debug('Not enough credentials to access this '
470 log.debug('Not enough credentials to access this '
471 'repository as anonymous user')
471 'repository as anonymous user')
472
472
473 username = None
473 username = None
474 # ==============================================================
474 # ==============================================================
475 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
475 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
476 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
476 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
477 # ==============================================================
477 # ==============================================================
478
478
479 # try to auth based on environ, container auth methods
479 # try to auth based on environ, container auth methods
480 log.debug('Running PRE-AUTH for container based authentication')
480 log.debug('Running PRE-AUTH for container based authentication')
481 pre_auth = authenticate(
481 pre_auth = authenticate(
482 '', '', environ, VCS_TYPE, registry=self.registry,
482 '', '', environ, VCS_TYPE, registry=self.registry,
483 acl_repo_name=self.acl_repo_name)
483 acl_repo_name=self.acl_repo_name)
484 if pre_auth and pre_auth.get('username'):
484 if pre_auth and pre_auth.get('username'):
485 username = pre_auth['username']
485 username = pre_auth['username']
486 log.debug('PRE-AUTH got %s as username', username)
486 log.debug('PRE-AUTH got %s as username', username)
487 if pre_auth:
487 if pre_auth:
488 log.debug('PRE-AUTH successful from %s',
488 log.debug('PRE-AUTH successful from %s',
489 pre_auth.get('auth_data', {}).get('_plugin'))
489 pre_auth.get('auth_data', {}).get('_plugin'))
490
490
491 # If not authenticated by the container, run basic auth;
491 # If not authenticated by the container, run basic auth;
492 # before that, inject the calling repo_name for special scope checks
492 # before that, inject the calling repo_name for special scope checks
493 self.authenticate.acl_repo_name = self.acl_repo_name
493 self.authenticate.acl_repo_name = self.acl_repo_name
494
494
495 plugin_cache_active, cache_ttl = False, 0
495 plugin_cache_active, cache_ttl = False, 0
496 plugin = None
496 plugin = None
497 if not username:
497 if not username:
498 self.authenticate.realm = self.authenticate.get_rc_realm()
498 self.authenticate.realm = self.authenticate.get_rc_realm()
499
499
500 try:
500 try:
501 auth_result = self.authenticate(environ)
501 auth_result = self.authenticate(environ)
502 except (UserCreationError, NotAllowedToCreateUserError) as e:
502 except (UserCreationError, NotAllowedToCreateUserError) as e:
503 log.error(e)
503 log.error(e)
504 reason = safe_str(e)
504 reason = safe_str(e)
505 return HTTPNotAcceptable(reason)(environ, start_response)
505 return HTTPNotAcceptable(reason)(environ, start_response)
506
506
507 if isinstance(auth_result, dict):
507 if isinstance(auth_result, dict):
508 AUTH_TYPE.update(environ, 'basic')
508 AUTH_TYPE.update(environ, 'basic')
509 REMOTE_USER.update(environ, auth_result['username'])
509 REMOTE_USER.update(environ, auth_result['username'])
510 username = auth_result['username']
510 username = auth_result['username']
511 plugin = auth_result.get('auth_data', {}).get('_plugin')
511 plugin = auth_result.get('auth_data', {}).get('_plugin')
512 log.info(
512 log.info(
513 'MAIN-AUTH successful for user `%s` from %s plugin',
513 'MAIN-AUTH successful for user `%s` from %s plugin',
514 username, plugin)
514 username, plugin)
515
515
516 plugin_cache_active, cache_ttl = auth_result.get(
516 plugin_cache_active, cache_ttl = auth_result.get(
517 'auth_data', {}).get('_ttl_cache') or (False, 0)
517 'auth_data', {}).get('_ttl_cache') or (False, 0)
518 else:
518 else:
519 return auth_result.wsgi_application(environ, start_response)
519 return auth_result.wsgi_application(environ, start_response)
520
520
521 # ==============================================================
521 # ==============================================================
522 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
522 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
523 # ==============================================================
523 # ==============================================================
524 user = User.get_by_username(username)
524 user = User.get_by_username(username)
525 if not self.valid_and_active_user(user):
525 if not self.valid_and_active_user(user):
526 return HTTPForbidden()(environ, start_response)
526 return HTTPForbidden()(environ, start_response)
527 username = user.username
527 username = user.username
528 user_id = user.user_id
528 user_id = user.user_id
529
529
530 # check user attributes for password change flag
530 # check user attributes for password change flag
531 user_obj = user
531 user_obj = user
532 auth_user = user_obj.AuthUser()
532 auth_user = user_obj.AuthUser()
533 if user_obj and user_obj.username != User.DEFAULT_USER and \
533 if user_obj and user_obj.username != User.DEFAULT_USER and \
534 user_obj.user_data.get('force_password_change'):
534 user_obj.user_data.get('force_password_change'):
535 reason = 'password change required'
535 reason = 'password change required'
536 log.debug('User not allowed to authenticate, %s', reason)
536 log.debug('User not allowed to authenticate, %s', reason)
537 return HTTPNotAcceptable(reason)(environ, start_response)
537 return HTTPNotAcceptable(reason)(environ, start_response)
538
538
539 # check permissions for this repository
539 # check permissions for this repository
540 perm = self._check_permission(
540 perm = self._check_permission(
541 action, user, auth_user, self.acl_repo_name, ip_addr,
541 action, user, auth_user, self.acl_repo_name, ip_addr,
542 plugin, plugin_cache_active, cache_ttl)
542 plugin, plugin_cache_active, cache_ttl)
543 if not perm:
543 if not perm:
544 return HTTPForbidden()(environ, start_response)
544 return HTTPForbidden()(environ, start_response)
545 environ['rc_auth_user_id'] = user_id
545 environ['rc_auth_user_id'] = user_id
546
546
547 if action == 'push':
547 if action == 'push':
548 perms = auth_user.get_branch_permissions(self.acl_repo_name)
548 perms = auth_user.get_branch_permissions(self.acl_repo_name)
549 if perms:
549 if perms:
550 check_branch_perms = True
550 check_branch_perms = True
551 detect_force_push = True
551 detect_force_push = True
552
552
553 # extras are injected into UI object and later available
553 # extras are injected into UI object and later available
554 # in hooks executed by RhodeCode
554 # in hooks executed by RhodeCode
555 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
555 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
556
556
557 extras = vcs_operation_context(
557 extras = vcs_operation_context(
558 environ, repo_name=self.acl_repo_name, username=username,
558 environ, repo_name=self.acl_repo_name, username=username,
559 action=action, scm=self.SCM, check_locking=check_locking,
559 action=action, scm=self.SCM, check_locking=check_locking,
560 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
560 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
561 detect_force_push=detect_force_push
561 detect_force_push=detect_force_push
562 )
562 )
563
563
564 # ======================================================================
564 # ======================================================================
565 # REQUEST HANDLING
565 # REQUEST HANDLING
566 # ======================================================================
566 # ======================================================================
567 repo_path = os.path.join(
567 repo_path = os.path.join(
568 safe_str(self.base_path), safe_str(self.vcs_repo_name))
568 safe_str(self.base_path), safe_str(self.vcs_repo_name))
569 log.debug('Repository path is %s', repo_path)
569 log.debug('Repository path is %s', repo_path)
570
570
571 fix_PATH()
571 fix_PATH()
572
572
573 log.info(
573 log.info(
574 '%s action on %s repo "%s" by "%s" from %s %s',
574 '%s action on %s repo "%s" by "%s" from %s %s',
575 action, self.SCM, safe_str(self.url_repo_name),
575 action, self.SCM, safe_str(self.url_repo_name),
576 safe_str(username), ip_addr, user_agent)
576 safe_str(username), ip_addr, user_agent)
577
577
578 return self._generate_vcs_response(
578 return self._generate_vcs_response(
579 environ, start_response, repo_path, extras, action)
579 environ, start_response, repo_path, extras, action)
580
580
581 @initialize_generator
581 @initialize_generator
582 def _generate_vcs_response(
582 def _generate_vcs_response(
583 self, environ, start_response, repo_path, extras, action):
583 self, environ, start_response, repo_path, extras, action):
584 """
584 """
585 Returns a generator for the response content.
585 Returns a generator for the response content.
586
586
587 This method is implemented as a generator, so that it can trigger
587 This method is implemented as a generator, so that it can trigger
588 the cache validation after all content has been sent back to the client. It
588 the cache validation after all content has been sent back to the client. It
589 also handles the locking exceptions which will be triggered when
589 also handles the locking exceptions which will be triggered when
590 the first chunk is produced by the underlying WSGI application.
590 the first chunk is produced by the underlying WSGI application.
591 """
591 """
592 txn_id = ''
592 txn_id = ''
593 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
593 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
594 # case for SVN, we want to re-use the callback daemon port; for that
594 # case for SVN, we want to re-use the callback daemon port; for that
595 # we use the txn_id, which we get by peeking at the body while still
595 # we use the txn_id, which we get by peeking at the body while still
596 # saving it back as wsgi.input
596 # saving it back as wsgi.input
597 data = environ['wsgi.input'].read()
597 data = environ['wsgi.input'].read()
598 environ['wsgi.input'] = StringIO(data)
598 environ['wsgi.input'] = io.BytesIO(data)  # wsgi.input yields bytes under python3
599 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
599 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
600
600
601 callback_daemon, extras = self._prepare_callback_daemon(
601 callback_daemon, extras = self._prepare_callback_daemon(
602 extras, environ, action, txn_id=txn_id)
602 extras, environ, action, txn_id=txn_id)
603 log.debug('HOOKS extras is %s', extras)
603 log.debug('HOOKS extras is %s', extras)
604
604
605 http_scheme = self._get_http_scheme(environ)
605 http_scheme = self._get_http_scheme(environ)
606
606
607 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
607 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
608 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
608 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
609 with callback_daemon:
609 with callback_daemon:
610 app.rc_extras = extras
610 app.rc_extras = extras
611
611
612 try:
612 try:
613 response = app(environ, start_response)
613 response = app(environ, start_response)
614 finally:
614 finally:
615 # This statement works together with the decorator
615 # This statement works together with the decorator
616 # "initialize_generator" above. The decorator ensures that
616 # "initialize_generator" above. The decorator ensures that
617 # we hit the first yield statement before the generator is
617 # we hit the first yield statement before the generator is
618 # returned back to the WSGI server. This is needed to
618 # returned back to the WSGI server. This is needed to
619 # ensure that the call to "app" above triggers the
619 # ensure that the call to "app" above triggers the
620 # needed callback to "start_response" before the
620 # needed callback to "start_response" before the
621 # generator is actually used.
621 # generator is actually used.
622 yield "__init__"
622 yield "__init__"
623
623
624 # iter content
624 # iter content
625 for chunk in response:
625 for chunk in response:
626 yield chunk
626 yield chunk
627
627
628 try:
628 try:
629 # invalidate cache on push
629 # invalidate cache on push
630 if action == 'push':
630 if action == 'push':
631 self._invalidate_cache(self.url_repo_name)
631 self._invalidate_cache(self.url_repo_name)
632 finally:
632 finally:
633 meta.Session.remove()
633 meta.Session.remove()
634
634
635 def _get_repository_name(self, environ):
635 def _get_repository_name(self, environ):
636 """Get repository name out of the environmnent
636 """Get repository name out of the environmnent
637
637
638 :param environ: WSGI environment
638 :param environ: WSGI environment
639 """
639 """
640 raise NotImplementedError()
640 raise NotImplementedError()
641
641
642 def _get_action(self, environ):
642 def _get_action(self, environ):
643 """Map request commands into a pull or push command.
643 """Map request commands into a pull or push command.
644
644
645 :param environ: WSGI environment
645 :param environ: WSGI environment
646 """
646 """
647 raise NotImplementedError()
647 raise NotImplementedError()
648
648
649 def _create_wsgi_app(self, repo_path, repo_name, config):
649 def _create_wsgi_app(self, repo_path, repo_name, config):
650 """Return the WSGI app that will finally handle the request."""
650 """Return the WSGI app that will finally handle the request."""
651 raise NotImplementedError()
651 raise NotImplementedError()
652
652
653 def _create_config(self, extras, repo_name, scheme='http'):
653 def _create_config(self, extras, repo_name, scheme='http'):
654 """Create a safe config representation."""
654 """Create a safe config representation."""
655 raise NotImplementedError()
655 raise NotImplementedError()
656
656
657 def _should_use_callback_daemon(self, extras, environ, action):
657 def _should_use_callback_daemon(self, extras, environ, action):
658 if extras.get('is_shadow_repo'):
658 if extras.get('is_shadow_repo'):
659 # we don't want to execute hooks or start the callback daemon for shadow repos
659 # we don't want to execute hooks or start the callback daemon for shadow repos
660 return False
660 return False
661 return True
661 return True
662
662
663 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
663 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
664 direct_calls = vcs_settings.HOOKS_DIRECT_CALLS
664 direct_calls = vcs_settings.HOOKS_DIRECT_CALLS
665 if not self._should_use_callback_daemon(extras, environ, action):
665 if not self._should_use_callback_daemon(extras, environ, action):
666 # disable callback daemon for actions that don't require it
666 # disable callback daemon for actions that don't require it
667 direct_calls = True
667 direct_calls = True
668
668
669 return prepare_callback_daemon(
669 return prepare_callback_daemon(
670 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
670 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
671 host=vcs_settings.HOOKS_HOST, use_direct_calls=direct_calls, txn_id=txn_id)
671 host=vcs_settings.HOOKS_HOST, use_direct_calls=direct_calls, txn_id=txn_id)
672
672
673
673
674 def _should_check_locking(query_string):
674 def _should_check_locking(query_string):
675 # this is kind of hacky, but due to how mercurial handles client-server
675 # this is kind of hacky, but due to how mercurial handles client-server
676 # communication, the server sees operations on bookmarks, phases and
676 # communication, the server sees operations on bookmarks, phases and
677 # obsolescence markers in separate transactions, so we don't want to check
677 # obsolescence markers in separate transactions, so we don't want to check
678 # locking on those
678 # locking on those
679 return query_string not in ['cmd=listkeys']
679 return query_string not in ['cmd=listkeys']
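# A minimal usage sketch of the helper above; 'cmd=unbundle' is just an
# illustrative example of a non-exempt operation.
assert _should_check_locking('cmd=unbundle') is True
assert _should_check_locking('cmd=listkeys') is False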
@@ -1,284 +1,284 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import gzip
21 import gzip
22 import shutil
22 import shutil
23 import logging
23 import logging
24 import tempfile
24 import tempfile
25 import urllib.parse
25 import urllib.parse
26
26
27 from webob.exc import HTTPNotFound
27 from webob.exc import HTTPNotFound
28
28
29 import rhodecode
29 import rhodecode
30 from rhodecode.lib.middleware.appenlight import wrap_in_appenlight_if_enabled
30 from rhodecode.lib.middleware.appenlight import wrap_in_appenlight_if_enabled
31 from rhodecode.lib.middleware.simplegit import SimpleGit, GIT_PROTO_PAT
31 from rhodecode.lib.middleware.simplegit import SimpleGit, GIT_PROTO_PAT
32 from rhodecode.lib.middleware.simplehg import SimpleHg
32 from rhodecode.lib.middleware.simplehg import SimpleHg
33 from rhodecode.lib.middleware.simplesvn import SimpleSvn
33 from rhodecode.lib.middleware.simplesvn import SimpleSvn
34 from rhodecode.model.settings import VcsSettingsModel
34 from rhodecode.model.settings import VcsSettingsModel
35
35
36 log = logging.getLogger(__name__)
36 log = logging.getLogger(__name__)
37
37
38 VCS_TYPE_KEY = '_rc_vcs_type'
38 VCS_TYPE_KEY = '_rc_vcs_type'
39 VCS_TYPE_SKIP = '_rc_vcs_skip'
39 VCS_TYPE_SKIP = '_rc_vcs_skip'
40
40
41
41
42 def is_git(environ):
42 def is_git(environ):
43 """
43 """
44 Returns True if requests should be handled by GIT wsgi middleware
44 Returns True if requests should be handled by GIT wsgi middleware
45 """
45 """
46 is_git_path = GIT_PROTO_PAT.match(environ['PATH_INFO'])
46 is_git_path = GIT_PROTO_PAT.match(environ['PATH_INFO'])
47 log.debug(
47 log.debug(
48 'request path: `%s` detected as GIT PROTOCOL %s', environ['PATH_INFO'],
48 'request path: `%s` detected as GIT PROTOCOL %s', environ['PATH_INFO'],
49 is_git_path is not None)
49 is_git_path is not None)
50
50
51 return is_git_path
51 return is_git_path
52
52
53
53
54 def is_hg(environ):
54 def is_hg(environ):
55 """
55 """
56 Returns True if the request targets a mercurial server - the header
56 Returns True if the request targets a mercurial server - the header
57 ``HTTP_ACCEPT`` of such a request starts with ``application/mercurial``.
57 ``HTTP_ACCEPT`` of such a request starts with ``application/mercurial``.
58 """
58 """
59 is_hg_path = False
59 is_hg_path = False
60
60
61 http_accept = environ.get('HTTP_ACCEPT')
61 http_accept = environ.get('HTTP_ACCEPT')
62
62
63 if http_accept and http_accept.startswith('application/mercurial'):
63 if http_accept and http_accept.startswith('application/mercurial'):
64 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'])
64 query = urllib.parse.parse_qs(environ['QUERY_STRING'])
65 if 'cmd' in query:
65 if 'cmd' in query:
66 is_hg_path = True
66 is_hg_path = True
67
67
68 log.debug(
68 log.debug(
69 'request path: `%s` detected as HG PROTOCOL %s', environ['PATH_INFO'],
69 'request path: `%s` detected as HG PROTOCOL %s', environ['PATH_INFO'],
70 is_hg_path)
70 is_hg_path)
71
71
72 return is_hg_path
72 return is_hg_path
73
73
74
74
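# A small sketch of hg detection using a minimal, hypothetical WSGI environ;
# detection needs the mercurial Accept header plus a 'cmd' query argument.
hg_environ = {
    'PATH_INFO': '/some/repo',
    'HTTP_ACCEPT': 'application/mercurial-0.1',
    'QUERY_STRING': 'cmd=capabilities',
}
assert is_hg(hg_environ) is True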
75 def is_svn(environ):
75 def is_svn(environ):
76 """
76 """
77 Returns True if the request targets a Subversion server
77 Returns True if the request targets a Subversion server
78 """
78 """
79
79
80 http_dav = environ.get('HTTP_DAV', '')
80 http_dav = environ.get('HTTP_DAV', '')
81 magic_path_segment = rhodecode.CONFIG.get(
81 magic_path_segment = rhodecode.CONFIG.get(
82 'rhodecode_subversion_magic_path', '/!svn')
82 'rhodecode_subversion_magic_path', '/!svn')
83 is_svn_path = (
83 is_svn_path = (
84 'subversion' in http_dav or
84 'subversion' in http_dav or
85 magic_path_segment in environ['PATH_INFO']
85 magic_path_segment in environ['PATH_INFO']
86 or environ['REQUEST_METHOD'] in ['PROPFIND', 'PROPPATCH']
86 or environ['REQUEST_METHOD'] in ['PROPFIND', 'PROPPATCH']
87 )
87 )
88 log.debug(
88 log.debug(
89 'request path: `%s` detected as SVN PROTOCOL %s', environ['PATH_INFO'],
89 'request path: `%s` detected as SVN PROTOCOL %s', environ['PATH_INFO'],
90 is_svn_path)
90 is_svn_path)
91
91
92 return is_svn_path
92 return is_svn_path
93
93
94
94
95 class GunzipMiddleware(object):
95 class GunzipMiddleware(object):
96 """
96 """
97 WSGI middleware that unzips gzip-encoded requests before
97 WSGI middleware that unzips gzip-encoded requests before
98 passing on to the underlying application.
98 passing on to the underlying application.
99 """
99 """
100
100
101 def __init__(self, application):
101 def __init__(self, application):
102 self.app = application
102 self.app = application
103
103
104 def __call__(self, environ, start_response):
104 def __call__(self, environ, start_response):
105 accepts_encoding_header = environ.get('HTTP_CONTENT_ENCODING', '')
105 accepts_encoding_header = environ.get('HTTP_CONTENT_ENCODING', '')
106
106
107 if 'gzip' in accepts_encoding_header:
107 if 'gzip' in accepts_encoding_header:
108 log.debug('gzip detected, now running gunzip wrapper')
108 log.debug('gzip detected, now running gunzip wrapper')
109 wsgi_input = environ['wsgi.input']
109 wsgi_input = environ['wsgi.input']
110
110
111 if not hasattr(environ['wsgi.input'], 'seek'):
111 if not hasattr(environ['wsgi.input'], 'seek'):
112 # The gzip implementation in the standard library of Python 2.x
112 # The gzip implementation in the standard library of Python 2.x
113 # requires the '.seek()' and '.tell()' methods to be available
113 # requires the '.seek()' and '.tell()' methods to be available
114 # on the input stream. Read the data into a temporary file to
114 # on the input stream. Read the data into a temporary file to
115 # work around this limitation.
115 # work around this limitation.
116
116
117 wsgi_input = tempfile.SpooledTemporaryFile(64 * 1024 * 1024)
117 wsgi_input = tempfile.SpooledTemporaryFile(64 * 1024 * 1024)
118 shutil.copyfileobj(environ['wsgi.input'], wsgi_input)
118 shutil.copyfileobj(environ['wsgi.input'], wsgi_input)
119 wsgi_input.seek(0)
119 wsgi_input.seek(0)
120
120
121 environ['wsgi.input'] = gzip.GzipFile(fileobj=wsgi_input, mode='r')
121 environ['wsgi.input'] = gzip.GzipFile(fileobj=wsgi_input, mode='r')
122 # since we "Ungzipped" the content we say now it's no longer gzip
122 # since we "Ungzipped" the content we say now it's no longer gzip
123 # content encoding
123 # content encoding
124 del environ['HTTP_CONTENT_ENCODING']
124 del environ['HTTP_CONTENT_ENCODING']
125
125
126 # the content length changed after decompression, so drop the stale header
126 # the content length changed after decompression, so drop the stale header
127 if 'CONTENT_LENGTH' in environ:
127 if 'CONTENT_LENGTH' in environ:
128 del environ['CONTENT_LENGTH']
128 del environ['CONTENT_LENGTH']
129 else:
129 else:
130 log.debug('content not gzipped, gzipMiddleware passing '
130 log.debug('content not gzipped, gzipMiddleware passing '
131 'request further')
131 'request further')
132 return self.app(environ, start_response)
132 return self.app(environ, start_response)
133
133
134
134
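# A minimal sketch of wiring GunzipMiddleware around any WSGI app; the
# inner application below is a hypothetical placeholder.
def inner_app(environ, start_response):
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'ok']

wrapped_app = GunzipMiddleware(inner_app)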
135 def is_vcs_call(environ):
135 def is_vcs_call(environ):
136 if VCS_TYPE_KEY in environ:
136 if VCS_TYPE_KEY in environ:
137 raw_type = environ[VCS_TYPE_KEY]
137 raw_type = environ[VCS_TYPE_KEY]
138 return raw_type and raw_type != VCS_TYPE_SKIP
138 return raw_type and raw_type != VCS_TYPE_SKIP
139 return False
139 return False
140
140
141
141
142 def get_path_elem(route_path):
142 def get_path_elem(route_path):
143 if not route_path:
143 if not route_path:
144 return None
144 return None
145
145
146 cleaned_route_path = route_path.lstrip('/')
146 cleaned_route_path = route_path.lstrip('/')
147 if cleaned_route_path:
147 if cleaned_route_path:
148 cleaned_route_path_elems = cleaned_route_path.split('/')
148 cleaned_route_path_elems = cleaned_route_path.split('/')
149 if cleaned_route_path_elems:
149 if cleaned_route_path_elems:
150 return cleaned_route_path_elems[0]
150 return cleaned_route_path_elems[0]
151 return None
151 return None
152
152
153
153
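# Illustrative behaviour of get_path_elem (the paths are hypothetical):
assert get_path_elem('/_admin/ops/ping') == '_admin'
assert get_path_elem('/repo_name/sub') == 'repo_name'
assert get_path_elem('') is None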
154 def detect_vcs_request(environ, backends):
154 def detect_vcs_request(environ, backends):
155 checks = {
155 checks = {
156 'hg': (is_hg, SimpleHg),
156 'hg': (is_hg, SimpleHg),
157 'git': (is_git, SimpleGit),
157 'git': (is_git, SimpleGit),
158 'svn': (is_svn, SimpleSvn),
158 'svn': (is_svn, SimpleSvn),
159 }
159 }
160 handler = None
160 handler = None
161 # list of first path chunks / full paths for which we skip any checks
161 # list of first path chunks / full paths for which we skip any checks
162 white_list = [
162 white_list = [
163 # e.g /_file_store/download
163 # e.g /_file_store/download
164 '_file_store',
164 '_file_store',
165
165
166 # static files no detection
166 # static files no detection
167 '_static',
167 '_static',
168
168
169 # skip ops ping, status
169 # skip ops ping, status
170 '_admin/ops/ping',
170 '_admin/ops/ping',
171 '_admin/ops/status',
171 '_admin/ops/status',
172
172
173 # full channelstream connect should be VCS skipped
173 # full channelstream connect should be VCS skipped
174 '_admin/channelstream/connect',
174 '_admin/channelstream/connect',
175 ]
175 ]
176
176
177 path_info = environ['PATH_INFO']
177 path_info = environ['PATH_INFO']
178
178
179 path_elem = get_path_elem(path_info)
179 path_elem = get_path_elem(path_info)
180
180
181 if path_elem in white_list:
181 if path_elem in white_list:
182 log.debug('path `%s` in whitelist, skipping...', path_info)
182 log.debug('path `%s` in whitelist, skipping...', path_info)
183 return handler
183 return handler
184
184
185 path_url = path_info.lstrip('/')
185 path_url = path_info.lstrip('/')
186 if path_url in white_list:
186 if path_url in white_list:
187 log.debug('full url path `%s` in whitelist, skipping...', path_url)
187 log.debug('full url path `%s` in whitelist, skipping...', path_url)
188 return handler
188 return handler
189
189
190 if VCS_TYPE_KEY in environ:
190 if VCS_TYPE_KEY in environ:
191 raw_type = environ[VCS_TYPE_KEY]
191 raw_type = environ[VCS_TYPE_KEY]
192 if raw_type == VCS_TYPE_SKIP:
192 if raw_type == VCS_TYPE_SKIP:
193 log.debug('got `skip` marker for vcs detection, skipping...')
193 log.debug('got `skip` marker for vcs detection, skipping...')
194 return handler
194 return handler
195
195
196 _check, handler = checks.get(raw_type) or [None, None]
196 _check, handler = checks.get(raw_type) or [None, None]
197 if handler:
197 if handler:
198 log.debug('got handler:%s from environ', handler)
198 log.debug('got handler:%s from environ', handler)
199
199
200 if not handler:
200 if not handler:
201 log.debug('request start: checking if request for `%s` is of VCS type in order: %s', path_elem, backends)
201 log.debug('request start: checking if request for `%s` is of VCS type in order: %s', path_elem, backends)
202 for vcs_type in backends:
202 for vcs_type in backends:
203 vcs_check, _handler = checks[vcs_type]
203 vcs_check, _handler = checks[vcs_type]
204 if vcs_check(environ):
204 if vcs_check(environ):
205 log.debug('vcs handler found %s', _handler)
205 log.debug('vcs handler found %s', _handler)
206 handler = _handler
206 handler = _handler
207 break
207 break
208
208
209 return handler
209 return handler
210
210
211
211
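# A small sketch of detect_vcs_request: whitelisted path prefixes
# short-circuit detection and return no handler (the path is hypothetical).
static_environ = {'PATH_INFO': '/_static/js/app.js'}
assert detect_vcs_request(static_environ, backends=['hg', 'git', 'svn']) is None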
212 class VCSMiddleware(object):
212 class VCSMiddleware(object):
213
213
214 def __init__(self, app, registry, config, appenlight_client):
214 def __init__(self, app, registry, config, appenlight_client):
215 self.application = app
215 self.application = app
216 self.registry = registry
216 self.registry = registry
217 self.config = config
217 self.config = config
218 self.appenlight_client = appenlight_client
218 self.appenlight_client = appenlight_client
219 self.use_gzip = True
219 self.use_gzip = True
220 # order in which we check the middlewares, based on vcs.backends config
220 # order in which we check the middlewares, based on vcs.backends config
221 self.check_middlewares = config['vcs.backends']
221 self.check_middlewares = config['vcs.backends']
222
222
223 def vcs_config(self, repo_name=None):
223 def vcs_config(self, repo_name=None):
224 """
224 """
225 returns serialized VcsSettings
225 returns serialized VcsSettings
226 """
226 """
227 try:
227 try:
228 return VcsSettingsModel(
228 return VcsSettingsModel(
229 repo=repo_name).get_ui_settings_as_config_obj()
229 repo=repo_name).get_ui_settings_as_config_obj()
230 except Exception:
230 except Exception:
231 pass
231 pass
232
232
233 def wrap_in_gzip_if_enabled(self, app, config):
233 def wrap_in_gzip_if_enabled(self, app, config):
234 if self.use_gzip:
234 if self.use_gzip:
235 app = GunzipMiddleware(app)
235 app = GunzipMiddleware(app)
236 return app
236 return app
237
237
238 def _get_handler_app(self, environ):
238 def _get_handler_app(self, environ):
239 app = None
239 app = None
240 log.debug('VCSMiddleware: detecting vcs type.')
240 log.debug('VCSMiddleware: detecting vcs type.')
241 handler = detect_vcs_request(environ, self.check_middlewares)
241 handler = detect_vcs_request(environ, self.check_middlewares)
242 if handler:
242 if handler:
243 app = handler(self.config, self.registry)
243 app = handler(self.config, self.registry)
244
244
245 return app
245 return app
246
246
247 def __call__(self, environ, start_response):
247 def __call__(self, environ, start_response):
248 # check if we handle one of the interesting protocols; optionally extract
248 # check if we handle one of the interesting protocols; optionally extract
249 # specific vcsSettings and allow changing how things are wrapped
249 # specific vcsSettings and allow changing how things are wrapped
250 vcs_handler = self._get_handler_app(environ)
250 vcs_handler = self._get_handler_app(environ)
251 if vcs_handler:
251 if vcs_handler:
252 # translate the _REPO_ID into real repo NAME for usage
252 # translate the _REPO_ID into real repo NAME for usage
253 # in middleware
253 # in middleware
254 environ['PATH_INFO'] = vcs_handler._get_by_id(environ['PATH_INFO'])
254 environ['PATH_INFO'] = vcs_handler._get_by_id(environ['PATH_INFO'])
255
255
256 # Set acl, url and vcs repo names.
256 # Set acl, url and vcs repo names.
257 vcs_handler.set_repo_names(environ)
257 vcs_handler.set_repo_names(environ)
258
258
259 # register repo config back to the handler
259 # register repo config back to the handler
260 vcs_conf = self.vcs_config(vcs_handler.acl_repo_name)
260 vcs_conf = self.vcs_config(vcs_handler.acl_repo_name)
261 # settings might be damaged or non-existent. We still want to
261 # settings might be damaged or non-existent. We still want to
262 # get past this point to validate via is_valid_and_existing_repo
262 # get past this point to validate via is_valid_and_existing_repo
263 # and return a proper HTTP code back to the client
263 # and return a proper HTTP code back to the client
264 if vcs_conf:
264 if vcs_conf:
265 vcs_handler.repo_vcs_config = vcs_conf
265 vcs_handler.repo_vcs_config = vcs_conf
266
266
267 # check for type, presence in database and on filesystem
267 # check for type, presence in database and on filesystem
268 if not vcs_handler.is_valid_and_existing_repo(
268 if not vcs_handler.is_valid_and_existing_repo(
269 vcs_handler.acl_repo_name,
269 vcs_handler.acl_repo_name,
270 vcs_handler.base_path,
270 vcs_handler.base_path,
271 vcs_handler.SCM):
271 vcs_handler.SCM):
272 return HTTPNotFound()(environ, start_response)
272 return HTTPNotFound()(environ, start_response)
273
273
274 environ['REPO_NAME'] = vcs_handler.url_repo_name
274 environ['REPO_NAME'] = vcs_handler.url_repo_name
275
275
276 # Wrap handler in middlewares if they are enabled.
276 # Wrap handler in middlewares if they are enabled.
277 vcs_handler = self.wrap_in_gzip_if_enabled(
277 vcs_handler = self.wrap_in_gzip_if_enabled(
278 vcs_handler, self.config)
278 vcs_handler, self.config)
279 vcs_handler, _ = wrap_in_appenlight_if_enabled(
279 vcs_handler, _ = wrap_in_appenlight_if_enabled(
280 vcs_handler, self.config, self.appenlight_client)
280 vcs_handler, self.config, self.appenlight_client)
281
281
282 return vcs_handler(environ, start_response)
282 return vcs_handler(environ, start_response)
283
283
284 return self.application(environ, start_response)
284 return self.application(environ, start_response)
@@ -1,1061 +1,1061 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (c) 2007-2012 Christoph Haas <email@christoph-haas.de>
3 # Copyright (c) 2007-2012 Christoph Haas <email@christoph-haas.de>
4 # NOTE: MIT license based code, backported and edited by RhodeCode GmbH
4 # NOTE: MIT license based code, backported and edited by RhodeCode GmbH
5
5
6 """
6 """
7 paginate: helps split up large collections into individual pages
7 paginate: helps split up large collections into individual pages
8 ================================================================
8 ================================================================
9
9
10 What is pagination?
10 What is pagination?
11 ---------------------
11 ---------------------
12
12
13 This module helps split large lists of items into pages. The user is shown one page at a time and
13 This module helps split large lists of items into pages. The user is shown one page at a time and
14 can navigate to other pages. Imagine you are offering a company phonebook and let the user search
14 can navigate to other pages. Imagine you are offering a company phonebook and let the user search
15 the entries. The entire search result may contain 23 entries but you want to display no more than
15 the entries. The entire search result may contain 23 entries but you want to display no more than
16 10 entries at once. The first page contains entries 1-10, the second 11-20 and the third 21-23.
16 10 entries at once. The first page contains entries 1-10, the second 11-20 and the third 21-23.
17 Each "Page" instance represents the items of one of these three pages.
17 Each "Page" instance represents the items of one of these three pages.
18
18
19 See the documentation of the "Page" class for more information.
19 See the documentation of the "Page" class for more information.
20
20
21 How do I use it?
21 How do I use it?
22 ------------------
22 ------------------
23
23
24 A page of items is represented by the *Page* object. A *Page* gets initialized with these arguments:
24 A page of items is represented by the *Page* object. A *Page* gets initialized with these arguments:
25
25
26 - The collection of items to pick a range from. Usually just a list.
26 - The collection of items to pick a range from. Usually just a list.
27 - The page number you want to display. Default is 1: the first page.
27 - The page number you want to display. Default is 1: the first page.
28
28
29 Now we can make up a collection and create a Page instance of it::
29 Now we can make up a collection and create a Page instance of it::
30
30
31 # Create a sample collection of 1000 items
31 # Create a sample collection of 1000 items
32 >> my_collection = range(1000)
32 >> my_collection = range(1000)
33
33
34 # Create a Page object for the 3rd page (20 items per page is the default)
34 # Create a Page object for the 3rd page (20 items per page is the default)
35 >> my_page = Page(my_collection, page=3)
35 >> my_page = Page(my_collection, page=3)
36
36
37 # The page object can be printed as a string to get its details
37 # The page object can be printed as a string to get its details
38 >> str(my_page)
38 >> str(my_page)
39 Page:
39 Page:
40 Collection type: <class 'range'>
40 Collection type: <class 'range'>
41 Current page: 3
41 Current page: 3
42 First item: 41
42 First item: 41
43 Last item: 60
43 Last item: 60
44 First page: 1
44 First page: 1
45 Last page: 50
45 Last page: 50
46 Previous page: 2
46 Previous page: 2
47 Next page: 4
47 Next page: 4
48 Items per page: 20
48 Items per page: 20
49 Number of items: 1000
49 Number of items: 1000
50 Number of pages: 50
50 Number of pages: 50
51
51
52 # Print a list of items on the current page
52 # Print a list of items on the current page
53 >> my_page.items
53 >> my_page.items
54 [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
54 [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
55
55
56 # The *Page* object can be used as an iterator:
56 # The *Page* object can be used as an iterator:
57 >> for my_item in my_page: print(my_item)
57 >> for my_item in my_page: print(my_item)
58 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
58 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
59
59
60 # The .pager() method returns an HTML fragment with links to surrounding pages.
60 # The .pager() method returns an HTML fragment with links to surrounding pages.
61 >> my_page.pager(url="http://example.org/foo/page=$page")
61 >> my_page.pager(url="http://example.org/foo/page=$page")
62
62
63 <a href="http://example.org/foo/page=1">1</a>
63 <a href="http://example.org/foo/page=1">1</a>
64 <a href="http://example.org/foo/page=2">2</a>
64 <a href="http://example.org/foo/page=2">2</a>
65 3
65 3
66 <a href="http://example.org/foo/page=4">4</a>
66 <a href="http://example.org/foo/page=4">4</a>
67 <a href="http://example.org/foo/page=5">5</a>
67 <a href="http://example.org/foo/page=5">5</a>
68 ..
68 ..
69 <a href="http://example.org/foo/page=50">50</a>
69 <a href="http://example.org/foo/page=50">50</a>
70
70
71 # Without the HTML it would just look like:
71 # Without the HTML it would just look like:
72 # 1 2 [3] 4 5 .. 50
72 # 1 2 [3] 4 5 .. 50
73
73
74 # The pager can be customized:
74 # The pager can be customized:
75 >> my_page.pager('$link_previous ~3~ $link_next (Page $page of $page_count)',
75 >> my_page.pager('$link_previous ~3~ $link_next (Page $page of $page_count)',
76 url="http://example.org/foo/page=$page")
76 url="http://example.org/foo/page=$page")
77
77
78 <a href="http://example.org/foo/page=2">&lt;</a>
78 <a href="http://example.org/foo/page=2">&lt;</a>
79 <a href="http://example.org/foo/page=1">1</a>
79 <a href="http://example.org/foo/page=1">1</a>
80 <a href="http://example.org/foo/page=2">2</a>
80 <a href="http://example.org/foo/page=2">2</a>
81 3
81 3
82 <a href="http://example.org/foo/page=4">4</a>
82 <a href="http://example.org/foo/page=4">4</a>
83 <a href="http://example.org/foo/page=5">5</a>
83 <a href="http://example.org/foo/page=5">5</a>
84 <a href="http://example.org/foo/page=6">6</a>
84 <a href="http://example.org/foo/page=6">6</a>
85 ..
85 ..
86 <a href="http://example.org/foo/page=50">50</a>
86 <a href="http://example.org/foo/page=50">50</a>
87 <a href="http://example.org/foo/page=4">&gt;</a>
87 <a href="http://example.org/foo/page=4">&gt;</a>
88 (Page 3 of 50)
88 (Page 3 of 50)
89
89
90 # Without the HTML it would just look like:
90 # Without the HTML it would just look like:
91 # 1 2 [3] 4 5 6 .. 50 > (Page 3 of 50)
91 # 1 2 [3] 4 5 6 .. 50 > (Page 3 of 50)
92
92
93 # The url argument to the pager method can be omitted when an url_maker is
93 # The url argument to the pager method can be omitted when an url_maker is
94 # given during instantiation:
94 # given during instantiation:
95 >> my_page = Page(my_collection, page=3,
95 >> my_page = Page(my_collection, page=3,
96 url_maker=lambda p: "http://example.org/%s" % p)
96 url_maker=lambda p: "http://example.org/%s" % p)
97 >> my_page.pager()
97 >> my_page.pager()
98
98
99 There are some interesting parameters that customize the Page's behavior. See the documentation on
99 There are some interesting parameters that customize the Page's behavior. See the documentation on
100 ``Page`` and ``Page.pager()``.
100 ``Page`` and ``Page.pager()``.
101
101
102
102
103 Notes
103 Notes
104 -------
104 -------
105
105
106 Page numbers and item numbers start at 1. This concept has been used because users expect that the
106 Page numbers and item numbers start at 1. This concept has been used because users expect that the
107 first page has number 1 and the first item on a page also has number 1. So if you want to use the
107 first page has number 1 and the first item on a page also has number 1. So if you want to use the
108 page's items by their index number please note that you have to subtract 1.
108 page's items by their index number please note that you have to subtract 1.
109 """
109 """
110
110
111 import re
111 import re
112 import sys
112 import sys
113 from string import Template
113 from string import Template
114 from webhelpers2.html import literal
114 from webhelpers2.html import literal
115
115
116 # are we running at least python 3.x ?
116 # are we running at least python 3.x ?
117 PY3 = sys.version_info[0] >= 3
117 PY3 = sys.version_info[0] >= 3
118
118
119 if PY3:
119 if PY3:
120 unicode = str
120 unicode = str
121
121
122
122
123 def make_html_tag(tag, text=None, **params):
123 def make_html_tag(tag, text=None, **params):
124 """Create an HTML tag string.
124 """Create an HTML tag string.
125
125
126 tag
126 tag
127 The HTML tag to use (e.g. 'a', 'span' or 'div')
127 The HTML tag to use (e.g. 'a', 'span' or 'div')
128
128
129 text
129 text
130 The text to enclose between opening and closing tag. If no text is specified then only
130 The text to enclose between opening and closing tag. If no text is specified then only
131 the opening tag is returned.
131 the opening tag is returned.
132
132
133 Example::
133 Example::
134 make_html_tag('a', text="Hello", href="/another/page")
134 make_html_tag('a', text="Hello", href="/another/page")
135 -> <a href="/another/page">Hello</a>
135 -> <a href="/another/page">Hello</a>
136
136
137 To use reserved Python keywords like "class" as a parameter prepend it with
137 To use reserved Python keywords like "class" as a parameter prepend it with
138 an underscore. Instead of "class='green'" use "_class='green'".
138 an underscore. Instead of "class='green'" use "_class='green'".
139
139
140 Warning: Quotes and apostrophes are not escaped."""
140 Warning: Quotes and apostrophes are not escaped."""
141 params_string = ""
141 params_string = ""
142
142
143 # Parameters are passed. Turn the dict into a string like "a=1 b=2 c=3" string.
143 # Parameters are passed. Turn the dict into a string like "a=1 b=2 c=3" string.
144 for key, value in sorted(params.items()):
144 for key, value in sorted(params.items()):
145 # Strip off a leading underscore from the attribute's key to allow attributes like '_class'
145 # Strip off a leading underscore from the attribute's key to allow attributes like '_class'
146 # to be used as a CSS class specification instead of the reserved Python keyword 'class'.
146 # to be used as a CSS class specification instead of the reserved Python keyword 'class'.
147 key = key.lstrip("_")
147 key = key.lstrip("_")
148
148
149 params_string += u' {0}="{1}"'.format(key, value)
149 params_string += ' {0}="{1}"'.format(key, value)
150
150
151 # Create the tag string
151 # Create the tag string
152 tag_string = u"<{0}{1}>".format(tag, params_string)
152 tag_string = "<{0}{1}>".format(tag, params_string)
153
153
154 # Add text and closing tag if required.
154 # Add text and closing tag if required.
155 if text:
155 if text:
156 tag_string += u"{0}</{1}>".format(text, tag)
156 tag_string += "{0}</{1}>".format(text, tag)
157
157
158 return tag_string
158 return tag_string
159
159
160
160
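# Per the docstring above, a leading underscore lets reserved words such
# as 'class' be passed as HTML attributes:
assert make_html_tag('span', text='hi', _class='green') == '<span class="green">hi</span>'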
161 # Since the items on a page are mainly a list we subclass the "list" type
161 # Since the items on a page are mainly a list we subclass the "list" type
162 class _Page(list):
162 class _Page(list):
163 """A list/iterator representing the items on one page of a larger collection.
163 """A list/iterator representing the items on one page of a larger collection.
164
164
165 An instance of the "Page" class is created from a _collection_ which is any
165 An instance of the "Page" class is created from a _collection_ which is any
166 list-like object that allows random access to its elements.
166 list-like object that allows random access to its elements.
167
167
168 The instance works as an iterator running from the first item to the last item on the given
168 The instance works as an iterator running from the first item to the last item on the given
169 page. The Page.pager() method creates a link list allowing the user to go to other pages.
169 page. The Page.pager() method creates a link list allowing the user to go to other pages.
170
170
171 A "Page" does not only carry the items on a certain page. It gives you additional information
171 A "Page" does not only carry the items on a certain page. It gives you additional information
172 about the page in these "Page" object attributes:
172 about the page in these "Page" object attributes:
173
173
174 item_count
174 item_count
175 Number of items in the collection
175 Number of items in the collection
176
176
177 **WARNING:** Unless you pass in an item_count, a count will be
177 **WARNING:** Unless you pass in an item_count, a count will be
178 performed on the collection every time a Page instance is created.
178 performed on the collection every time a Page instance is created.
179
179
180 page
180 page
181 Number of the current page
181 Number of the current page
182
182
183 items_per_page
183 items_per_page
184 Maximal number of items displayed on a page
184 Maximal number of items displayed on a page
185
185
186 first_page
186 first_page
187 Number of the first page - usually 1 :)
187 Number of the first page - usually 1 :)
188
188
189 last_page
189 last_page
190 Number of the last page
190 Number of the last page
191
191
192 previous_page
192 previous_page
193 Number of the previous page. If this is the first page it returns None.
193 Number of the previous page. If this is the first page it returns None.
194
194
195 next_page
195 next_page
196 Number of the next page. If this is the last page it returns None.
196 Number of the next page. If this is the last page it returns None.
197
197
198 page_count
198 page_count
199 Number of pages
199 Number of pages
200
200
201 items
201 items
202 Sequence/iterator of items on the current page
202 Sequence/iterator of items on the current page
203
203
204 first_item
204 first_item
205 Index of first item on the current page - starts with 1
205 Index of first item on the current page - starts with 1
206
206
207 last_item
207 last_item
208 Index of last item on the current page
208 Index of last item on the current page
209 """
209 """
210
210
211 def __init__(
211 def __init__(
212 self,
212 self,
213 collection,
213 collection,
214 page=1,
214 page=1,
215 items_per_page=20,
215 items_per_page=20,
216 item_count=None,
216 item_count=None,
217 wrapper_class=None,
217 wrapper_class=None,
218 url_maker=None,
218 url_maker=None,
219 bar_size=10,
219 bar_size=10,
220 **kwargs
220 **kwargs
221 ):
221 ):
222 """Create a "Page" instance.
222 """Create a "Page" instance.
223
223
224 Parameters:
224 Parameters:
225
225
226 collection
226 collection
227 Sequence representing the collection of items to page through.
227 Sequence representing the collection of items to page through.
228
228
229 page
229 page
230 The requested page number - starts with 1. Default: 1.
230 The requested page number - starts with 1. Default: 1.
231
231
232 items_per_page
232 items_per_page
233 The maximal number of items to be displayed per page.
233 The maximal number of items to be displayed per page.
234 Default: 20.
234 Default: 20.
235
235
236 item_count (optional)
236 item_count (optional)
237 The total number of items in the collection - if known.
237 The total number of items in the collection - if known.
238 If this parameter is not given then the paginator will count
238 If this parameter is not given then the paginator will count
239 the number of elements in the collection every time a "Page"
239 the number of elements in the collection every time a "Page"
240 is created. Giving this parameter will speed up things. In a busy
240 is created. Giving this parameter will speed up things. In a busy
241 real-life application you may want to cache the number of items.
241 real-life application you may want to cache the number of items.
242
242
243 url_maker (optional)
243 url_maker (optional)
244 Callback to generate the URL of another page, given its number.
244 Callback to generate the URL of another page, given its number.
245 Must accept one int parameter and return a URI string.
245 Must accept one int parameter and return a URI string.
246
246
247 bar_size
247 bar_size
248 maximum number of rendered page numbers within the radius
248 maximum number of rendered page numbers within the radius
249
249
250 """
250 """
251 if collection is not None:
251 if collection is not None:
252 if wrapper_class is None:
252 if wrapper_class is None:
253 # Default case. The collection is already a list-type object.
253 # Default case. The collection is already a list-type object.
254 self.collection = collection
254 self.collection = collection
255 else:
255 else:
256 # Special case. A custom wrapper class is used to access elements of the collection.
256 # Special case. A custom wrapper class is used to access elements of the collection.
257 self.collection = wrapper_class(collection)
257 self.collection = wrapper_class(collection)
258 else:
258 else:
259 self.collection = []
259 self.collection = []
260
260
261 self.collection_type = type(collection)
261 self.collection_type = type(collection)
262
262
263 if url_maker is not None:
263 if url_maker is not None:
264 self.url_maker = url_maker
264 self.url_maker = url_maker
265 else:
265 else:
266 self.url_maker = self._default_url_maker
266 self.url_maker = self._default_url_maker
267 self.bar_size = bar_size
267 self.bar_size = bar_size
268 # Assign kwargs to self
268 # Assign kwargs to self
269 self.kwargs = kwargs
269 self.kwargs = kwargs
270
270
271 # The self.page is the number of the current page.
271 # The self.page is the number of the current page.
272 # The first page has the number 1!
272 # The first page has the number 1!
273 try:
273 try:
274 self.page = int(page) # make it int() if we get it as a string
274 self.page = int(page) # make it int() if we get it as a string
275 except (ValueError, TypeError):
275 except (ValueError, TypeError):
276 self.page = 1
276 self.page = 1
277 # normally page should be always at least 1 but the original maintainer
277 # normally page should be always at least 1 but the original maintainer
278 # decided that for empty collection and empty page it can be...0? (based on tests)
278 # decided that for empty collection and empty page it can be...0? (based on tests)
279 # preserving behavior for BW compat
279 # preserving behavior for BW compat
280 if self.page < 1:
280 if self.page < 1:
281 self.page = 1
281 self.page = 1
282
282
283 self.items_per_page = items_per_page
283 self.items_per_page = items_per_page
284
284
285 # We subclassed "list" so we need to call its init() method
285 # We subclassed "list" so we need to call its init() method
286 # and fill the new list with the items to be displayed on the page.
286 # and fill the new list with the items to be displayed on the page.
287 # We use list() so that the items on the current page are retrieved
287 # We use list() so that the items on the current page are retrieved
288 # only once. In an SQL context that could otherwise lead to running the
288 # only once. In an SQL context that could otherwise lead to running the
289 # same SQL query every time items would be accessed.
289 # same SQL query every time items would be accessed.
290 # We do this here, prior to calling len() on the collection so that a
290 # We do this here, prior to calling len() on the collection so that a
291 # wrapper class can execute a query with the knowledge of what the
291 # wrapper class can execute a query with the knowledge of what the
292 # slice will be (for efficiency) and, in the same query, ask for the
292 # slice will be (for efficiency) and, in the same query, ask for the
293 # total number of items and only execute one query.
293 # total number of items and only execute one query.
294
294
295 try:
295 try:
296 first = (self.page - 1) * items_per_page
296 first = (self.page - 1) * items_per_page
297 last = first + items_per_page
297 last = first + items_per_page
298 self.items = list(self.collection[first:last])
298 self.items = list(self.collection[first:last])
299 except TypeError as err:
299 except TypeError as err:
300 raise TypeError(
300 raise TypeError(
301 f"Your collection of type {type(self.collection)} cannot be handled "
301 f"Your collection of type {type(self.collection)} cannot be handled "
302 f"by paginate. ERROR:{err}"
302 f"by paginate. ERROR:{err}"
303 )
303 )
304
304
305 # Unless the user tells us how many items the collections has
305 # Unless the user tells us how many items the collections has
306 # we calculate that ourselves.
306 # we calculate that ourselves.
307 if item_count is not None:
307 if item_count is not None:
308 self.item_count = item_count
308 self.item_count = item_count
309 else:
309 else:
310 self.item_count = len(self.collection)
310 self.item_count = len(self.collection)
311
311
312 # Compute the number of the first and last available page
312 # Compute the number of the first and last available page
313 if self.item_count > 0:
313 if self.item_count > 0:
314 self.first_page = 1
314 self.first_page = 1
315 self.page_count = ((self.item_count - 1) // self.items_per_page) + 1
315 self.page_count = ((self.item_count - 1) // self.items_per_page) + 1
316 self.last_page = self.first_page + self.page_count - 1
316 self.last_page = self.first_page + self.page_count - 1
317
317
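# worked example matching the module docstring: item_count=23 and
# items_per_page=10 give page_count = ((23 - 1) // 10) + 1 = 3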
318 # Make sure that the requested page number is in the range of valid pages
318 # Make sure that the requested page number is in the range of valid pages
319 if self.page > self.last_page:
319 if self.page > self.last_page:
320 self.page = self.last_page
320 self.page = self.last_page
321 elif self.page < self.first_page:
321 elif self.page < self.first_page:
322 self.page = self.first_page
322 self.page = self.first_page
323
323
324 # Note: the number of items on this page can be less than
324 # Note: the number of items on this page can be less than
325 # items_per_page if the last page is not full
325 # items_per_page if the last page is not full
326 self.first_item = (self.page - 1) * items_per_page + 1
326 self.first_item = (self.page - 1) * items_per_page + 1
327 self.last_item = min(self.first_item + items_per_page - 1, self.item_count)
327 self.last_item = min(self.first_item + items_per_page - 1, self.item_count)
328
328
329 # Links to previous and next page
329 # Links to previous and next page
330 if self.page > self.first_page:
330 if self.page > self.first_page:
331 self.previous_page = self.page - 1
331 self.previous_page = self.page - 1
332 else:
332 else:
333 self.previous_page = None
333 self.previous_page = None
334
334
335 if self.page < self.last_page:
335 if self.page < self.last_page:
336 self.next_page = self.page + 1
336 self.next_page = self.page + 1
337 else:
337 else:
338 self.next_page = None
338 self.next_page = None
339
339
340 # No items available
340 # No items available
341 else:
341 else:
342 self.first_page = None
342 self.first_page = None
343 self.page_count = 0
343 self.page_count = 0
344 self.last_page = None
344 self.last_page = None
345 self.first_item = None
345 self.first_item = None
346 self.last_item = None
346 self.last_item = None
347 self.previous_page = None
347 self.previous_page = None
348 self.next_page = None
348 self.next_page = None
349 self.items = []
349 self.items = []
350
350
351 # This is a subclass of the 'list' type. Initialise the list now.
351 # This is a subclass of the 'list' type. Initialise the list now.
352 list.__init__(self, self.items)
352 list.__init__(self, self.items)
353
353
354 def __str__(self):
354 def __str__(self):
355 return (
355 return (
356 "Page:\n"
356 "Page:\n"
357 "Collection type: {0.collection_type}\n"
357 "Collection type: {0.collection_type}\n"
358 "Current page: {0.page}\n"
358 "Current page: {0.page}\n"
359 "First item: {0.first_item}\n"
359 "First item: {0.first_item}\n"
360 "Last item: {0.last_item}\n"
360 "Last item: {0.last_item}\n"
361 "First page: {0.first_page}\n"
361 "First page: {0.first_page}\n"
362 "Last page: {0.last_page}\n"
362 "Last page: {0.last_page}\n"
363 "Previous page: {0.previous_page}\n"
363 "Previous page: {0.previous_page}\n"
364 "Next page: {0.next_page}\n"
364 "Next page: {0.next_page}\n"
365 "Items per page: {0.items_per_page}\n"
365 "Items per page: {0.items_per_page}\n"
366 "Total number of items: {0.item_count}\n"
366 "Total number of items: {0.item_count}\n"
367 "Number of pages: {0.page_count}\n"
367 "Number of pages: {0.page_count}\n"
368 ).format(self)
368 ).format(self)
369
369
370 def __repr__(self):
370 def __repr__(self):
371 return "<paginate.Page: Page {0}/{1}>".format(self.page, self.page_count)
371 return "<paginate.Page: Page {0}/{1}>".format(self.page, self.page_count)
372
372
373 def pager(
373 def pager(
374 self,
374 self,
375 tmpl_format="~2~",
375 tmpl_format="~2~",
376 url=None,
376 url=None,
377 show_if_single_page=False,
377 show_if_single_page=False,
378 separator=" ",
378 separator=" ",
379 symbol_first="&lt;&lt;",
379 symbol_first="&lt;&lt;",
380 symbol_last="&gt;&gt;",
380 symbol_last="&gt;&gt;",
381 symbol_previous="&lt;",
381 symbol_previous="&lt;",
382 symbol_next="&gt;",
382 symbol_next="&gt;",
383 link_attr=None,
383 link_attr=None,
384 curpage_attr=None,
384 curpage_attr=None,
385 dotdot_attr=None,
385 dotdot_attr=None,
386 link_tag=None,
386 link_tag=None,
387 ):
387 ):
388 """
388 """
389 Return string with links to other pages (e.g. '1 .. 5 6 7 [8] 9 10 11 .. 50').
389 Return string with links to other pages (e.g. '1 .. 5 6 7 [8] 9 10 11 .. 50').
390
390
391 tmpl_format:
391 tmpl_format:
392 Format string that defines how the pager is rendered. The string
392 Format string that defines how the pager is rendered. The string
393 can contain the following $-tokens that are substituted by the
393 can contain the following $-tokens that are substituted by the
394 string.Template module:
394 string.Template module:
395
395
396 - $first_page: number of first reachable page
396 - $first_page: number of first reachable page
397 - $last_page: number of last reachable page
397 - $last_page: number of last reachable page
398 - $page: number of currently selected page
398 - $page: number of currently selected page
399 - $page_count: number of reachable pages
399 - $page_count: number of reachable pages
400 - $items_per_page: maximal number of items per page
400 - $items_per_page: maximal number of items per page
401 - $first_item: index of first item on the current page
401 - $first_item: index of first item on the current page
402 - $last_item: index of last item on the current page
402 - $last_item: index of last item on the current page
403 - $item_count: total number of items
403 - $item_count: total number of items
404 - $link_first: link to first page (unless this is first page)
404 - $link_first: link to first page (unless this is first page)
405 - $link_last: link to last page (unless this is last page)
405 - $link_last: link to last page (unless this is last page)
406 - $link_previous: link to previous page (unless this is first page)
406 - $link_previous: link to previous page (unless this is first page)
407 - $link_next: link to next page (unless this is last page)
407 - $link_next: link to next page (unless this is last page)
408
408
409 To render a range of pages the token '~3~' can be used. The
409 To render a range of pages the token '~3~' can be used. The
410 number sets the radius of pages around the current page.
410 number sets the radius of pages around the current page.
411 Example for a range with radius 3:
411 Example for a range with radius 3:
412
412
413 '1 .. 5 6 7 [8] 9 10 11 .. 50'
413 '1 .. 5 6 7 [8] 9 10 11 .. 50'
414
414
415 Default: '~2~'
415 Default: '~2~'
416
416
417 url
417 url
418 The URL that page links will point to. Make sure it contains the string
418 The URL that page links will point to. Make sure it contains the string
419 $page which will be replaced by the actual page number.
419 $page which will be replaced by the actual page number.
420 Must be given unless a url_maker is specified to __init__, in which
420 Must be given unless a url_maker is specified to __init__, in which
421 case this parameter is ignored.
421 case this parameter is ignored.
422
422
423 symbol_first
423 symbol_first
424 String to be displayed as the text for the $link_first link above.
424 String to be displayed as the text for the $link_first link above.
425
425
426 Default: '&lt;&lt;' (<<)
426 Default: '&lt;&lt;' (<<)
427
427
428 symbol_last
428 symbol_last
429 String to be displayed as the text for the $link_last link above.
429 String to be displayed as the text for the $link_last link above.
430
430
431 Default: '&gt;&gt;' (>>)
431 Default: '&gt;&gt;' (>>)
432
432
433 symbol_previous
433 symbol_previous
434 String to be displayed as the text for the $link_previous link above.
434 String to be displayed as the text for the $link_previous link above.
435
435
436 Default: '&lt;' (<)
436 Default: '&lt;' (<)
437
437
438 symbol_next
438 symbol_next
439 String to be displayed as the text for the $link_next link above.
439 String to be displayed as the text for the $link_next link above.
440
440
441 Default: '&gt;' (>)
441 Default: '&gt;' (>)
442
442
443 separator:
443 separator:
444 String that is used to separate page links/numbers in the above range of pages.
444 String that is used to separate page links/numbers in the above range of pages.
445
445
446 Default: ' '
446 Default: ' '
447
447
448 show_if_single_page:
448 show_if_single_page:
449 if True the navigator will be shown even if there is only one page.
449 if True the navigator will be shown even if there is only one page.
450
450
451 Default: False
451 Default: False
452
452
453 link_attr (optional)
453 link_attr (optional)
454 A dictionary of attributes that get added to A-HREF links pointing to other pages. Can
454 A dictionary of attributes that get added to A-HREF links pointing to other pages. Can
455 be used to define a CSS style or class to customize the look of links.
455 be used to define a CSS style or class to customize the look of links.
456
456
457 Example: { 'style':'border: 1px solid green' }
457 Example: { 'style':'border: 1px solid green' }
458 Example: { 'class':'pager_link' }
458 Example: { 'class':'pager_link' }
459
459
460 curpage_attr (optional)
460 curpage_attr (optional)
461 A dictionary of attributes that get added to the current page number in the pager (which
461 A dictionary of attributes that get added to the current page number in the pager (which
462 is obviously not a link). If this dictionary is not empty then the elements will be
462 is obviously not a link). If this dictionary is not empty then the elements will be
463 wrapped in a SPAN tag with the given attributes.
463 wrapped in a SPAN tag with the given attributes.
464
464
465 Example: { 'style':'border: 3px solid blue' }
465 Example: { 'style':'border: 3px solid blue' }
466 Example: { 'class':'pager_curpage' }
466 Example: { 'class':'pager_curpage' }
467
467
468 dotdot_attr (optional)
468 dotdot_attr (optional)
469 A dictionary of attributes that get added to the '..' string in the pager (which is
469 A dictionary of attributes that get added to the '..' string in the pager (which is
470 obviously not a link). If this dictionary is not empty then the elements will be wrapped
470 obviously not a link). If this dictionary is not empty then the elements will be wrapped
471 in a SPAN tag with the given attributes.
471 in a SPAN tag with the given attributes.
472
472
473 Example: { 'style':'color: #808080' }
473 Example: { 'style':'color: #808080' }
474 Example: { 'class':'pager_dotdot' }
474 Example: { 'class':'pager_dotdot' }
475
475
476 link_tag (optional)
476 link_tag (optional)
477 A callable that accepts single argument `page` (page link information)
477 A callable that accepts single argument `page` (page link information)
478 and generates string with html that represents the link for specific page.
478 and generates string with html that represents the link for specific page.
479 Page objects are supplied from `link_map()` so the keys are the same.
479 Page objects are supplied from `link_map()` so the keys are the same.
480
480
481
481
482 """
482 """
483 link_attr = link_attr or {}
483 link_attr = link_attr or {}
484 curpage_attr = curpage_attr or {}
484 curpage_attr = curpage_attr or {}
485 dotdot_attr = dotdot_attr or {}
485 dotdot_attr = dotdot_attr or {}
486 self.curpage_attr = curpage_attr
486 self.curpage_attr = curpage_attr
487 self.separator = separator
487 self.separator = separator
488 self.link_attr = link_attr
488 self.link_attr = link_attr
489 self.dotdot_attr = dotdot_attr
489 self.dotdot_attr = dotdot_attr
490 self.url = url
490 self.url = url
491 self.link_tag = link_tag or self.default_link_tag
491 self.link_tag = link_tag or self.default_link_tag
492
492
493 # Don't show navigator if there is no more than one page
493 # Don't show navigator if there is no more than one page
494 if self.page_count == 0 or (self.page_count == 1 and not show_if_single_page):
494 if self.page_count == 0 or (self.page_count == 1 and not show_if_single_page):
495 return ""
495 return ""
496
496
497 regex_res = re.search(r"~(\d+)~", tmpl_format)
497 regex_res = re.search(r"~(\d+)~", tmpl_format)
498 if regex_res:
498 if regex_res:
499 radius = regex_res.group(1)
499 radius = regex_res.group(1)
500 else:
500 else:
501 radius = 2
501 radius = 2
502
502
503 self.radius = int(radius)
503 self.radius = int(radius)
504 link_map = self.link_map(
504 link_map = self.link_map(
505 tmpl_format=tmpl_format,
505 tmpl_format=tmpl_format,
506 url=url,
506 url=url,
507 show_if_single_page=show_if_single_page,
507 show_if_single_page=show_if_single_page,
508 separator=separator,
508 separator=separator,
509 symbol_first=symbol_first,
509 symbol_first=symbol_first,
510 symbol_last=symbol_last,
510 symbol_last=symbol_last,
511 symbol_previous=symbol_previous,
511 symbol_previous=symbol_previous,
512 symbol_next=symbol_next,
512 symbol_next=symbol_next,
513 link_attr=link_attr,
513 link_attr=link_attr,
514 curpage_attr=curpage_attr,
514 curpage_attr=curpage_attr,
515 dotdot_attr=dotdot_attr,
515 dotdot_attr=dotdot_attr,
516 link_tag=link_tag,
516 link_tag=link_tag,
517 )
517 )
518 links_markup = self._range(link_map, self.radius)
518 links_markup = self._range(link_map, self.radius)
519
519
520 # Replace the '~...~' token in tmpl_format with the range of pages
520 # Replace the '~...~' token in tmpl_format with the range of pages
521 result = re.sub(r"~(\d+)~", links_markup, tmpl_format)
521 result = re.sub(r"~(\d+)~", links_markup, tmpl_format)
522
522
523 link_first = (
523 link_first = (
524 self.page > self.first_page and self.link_tag(link_map["first_page"]) or ""
524 self.page > self.first_page and self.link_tag(link_map["first_page"]) or ""
525 )
525 )
526 link_last = (
526 link_last = (
527 self.page < self.last_page and self.link_tag(link_map["last_page"]) or ""
527 self.page < self.last_page and self.link_tag(link_map["last_page"]) or ""
528 )
528 )
529 link_previous = (
529 link_previous = (
530 self.previous_page and self.link_tag(link_map["previous_page"]) or ""
530 self.previous_page and self.link_tag(link_map["previous_page"]) or ""
531 )
531 )
532 link_next = self.next_page and self.link_tag(link_map["next_page"]) or ""
532 link_next = self.next_page and self.link_tag(link_map["next_page"]) or ""
533 # Interpolate '$' variables
533 # Interpolate '$' variables
534 result = Template(result).safe_substitute(
534 result = Template(result).safe_substitute(
535 {
535 {
536 "first_page": self.first_page,
536 "first_page": self.first_page,
537 "last_page": self.last_page,
537 "last_page": self.last_page,
538 "page": self.page,
538 "page": self.page,
539 "page_count": self.page_count,
539 "page_count": self.page_count,
540 "items_per_page": self.items_per_page,
540 "items_per_page": self.items_per_page,
541 "first_item": self.first_item,
541 "first_item": self.first_item,
542 "last_item": self.last_item,
542 "last_item": self.last_item,
543 "item_count": self.item_count,
543 "item_count": self.item_count,
544 "link_first": link_first,
544 "link_first": link_first,
545 "link_last": link_last,
545 "link_last": link_last,
546 "link_previous": link_previous,
546 "link_previous": link_previous,
547 "link_next": link_next,
547 "link_next": link_next,
548 }
548 }
549 )
549 )
550
550
551 return result
551 return result
552
552
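As a miniature of the substitution step above: string.Template fills in the known $-tokens and, with safe_substitute, leaves unknown ones intact (a sketch; the format string is an example, not one used by this module):

    from string import Template

    tmpl = 'Page $page of $page_count ($item_count items) $unknown'
    out = Template(tmpl).safe_substitute(
        {'page': 4, 'page_count': 10, 'item_count': 100})
    # -> 'Page 4 of 10 (100 items) $unknown'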
553 def _get_edges(self, cur_page, max_page, items):
553 def _get_edges(self, cur_page, max_page, items):
554 cur_page = int(cur_page)
554 cur_page = int(cur_page)
555 edge = (items / 2) + 1
555 edge = (items // 2) + 1
556 if cur_page <= edge:
556 if cur_page <= edge:
557 radius = max(items / 2, items - cur_page)
557 radius = max(items // 2, items - cur_page)
558 elif (max_page - cur_page) < edge:
558 elif (max_page - cur_page) < edge:
559 radius = (items - 1) - (max_page - cur_page)
559 radius = (items - 1) - (max_page - cur_page)
560 else:
560 else:
561 radius = (items / 2) - 1
561 radius = (items // 2) - 1
562
562
563 left = max(1, (cur_page - radius))
563 left = max(1, (cur_page - radius))
564 right = min(max_page, cur_page + radius)
564 right = min(max_page, cur_page + radius)
565 return left, right
565 return left, right
566
566
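A quick sketch of how the window computed above behaves at the edges (illustration only, using the Page subclass defined further down in this module; _get_edges reads nothing from self beyond its arguments):

    pager = Page(list(range(120)), page=7, items_per_page=10,
                 url_maker=lambda p: '/demo?page=%s' % p)
    assert pager._get_edges(1, 12, 5) == (1, 5)    # clamped at the left edge
    assert pager._get_edges(12, 12, 5) == (8, 12)  # clamped at the right edge
    assert pager._get_edges(7, 12, 5) == (6, 8)    # middle branch, radius (5 // 2) - 1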
567 def link_map(
567 def link_map(
568 self,
568 self,
569 tmpl_format="~2~",
569 tmpl_format="~2~",
570 url=None,
570 url=None,
571 show_if_single_page=False,
571 show_if_single_page=False,
572 separator=" ",
572 separator=" ",
573 symbol_first="&lt;&lt;",
573 symbol_first="&lt;&lt;",
574 symbol_last="&gt;&gt;",
574 symbol_last="&gt;&gt;",
575 symbol_previous="&lt;",
575 symbol_previous="&lt;",
576 symbol_next="&gt;",
576 symbol_next="&gt;",
577 link_attr=None,
577 link_attr=None,
578 curpage_attr=None,
578 curpage_attr=None,
579 dotdot_attr=None,
579 dotdot_attr=None,
580 link_tag=None
580 link_tag=None
581 ):
581 ):
582 """ Return map with links to other pages if default pager() function is not suitable solution.
582 """ Return map with links to other pages if default pager() function is not suitable solution.
583 tmpl_format:
583 tmpl_format:
584 Format string that defines how the pager would be normally rendered rendered. Uses same arguments as pager()
584 Format string that defines how the pager would be normally rendered rendered. Uses same arguments as pager()
585 method, but returns a simple dictionary in form of:
585 method, but returns a simple dictionary in form of:
586 {'current_page': {'attrs': {},
586 {'current_page': {'attrs': {},
587 'href': 'http://example.org/foo/page=1',
587 'href': 'http://example.org/foo/page=1',
588 'value': 1},
588 'value': 1},
589 'first_page': {'attrs': {},
589 'first_page': {'attrs': {},
590 'href': 'http://example.org/foo/page=1',
590 'href': 'http://example.org/foo/page=1',
591 'type': 'first_page',
591 'type': 'first_page',
592 'value': 1},
592 'value': 1},
593 'last_page': {'attrs': {},
593 'last_page': {'attrs': {},
594 'href': 'http://example.org/foo/page=8',
594 'href': 'http://example.org/foo/page=8',
595 'type': 'last_page',
595 'type': 'last_page',
596 'value': 8},
596 'value': 8},
597 'next_page': {'attrs': {}, 'href': 'HREF', 'type': 'next_page', 'value': 2},
597 'next_page': {'attrs': {}, 'href': 'HREF', 'type': 'next_page', 'value': 2},
598 'previous_page': None,
598 'previous_page': None,
599 'range_pages': [{'attrs': {},
599 'range_pages': [{'attrs': {},
600 'href': 'http://example.org/foo/page=1',
600 'href': 'http://example.org/foo/page=1',
601 'type': 'current_page',
601 'type': 'current_page',
602 'value': 1},
602 'value': 1},
603 ....
603 ....
604 {'attrs': {}, 'href': '', 'type': 'span', 'value': '..'}]}
604 {'attrs': {}, 'href': '', 'type': 'span', 'value': '..'}]}
605
605
606
606
607 The string can contain the following $-tokens that are substituted by the
607 The string can contain the following $-tokens that are substituted by the
608 string.Template module:
608 string.Template module:
609
609
610 - $first_page: number of first reachable page
610 - $first_page: number of first reachable page
611 - $last_page: number of last reachable page
611 - $last_page: number of last reachable page
612 - $page: number of currently selected page
612 - $page: number of currently selected page
613 - $page_count: number of reachable pages
613 - $page_count: number of reachable pages
614 - $items_per_page: maximal number of items per page
614 - $items_per_page: maximal number of items per page
615 - $first_item: index of first item on the current page
615 - $first_item: index of first item on the current page
616 - $last_item: index of last item on the current page
616 - $last_item: index of last item on the current page
617 - $item_count: total number of items
617 - $item_count: total number of items
618 - $link_first: link to first page (unless this is first page)
618 - $link_first: link to first page (unless this is first page)
619 - $link_last: link to last page (unless this is last page)
619 - $link_last: link to last page (unless this is last page)
620 - $link_previous: link to previous page (unless this is first page)
620 - $link_previous: link to previous page (unless this is first page)
621 - $link_next: link to next page (unless this is last page)
621 - $link_next: link to next page (unless this is last page)
622
622
623 To render a range of pages the token '~3~' can be used. The
623 To render a range of pages the token '~3~' can be used. The
624 number sets the radius of pages around the current page.
624 number sets the radius of pages around the current page.
625 Example for a range with radius 3:
625 Example for a range with radius 3:
626
626
627 '1 .. 5 6 7 [8] 9 10 11 .. 50'
627 '1 .. 5 6 7 [8] 9 10 11 .. 50'
628
628
629 Default: '~2~'
629 Default: '~2~'
630
630
631 url
631 url
632 The URL that page links will point to. Make sure it contains the string
632 The URL that page links will point to. Make sure it contains the string
633 $page which will be replaced by the actual page number.
633 $page which will be replaced by the actual page number.
634 Must be given unless a url_maker is specified to __init__, in which
634 Must be given unless a url_maker is specified to __init__, in which
635 case this parameter is ignored.
635 case this parameter is ignored.
636
636
637 symbol_first
637 symbol_first
638 String to be displayed as the text for the $link_first link above.
638 String to be displayed as the text for the $link_first link above.
639
639
640 Default: '&lt;&lt;' (<<)
640 Default: '&lt;&lt;' (<<)
641
641
642 symbol_last
642 symbol_last
643 String to be displayed as the text for the $link_last link above.
643 String to be displayed as the text for the $link_last link above.
644
644
645 Default: '&gt;&gt;' (>>)
645 Default: '&gt;&gt;' (>>)
646
646
647 symbol_previous
647 symbol_previous
648 String to be displayed as the text for the $link_previous link above.
648 String to be displayed as the text for the $link_previous link above.
649
649
650 Default: '&lt;' (<)
650 Default: '&lt;' (<)
651
651
652 symbol_next
652 symbol_next
653 String to be displayed as the text for the $link_next link above.
653 String to be displayed as the text for the $link_next link above.
654
654
655 Default: '&gt;' (>)
655 Default: '&gt;' (>)
656
656
657 separator:
657 separator:
658 String that is used to separate page links/numbers in the above range of pages.
658 String that is used to separate page links/numbers in the above range of pages.
659
659
660 Default: ' '
660 Default: ' '
661
661
662 show_if_single_page:
662 show_if_single_page:
663 If True, the navigator will be shown even if there is only one page.
663 If True, the navigator will be shown even if there is only one page.
664
664
665 Default: False
665 Default: False
666
666
667 link_attr (optional)
667 link_attr (optional)
668 A dictionary of attributes that get added to A-HREF links pointing to other pages. Can
668 A dictionary of attributes that get added to A-HREF links pointing to other pages. Can
669 be used to define a CSS style or class to customize the look of links.
669 be used to define a CSS style or class to customize the look of links.
670
670
671 Example: { 'style':'border: 1px solid green' }
671 Example: { 'style':'border: 1px solid green' }
672 Example: { 'class':'pager_link' }
672 Example: { 'class':'pager_link' }
673
673
674 curpage_attr (optional)
674 curpage_attr (optional)
675 A dictionary of attributes that get added to the current page number in the pager (which
675 A dictionary of attributes that get added to the current page number in the pager (which
676 is obviously not a link). If this dictionary is not empty then the elements will be
676 is obviously not a link). If this dictionary is not empty then the elements will be
677 wrapped in a SPAN tag with the given attributes.
677 wrapped in a SPAN tag with the given attributes.
678
678
679 Example: { 'style':'border: 3px solid blue' }
679 Example: { 'style':'border: 3px solid blue' }
680 Example: { 'class':'pager_curpage' }
680 Example: { 'class':'pager_curpage' }
681
681
682 dotdot_attr (optional)
682 dotdot_attr (optional)
683 A dictionary of attributes that get added to the '..' string in the pager (which is
683 A dictionary of attributes that get added to the '..' string in the pager (which is
684 obviously not a link). If this dictionary is not empty then the elements will be wrapped
684 obviously not a link). If this dictionary is not empty then the elements will be wrapped
685 in a SPAN tag with the given attributes.
685 in a SPAN tag with the given attributes.
686
686
687 Example: { 'style':'color: #808080' }
687 Example: { 'style':'color: #808080' }
688 Example: { 'class':'pager_dotdot' }
688 Example: { 'class':'pager_dotdot' }
689 """
689 """
690 link_attr = link_attr or {}
690 link_attr = link_attr or {}
691 curpage_attr = curpage_attr or {}
691 curpage_attr = curpage_attr or {}
692 dotdot_attr = dotdot_attr or {}
692 dotdot_attr = dotdot_attr or {}
693 self.curpage_attr = curpage_attr
693 self.curpage_attr = curpage_attr
694 self.separator = separator
694 self.separator = separator
695 self.link_attr = link_attr
695 self.link_attr = link_attr
696 self.dotdot_attr = dotdot_attr
696 self.dotdot_attr = dotdot_attr
697 self.url = url
697 self.url = url
698
698
699 regex_res = re.search(r"~(\d+)~", tmpl_format)
699 regex_res = re.search(r"~(\d+)~", tmpl_format)
700 if regex_res:
700 if regex_res:
701 radius = regex_res.group(1)
701 radius = regex_res.group(1)
702 else:
702 else:
703 radius = 2
703 radius = 2
704
704
705 self.radius = int(radius)
705 self.radius = int(radius)
706
706
707 # Compute the first and last page number within the radius
707 # Compute the first and last page number within the radius
708 # e.g. '1 .. 5 6 [7] 8 9 .. 12'
708 # e.g. '1 .. 5 6 [7] 8 9 .. 12'
709 # -> leftmost_page = 5
709 # -> leftmost_page = 5
710 # -> rightmost_page = 9
710 # -> rightmost_page = 9
711 leftmost_page, rightmost_page = self._get_edges(
711 leftmost_page, rightmost_page = self._get_edges(
712 self.page, self.last_page, (self.radius * 2) + 1)
712 self.page, self.last_page, (self.radius * 2) + 1)
713
713
714 nav_items = {
714 nav_items = {
715 "first_page": None,
715 "first_page": None,
716 "last_page": None,
716 "last_page": None,
717 "previous_page": None,
717 "previous_page": None,
718 "next_page": None,
718 "next_page": None,
719 "current_page": None,
719 "current_page": None,
720 "radius": self.radius,
720 "radius": self.radius,
721 "range_pages": [],
721 "range_pages": [],
722 }
722 }
723
723
724 if leftmost_page is None or rightmost_page is None:
724 if leftmost_page is None or rightmost_page is None:
725 return nav_items
725 return nav_items
726
726
727 nav_items["first_page"] = {
727 nav_items["first_page"] = {
728 "type": "first_page",
728 "type": "first_page",
729 "value": unicode(symbol_first),
729 "value": str(symbol_first),
730 "attrs": self.link_attr,
730 "attrs": self.link_attr,
731 "number": self.first_page,
731 "number": self.first_page,
732 "href": self.url_maker(self.first_page),
732 "href": self.url_maker(self.first_page),
733 }
733 }
734
734
735 # Insert dots if there are pages between the first page
735 # Insert dots if there are pages between the first page
736 # and the currently displayed page range
736 # and the currently displayed page range
737 if leftmost_page - self.first_page > 1:
737 if leftmost_page - self.first_page > 1:
738 # Wrap in a SPAN tag if dotdot_attr is set
738 # Wrap in a SPAN tag if dotdot_attr is set
739 nav_items["range_pages"].append(
739 nav_items["range_pages"].append(
740 {
740 {
741 "type": "span",
741 "type": "span",
742 "value": "..",
742 "value": "..",
743 "attrs": self.dotdot_attr,
743 "attrs": self.dotdot_attr,
744 "href": "",
744 "href": "",
745 "number": None,
745 "number": None,
746 }
746 }
747 )
747 )
748
748
749 for this_page in range(leftmost_page, rightmost_page + 1):
749 for this_page in range(leftmost_page, rightmost_page + 1):
750 # Highlight the current page number and do not use a link
750 # Highlight the current page number and do not use a link
751 if this_page == self.page:
751 if this_page == self.page:
752 # Wrap in a SPAN tag if curpage_attr is set
752 # Wrap in a SPAN tag if curpage_attr is set
753 nav_items["range_pages"].append(
753 nav_items["range_pages"].append(
754 {
754 {
755 "type": "current_page",
755 "type": "current_page",
756 "value": unicode(this_page),
756 "value": str(this_page),
757 "number": this_page,
757 "number": this_page,
758 "attrs": self.curpage_attr,
758 "attrs": self.curpage_attr,
759 "href": self.url_maker(this_page),
759 "href": self.url_maker(this_page),
760 }
760 }
761 )
761 )
762 nav_items["current_page"] = {
762 nav_items["current_page"] = {
763 "value": this_page,
763 "value": this_page,
764 "attrs": self.curpage_attr,
764 "attrs": self.curpage_attr,
765 "type": "current_page",
765 "type": "current_page",
766 "href": self.url_maker(this_page),
766 "href": self.url_maker(this_page),
767 }
767 }
768 # Otherwise create just a link to that page
768 # Otherwise create just a link to that page
769 else:
769 else:
770 nav_items["range_pages"].append(
770 nav_items["range_pages"].append(
771 {
771 {
772 "type": "page",
772 "type": "page",
773 "value": unicode(this_page),
773 "value": str(this_page),
774 "number": this_page,
774 "number": this_page,
775 "attrs": self.link_attr,
775 "attrs": self.link_attr,
776 "href": self.url_maker(this_page),
776 "href": self.url_maker(this_page),
777 }
777 }
778 )
778 )
779
779
780 # Insert dots if there are pages between the displayed
780 # Insert dots if there are pages between the displayed
781 # page numbers and the end of the page range
781 # page numbers and the end of the page range
782 if self.last_page - rightmost_page > 1:
782 if self.last_page - rightmost_page > 1:
783 # Wrap in a SPAN tag if dotdot_attr is set
783 # Wrap in a SPAN tag if dotdot_attr is set
784 nav_items["range_pages"].append(
784 nav_items["range_pages"].append(
785 {
785 {
786 "type": "span",
786 "type": "span",
787 "value": "..",
787 "value": "..",
788 "attrs": self.dotdot_attr,
788 "attrs": self.dotdot_attr,
789 "href": "",
789 "href": "",
790 "number": None,
790 "number": None,
791 }
791 }
792 )
792 )
793
793
794 # Entry for the very last page (always filled in here; whether it is
794 # Entry for the very last page (always filled in here; whether it is
795 # actually shown is decided by the template / link rendering)
795 # actually shown is decided by the template / link rendering)
796 nav_items["last_page"] = {
796 nav_items["last_page"] = {
797 "type": "last_page",
797 "type": "last_page",
798 "value": unicode(symbol_last),
798 "value": str(symbol_last),
799 "attrs": self.link_attr,
799 "attrs": self.link_attr,
800 "href": self.url_maker(self.last_page),
800 "href": self.url_maker(self.last_page),
801 "number": self.last_page,
801 "number": self.last_page,
802 }
802 }
803
803
804 nav_items["previous_page"] = {
804 nav_items["previous_page"] = {
805 "type": "previous_page",
805 "type": "previous_page",
806 "value": unicode(symbol_previous),
806 "value": str(symbol_previous),
807 "attrs": self.link_attr,
807 "attrs": self.link_attr,
808 "number": self.previous_page or self.first_page,
808 "number": self.previous_page or self.first_page,
809 "href": self.url_maker(self.previous_page or self.first_page),
809 "href": self.url_maker(self.previous_page or self.first_page),
810 }
810 }
811
811
812 nav_items["next_page"] = {
812 nav_items["next_page"] = {
813 "type": "next_page",
813 "type": "next_page",
814 "value": unicode(symbol_next),
814 "value": str(symbol_next),
815 "attrs": self.link_attr,
815 "attrs": self.link_attr,
816 "number": self.next_page or self.last_page,
816 "number": self.next_page or self.last_page,
817 "href": self.url_maker(self.next_page or self.last_page),
817 "href": self.url_maker(self.next_page or self.last_page),
818 }
818 }
819
819
820 return nav_items
820 return nav_items
821
821
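A minimal sketch of calling link_map() directly (again via the Page subclass defined below; the URL shape is made up):

    pager = Page(list(range(100)), page=4, items_per_page=10,
                 url_maker=lambda p: '/foo?page=%s' % p)
    nav = pager.link_map(tmpl_format='$link_previous ~2~ $link_next')
    assert nav['current_page']['number'] == 4
    assert nav['next_page']['href'] == '/foo?page=5'
    # nav['range_pages'] is the page window plus any '..' spacer entries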
822 def _range(self, link_map, radius):
822 def _range(self, link_map, radius):
823 """
823 """
824 Return range of linked pages to substitute placeholder in pattern
824 Return range of linked pages to substitute placeholder in pattern
825 """
825 """
826 # Compute the first and last page number within the radius
826 # Compute the first and last page number within the radius
827 # e.g. '1 .. 5 6 [7] 8 9 .. 12'
827 # e.g. '1 .. 5 6 [7] 8 9 .. 12'
828 # -> leftmost_page = 5
828 # -> leftmost_page = 5
829 # -> rightmost_page = 9
829 # -> rightmost_page = 9
830 leftmost_page, rightmost_page = self._get_edges(
830 leftmost_page, rightmost_page = self._get_edges(
831 self.page, self.last_page, (radius * 2) + 1)
831 self.page, self.last_page, (radius * 2) + 1)
832
832
833 nav_items = []
833 nav_items = []
834 # Create a link to the first page (unless we are on the first page
834 # Create a link to the first page (unless we are on the first page
835 # or there would be no need to insert '..' spacers)
835 # or there would be no need to insert '..' spacers)
836 if self.first_page and self.page != self.first_page and self.first_page < leftmost_page:
836 if self.first_page and self.page != self.first_page and self.first_page < leftmost_page:
837 page = link_map["first_page"].copy()
837 page = link_map["first_page"].copy()
838 page["value"] = unicode(page["number"])
838 page["value"] = str(page["number"])
839 nav_items.append(self.link_tag(page))
839 nav_items.append(self.link_tag(page))
840
840
841 for item in link_map["range_pages"]:
841 for item in link_map["range_pages"]:
842 nav_items.append(self.link_tag(item))
842 nav_items.append(self.link_tag(item))
843
843
844 # Create a link to the very last page (unless we are on the last
844 # Create a link to the very last page (unless we are on the last
845 # page or there would be no need to insert '..' spacers)
845 # page or there would be no need to insert '..' spacers)
846 if self.last_page and self.page != self.last_page and rightmost_page < self.last_page:
846 if self.last_page and self.page != self.last_page and rightmost_page < self.last_page:
847 page = link_map["last_page"].copy()
847 page = link_map["last_page"].copy()
848 page["value"] = unicode(page["number"])
848 page["value"] = str(page["number"])
849 nav_items.append(self.link_tag(page))
849 nav_items.append(self.link_tag(page))
850
850
851 return self.separator.join(nav_items)
851 return self.separator.join(nav_items)
852
852
853 def _default_url_maker(self, page_number):
853 def _default_url_maker(self, page_number):
854 if self.url is None:
854 if self.url is None:
855 raise Exception(
855 raise Exception(
856 "You need to specify a 'url' parameter containing a '$page' placeholder."
856 "You need to specify a 'url' parameter containing a '$page' placeholder."
857 )
857 )
858
858
859 if "$page" not in self.url:
859 if "$page" not in self.url:
860 raise Exception("The 'url' parameter must contain a '$page' placeholder.")
860 raise Exception("The 'url' parameter must contain a '$page' placeholder.")
861
861
862 return self.url.replace("$page", unicode(page_number))
862 return self.url.replace("$page", str(page_number))
863
863
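With a plain url string, the '$page' replacement above behaves like this (sketch; the URL is an example):

    pager.url = '/commits?page=$page'
    assert pager._default_url_maker(3) == '/commits?page=3'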
864 @staticmethod
864 @staticmethod
865 def default_link_tag(item):
865 def default_link_tag(item):
866 """
866 """
867 Create an A-HREF tag that points to another page.
867 Create an A-HREF tag that points to another page.
868 """
868 """
869 text = item["value"]
869 text = item["value"]
870 target_url = item["href"]
870 target_url = item["href"]
871
871
872 if not item["href"] or item["type"] in ("span", "current_page"):
872 if not item["href"] or item["type"] in ("span", "current_page"):
873 if item["attrs"]:
873 if item["attrs"]:
874 text = make_html_tag("span", **item["attrs"]) + text + "</span>"
874 text = make_html_tag("span", **item["attrs"]) + text + "</span>"
875 return text
875 return text
876
876
877 return make_html_tag("a", text=text, href=target_url, **item["attrs"])
877 return make_html_tag("a", text=text, href=target_url, **item["attrs"])
878
878
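For the two item shapes, default_link_tag() yields roughly the following (make_html_tag lives elsewhere in this module; the attribute values here are examples):

    item = {'type': 'page', 'value': '2', 'href': '/foo?page=2',
            'attrs': {'class': 'pager_link'}}
    # -> '<a class="pager_link" href="/foo?page=2">2</a>'

    item = {'type': 'current_page', 'value': '4', 'href': '/foo?page=4',
            'attrs': {'class': 'pager_curpage'}}
    # -> '<span class="pager_curpage">4</span>'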
879 # Below is RhodeCode custom code
879 # Below is RhodeCode custom code
880
880
881 # Copyright (C) 2010-2020 RhodeCode GmbH
881 # Copyright (C) 2010-2020 RhodeCode GmbH
882 #
882 #
883 # This program is free software: you can redistribute it and/or modify
883 # This program is free software: you can redistribute it and/or modify
884 # it under the terms of the GNU Affero General Public License, version 3
884 # it under the terms of the GNU Affero General Public License, version 3
885 # (only), as published by the Free Software Foundation.
885 # (only), as published by the Free Software Foundation.
886 #
886 #
887 # This program is distributed in the hope that it will be useful,
887 # This program is distributed in the hope that it will be useful,
888 # but WITHOUT ANY WARRANTY; without even the implied warranty of
888 # but WITHOUT ANY WARRANTY; without even the implied warranty of
889 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
889 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
890 # GNU General Public License for more details.
890 # GNU General Public License for more details.
891 #
891 #
892 # You should have received a copy of the GNU Affero General Public License
892 # You should have received a copy of the GNU Affero General Public License
893 # along with this program. If not, see <http://www.gnu.org/licenses/>.
893 # along with this program. If not, see <http://www.gnu.org/licenses/>.
894 #
894 #
895 # This program is dual-licensed. If you wish to learn more about the
895 # This program is dual-licensed. If you wish to learn more about the
896 # RhodeCode Enterprise Edition, including its added features, Support services,
896 # RhodeCode Enterprise Edition, including its added features, Support services,
897 # and proprietary license terms, please see https://rhodecode.com/licenses/
897 # and proprietary license terms, please see https://rhodecode.com/licenses/
898
898
899
899
900 PAGE_FORMAT = '$link_previous ~3~ $link_next'
900 PAGE_FORMAT = '$link_previous ~3~ $link_next'
901
901
902
902
903 class SqlalchemyOrmWrapper(object):
903 class SqlalchemyOrmWrapper(object):
904 """Wrapper class to access elements of a collection."""
904 """Wrapper class to access elements of a collection."""
905
905
906 def __init__(self, pager, collection):
906 def __init__(self, pager, collection):
907 self.pager = pager
907 self.pager = pager
908 self.collection = collection
908 self.collection = collection
909
909
910 def __getitem__(self, range):
910 def __getitem__(self, range):
911 # Return a range of objects of an sqlalchemy.orm.query.Query object
911 # Return a range of objects of an sqlalchemy.orm.query.Query object
912 return self.collection[range]
912 return self.collection[range]
913
913
914 def __len__(self):
914 def __len__(self):
915 # support empty types, without actually making a query.
915 # support empty types, without actually making a query.
916 if self.collection is None or self.collection == []:
916 if self.collection is None or self.collection == []:
917 return 0
917 return 0
918
918
919 # Count the number of objects in an sqlalchemy.orm.query.Query object
919 # Count the number of objects in an sqlalchemy.orm.query.Query object
920 return self.collection.count()
920 return self.collection.count()
921
921
922
922
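A sketch of what the wrapper buys (Repository and Session are this codebase's models; the query itself is an arbitrary example):

    query = Session().query(Repository).order_by(Repository.repo_name)
    wrapped = SqlalchemyOrmWrapper(pager=None, collection=query)
    total = len(wrapped)        # runs query.count(), i.e. SELECT COUNT(*)
    page_items = wrapped[0:20]  # slices the query, i.e. LIMIT 20 OFFSET 0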
923 class CustomPager(_Page):
923 class CustomPager(_Page):
924
924
925 @staticmethod
925 @staticmethod
926 def disabled_link_tag(item):
926 def disabled_link_tag(item):
927 """
927 """
928 Create an A-HREF tag that is disabled
928 Create an A-HREF tag that is disabled
929 """
929 """
930 text = item['value']
930 text = item['value']
931 attrs = item['attrs'].copy()
931 attrs = item['attrs'].copy()
932 attrs['class'] = 'disabled ' + attrs['class']
932 attrs['class'] = 'disabled ' + attrs['class']
933
933
934 return make_html_tag('a', text=text, **attrs)
934 return make_html_tag('a', text=text, **attrs)
935
935
936 def render(self):
936 def render(self):
937 # Don't show the navigator if there are no pages
937 # Don't show the navigator if there are no pages
938 if self.page_count == 0:
938 if self.page_count == 0:
939 return ""
939 return ""
940
940
941 self.link_tag = self.default_link_tag
941 self.link_tag = self.default_link_tag
942
942
943 link_map = self.link_map(
943 link_map = self.link_map(
944 tmpl_format=PAGE_FORMAT, url=None,
944 tmpl_format=PAGE_FORMAT, url=None,
945 show_if_single_page=False, separator=' ',
945 show_if_single_page=False, separator=' ',
946 symbol_first='<<', symbol_last='>>',
946 symbol_first='<<', symbol_last='>>',
947 symbol_previous='<', symbol_next='>',
947 symbol_previous='<', symbol_next='>',
948 link_attr={'class': 'pager_link'},
948 link_attr={'class': 'pager_link'},
949 curpage_attr={'class': 'pager_curpage'},
949 curpage_attr={'class': 'pager_curpage'},
950 dotdot_attr={'class': 'pager_dotdot'})
950 dotdot_attr={'class': 'pager_dotdot'})
951
951
952 links_markup = self._range(link_map, self.radius)
952 links_markup = self._range(link_map, self.radius)
953
953
954 link_first = (
954 link_first = (
955 self.page > self.first_page and self.link_tag(link_map['first_page']) or ''
955 self.page > self.first_page and self.link_tag(link_map['first_page']) or ''
956 )
956 )
957 link_last = (
957 link_last = (
958 self.page < self.last_page and self.link_tag(link_map['last_page']) or ''
958 self.page < self.last_page and self.link_tag(link_map['last_page']) or ''
959 )
959 )
960
960
961 link_previous = (
961 link_previous = (
962 self.previous_page and self.link_tag(link_map['previous_page'])
962 self.previous_page and self.link_tag(link_map['previous_page'])
963 or self.disabled_link_tag(link_map['previous_page'])
963 or self.disabled_link_tag(link_map['previous_page'])
964 )
964 )
965 link_next = (
965 link_next = (
966 self.next_page and self.link_tag(link_map['next_page'])
966 self.next_page and self.link_tag(link_map['next_page'])
967 or self.disabled_link_tag(link_map['next_page'])
967 or self.disabled_link_tag(link_map['next_page'])
968 )
968 )
969
969
970 # Interpolate '$' variables
970 # Interpolate '$' variables
971 # Replace the ~...~ token in tmpl_format with the rendered range of pages
971 # Replace the ~...~ token in tmpl_format with the rendered range of pages
972 result = re.sub(r"~(\d+)~", links_markup, PAGE_FORMAT)
972 result = re.sub(r"~(\d+)~", links_markup, PAGE_FORMAT)
973 result = Template(result).safe_substitute(
973 result = Template(result).safe_substitute(
974 {
974 {
975 "links": links_markup,
975 "links": links_markup,
976 "first_page": self.first_page,
976 "first_page": self.first_page,
977 "last_page": self.last_page,
977 "last_page": self.last_page,
978 "page": self.page,
978 "page": self.page,
979 "page_count": self.page_count,
979 "page_count": self.page_count,
980 "items_per_page": self.items_per_page,
980 "items_per_page": self.items_per_page,
981 "first_item": self.first_item,
981 "first_item": self.first_item,
982 "last_item": self.last_item,
982 "last_item": self.last_item,
983 "item_count": self.item_count,
983 "item_count": self.item_count,
984 "link_first": link_first,
984 "link_first": link_first,
985 "link_last": link_last,
985 "link_last": link_last,
986 "link_previous": link_previous,
986 "link_previous": link_previous,
987 "link_next": link_next,
987 "link_next": link_next,
988 }
988 }
989 )
989 )
990
990
991 return literal(result)
991 return literal(result)
992
992
993
993
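End to end, rendering then looks like this (sketch; CustomPager inherits its constructor from the paginate base class):

    pager = CustomPager(list(range(100)), page=4, items_per_page=10,
                        url_maker=lambda p: '/foo?page=%s' % p)
    html = pager.render()
    # literal() markup: previous link, the ~3~ page window, next link,
    # styled with the pager_link/pager_curpage/pager_dotdot classes above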
994 class Page(CustomPager):
994 class Page(CustomPager):
995 """
995 """
996 Custom pager to match rendering style with paginator
996 Custom pager to match rendering style with paginator
997 """
997 """
998
998
999 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
999 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1000 url_maker=None, **kwargs):
1000 url_maker=None, **kwargs):
1001 """
1001 """
1002 Special type of pager. We pass the collection through unchanged and
1002 Special type of pager. We pass the collection through unchanged and
1003 explicitly disable wrapper_class
1003 explicitly disable wrapper_class
1004 """
1004 """
1005
1005
1006 super(Page, self).__init__(collection=collection, page=page,
1006 super(Page, self).__init__(collection=collection, page=page,
1007 items_per_page=items_per_page, item_count=item_count,
1007 items_per_page=items_per_page, item_count=item_count,
1008 wrapper_class=None, url_maker=url_maker, **kwargs)
1008 wrapper_class=None, url_maker=url_maker, **kwargs)
1009
1009
1010
1010
1011 class SqlPage(CustomPager):
1011 class SqlPage(CustomPager):
1012 """
1012 """
1013 Custom pager to match rendering style with paginator
1013 Custom pager to match rendering style with paginator
1014 """
1014 """
1015
1015
1016 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1016 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1017 url_maker=None, **kwargs):
1017 url_maker=None, **kwargs):
1018 """
1018 """
1019 Special type of pager. We intercept collection to wrap it in our custom
1019 Special type of pager. We intercept collection to wrap it in our custom
1020 logic instead of using wrapper_class
1020 logic instead of using wrapper_class
1021 """
1021 """
1022 collection = SqlalchemyOrmWrapper(self, collection)
1022 collection = SqlalchemyOrmWrapper(self, collection)
1023
1023
1024 super(SqlPage, self).__init__(collection=collection, page=page,
1024 super(SqlPage, self).__init__(collection=collection, page=page,
1025 items_per_page=items_per_page, item_count=item_count,
1025 items_per_page=items_per_page, item_count=item_count,
1026 wrapper_class=None, url_maker=url_maker, **kwargs)
1026 wrapper_class=None, url_maker=url_maker, **kwargs)
1027
1027
1028
1028
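In short: plain in-memory collections go through Page, SQLAlchemy queries through SqlPage, which defers counting and slicing to the database (sketch; the query is an example):

    in_memory = Page(list(range(100)), page=2, items_per_page=20,
                     url_maker=lambda p: '/log?page=%s' % p)
    from_db = SqlPage(Session().query(Repository), page=1, items_per_page=20,
                      url_maker=lambda p: '/repos?page=%s' % p)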
1029 class RepoCommitsWrapper(object):
1029 class RepoCommitsWrapper(object):
1030 """Wrapper class to access elements of a collection."""
1030 """Wrapper class to access elements of a collection."""
1031
1031
1032 def __init__(self, pager, collection):
1032 def __init__(self, pager, collection):
1033 self.pager = pager
1033 self.pager = pager
1034 self.collection = collection
1034 self.collection = collection
1035
1035
1036 def __getitem__(self, range):
1036 def __getitem__(self, range):
1037 cur_page = self.pager.page
1037 cur_page = self.pager.page
1038 items_per_page = self.pager.items_per_page
1038 items_per_page = self.pager.items_per_page
1039 first_item = max(0, (len(self.collection) - (cur_page * items_per_page)))
1039 first_item = max(0, (len(self.collection) - (cur_page * items_per_page)))
1040 last_item = ((len(self.collection) - 1) - items_per_page * (cur_page - 1))
1040 last_item = ((len(self.collection) - 1) - items_per_page * (cur_page - 1))
1041 return reversed(list(self.collection[first_item:last_item + 1]))
1041 return reversed(list(self.collection[first_item:last_item + 1]))
1042
1042
1043 def __len__(self):
1043 def __len__(self):
1044 return len(self.collection)
1044 return len(self.collection)
1045
1045
1046
1046
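The slice arithmetic above serves commits newest-first; a worked example (illustration only):

    # 100 commits, 20 per page, requesting page 1:
    #   first_item = max(0, 100 - 1 * 20)     = 80
    #   last_item  = (100 - 1) - 20 * (1 - 1) = 99
    # -> collection[80:100], reversed: commits 99, 98, ..., 80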
1047 class RepoPage(CustomPager):
1047 class RepoPage(CustomPager):
1048 """
1048 """
1049 Create a "RepoPage" instance: a special pager for paging repository commits
1049 Create a "RepoPage" instance: a special pager for paging repository commits
1050 """
1050 """
1051
1051
1052 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1052 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1053 url_maker=None, **kwargs):
1053 url_maker=None, **kwargs):
1054 """
1054 """
1055 Special type of pager. We intercept collection to wrap it in our custom
1055 Special type of pager. We intercept collection to wrap it in our custom
1056 logic instead of using wrapper_class
1056 logic instead of using wrapper_class
1057 """
1057 """
1058 collection = RepoCommitsWrapper(self, collection)
1058 collection = RepoCommitsWrapper(self, collection)
1059 super(RepoPage, self).__init__(collection=collection, page=page,
1059 super(RepoPage, self).__init__(collection=collection, page=page,
1060 items_per_page=items_per_page, item_count=item_count,
1060 items_per_page=items_per_page, item_count=item_count,
1061 wrapper_class=None, url_maker=url_maker, **kwargs)
1061 wrapper_class=None, url_maker=url_maker, **kwargs)
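Hypothetical usage (repo.get_commits() stands in for whatever commit iterable the caller already has; the url_maker shape is made up):

    commits = RepoPage(repo.get_commits(), page=1, items_per_page=20,
                       url_maker=lambda p: '/changelog?page=%s' % p)
    for commit in commits:  # newest-first slice of the requested page
        ...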
@@ -1,264 +1,264 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2017-2020 RhodeCode GmbH
3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import os
21 import os
22 import re
22 import re
23 import time
23 import time
24 import datetime
24 import datetime
25 import dateutil
25 import dateutil
26 import pickle
26 import pickle
27
27
28 from rhodecode.model.db import DbSession, Session
28 from rhodecode.model.db import DbSession, Session
29
29
30
30
31 class CleanupCommand(Exception):
31 class CleanupCommand(Exception):
32 pass
32 pass
33
33
34
34
35 class BaseAuthSessions(object):
35 class BaseAuthSessions(object):
36 SESSION_TYPE = None
36 SESSION_TYPE = None
37 NOT_AVAILABLE = 'NOT AVAILABLE'
37 NOT_AVAILABLE = 'NOT AVAILABLE'
38
38
39 def __init__(self, config):
39 def __init__(self, config):
40 session_conf = {}
40 session_conf = {}
41 for k, v in config.items():
41 for k, v in config.items():
42 if k.startswith('beaker.session'):
42 if k.startswith('beaker.session'):
43 session_conf[k] = v
43 session_conf[k] = v
44 self.config = session_conf
44 self.config = session_conf
45
45
46 def get_count(self):
46 def get_count(self):
47 raise NotImplementedError
47 raise NotImplementedError
48
48
49 def get_expired_count(self, older_than_seconds=None):
49 def get_expired_count(self, older_than_seconds=None):
50 raise NotImplementedError
50 raise NotImplementedError
51
51
52 def clean_sessions(self, older_than_seconds=None):
52 def clean_sessions(self, older_than_seconds=None):
53 raise NotImplementedError
53 raise NotImplementedError
54
54
55 def _seconds_to_date(self, seconds):
55 def _seconds_to_date(self, seconds):
56 return datetime.datetime.utcnow() - dateutil.relativedelta.relativedelta(
56 return datetime.datetime.utcnow() - dateutil.relativedelta.relativedelta(
57 seconds=seconds)
57 seconds=seconds)
58
58
59
59
60 class DbAuthSessions(BaseAuthSessions):
60 class DbAuthSessions(BaseAuthSessions):
61 SESSION_TYPE = 'ext:database'
61 SESSION_TYPE = 'ext:database'
62
62
63 def get_count(self):
63 def get_count(self):
64 return DbSession.query().count()
64 return DbSession.query().count()
65
65
66 def get_expired_count(self, older_than_seconds=None):
66 def get_expired_count(self, older_than_seconds=None):
67 expiry_date = self._seconds_to_date(older_than_seconds)
67 expiry_date = self._seconds_to_date(older_than_seconds)
68 return DbSession.query().filter(DbSession.accessed < expiry_date).count()
68 return DbSession.query().filter(DbSession.accessed < expiry_date).count()
69
69
70 def clean_sessions(self, older_than_seconds=None):
70 def clean_sessions(self, older_than_seconds=None):
71 expiry_date = self._seconds_to_date(older_than_seconds)
71 expiry_date = self._seconds_to_date(older_than_seconds)
72 to_remove = DbSession.query().filter(DbSession.accessed < expiry_date).count()
72 to_remove = DbSession.query().filter(DbSession.accessed < expiry_date).count()
73 DbSession.query().filter(DbSession.accessed < expiry_date).delete()
73 DbSession.query().filter(DbSession.accessed < expiry_date).delete()
74 Session().commit()
74 Session().commit()
75 return to_remove
75 return to_remove
76
76
77
77
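A sketch of the database-backed cleanup in use (config is the settings dict handed to __init__; the cutoff is an example value):

    sessions = DbAuthSessions(config)
    stale = sessions.get_expired_count(older_than_seconds=24 * 60 * 60)
    removed = sessions.clean_sessions(older_than_seconds=24 * 60 * 60)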
78 class FileAuthSessions(BaseAuthSessions):
78 class FileAuthSessions(BaseAuthSessions):
79 SESSION_TYPE = 'file sessions'
79 SESSION_TYPE = 'file sessions'
80
80
81 def _get_sessions_dir(self):
81 def _get_sessions_dir(self):
82 data_dir = self.config.get('beaker.session.data_dir')
82 data_dir = self.config.get('beaker.session.data_dir')
83 return data_dir
83 return data_dir
84
84
85 def _count_on_filesystem(self, path, older_than=0, callback=None):
85 def _count_on_filesystem(self, path, older_than=0, callback=None):
86 value = dict(percent=0, used=0, total=0, items=0, callbacks=0,
86 value = dict(percent=0, used=0, total=0, items=0, callbacks=0,
87 path=path, text='')
87 path=path, text='')
88 items_count = 0
88 items_count = 0
89 used = 0
89 used = 0
90 callbacks = 0
90 callbacks = 0
91 cur_time = time.time()
91 cur_time = time.time()
92 for root, dirs, files in os.walk(path):
92 for root, dirs, files in os.walk(path):
93 for f in files:
93 for f in files:
94 final_path = os.path.join(root, f)
94 final_path = os.path.join(root, f)
95 try:
95 try:
96 mtime = os.stat(final_path).st_mtime
96 mtime = os.stat(final_path).st_mtime
97 if (cur_time - mtime) > older_than:
97 if (cur_time - mtime) > older_than:
98 items_count += 1
98 items_count += 1
99 if callback:
99 if callback:
100 callback_res = callback(final_path)
100 callback_res = callback(final_path)
101 callbacks += 1
101 callbacks += 1
102 else:
102 else:
103 used += os.path.getsize(final_path)
103 used += os.path.getsize(final_path)
104 except OSError:
104 except OSError:
105 pass
105 pass
106 value.update({
106 value.update({
107 'percent': 100,
107 'percent': 100,
108 'used': used,
108 'used': used,
109 'total': used,
109 'total': used,
110 'items': items_count,
110 'items': items_count,
111 'callbacks': callbacks
111 'callbacks': callbacks
112 })
112 })
113 return value
113 return value
114
114
115 def get_count(self):
115 def get_count(self):
116 try:
116 try:
117 sessions_dir = self._get_sessions_dir()
117 sessions_dir = self._get_sessions_dir()
118 items_count = self._count_on_filesystem(sessions_dir)['items']
118 items_count = self._count_on_filesystem(sessions_dir)['items']
119 except Exception:
119 except Exception:
120 items_count = self.NOT_AVAILABLE
120 items_count = self.NOT_AVAILABLE
121 return items_count
121 return items_count
122
122
123 def get_expired_count(self, older_than_seconds=0):
123 def get_expired_count(self, older_than_seconds=0):
124 try:
124 try:
125 sessions_dir = self._get_sessions_dir()
125 sessions_dir = self._get_sessions_dir()
126 items_count = self._count_on_filesystem(
126 items_count = self._count_on_filesystem(
127 sessions_dir, older_than=older_than_seconds)['items']
127 sessions_dir, older_than=older_than_seconds)['items']
128 except Exception:
128 except Exception:
129 items_count = self.NOT_AVAILABLE
129 items_count = self.NOT_AVAILABLE
130 return items_count
130 return items_count
131
131
132 def clean_sessions(self, older_than_seconds=0):
132 def clean_sessions(self, older_than_seconds=0):
133 # find . -mtime +60 -exec rm {} \;
133 # find . -mtime +60 -exec rm {} \;
134
134
135 sessions_dir = self._get_sessions_dir()
135 sessions_dir = self._get_sessions_dir()
136
136
137 def remove_item(path):
137 def remove_item(path):
138 os.remove(path)
138 os.remove(path)
139
139
140 stats = self._count_on_filesystem(
140 stats = self._count_on_filesystem(
141 sessions_dir, older_than=older_than_seconds,
141 sessions_dir, older_than=older_than_seconds,
142 callback=remove_item)
142 callback=remove_item)
143 return stats['callbacks']
143 return stats['callbacks']
144
144
145
145
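The `find . -mtime +60 -exec rm {} \;` comment above, expressed through this class (sketch):

    sessions = FileAuthSessions(config)
    # 60 days in seconds; remove_item deletes each matching session file
    removed = sessions.clean_sessions(older_than_seconds=60 * 24 * 60 * 60)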
146 class MemcachedAuthSessions(BaseAuthSessions):
146 class MemcachedAuthSessions(BaseAuthSessions):
147 SESSION_TYPE = 'ext:memcached'
147 SESSION_TYPE = 'ext:memcached'
148 _key_regex = re.compile(r'ITEM (.*_session) \[(.*); (.*)\]')
148 _key_regex = re.compile(r'ITEM (.*_session) \[(.*); (.*)\]')
149
149
150 def _get_client(self):
150 def _get_client(self):
151 import memcache
151 import memcache
152 client = memcache.Client([self.config.get('beaker.session.url')])
152 client = memcache.Client([self.config.get('beaker.session.url')])
153 return client
153 return client
154
154
155 def _get_telnet_client(self, host, port):
155 def _get_telnet_client(self, host, port):
156 import telnetlib
156 import telnetlib
157 client = telnetlib.Telnet(host, port, None)
157 client = telnetlib.Telnet(host, port, None)
158 return client
158 return client
159
159
160 def _run_telnet_cmd(self, client, cmd):
160 def _run_telnet_cmd(self, client, cmd):
161 client.write("%s\n" % cmd)
161 client.write("%s\n" % cmd)
162 return client.read_until('END')
162 return client.read_until('END')
163
163
164 def key_details(self, client, slab_ids, limit=100):
164 def key_details(self, client, slab_ids, limit=100):
165 """ Return a list of tuples containing keys and details """
165 """ Return a list of tuples containing keys and details """
166 cmd = 'stats cachedump %s %s'
166 cmd = 'stats cachedump %s %s'
167 for slab_id in slab_ids:
167 for slab_id in slab_ids:
168 for key in self._key_regex.finditer(
168 for key in self._key_regex.finditer(
169 self._run_telnet_cmd(client, cmd % (slab_id, limit))):
169 self._run_telnet_cmd(client, cmd % (slab_id, limit))):
170 yield key
170 yield key
171
171
172 def get_count(self):
172 def get_count(self):
173 client = self._get_client()
173 client = self._get_client()
174 count = self.NOT_AVAILABLE
174 count = self.NOT_AVAILABLE
175 try:
175 try:
176 slabs = []
176 slabs = []
177 for server, slabs_data in client.get_slabs():
177 for server, slabs_data in client.get_slabs():
178 slabs.extend(slabs_data.keys())
178 slabs.extend(list(slabs_data.keys()))
179
179
180 host, port = client.servers[0].address
180 host, port = client.servers[0].address
181 telnet_client = self._get_telnet_client(host, port)
181 telnet_client = self._get_telnet_client(host, port)
182 keys = self.key_details(telnet_client, slabs)
182 keys = self.key_details(telnet_client, slabs)
183 count = 0
183 count = 0
184 for _k in keys:
184 for _k in keys:
185 count += 1
185 count += 1
186 except Exception:
186 except Exception:
187 return count
187 return count
188
188
189 return count
189 return count
190
190
191 def get_expired_count(self, older_than_seconds=None):
191 def get_expired_count(self, older_than_seconds=None):
192 return self.NOT_AVAILABLE
192 return self.NOT_AVAILABLE
193
193
194 def clean_sessions(self, older_than_seconds=None):
194 def clean_sessions(self, older_than_seconds=None):
195 raise CleanupCommand('Cleanup for this session type not yet available')
195 raise CleanupCommand('Cleanup for this session type not yet available')
196
196
197
197
198 class RedisAuthSessions(BaseAuthSessions):
198 class RedisAuthSessions(BaseAuthSessions):
199 SESSION_TYPE = 'ext:redis'
199 SESSION_TYPE = 'ext:redis'
200
200
201 def _get_client(self):
201 def _get_client(self):
202 import redis
202 import redis
203 args = {
203 args = {
204 'socket_timeout': 60,
204 'socket_timeout': 60,
205 'url': self.config.get('beaker.session.url')
205 'url': self.config.get('beaker.session.url')
206 }
206 }
207
207
208 client = redis.StrictRedis.from_url(**args)
208 client = redis.StrictRedis.from_url(**args)
209 return client
209 return client
210
210
211 def get_count(self):
211 def get_count(self):
212 client = self._get_client()
212 client = self._get_client()
213 return len(client.keys('beaker_cache:*'))
213 return len(client.keys('beaker_cache:*'))
214
214
215 def get_expired_count(self, older_than_seconds=None):
215 def get_expired_count(self, older_than_seconds=None):
216 expiry_date = self._seconds_to_date(older_than_seconds)
216 expiry_date = self._seconds_to_date(older_than_seconds)
217 return self.NOT_AVAILABLE
217 return self.NOT_AVAILABLE
218
218
219 def clean_sessions(self, older_than_seconds=None):
219 def clean_sessions(self, older_than_seconds=None):
220 client = self._get_client()
220 client = self._get_client()
221 expiry_time = time.time() - older_than_seconds
221 expiry_time = time.time() - older_than_seconds
222 deleted_keys = 0
222 deleted_keys = 0
223 for key in client.keys('beaker_cache:*'):
223 for key in client.keys('beaker_cache:*'):
224 data = client.get(key)
224 data = client.get(key)
225 if data:
225 if data:
226 json_data = pickle.loads(data)
226 json_data = pickle.loads(data)
227 try:
227 try:
228 accessed_time = json_data['_accessed_time']
228 accessed_time = json_data['_accessed_time']
229 except KeyError:
229 except KeyError:
230 accessed_time = 0
230 accessed_time = 0
231 if accessed_time < expiry_time:
231 if accessed_time < expiry_time:
232 client.delete(key)
232 client.delete(key)
233 deleted_keys += 1
233 deleted_keys += 1
234
234
235 return deleted_keys
235 return deleted_keys
236
236
237
237
238 class MemoryAuthSessions(BaseAuthSessions):
238 class MemoryAuthSessions(BaseAuthSessions):
239 SESSION_TYPE = 'memory'
239 SESSION_TYPE = 'memory'
240
240
241 def get_count(self):
241 def get_count(self):
242 return self.NOT_AVAILABLE
242 return self.NOT_AVAILABLE
243
243
244 def get_expired_count(self, older_than_seconds=None):
244 def get_expired_count(self, older_than_seconds=None):
245 return self.NOT_AVAILABLE
245 return self.NOT_AVAILABLE
246
246
247 def clean_sessions(self, older_than_seconds=None):
247 def clean_sessions(self, older_than_seconds=None):
248 raise CleanupCommand('Cleanup for this session type not yet available')
248 raise CleanupCommand('Cleanup for this session type not yet available')
249
249
250
250
251 def get_session_handler(session_type):
251 def get_session_handler(session_type):
252 types = {
252 types = {
253 'file': FileAuthSessions,
253 'file': FileAuthSessions,
254 'ext:memcached': MemcachedAuthSessions,
254 'ext:memcached': MemcachedAuthSessions,
255 'ext:redis': RedisAuthSessions,
255 'ext:redis': RedisAuthSessions,
256 'ext:database': DbAuthSessions,
256 'ext:database': DbAuthSessions,
257 'memory': MemoryAuthSessions
257 'memory': MemoryAuthSessions
258 }
258 }
259
259
260 try:
260 try:
261 return types[session_type]
261 return types[session_type]
262 except KeyError:
262 except KeyError:
263 raise ValueError(
263 raise ValueError(
264 'This type {} is not supported'.format(session_type))
264 'This type {} is not supported'.format(session_type))
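Putting it together, callers resolve a handler class from the configured session type (sketch; 'ext:redis' and config are example values):

    handler_cls = get_session_handler('ext:redis')
    sessions = handler_cls(config)
    print(sessions.SESSION_TYPE, sessions.get_count())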
@@ -1,799 +1,799 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 Utilities library for RhodeCode
22 Utilities library for RhodeCode
23 """
23 """
24
24
25 import datetime
25 import datetime
26 import decorator
26 import decorator
27 import json
27 import json
28 import logging
28 import logging
29 import os
29 import os
30 import re
30 import re
31 import sys
31 import sys
32 import shutil
32 import shutil
33 import socket
33 import socket
34 import tempfile
34 import tempfile
35 import traceback
35 import traceback
36 import tarfile
36 import tarfile
37 import warnings
37 import warnings
38 import hashlib
38 import hashlib
39 from os.path import join as jn
39 from os.path import join as jn
40
40
41 import paste
41 import paste
42 import pkg_resources
42 import pkg_resources
43 from webhelpers2.text import collapse, remove_formatting
43 from webhelpers2.text import collapse, remove_formatting
44 from mako import exceptions
44 from mako import exceptions
45 from pyramid.threadlocal import get_current_registry
45 from pyramid.threadlocal import get_current_registry
46
46
47 from rhodecode.lib.vcs.backends.base import Config
47 from rhodecode.lib.vcs.backends.base import Config
48 from rhodecode.lib.vcs.exceptions import VCSError
48 from rhodecode.lib.vcs.exceptions import VCSError
49 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
49 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
50 from rhodecode.lib.utils2 import (
50 from rhodecode.lib.utils2 import (
51 safe_str, safe_unicode, get_current_rhodecode_user, md5, sha1)
51 safe_str, safe_unicode, get_current_rhodecode_user, md5, sha1)
52 from rhodecode.model import meta
52 from rhodecode.model import meta
53 from rhodecode.model.db import (
53 from rhodecode.model.db import (
54 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
54 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
55 from rhodecode.model.meta import Session
55 from rhodecode.model.meta import Session
56
56
57
57
58 log = logging.getLogger(__name__)
58 log = logging.getLogger(__name__)
59
59
60 REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
60 REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
61
61
62 # String which contains characters that are not allowed in slug names for
62 # String which contains characters that are not allowed in slug names for
63 # repositories or repository groups. It is properly escaped to use it in
63 # repositories or repository groups. It is properly escaped to use it in
64 # regular expressions.
64 # regular expressions.
65 SLUG_BAD_CHARS = re.escape('`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')
65 SLUG_BAD_CHARS = re.escape('`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')
66
66
67 # Regex that matches forbidden characters in repo/group slugs.
67 # Regex that matches forbidden characters in repo/group slugs.
68 SLUG_BAD_CHAR_RE = re.compile('[{}\x00-\x08\x0b-\x0c\x0e-\x1f]'.format(SLUG_BAD_CHARS))
68 SLUG_BAD_CHAR_RE = re.compile('[{}\x00-\x08\x0b-\x0c\x0e-\x1f]'.format(SLUG_BAD_CHARS))
69
69
70 # Regex that matches allowed characters in repo/group slugs.
70 # Regex that matches allowed characters in repo/group slugs.
71 SLUG_GOOD_CHAR_RE = re.compile('[^{}]'.format(SLUG_BAD_CHARS))
71 SLUG_GOOD_CHAR_RE = re.compile('[^{}]'.format(SLUG_BAD_CHARS))
72
72
73 # Regex that matches whole repo/group slugs.
73 # Regex that matches whole repo/group slugs.
74 SLUG_RE = re.compile('[^{}]+'.format(SLUG_BAD_CHARS))
74 SLUG_RE = re.compile('[^{}]+'.format(SLUG_BAD_CHARS))
75
75
76 _license_cache = None
76 _license_cache = None
77
77
78
78
79 def repo_name_slug(value):
79 def repo_name_slug(value):
80 """
80 """
81 Return slug of name of repository
81 Return slug of name of repository
82 This function is called on each creation/modification
82 This function is called on each creation/modification
83 of repository to prevent bad names in repo
83 of repository to prevent bad names in repo
84 """
84 """
85 replacement_char = '-'
85 replacement_char = '-'
86
86
87 slug = remove_formatting(value)
87 slug = remove_formatting(value)
88 slug = SLUG_BAD_CHAR_RE.sub('', slug)
88 slug = SLUG_BAD_CHAR_RE.sub('', slug)
89 slug = re.sub('[\s]+', '-', slug)
89 slug = re.sub('[\s]+', '-', slug)
90 slug = collapse(slug, replacement_char)
90 slug = collapse(slug, replacement_char)
91 return slug
91 return slug
92
92
93
93
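For example (sketch; the exact result also depends on what webhelpers2's remove_formatting strips first):

    repo_name_slug('My Repo (fork #2)')  # -> 'My-Repo-fork-2'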
94 #==============================================================================
94 #==============================================================================
95 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
95 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
96 #==============================================================================
96 #==============================================================================
97 def get_repo_slug(request):
97 def get_repo_slug(request):
98 _repo = ''
98 _repo = ''
99
99
100 if hasattr(request, 'db_repo'):
100 if hasattr(request, 'db_repo'):
101 # if our request has a db reference set, use it for the name; this
101 # if our request has a db reference set, use it for the name; this
102 # translates the example.com/_<id> URL into a proper repo name
102 # translates the example.com/_<id> URL into a proper repo name
103 _repo = request.db_repo.repo_name
103 _repo = request.db_repo.repo_name
104 elif getattr(request, 'matchdict', None):
104 elif getattr(request, 'matchdict', None):
105 # pyramid
105 # pyramid
106 _repo = request.matchdict.get('repo_name')
106 _repo = request.matchdict.get('repo_name')
107
107
108 if _repo:
108 if _repo:
109 _repo = _repo.rstrip('/')
109 _repo = _repo.rstrip('/')
110 return _repo
110 return _repo
111
111
112
112
113 def get_repo_group_slug(request):
113 def get_repo_group_slug(request):
114 _group = ''
114 _group = ''
115 if hasattr(request, 'db_repo_group'):
115 if hasattr(request, 'db_repo_group'):
116 # if our request has a db reference set, use it for the name; this
116 # if our request has a db reference set, use it for the name; this
117 # translates the example.com/_<id> URL into a proper repo group name
117 # translates the example.com/_<id> URL into a proper repo group name
118 _group = request.db_repo_group.group_name
118 _group = request.db_repo_group.group_name
119 elif getattr(request, 'matchdict', None):
119 elif getattr(request, 'matchdict', None):
120 # pyramid
120 # pyramid
121 _group = request.matchdict.get('repo_group_name')
121 _group = request.matchdict.get('repo_group_name')
122
122
123 if _group:
123 if _group:
124 _group = _group.rstrip('/')
124 _group = _group.rstrip('/')
125 return _group
125 return _group
126
126
127
127
128 def get_user_group_slug(request):
128 def get_user_group_slug(request):
129 _user_group = ''
129 _user_group = ''
130
130
131 if hasattr(request, 'db_user_group'):
131 if hasattr(request, 'db_user_group'):
132 _user_group = request.db_user_group.users_group_name
132 _user_group = request.db_user_group.users_group_name
133 elif getattr(request, 'matchdict', None):
133 elif getattr(request, 'matchdict', None):
134 # pyramid
134 # pyramid
135 _user_group = request.matchdict.get('user_group_id')
135 _user_group = request.matchdict.get('user_group_id')
136 _user_group_name = request.matchdict.get('user_group_name')
136 _user_group_name = request.matchdict.get('user_group_name')
137 try:
137 try:
138 if _user_group:
138 if _user_group:
139 _user_group = UserGroup.get(_user_group)
139 _user_group = UserGroup.get(_user_group)
140 elif _user_group_name:
140 elif _user_group_name:
141 _user_group = UserGroup.get_by_group_name(_user_group_name)
141 _user_group = UserGroup.get_by_group_name(_user_group_name)
142
142
143 if _user_group:
143 if _user_group:
144 _user_group = _user_group.users_group_name
144 _user_group = _user_group.users_group_name
145 except Exception:
145 except Exception:
146 log.exception('Failed to get user group by id and name')
146 log.exception('Failed to get user group by id and name')
147 # catch all failures here
147 # catch all failures here
148 return None
148 return None
149
149
150 return _user_group
150 return _user_group
151
151
152
152
153 def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
153 def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
154 """
154 """
155 Scans given path for repos and return (name,(type,path)) tuple
155 Scans given path for repos and return (name,(type,path)) tuple
156
156
157 :param path: path to scan for repositories
157 :param path: path to scan for repositories
158 :param recursive: recursive search; returned names are prefixed with their subdirectory path
158 :param recursive: recursive search; returned names are prefixed with their subdirectory path
159 """
159 """
160
160
161 # remove ending slash for better results
161 # remove ending slash for better results
162 path = path.rstrip(os.sep)
162 path = path.rstrip(os.sep)
163 log.debug('now scanning in %s location recursive:%s...', path, recursive)
163 log.debug('now scanning in %s location recursive:%s...', path, recursive)
164
164
165 def _get_repos(p):
165 def _get_repos(p):
166 dirpaths = _get_dirpaths(p)
166 dirpaths = _get_dirpaths(p)
167 if not _is_dir_writable(p):
167 if not _is_dir_writable(p):
168 log.warning('repo path without write access: %s', p)
168 log.warning('repo path without write access: %s', p)
169
169
170 for dirpath in dirpaths:
170 for dirpath in dirpaths:
171 if os.path.isfile(os.path.join(p, dirpath)):
171 if os.path.isfile(os.path.join(p, dirpath)):
172 continue
172 continue
173 cur_path = os.path.join(p, dirpath)
173 cur_path = os.path.join(p, dirpath)
174
174
175 # skip removed repos
175 # skip removed repos
176 if skip_removed_repos and REMOVED_REPO_PAT.match(dirpath):
176 if skip_removed_repos and REMOVED_REPO_PAT.match(dirpath):
177 continue
177 continue
178
178
179 # skip .<something> dirs
179 # skip .<something> dirs
180 if dirpath.startswith('.'):
180 if dirpath.startswith('.'):
181 continue
181 continue
182
182
183 try:
183 try:
184 scm_info = get_scm(cur_path)
184 scm_info = get_scm(cur_path)
185 yield scm_info[1].split(path, 1)[-1].lstrip(os.sep), scm_info
185 yield scm_info[1].split(path, 1)[-1].lstrip(os.sep), scm_info
186 except VCSError:
186 except VCSError:
187 if not recursive:
187 if not recursive:
188 continue
188 continue
189 # check if this dir contains other repos for recursive scan
189 # check if this dir contains other repos for recursive scan
190 rec_path = os.path.join(p, dirpath)
190 rec_path = os.path.join(p, dirpath)
191 if os.path.isdir(rec_path):
191 if os.path.isdir(rec_path):
192 for inner_scm in _get_repos(rec_path):
192 for inner_scm in _get_repos(rec_path):
193 yield inner_scm
193 yield inner_scm
194
194
195 return _get_repos(path)
195 return _get_repos(path)
196
196
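# Hedged usage sketch for the scanner above; '/srv/repos' is a made-up
# example root, and each yielded item is (name, (type, path)) per the
# docstring:
for repo_name, (repo_type, repo_path) in get_filesystem_repos('/srv/repos', recursive=True):
    log.info('found %s repo %s at %s', repo_type, repo_name, repo_path)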
197
197
198 def _get_dirpaths(p):
198 def _get_dirpaths(p):
199 try:
199 try:
200 # OS-independent way of checking if we have at least read-only
200 # OS-independent way of checking if we have at least read-only
201 # access or not.
201 # access or not.
202 dirpaths = os.listdir(p)
202 dirpaths = os.listdir(p)
203 except OSError:
203 except OSError:
204 log.warning('ignoring repo path without read access: %s', p)
204 log.warning('ignoring repo path without read access: %s', p)
205 return []
205 return []
206
206
207 # os.listdir has a quirk: if a unicode path is passed into it, it tries to
207 # os.listdir has a quirk: if a unicode path is passed into it, it tries to
208 # decode paths and suddenly returns unicode objects itself. The items it
208 # decode paths and suddenly returns unicode objects itself. The items it
209 # cannot decode are returned as byte strings and cause issues.
209 # cannot decode are returned as byte strings and cause issues.
210 #
210 #
211 # Those paths are ignored here until a solid solution for path handling has
211 # Those paths are ignored here until a solid solution for path handling has
212 # been built.
212 # been built.
213 expected_type = type(p)
213 expected_type = type(p)
214
214
215 def _has_correct_type(item):
215 def _has_correct_type(item):
216 if type(item) is not expected_type:
216 if type(item) is not expected_type:
217 log.error(
217 log.error(
218 u"Ignoring path %s since it cannot be decoded into unicode.",
218 "Ignoring path %s since it cannot be decoded into unicode.",
219 # Using "repr" to make sure that we see the byte value in case
219 # Using "repr" to make sure that we see the byte value in case
220 # of support requests.
220 # of support requests.
221 repr(item))
221 repr(item))
222 return False
222 return False
223 return True
223 return True
224
224
225 dirpaths = [item for item in dirpaths if _has_correct_type(item)]
225 dirpaths = [item for item in dirpaths if _has_correct_type(item)]
226
226
227 return dirpaths
227 return dirpaths
228
228
229
229
230 def _is_dir_writable(path):
230 def _is_dir_writable(path):
231 """
231 """
232 Probe if `path` is writable.
232 Probe if `path` is writable.
233
233
234 Due to trouble on Cygwin / Windows, this is actually probing if it is
234 Due to trouble on Cygwin / Windows, this is actually probing if it is
235 possible to create a file inside of `path`; stat does not produce reliable
235 possible to create a file inside of `path`; stat does not produce reliable
236 results in this case.
236 results in this case.
237 """
237 """
238 try:
238 try:
239 with tempfile.TemporaryFile(dir=path):
239 with tempfile.TemporaryFile(dir=path):
240 pass
240 pass
241 except OSError:
241 except OSError:
242 return False
242 return False
243 return True
243 return True
244
244
245
245
246 def is_valid_repo(repo_name, base_path, expect_scm=None, explicit_scm=None, config=None):
246 def is_valid_repo(repo_name, base_path, expect_scm=None, explicit_scm=None, config=None):
247 """
247 """
248 Returns True if the given path is a valid repository, False otherwise.
248 Returns True if the given path is a valid repository, False otherwise.
249 If the expect_scm param is given, also compare whether the detected scm
249 If the expect_scm param is given, also compare whether the detected scm
250 matches the expected one. If explicit_scm is given, don't try to
250 matches the expected one. If explicit_scm is given, don't try to
251 detect the scm, just use the given one to check if the repo is valid
251 detect the scm, just use the given one to check if the repo is valid
252
252
253 :param repo_name:
253 :param repo_name:
254 :param base_path:
254 :param base_path:
255 :param expect_scm:
255 :param expect_scm:
256 :param explicit_scm:
256 :param explicit_scm:
257 :param config:
257 :param config:
258
258
259 :return True: if given path is a valid repository
259 :return True: if given path is a valid repository
260 """
260 """
261 full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
261 full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
262 log.debug('Checking if `%s` is a valid path for repository. '
262 log.debug('Checking if `%s` is a valid path for repository. '
263 'Explicit type: %s', repo_name, explicit_scm)
263 'Explicit type: %s', repo_name, explicit_scm)
264
264
265 try:
265 try:
266 if explicit_scm:
266 if explicit_scm:
267 detected_scms = [get_scm_backend(explicit_scm)(
267 detected_scms = [get_scm_backend(explicit_scm)(
268 full_path, config=config).alias]
268 full_path, config=config).alias]
269 else:
269 else:
270 detected_scms = get_scm(full_path)
270 detected_scms = get_scm(full_path)
271
271
272 if expect_scm:
272 if expect_scm:
273 return detected_scms[0] == expect_scm
273 return detected_scms[0] == expect_scm
274 log.debug('path: %s is a vcs object:%s', full_path, detected_scms)
274 log.debug('path: %s is a vcs object:%s', full_path, detected_scms)
275 return True
275 return True
276 except VCSError:
276 except VCSError:
277 log.debug('path: %s is not a valid repo !', full_path)
277 log.debug('path: %s is not a valid repo !', full_path)
278 return False
278 return False
279
279
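# Hedged usage sketch; the path and repo name are made-up examples.
# With expect_scm the call doubles as a repository-type check:
if is_valid_repo('my-project', '/srv/repos', expect_scm='git'):
    log.info('my-project is a git repository')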
280
280
281 def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
281 def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
282 """
282 """
283 Returns True if given path is a repository group, False otherwise
283 Returns True if given path is a repository group, False otherwise
284
284
285 :param repo_group_name:
285 :param repo_group_name:
286 :param base_path:
286 :param base_path:
287 """
287 """
288 full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
288 full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
289 log.debug('Checking if `%s` is a valid path for repository group',
289 log.debug('Checking if `%s` is a valid path for repository group',
290 repo_group_name)
290 repo_group_name)
291
291
292 # check if it's not a repo
292 # check if it's not a repo
293 if is_valid_repo(repo_group_name, base_path):
293 if is_valid_repo(repo_group_name, base_path):
294 log.debug('Repo called %s exists, it is not a valid repo group', repo_group_name)
294 log.debug('Repo called %s exists, it is not a valid repo group', repo_group_name)
295 return False
295 return False
296
296
297 try:
297 try:
298 # we need to check bare git repos at higher level
298 # we need to check bare git repos at higher level
299 # since we might match branches/hooks/info/objects or possibly
299 # since we might match branches/hooks/info/objects or possibly
300 # other things inside a bare git repo
300 # other things inside a bare git repo
301 maybe_repo = os.path.dirname(full_path)
301 maybe_repo = os.path.dirname(full_path)
302 if maybe_repo == base_path:
302 if maybe_repo == base_path:
303 # skip root level repo check, we know root location CANNOT BE a repo group
303 # skip root level repo check, we know root location CANNOT BE a repo group
304 return False
304 return False
305
305
306 scm_ = get_scm(maybe_repo)
306 scm_ = get_scm(maybe_repo)
307 log.debug('path: %s is a vcs object:%s, not valid repo group', full_path, scm_)
307 log.debug('path: %s is a vcs object:%s, not valid repo group', full_path, scm_)
308 return False
308 return False
309 except VCSError:
309 except VCSError:
310 pass
310 pass
311
311
312 # check if it's a valid path
312 # check if it's a valid path
313 if skip_path_check or os.path.isdir(full_path):
313 if skip_path_check or os.path.isdir(full_path):
314 log.debug('path: %s is a valid repo group !', full_path)
314 log.debug('path: %s is a valid repo group !', full_path)
315 return True
315 return True
316
316
317 log.debug('path: %s is not a valid repo group !', full_path)
317 log.debug('path: %s is not a valid repo group !', full_path)
318 return False
318 return False
319
319
320
320
321 def ask_ok(prompt, retries=4, complaint='[y]es or [n]o please!'):
321 def ask_ok(prompt, retries=4, complaint='[y]es or [n]o please!'):
322 while True:
322 while True:
323 ok = raw_input(prompt)
323 ok = input(prompt)
324 if ok.lower() in ('y', 'ye', 'yes'):
324 if ok.lower() in ('y', 'ye', 'yes'):
325 return True
325 return True
326 if ok.lower() in ('n', 'no', 'nop', 'nope'):
326 if ok.lower() in ('n', 'no', 'nop', 'nope'):
327 return False
327 return False
328 retries = retries - 1
328 retries = retries - 1
329 if retries < 0:
329 if retries < 0:
330 raise IOError
330 raise IOError('too many invalid answers')
331 print(complaint)
331 print(complaint)
332
332
333 # section names taken from the mercurial documentation
333 # section names taken from the mercurial documentation
334 ui_sections = [
334 ui_sections = [
335 'alias', 'auth',
335 'alias', 'auth',
336 'decode/encode', 'defaults',
336 'decode/encode', 'defaults',
337 'diff', 'email',
337 'diff', 'email',
338 'extensions', 'format',
338 'extensions', 'format',
339 'merge-patterns', 'merge-tools',
339 'merge-patterns', 'merge-tools',
340 'hooks', 'http_proxy',
340 'hooks', 'http_proxy',
341 'smtp', 'patch',
341 'smtp', 'patch',
342 'paths', 'profiling',
342 'paths', 'profiling',
343 'server', 'trusted',
343 'server', 'trusted',
344 'ui', 'web', ]
344 'ui', 'web', ]
345
345
346
346
347 def config_data_from_db(clear_session=True, repo=None):
347 def config_data_from_db(clear_session=True, repo=None):
348 """
348 """
349 Read the configuration data from the database and return configuration
349 Read the configuration data from the database and return configuration
350 tuples.
350 tuples.
351 """
351 """
352 from rhodecode.model.settings import VcsSettingsModel
352 from rhodecode.model.settings import VcsSettingsModel
353
353
354 config = []
354 config = []
355
355
356 sa = meta.Session()
356 sa = meta.Session()
357 settings_model = VcsSettingsModel(repo=repo, sa=sa)
357 settings_model = VcsSettingsModel(repo=repo, sa=sa)
358
358
359 ui_settings = settings_model.get_ui_settings()
359 ui_settings = settings_model.get_ui_settings()
360
360
361 ui_data = []
361 ui_data = []
362 for setting in ui_settings:
362 for setting in ui_settings:
363 if setting.active:
363 if setting.active:
364 ui_data.append((setting.section, setting.key, setting.value))
364 ui_data.append((setting.section, setting.key, setting.value))
365 config.append((
365 config.append((
366 safe_str(setting.section), safe_str(setting.key),
366 safe_str(setting.section), safe_str(setting.key),
367 safe_str(setting.value)))
367 safe_str(setting.value)))
368 if setting.key == 'push_ssl':
368 if setting.key == 'push_ssl':
369 # force set push_ssl requirement to False, rhodecode
369 # force set push_ssl requirement to False, rhodecode
370 # handles that
370 # handles that
371 config.append((
371 config.append((
372 safe_str(setting.section), safe_str(setting.key), False))
372 safe_str(setting.section), safe_str(setting.key), False))
373 log.debug(
373 log.debug(
374 'settings ui from db@repo[%s]: %s',
374 'settings ui from db@repo[%s]: %s',
375 repo,
375 repo,
376 ','.join(map(lambda s: '[{}] {}={}'.format(*s), ui_data)))
376 ','.join(map(lambda s: '[{}] {}={}'.format(*s), ui_data)))
377 if clear_session:
377 if clear_session:
378 meta.Session.remove()
378 meta.Session.remove()
379
379
380 # TODO: mikhail: probably it makes no sense to re-read hooks information.
380 # TODO: mikhail: probably it makes no sense to re-read hooks information.
381 # It's already there and activated/deactivated
381 # It's already there and activated/deactivated
382 skip_entries = []
382 skip_entries = []
383 enabled_hook_classes = get_enabled_hook_classes(ui_settings)
383 enabled_hook_classes = get_enabled_hook_classes(ui_settings)
384 if 'pull' not in enabled_hook_classes:
384 if 'pull' not in enabled_hook_classes:
385 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PULL))
385 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PULL))
386 if 'push' not in enabled_hook_classes:
386 if 'push' not in enabled_hook_classes:
387 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PUSH))
387 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PUSH))
388 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRETX_PUSH))
388 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRETX_PUSH))
389 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PUSH_KEY))
389 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PUSH_KEY))
390
390
391 config = [entry for entry in config if entry[:2] not in skip_entries]
391 config = [entry for entry in config if entry[:2] not in skip_entries]
392
392
393 return config
393 return config
394
394
395
395
396 def make_db_config(clear_session=True, repo=None):
396 def make_db_config(clear_session=True, repo=None):
397 """
397 """
398 Create a :class:`Config` instance based on the values in the database.
398 Create a :class:`Config` instance based on the values in the database.
399 """
399 """
400 config = Config()
400 config = Config()
401 config_data = config_data_from_db(clear_session=clear_session, repo=repo)
401 config_data = config_data_from_db(clear_session=clear_session, repo=repo)
402 for section, option, value in config_data:
402 for section, option, value in config_data:
403 config.set(section, option, value)
403 config.set(section, option, value)
404 return config
404 return config
405
405
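# Hedged usage sketch: config_data_from_db() returns plain
# (section, option, value) tuples, which make_db_config() then loads into
# a Config object via config.set(); the repo name is a made-up example:
for section, option, value in config_data_from_db(clear_session=False, repo='some-repo'):
    log.debug('[%s] %s = %s', section, option, value)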
406
406
407 def get_enabled_hook_classes(ui_settings):
407 def get_enabled_hook_classes(ui_settings):
408 """
408 """
409 Return the enabled hook classes.
409 Return the enabled hook classes.
410
410
411 :param ui_settings: List of ui_settings as returned
411 :param ui_settings: List of ui_settings as returned
412 by :meth:`VcsSettingsModel.get_ui_settings`
412 by :meth:`VcsSettingsModel.get_ui_settings`
413
413
414 :return: a list with the enabled hook classes. The order is not guaranteed.
414 :return: a list with the enabled hook classes. The order is not guaranteed.
415 :rtype: list
415 :rtype: list
416 """
416 """
417 enabled_hooks = []
417 enabled_hooks = []
418 active_hook_keys = [
418 active_hook_keys = [
419 key for section, key, value, active in ui_settings
419 key for section, key, value, active in ui_settings
420 if section == 'hooks' and active]
420 if section == 'hooks' and active]
421
421
422 hook_names = {
422 hook_names = {
423 RhodeCodeUi.HOOK_PUSH: 'push',
423 RhodeCodeUi.HOOK_PUSH: 'push',
424 RhodeCodeUi.HOOK_PULL: 'pull',
424 RhodeCodeUi.HOOK_PULL: 'pull',
425 RhodeCodeUi.HOOK_REPO_SIZE: 'repo_size'
425 RhodeCodeUi.HOOK_REPO_SIZE: 'repo_size'
426 }
426 }
427
427
428 for key in active_hook_keys:
428 for key in active_hook_keys:
429 hook = hook_names.get(key)
429 hook = hook_names.get(key)
430 if hook:
430 if hook:
431 enabled_hooks.append(hook)
431 enabled_hooks.append(hook)
432
432
433 return enabled_hooks
433 return enabled_hooks
434
434
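# A minimal sketch with hand-built (section, key, value, active) tuples in
# the shape the function unpacks; the values are placeholders:
demo_ui_settings = [
    ('hooks', RhodeCodeUi.HOOK_PUSH, 'python:hook', True),
    ('hooks', RhodeCodeUi.HOOK_PULL, 'python:hook', False),
]
assert get_enabled_hook_classes(demo_ui_settings) == ['push']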
435
435
436 def set_rhodecode_config(config):
436 def set_rhodecode_config(config):
437 """
437 """
438 Updates pyramid config with new settings from database
438 Updates pyramid config with new settings from database
439
439
440 :param config:
440 :param config:
441 """
441 """
442 from rhodecode.model.settings import SettingsModel
442 from rhodecode.model.settings import SettingsModel
443 app_settings = SettingsModel().get_all_settings()
443 app_settings = SettingsModel().get_all_settings()
444
444
445 for k, v in app_settings.items():
445 for k, v in app_settings.items():
446 config[k] = v
446 config[k] = v
447
447
448
448
449 def get_rhodecode_realm():
449 def get_rhodecode_realm():
450 """
450 """
451 Return the rhodecode realm from database.
451 Return the rhodecode realm from database.
452 """
452 """
453 from rhodecode.model.settings import SettingsModel
453 from rhodecode.model.settings import SettingsModel
454 realm = SettingsModel().get_setting_by_name('realm')
454 realm = SettingsModel().get_setting_by_name('realm')
455 return safe_str(realm.app_settings_value)
455 return safe_str(realm.app_settings_value)
456
456
457
457
458 def get_rhodecode_base_path():
458 def get_rhodecode_base_path():
459 """
459 """
460 Returns the base path. The base path is the filesystem path which points
460 Returns the base path. The base path is the filesystem path which points
461 to the repository store.
461 to the repository store.
462 """
462 """
463 from rhodecode.model.settings import SettingsModel
463 from rhodecode.model.settings import SettingsModel
464 paths_ui = SettingsModel().get_ui_by_section_and_key('paths', '/')
464 paths_ui = SettingsModel().get_ui_by_section_and_key('paths', '/')
465 return safe_str(paths_ui.ui_value)
465 return safe_str(paths_ui.ui_value)
466
466
467
467
468 def map_groups(path):
468 def map_groups(path):
469 """
469 """
470 Given a full path to a repository, create all nested groups that this
470 Given a full path to a repository, create all nested groups that this
471 repo is inside. This function creates parent-child relationships between
471 repo is inside. This function creates parent-child relationships between
472 groups and creates default perms for all new groups.
472 groups and creates default perms for all new groups.
473
473
474 :param path: full path to repository
474 :param path: full path to repository
475 """
475 """
476 from rhodecode.model.repo_group import RepoGroupModel
476 from rhodecode.model.repo_group import RepoGroupModel
477 sa = meta.Session()
477 sa = meta.Session()
478 groups = path.split(Repository.NAME_SEP)
478 groups = path.split(Repository.NAME_SEP)
479 parent = None
479 parent = None
480 group = None
480 group = None
481
481
482 # last element is repo in nested groups structure
482 # last element is repo in nested groups structure
483 groups = groups[:-1]
483 groups = groups[:-1]
484 rgm = RepoGroupModel(sa)
484 rgm = RepoGroupModel(sa)
485 owner = User.get_first_super_admin()
485 owner = User.get_first_super_admin()
486 for lvl, group_name in enumerate(groups):
486 for lvl, group_name in enumerate(groups):
487 group_name = '/'.join(groups[:lvl] + [group_name])
487 group_name = '/'.join(groups[:lvl] + [group_name])
488 group = RepoGroup.get_by_group_name(group_name)
488 group = RepoGroup.get_by_group_name(group_name)
489 desc = '%s group' % group_name
489 desc = '%s group' % group_name
490
490
491 # skip folders that are now removed repos
491 # skip folders that are now removed repos
492 if REMOVED_REPO_PAT.match(group_name):
492 if REMOVED_REPO_PAT.match(group_name):
493 break
493 break
494
494
495 if group is None:
495 if group is None:
496 log.debug('creating group level: %s group_name: %s',
496 log.debug('creating group level: %s group_name: %s',
497 lvl, group_name)
497 lvl, group_name)
498 group = RepoGroup(group_name, parent)
498 group = RepoGroup(group_name, parent)
499 group.group_description = desc
499 group.group_description = desc
500 group.user = owner
500 group.user = owner
501 sa.add(group)
501 sa.add(group)
502 perm_obj = rgm._create_default_perms(group)
502 perm_obj = rgm._create_default_perms(group)
503 sa.add(perm_obj)
503 sa.add(perm_obj)
504 sa.flush()
504 sa.flush()
505
505
506 parent = group
506 parent = group
507 return group
507 return group
508
508
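# A minimal sketch of how the group chain is derived from a repo path in
# the loop above, assuming Repository.NAME_SEP is '/' (its usual value):
_parts = 'projects/backend/my-repo'.split('/')[:-1]  # drop the repo itself
_chain = ['/'.join(_parts[:i + 1]) for i in range(len(_parts))]
assert _chain == ['projects', 'projects/backend']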
509
509
510 def repo2db_mapper(initial_repo_list, remove_obsolete=False):
510 def repo2db_mapper(initial_repo_list, remove_obsolete=False):
511 """
511 """
512 Maps all repos given in initial_repo_list; non-existing repositories
512 Maps all repos given in initial_repo_list; non-existing repositories
513 are created. If remove_obsolete is True, it also checks for db entries
513 are created. If remove_obsolete is True, it also checks for db entries
514 that are not in initial_repo_list and removes them.
514 that are not in initial_repo_list and removes them.
515
515
516 :param initial_repo_list: list of repositories found by scanning methods
516 :param initial_repo_list: list of repositories found by scanning methods
517 :param remove_obsolete: check for obsolete entries in database
517 :param remove_obsolete: check for obsolete entries in database
518 """
518 """
519 from rhodecode.model.repo import RepoModel
519 from rhodecode.model.repo import RepoModel
520 from rhodecode.model.repo_group import RepoGroupModel
520 from rhodecode.model.repo_group import RepoGroupModel
521 from rhodecode.model.settings import SettingsModel
521 from rhodecode.model.settings import SettingsModel
522
522
523 sa = meta.Session()
523 sa = meta.Session()
524 repo_model = RepoModel()
524 repo_model = RepoModel()
525 user = User.get_first_super_admin()
525 user = User.get_first_super_admin()
526 added = []
526 added = []
527
527
528 # creation defaults
528 # creation defaults
529 defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
529 defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
530 enable_statistics = defs.get('repo_enable_statistics')
530 enable_statistics = defs.get('repo_enable_statistics')
531 enable_locking = defs.get('repo_enable_locking')
531 enable_locking = defs.get('repo_enable_locking')
532 enable_downloads = defs.get('repo_enable_downloads')
532 enable_downloads = defs.get('repo_enable_downloads')
533 private = defs.get('repo_private')
533 private = defs.get('repo_private')
534
534
535 for name, repo in initial_repo_list.items():
535 for name, repo in initial_repo_list.items():
536 group = map_groups(name)
536 group = map_groups(name)
537 unicode_name = safe_unicode(name)
537 unicode_name = safe_unicode(name)
538 db_repo = repo_model.get_by_repo_name(unicode_name)
538 db_repo = repo_model.get_by_repo_name(unicode_name)
539 # found a repo that is on the filesystem but not in the RhodeCode database
539 # found a repo that is on the filesystem but not in the RhodeCode database
540 if not db_repo:
540 if not db_repo:
541 log.info('repository %s not found, creating now', name)
541 log.info('repository %s not found, creating now', name)
542 added.append(name)
542 added.append(name)
543 desc = (repo.description
543 desc = (repo.description
544 if repo.description != 'unknown'
544 if repo.description != 'unknown'
545 else '%s repository' % name)
545 else '%s repository' % name)
546
546
547 db_repo = repo_model._create_repo(
547 db_repo = repo_model._create_repo(
548 repo_name=name,
548 repo_name=name,
549 repo_type=repo.alias,
549 repo_type=repo.alias,
550 description=desc,
550 description=desc,
551 repo_group=getattr(group, 'group_id', None),
551 repo_group=getattr(group, 'group_id', None),
552 owner=user,
552 owner=user,
553 enable_locking=enable_locking,
553 enable_locking=enable_locking,
554 enable_downloads=enable_downloads,
554 enable_downloads=enable_downloads,
555 enable_statistics=enable_statistics,
555 enable_statistics=enable_statistics,
556 private=private,
556 private=private,
557 state=Repository.STATE_CREATED
557 state=Repository.STATE_CREATED
558 )
558 )
559 sa.commit()
559 sa.commit()
560 # we just added that repo, so make sure we update the server info
560 # we just added that repo, so make sure we update the server info
561 if db_repo.repo_type == 'git':
561 if db_repo.repo_type == 'git':
562 git_repo = db_repo.scm_instance()
562 git_repo = db_repo.scm_instance()
563 # update repository server-info
563 # update repository server-info
564 log.debug('Running update server info')
564 log.debug('Running update server info')
565 git_repo._update_server_info()
565 git_repo._update_server_info()
566
566
567 db_repo.update_commit_cache()
567 db_repo.update_commit_cache()
568
568
569 config = db_repo._config
569 config = db_repo._config
570 config.set('extensions', 'largefiles', '')
570 config.set('extensions', 'largefiles', '')
571 repo = db_repo.scm_instance(config=config)
571 repo = db_repo.scm_instance(config=config)
572 repo.install_hooks()
572 repo.install_hooks()
573
573
574 removed = []
574 removed = []
575 if remove_obsolete:
575 if remove_obsolete:
576 # remove from database those repositories that are not in the filesystem
576 # remove from database those repositories that are not in the filesystem
577 for repo in sa.query(Repository).all():
577 for repo in sa.query(Repository).all():
578 if repo.repo_name not in initial_repo_list.keys():
578 if repo.repo_name not in initial_repo_list.keys():
579 log.debug("Removing non-existing repository found in db `%s`",
579 log.debug("Removing non-existing repository found in db `%s`",
580 repo.repo_name)
580 repo.repo_name)
581 try:
581 try:
582 RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
582 RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
583 sa.commit()
583 sa.commit()
584 removed.append(repo.repo_name)
584 removed.append(repo.repo_name)
585 except Exception:
585 except Exception:
586 # don't hold further removals on error
586 # don't hold further removals on error
587 log.error(traceback.format_exc())
587 log.error(traceback.format_exc())
588 sa.rollback()
588 sa.rollback()
589
589
590 def splitter(full_repo_name):
590 def splitter(full_repo_name):
591 _parts = full_repo_name.rsplit(RepoGroup.url_sep(), 1)
591 _parts = full_repo_name.rsplit(RepoGroup.url_sep(), 1)
592 gr_name = None
592 gr_name = None
593 if len(_parts) == 2:
593 if len(_parts) == 2:
594 gr_name = _parts[0]
594 gr_name = _parts[0]
595 return gr_name
595 return gr_name
596
596
597 initial_repo_group_list = [splitter(x) for x in
597 initial_repo_group_list = [splitter(x) for x in
598 initial_repo_list.keys() if splitter(x)]
598 initial_repo_list.keys() if splitter(x)]
599
599
600 # remove from database those repository groups that are not in the
600 # remove from database those repository groups that are not in the
601 # filesystem; due to parent-child relationships we need to delete them
601 # filesystem; due to parent-child relationships we need to delete them
602 # in a specific order, most nested first
602 # in a specific order, most nested first
603 all_groups = [x.group_name for x in sa.query(RepoGroup).all()]
603 all_groups = [x.group_name for x in sa.query(RepoGroup).all()]
604 nested_sort = lambda gr: len(gr.split('/'))
604 nested_sort = lambda gr: len(gr.split('/'))
605 for group_name in sorted(all_groups, key=nested_sort, reverse=True):
605 for group_name in sorted(all_groups, key=nested_sort, reverse=True):
606 if group_name not in initial_repo_group_list:
606 if group_name not in initial_repo_group_list:
607 repo_group = RepoGroup.get_by_group_name(group_name)
607 repo_group = RepoGroup.get_by_group_name(group_name)
608 if (repo_group.children.all() or
608 if (repo_group.children.all() or
609 not RepoGroupModel().check_exist_filesystem(
609 not RepoGroupModel().check_exist_filesystem(
610 group_name=group_name, exc_on_failure=False)):
610 group_name=group_name, exc_on_failure=False)):
611 continue
611 continue
612
612
613 log.info(
613 log.info(
614 'Removing non-existing repository group found in db `%s`',
614 'Removing non-existing repository group found in db `%s`',
615 group_name)
615 group_name)
616 try:
616 try:
617 RepoGroupModel(sa).delete(group_name, fs_remove=False)
617 RepoGroupModel(sa).delete(group_name, fs_remove=False)
618 sa.commit()
618 sa.commit()
619 removed.append(group_name)
619 removed.append(group_name)
620 except Exception:
620 except Exception:
621 # don't hold further removals on error
621 # don't hold further removals on error
622 log.exception(
622 log.exception(
623 'Unable to remove repository group `%s`',
623 'Unable to remove repository group `%s`',
624 group_name)
624 group_name)
625 sa.rollback()
625 sa.rollback()
626 raise
626 raise
627
627
628 return added, removed
628 return added, removed
629
629
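# Hedged usage sketch; in practice initial_repo_list typically comes from
# a filesystem scan (e.g. ScmModel().repo_scan()) that maps repo names to
# vcs backend instances exposing .alias and .description:
from rhodecode.model.scm import ScmModel

scanned_repos = ScmModel().repo_scan()
added, removed = repo2db_mapper(scanned_repos, remove_obsolete=True)
log.info('added repos: %s, removed repos: %s', added, removed)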
630
630
631 def load_rcextensions(root_path):
631 def load_rcextensions(root_path):
632 import rhodecode
632 import rhodecode
633 from rhodecode.config import conf
633 from rhodecode.config import conf
634
634
635 path = os.path.join(root_path)
635 path = root_path
636 sys.path.append(path)
636 sys.path.append(path)
637
637
638 try:
638 try:
639 rcextensions = __import__('rcextensions')
639 rcextensions = __import__('rcextensions')
640 except ImportError:
640 except ImportError:
641 if os.path.isdir(os.path.join(path, 'rcextensions')):
641 if os.path.isdir(os.path.join(path, 'rcextensions')):
642 log.warn('Unable to load rcextensions from %s', path)
642 log.warning('Unable to load rcextensions from %s', path)
643 rcextensions = None
643 rcextensions = None
644
644
645 if rcextensions:
645 if rcextensions:
646 log.info('Loaded rcextensions from %s...', rcextensions)
646 log.info('Loaded rcextensions from %s...', rcextensions)
647 rhodecode.EXTENSIONS = rcextensions
647 rhodecode.EXTENSIONS = rcextensions
648
648
649 # Additional mappings that are not present in the pygments lexers
649 # Additional mappings that are not present in the pygments lexers
650 conf.LANGUAGES_EXTENSIONS_MAP.update(
650 conf.LANGUAGES_EXTENSIONS_MAP.update(
651 getattr(rhodecode.EXTENSIONS, 'EXTRA_MAPPINGS', {}))
651 getattr(rhodecode.EXTENSIONS, 'EXTRA_MAPPINGS', {}))
652
652
653
653
654 def get_custom_lexer(extension):
654 def get_custom_lexer(extension):
655 """
655 """
656 returns a custom lexer if it is defined in rcextensions module, or None
656 returns a custom lexer if it is defined in rcextensions module, or None
657 if there's no custom lexer defined
657 if there's no custom lexer defined
658 """
658 """
659 import rhodecode
659 import rhodecode
660 from pygments import lexers
660 from pygments import lexers
661
661
662 # custom override made by RhodeCode
662 # custom override made by RhodeCode
663 if extension in ['mako']:
663 if extension in ['mako']:
664 return lexers.get_lexer_by_name('html+mako')
664 return lexers.get_lexer_by_name('html+mako')
665
665
666 # check if this extension is mapped to a different lexer
666 # check if this extension is mapped to a different lexer
667 extensions = rhodecode.EXTENSIONS and getattr(rhodecode.EXTENSIONS, 'EXTRA_LEXERS', None)
667 extensions = rhodecode.EXTENSIONS and getattr(rhodecode.EXTENSIONS, 'EXTRA_LEXERS', None)
668 if extensions and extension in rhodecode.EXTENSIONS.EXTRA_LEXERS:
668 if extensions and extension in rhodecode.EXTENSIONS.EXTRA_LEXERS:
669 _lexer_name = rhodecode.EXTENSIONS.EXTRA_LEXERS[extension]
669 _lexer_name = rhodecode.EXTENSIONS.EXTRA_LEXERS[extension]
670 return lexers.get_lexer_by_name(_lexer_name)
670 return lexers.get_lexer_by_name(_lexer_name)
671
671
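# Hedged sketch: 'mako' is special-cased above, everything else falls back
# to the optional rcextensions EXTRA_LEXERS mapping (file extension ->
# pygments lexer name); the name below matches pygments' MakoHtmlLexer:
mako_lexer = get_custom_lexer('mako')
assert mako_lexer.name == 'HTML+Mako'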
672
672
673 #==============================================================================
673 #==============================================================================
674 # TEST FUNCTIONS AND CREATORS
674 # TEST FUNCTIONS AND CREATORS
675 #==============================================================================
675 #==============================================================================
676 def create_test_index(repo_location, config):
676 def create_test_index(repo_location, config):
677 """
677 """
678 Makes default test index.
678 Makes default test index.
679 """
679 """
680 import rc_testdata
680 import rc_testdata
681
681
682 rc_testdata.extract_search_index(
682 rc_testdata.extract_search_index(
683 'vcs_search_index', os.path.dirname(config['search.location']))
683 'vcs_search_index', os.path.dirname(config['search.location']))
684
684
685
685
686 def create_test_directory(test_path):
686 def create_test_directory(test_path):
687 """
687 """
688 Create test directory if it doesn't exist.
688 Create test directory if it doesn't exist.
689 """
689 """
690 if not os.path.isdir(test_path):
690 if not os.path.isdir(test_path):
691 log.debug('Creating testdir %s', test_path)
691 log.debug('Creating testdir %s', test_path)
692 os.makedirs(test_path)
692 os.makedirs(test_path)
693
693
694
694
695 def create_test_database(test_path, config):
695 def create_test_database(test_path, config):
696 """
696 """
697 Makes a fresh database.
697 Makes a fresh database.
698 """
698 """
699 from rhodecode.lib.db_manage import DbManage
699 from rhodecode.lib.db_manage import DbManage
700
700
701 # PART ONE create db
701 # PART ONE create db
702 dbconf = config['sqlalchemy.db1.url']
702 dbconf = config['sqlalchemy.db1.url']
703 log.debug('making test db %s', dbconf)
703 log.debug('making test db %s', dbconf)
704
704
705 dbmanage = DbManage(log_sql=False, dbconf=dbconf, root=config['here'],
705 dbmanage = DbManage(log_sql=False, dbconf=dbconf, root=config['here'],
706 tests=True, cli_args={'force_ask': True})
706 tests=True, cli_args={'force_ask': True})
707 dbmanage.create_tables(override=True)
707 dbmanage.create_tables(override=True)
708 dbmanage.set_db_version()
708 dbmanage.set_db_version()
709 # for tests dynamically set new root paths based on generated content
709 # for tests dynamically set new root paths based on generated content
710 dbmanage.create_settings(dbmanage.config_prompt(test_path))
710 dbmanage.create_settings(dbmanage.config_prompt(test_path))
711 dbmanage.create_default_user()
711 dbmanage.create_default_user()
712 dbmanage.create_test_admin_and_users()
712 dbmanage.create_test_admin_and_users()
713 dbmanage.create_permissions()
713 dbmanage.create_permissions()
714 dbmanage.populate_default_permissions()
714 dbmanage.populate_default_permissions()
715 Session().commit()
715 Session().commit()
716
716
717
717
718 def create_test_repositories(test_path, config):
718 def create_test_repositories(test_path, config):
719 """
719 """
720 Creates test repositories in the temporary directory. Repositories are
720 Creates test repositories in the temporary directory. Repositories are
721 extracted from archives within the rc_testdata package.
721 extracted from archives within the rc_testdata package.
722 """
722 """
723 import rc_testdata
723 import rc_testdata
724 from rhodecode.tests import HG_REPO, GIT_REPO, SVN_REPO
724 from rhodecode.tests import HG_REPO, GIT_REPO, SVN_REPO
725
725
726 log.debug('making test vcs repositories')
726 log.debug('making test vcs repositories')
727
727
728 idx_path = config['search.location']
728 idx_path = config['search.location']
729 data_path = config['cache_dir']
729 data_path = config['cache_dir']
730
730
731 # clean index and data
731 # clean index and data
732 if idx_path and os.path.exists(idx_path):
732 if idx_path and os.path.exists(idx_path):
733 log.debug('remove %s', idx_path)
733 log.debug('remove %s', idx_path)
734 shutil.rmtree(idx_path)
734 shutil.rmtree(idx_path)
735
735
736 if data_path and os.path.exists(data_path):
736 if data_path and os.path.exists(data_path):
737 log.debug('remove %s', data_path)
737 log.debug('remove %s', data_path)
738 shutil.rmtree(data_path)
738 shutil.rmtree(data_path)
739
739
740 rc_testdata.extract_hg_dump('vcs_test_hg', jn(test_path, HG_REPO))
740 rc_testdata.extract_hg_dump('vcs_test_hg', jn(test_path, HG_REPO))
741 rc_testdata.extract_git_dump('vcs_test_git', jn(test_path, GIT_REPO))
741 rc_testdata.extract_git_dump('vcs_test_git', jn(test_path, GIT_REPO))
742
742
743 # Note: Subversion is in the process of being integrated with the system,
743 # Note: Subversion is in the process of being integrated with the system,
744 # until we have a properly packed version of the test svn repository, this
744 # until we have a properly packed version of the test svn repository, this
745 # tries to copy over the repo from a package "rc_testdata"
745 # tries to copy over the repo from a package "rc_testdata"
746 svn_repo_path = rc_testdata.get_svn_repo_archive()
746 svn_repo_path = rc_testdata.get_svn_repo_archive()
747 with tarfile.open(svn_repo_path) as tar:
747 with tarfile.open(svn_repo_path) as tar:
748 tar.extractall(jn(test_path, SVN_REPO))
748 tar.extractall(jn(test_path, SVN_REPO))
749
749
750
750
751 def password_changed(auth_user, session):
751 def password_changed(auth_user, session):
752 # Never report password change in case of default user or anonymous user.
752 # Never report password change in case of default user or anonymous user.
753 if auth_user.username == User.DEFAULT_USER or auth_user.user_id is None:
753 if auth_user.username == User.DEFAULT_USER or auth_user.user_id is None:
754 return False
754 return False
755
755
756 password_hash = md5(auth_user.password) if auth_user.password else None
756 password_hash = md5(auth_user.password) if auth_user.password else None
757 rhodecode_user = session.get('rhodecode_user', {})
757 rhodecode_user = session.get('rhodecode_user', {})
758 session_password_hash = rhodecode_user.get('password', '')
758 session_password_hash = rhodecode_user.get('password', '')
759 return password_hash != session_password_hash
759 return password_hash != session_password_hash
760
760
761
761
762 def read_opensource_licenses():
762 def read_opensource_licenses():
763 global _license_cache
763 global _license_cache
764
764
765 if not _license_cache:
765 if not _license_cache:
766 licenses = pkg_resources.resource_string(
766 licenses = pkg_resources.resource_string(
767 'rhodecode', 'config/licenses.json')
767 'rhodecode', 'config/licenses.json')
768 _license_cache = json.loads(licenses)
768 _license_cache = json.loads(licenses)
769
769
770 return _license_cache
770 return _license_cache
771
771
772
772
773 def generate_platform_uuid():
773 def generate_platform_uuid():
774 """
774 """
775 Generates a platform UUID based on its name
775 Generates a platform UUID based on its name
776 """
776 """
777 import platform
777 import platform
778
778
779 try:
779 try:
780 uuid_list = [platform.platform()]
780 uuid_list = [platform.platform()]
781 return hashlib.sha256(':'.join(uuid_list)).hexdigest()
781 return hashlib.sha256(':'.join(uuid_list).encode('utf-8')).hexdigest()
782 except Exception as e:
782 except Exception as e:
783 log.error('Failed to generate host uuid: %s', e)
783 log.error('Failed to generate host uuid: %s', e)
784 return 'UNDEFINED'
784 return 'UNDEFINED'
785
785
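# Quick sketch of the hashing path above: under Python 3, hashlib needs
# bytes, so the joined platform string must be encoded before digesting:
import hashlib
import platform

digest = hashlib.sha256(platform.platform().encode('utf-8')).hexdigest()
assert len(digest) == 64  # sha256 hex digest is always 64 chars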
786
786
787 def send_test_email(recipients, email_body='TEST EMAIL'):
787 def send_test_email(recipients, email_body='TEST EMAIL'):
788 """
788 """
789 Simple code for generating test emails.
789 Simple code for generating test emails.
790 Usage::
790 Usage::
791
791
792 from rhodecode.lib import utils
792 from rhodecode.lib import utils
793 utils.send_test_email(['email@example.com'])
793 utils.send_test_email(['email@example.com'])
794 """
794 """
795 from rhodecode.lib.celerylib import tasks, run_task
795 from rhodecode.lib.celerylib import tasks, run_task
796
796
797 email_body = email_body_plaintext = email_body
797 email_body_plaintext = email_body
798 subject = 'SUBJECT FROM: {}'.format(socket.gethostname())
798 subject = 'SUBJECT FROM: {}'.format(socket.gethostname())
799 tasks.send_email(recipients, subject, email_body_plaintext, email_body)
799 tasks.send_email(recipients, subject, email_body_plaintext, email_body)
@@ -1,1045 +1,1045 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Some simple helper functions
23 Some simple helper functions
24 """
24 """
25
25
26 import collections
26 import collections
27 import datetime
27 import datetime
28 import dateutil.relativedelta
28 import dateutil.relativedelta
29 import logging
29 import logging
30 import re
30 import re
31 import sys
31 import sys
32 import time
32 import time
33 import urllib.request, urllib.parse, urllib.error
33 import urllib.request, urllib.parse, urllib.error
34 import urlobject
34 import urlobject
35 import uuid
35 import uuid
36 import getpass
36 import getpass
37 import socket
37 import socket
38 import errno
38 import errno
39 import random
39 import random
40 from functools import update_wrapper, partial, wraps
40 from functools import update_wrapper, partial, wraps
41 from contextlib import closing
41 from contextlib import closing
42
42
43 import pygments.lexers
43 import pygments.lexers
44 import sqlalchemy
44 import sqlalchemy
45 import sqlalchemy.engine.url
45 import sqlalchemy.engine.url
46 import sqlalchemy.exc
46 import sqlalchemy.exc
47 import sqlalchemy.sql
47 import sqlalchemy.sql
48 import webob
48 import webob
49 import pyramid.threadlocal
49 import pyramid.threadlocal
50 from pyramid.settings import asbool
50 from pyramid.settings import asbool
51
51
52 import rhodecode
52 import rhodecode
53 from rhodecode.translation import _, _pluralize
53 from rhodecode.translation import _, _pluralize
54 from rhodecode.lib.str_utils import safe_str, safe_int, safe_bytes
54 from rhodecode.lib.str_utils import safe_str, safe_int, safe_bytes
55 from rhodecode.lib.hash_utils import md5, md5_safe, sha1, sha1_safe
55 from rhodecode.lib.hash_utils import md5, md5_safe, sha1, sha1_safe
56 from rhodecode.lib.type_utils import aslist, str2bool
56 from rhodecode.lib.type_utils import aslist, str2bool
57 from functools import reduce
57 from functools import reduce
58
58
59 # TODO: safe_unicode no longer exists; it is aliased to safe_str for now, but should be removed
59 # TODO: safe_unicode no longer exists; it is aliased to safe_str for now, but should be removed
60 safe_unicode = safe_str
60 safe_unicode = safe_str
61
61
62
62
63 def __get_lem(extra_mapping=None):
63 def __get_lem(extra_mapping=None):
64 """
64 """
65 Get language extension map based on what's inside pygments lexers
65 Get language extension map based on what's inside pygments lexers
66 """
66 """
67 d = collections.defaultdict(lambda: [])
67 d = collections.defaultdict(lambda: [])
68
68
69 def __clean(s):
69 def __clean(s):
70 s = s.lstrip('*')
70 s = s.lstrip('*')
71 s = s.lstrip('.')
71 s = s.lstrip('.')
72
72
73 if s.find('[') != -1:
73 if s.find('[') != -1:
74 exts = []
74 exts = []
75 start, stop = s.find('['), s.find(']')
75 start, stop = s.find('['), s.find(']')
76
76
77 for suffix in s[start + 1:stop]:
77 for suffix in s[start + 1:stop]:
78 exts.append(s[:s.find('[')] + suffix)
78 exts.append(s[:s.find('[')] + suffix)
79 return [e.lower() for e in exts]
79 return [e.lower() for e in exts]
80 else:
80 else:
81 return [s.lower()]
81 return [s.lower()]
82
82
83 for lx, t in sorted(pygments.lexers.LEXERS.items()):
83 for lx, t in sorted(pygments.lexers.LEXERS.items()):
84 m = list(map(__clean, t[-2]))
84 m = list(map(__clean, t[-2]))
85 if m:
85 if m:
86 m = reduce(lambda x, y: x + y, m)
86 m = reduce(lambda x, y: x + y, m)
87 for ext in m:
87 for ext in m:
88 desc = lx.replace('Lexer', '')
88 desc = lx.replace('Lexer', '')
89 d[ext].append(desc)
89 d[ext].append(desc)
90
90
91 data = dict(d)
91 data = dict(d)
92
92
93 extra_mapping = extra_mapping or {}
93 extra_mapping = extra_mapping or {}
94 if extra_mapping:
94 if extra_mapping:
95 for k, v in extra_mapping.items():
95 for k, v in extra_mapping.items():
96 if k not in data:
96 if k not in data:
97 # register new mapping to lexer
97 # register new mapping to lexer
98 data[k] = [v]
98 data[k] = [v]
99
99
100 return data
100 return data
101
101
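# A minimal sketch of the bracket expansion __clean() performs on pygments
# filename patterns such as '*.ph[34]':
_s = 'ph[34]'
_start, _stop = _s.find('['), _s.find(']')
_exts = [_s[:_start] + suffix for suffix in _s[_start + 1:_stop]]
assert _exts == ['ph3', 'ph4']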
102
102
103 def convert_line_endings(line, mode):
103 def convert_line_endings(line, mode):
104 """
104 """
105 Converts a given line "line end" accordingly to given mode
105 Converts a given line "line end" accordingly to given mode
106
106
107 Available modes are::
107 Available modes are::
108 0 - Unix
108 0 - Unix
109 1 - Mac
109 1 - Mac
110 2 - DOS
110 2 - DOS
111
111
112 :param line: given line to convert
112 :param line: given line to convert
113 :param mode: mode to convert to
113 :param mode: mode to convert to
114 :rtype: str
114 :rtype: str
115 :return: converted line according to mode
115 :return: converted line according to mode
116 """
116 """
117 if mode == 0:
117 if mode == 0:
118 line = line.replace('\r\n', '\n')
118 line = line.replace('\r\n', '\n')
119 line = line.replace('\r', '\n')
119 line = line.replace('\r', '\n')
120 elif mode == 1:
120 elif mode == 1:
121 line = line.replace('\r\n', '\r')
121 line = line.replace('\r\n', '\r')
122 line = line.replace('\n', '\r')
122 line = line.replace('\n', '\r')
123 elif mode == 2:
123 elif mode == 2:
124 line = re.sub('\r(?!\n)|(?<!\r)\n', '\r\n', line)
124 line = re.sub('\r(?!\n)|(?<!\r)\n', '\r\n', line)
125 return line
125 return line
126
126
127
127
128 def detect_mode(line, default):
128 def detect_mode(line, default):
129 """
129 """
130 Detects the line break of a given line; if the line break can't be
130 Detects the line break of a given line; if the line break can't be
131 determined, the given default value is returned
131 determined, the given default value is returned
132
132
133 :param line: str line
133 :param line: str line
134 :param default: default
134 :param default: default
135 :rtype: int
135 :rtype: int
136 :return: line end value, one of 0 - Unix, 1 - Mac, 2 - DOS
136 :return: line end value, one of 0 - Unix, 1 - Mac, 2 - DOS
137 """
137 """
138 if line.endswith('\r\n'):
138 if line.endswith('\r\n'):
139 return 2
139 return 2
140 elif line.endswith('\n'):
140 elif line.endswith('\n'):
141 return 0
141 return 0
142 elif line.endswith('\r'):
142 elif line.endswith('\r'):
143 return 1
143 return 1
144 else:
144 else:
145 return default
145 return default
146
146
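# Round-trip sketch of the two helpers above: detect the line ending of
# one line, then normalize another to the same mode (0=Unix, 1=Mac, 2=DOS):
_mode = detect_mode('first line\r\n', default=0)
assert _mode == 2
assert convert_line_endings('second line\n', _mode) == 'second line\r\n'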
147
147
148 def remove_suffix(s, suffix):
148 def remove_suffix(s, suffix):
149 if s.endswith(suffix):
149 if s.endswith(suffix):
150 s = s[:-1 * len(suffix)]
150 s = s[:len(s) - len(suffix)]  # safe for an empty suffix as well
151 return s
151 return s
152
152
153
153
154 def remove_prefix(s, prefix):
154 def remove_prefix(s, prefix):
155 if s.startswith(prefix):
155 if s.startswith(prefix):
156 s = s[len(prefix):]
156 s = s[len(prefix):]
157 return s
157 return s
158
158
159
159
160 def find_calling_context(ignore_modules=None):
160 def find_calling_context(ignore_modules=None):
161 """
161 """
162 Look through the calling stack and return the frame which called
162 Look through the calling stack and return the frame which called
163 this function and is part of the core module (i.e. rhodecode.*)
163 this function and is part of the core module (i.e. rhodecode.*)
164
164
165 :param ignore_modules: list of modules to ignore eg. ['rhodecode.lib']
165 :param ignore_modules: list of modules to ignore eg. ['rhodecode.lib']
166
166
167 usage::
167 usage::
168 from rhodecode.lib.utils2 import find_calling_context
168 from rhodecode.lib.utils2 import find_calling_context
169
169
170 calling_context = find_calling_context(ignore_modules=[
170 calling_context = find_calling_context(ignore_modules=[
171 'rhodecode.lib.caching_query',
171 'rhodecode.lib.caching_query',
172 'rhodecode.model.settings',
172 'rhodecode.model.settings',
173 ])
173 ])
174
174
175 if calling_context:
175 if calling_context:
176 cc_str = 'call context %s:%s' % (
176 cc_str = 'call context %s:%s' % (
177 calling_context.f_code.co_filename,
177 calling_context.f_code.co_filename,
178 calling_context.f_lineno,
178 calling_context.f_lineno,
179 )
179 )
180 print(cc_str)
180 print(cc_str)
181 """
181 """
182
182
183 ignore_modules = ignore_modules or []
183 ignore_modules = ignore_modules or []
184
184
185 f = sys._getframe(2)
185 f = sys._getframe(2)
186 while f.f_back is not None:
186 while f.f_back is not None:
187 name = f.f_globals.get('__name__')
187 name = f.f_globals.get('__name__')
188 if name and name.startswith(__name__.split('.')[0]):
188 if name and name.startswith(__name__.split('.')[0]):
189 if name not in ignore_modules:
189 if name not in ignore_modules:
190 return f
190 return f
191 f = f.f_back
191 f = f.f_back
192 return None
192 return None
193
193
194
194
195 def ping_connection(connection, branch):
195 def ping_connection(connection, branch):
196 if branch:
196 if branch:
197 # "branch" refers to a sub-connection of a connection,
197 # "branch" refers to a sub-connection of a connection,
198 # we don't want to bother pinging on these.
198 # we don't want to bother pinging on these.
199 return
199 return
200
200
201 # turn off "close with result". This flag is only used with
201 # turn off "close with result". This flag is only used with
202 # "connectionless" execution, otherwise will be False in any case
202 # "connectionless" execution, otherwise will be False in any case
203 save_should_close_with_result = connection.should_close_with_result
203 save_should_close_with_result = connection.should_close_with_result
204 connection.should_close_with_result = False
204 connection.should_close_with_result = False
205
205
206 try:
206 try:
207 # run a SELECT 1. use a core select() so that
207 # run a SELECT 1. use a core select() so that
208 # the SELECT of a scalar value without a table is
208 # the SELECT of a scalar value without a table is
209 # appropriately formatted for the backend
209 # appropriately formatted for the backend
210 connection.scalar(sqlalchemy.sql.select([1]))
210 connection.scalar(sqlalchemy.sql.select([1]))
211 except sqlalchemy.exc.DBAPIError as err:
211 except sqlalchemy.exc.DBAPIError as err:
212 # catch SQLAlchemy's DBAPIError, which is a wrapper
212 # catch SQLAlchemy's DBAPIError, which is a wrapper
213 # for the DBAPI's exception. It includes a .connection_invalidated
213 # for the DBAPI's exception. It includes a .connection_invalidated
214 # attribute which specifies if this connection is a "disconnect"
214 # attribute which specifies if this connection is a "disconnect"
215 # condition, which is based on inspection of the original exception
215 # condition, which is based on inspection of the original exception
216 # by the dialect in use.
216 # by the dialect in use.
217 if err.connection_invalidated:
217 if err.connection_invalidated:
218 # run the same SELECT again - the connection will re-validate
218 # run the same SELECT again - the connection will re-validate
219 # itself and establish a new connection. The disconnect detection
219 # itself and establish a new connection. The disconnect detection
220 # here also causes the whole connection pool to be invalidated
220 # here also causes the whole connection pool to be invalidated
221 # so that all stale connections are discarded.
221 # so that all stale connections are discarded.
222 connection.scalar(sqlalchemy.sql.select([1]))
222 connection.scalar(sqlalchemy.sql.select([1]))
223 else:
223 else:
224 raise
224 raise
225 finally:
225 finally:
226 # restore "close with result"
226 # restore "close with result"
227 connection.should_close_with_result = save_should_close_with_result
227 connection.should_close_with_result = save_should_close_with_result
228
228
229
229
230 def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
230 def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
231 """Custom engine_from_config functions."""
231 """Custom engine_from_config functions."""
232 log = logging.getLogger('sqlalchemy.engine')
232 log = logging.getLogger('sqlalchemy.engine')
233 use_ping_connection = asbool(configuration.pop('sqlalchemy.db1.ping_connection', None))
233 use_ping_connection = asbool(configuration.pop('sqlalchemy.db1.ping_connection', None))
234 debug = asbool(configuration.pop('sqlalchemy.db1.debug_query', None))
234 debug = asbool(configuration.pop('sqlalchemy.db1.debug_query', None))
235
235
236 engine = sqlalchemy.engine_from_config(configuration, prefix, **kwargs)
236 engine = sqlalchemy.engine_from_config(configuration, prefix, **kwargs)
237
237
238 def color_sql(sql):
238 def color_sql(sql):
239 color_seq = '\033[1;33m' # This is yellow: code 33
239 color_seq = '\033[1;33m' # This is yellow: code 33
240 normal = '\x1b[0m'
240 normal = '\x1b[0m'
241 return ''.join([color_seq, sql, normal])
241 return ''.join([color_seq, sql, normal])
242
242
243 if use_ping_connection:
243 if use_ping_connection:
244 log.debug('Adding ping_connection on the engine config.')
244 log.debug('Adding ping_connection on the engine config.')
245 sqlalchemy.event.listen(engine, "engine_connect", ping_connection)
245 sqlalchemy.event.listen(engine, "engine_connect", ping_connection)
246
246
247 if debug:
247 if debug:
248 # attach events only for debug configuration
248 # attach events only for debug configuration
249 def before_cursor_execute(conn, cursor, statement,
249 def before_cursor_execute(conn, cursor, statement,
250 parameters, context, executemany):
250 parameters, context, executemany):
251 setattr(conn, 'query_start_time', time.time())
251 setattr(conn, 'query_start_time', time.time())
252 log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
252 log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
253 calling_context = find_calling_context(ignore_modules=[
253 calling_context = find_calling_context(ignore_modules=[
254 'rhodecode.lib.caching_query',
254 'rhodecode.lib.caching_query',
255 'rhodecode.model.settings',
255 'rhodecode.model.settings',
256 ])
256 ])
257 if calling_context:
257 if calling_context:
258 log.info(color_sql('call context %s:%s' % (
258 log.info(color_sql('call context %s:%s' % (
259 calling_context.f_code.co_filename,
259 calling_context.f_code.co_filename,
260 calling_context.f_lineno,
260 calling_context.f_lineno,
261 )))
261 )))
262
262
263 def after_cursor_execute(conn, cursor, statement,
263 def after_cursor_execute(conn, cursor, statement,
264 parameters, context, executemany):
264 parameters, context, executemany):
265 delattr(conn, 'query_start_time')
265 delattr(conn, 'query_start_time')
266
266
267 sqlalchemy.event.listen(engine, "before_cursor_execute", before_cursor_execute)
267 sqlalchemy.event.listen(engine, "before_cursor_execute", before_cursor_execute)
268 sqlalchemy.event.listen(engine, "after_cursor_execute", after_cursor_execute)
268 sqlalchemy.event.listen(engine, "after_cursor_execute", after_cursor_execute)
269
269
270 return engine
270 return engine
271
271
272
272
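# A minimal wiring sketch (hedged): assumes an ini-style config mapping with
# the 'sqlalchemy.db1.' prefix handled above; the keys are illustrative only.
#
# >>> conf = {
# ...     'sqlalchemy.db1.url': 'sqlite://',
# ...     'sqlalchemy.db1.ping_connection': 'true',
# ...     'sqlalchemy.db1.debug_query': 'false',
# ... }
# >>> engine = engine_from_config(conf, prefix='sqlalchemy.db1.')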
273 def get_encryption_key(config):
273 def get_encryption_key(config):
274 secret = config.get('rhodecode.encrypted_values.secret')
274 secret = config.get('rhodecode.encrypted_values.secret')
275 default = config['beaker.session.secret']
275 default = config['beaker.session.secret']
276 return secret or default
276 return secret or default
277
277
278
278
279 def age(prevdate, now=None, show_short_version=False, show_suffix=True, short_format=False):
279 def age(prevdate, now=None, show_short_version=False, show_suffix=True, short_format=False):
280 """
280 """
281 Turns a datetime into an age string.
281 Turns a datetime into an age string.
282 If show_short_version is True, this generates a shorter string with
282 If show_short_version is True, this generates a shorter string with
283 an approximate age; ex. '1 day ago', rather than '1 day and 23 hours ago'.
283 an approximate age; ex. '1 day ago', rather than '1 day and 23 hours ago'.
284
284
285 *IMPORTANT*
285 *IMPORTANT*
286 The code of this function is written in a special way so it's easier to
286 The code of this function is written in a special way so it's easier to
287 backport to JavaScript. If you update it, please also update the
287 backport to JavaScript. If you update it, please also update the
288 `jquery.timeago-extension.js` file.
288 `jquery.timeago-extension.js` file.
289
289
290 :param prevdate: datetime object
290 :param prevdate: datetime object
291 :param now: current time; if not defined we use
291 :param now: current time; if not defined we use
292 `datetime.datetime.now()`
292 `datetime.datetime.now()`
293 :param show_short_version: if it should approximate the date and
293 :param show_short_version: if it should approximate the date and
294 return a shorter string
294 return a shorter string
295 :param show_suffix: if True, adds the 'ago' / 'in' suffix to the result
295 :param show_suffix: if True, adds the 'ago' / 'in' suffix to the result
296 :param short_format: show short format, eg 2d instead of 2 days
296 :param short_format: show short format, eg 2d instead of 2 days
297 :rtype: str
297 :rtype: str
298 :returns: str words describing age
298 :returns: str words describing age
299 """
299 """
300
300
301 def _get_relative_delta(now, prevdate):
301 def _get_relative_delta(now, prevdate):
302 base = dateutil.relativedelta.relativedelta(now, prevdate)
302 base = dateutil.relativedelta.relativedelta(now, prevdate)
303 return {
303 return {
304 'year': base.years,
304 'year': base.years,
305 'month': base.months,
305 'month': base.months,
306 'day': base.days,
306 'day': base.days,
307 'hour': base.hours,
307 'hour': base.hours,
308 'minute': base.minutes,
308 'minute': base.minutes,
309 'second': base.seconds,
309 'second': base.seconds,
310 }
310 }
311
311
312 def _is_leap_year(year):
312 def _is_leap_year(year):
313 return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
313 return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
314
314
315 def get_month(prevdate):
315 def get_month(prevdate):
316 return prevdate.month
316 return prevdate.month
317
317
318 def get_year(prevdate):
318 def get_year(prevdate):
319 return prevdate.year
319 return prevdate.year
320
320
321 now = now or datetime.datetime.now()
321 now = now or datetime.datetime.now()
322 order = ['year', 'month', 'day', 'hour', 'minute', 'second']
322 order = ['year', 'month', 'day', 'hour', 'minute', 'second']
323 deltas = {}
323 deltas = {}
324 future = False
324 future = False
325
325
326 if prevdate > now:
326 if prevdate > now:
327 now_old = now
327 now_old = now
328 now = prevdate
328 now = prevdate
329 prevdate = now_old
329 prevdate = now_old
330 future = True
330 future = True
331 if future:
331 if future:
332 prevdate = prevdate.replace(microsecond=0)
332 prevdate = prevdate.replace(microsecond=0)
333 # Get date parts deltas
333 # Get date parts deltas
334 for part in order:
334 for part in order:
335 rel_delta = _get_relative_delta(now, prevdate)
335 rel_delta = _get_relative_delta(now, prevdate)
336 deltas[part] = rel_delta[part]
336 deltas[part] = rel_delta[part]
337
337
338 # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
338 # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
339 # not 1 hour, -59 minutes and -59 seconds)
339 # not 1 hour, -59 minutes and -59 seconds)
340 offsets = [[5, 60], [4, 60], [3, 24]]
340 offsets = [[5, 60], [4, 60], [3, 24]]
341 for element in offsets: # seconds, minutes, hours
341 for element in offsets: # seconds, minutes, hours
342 num = element[0]
342 num = element[0]
343 length = element[1]
343 length = element[1]
344
344
345 part = order[num]
345 part = order[num]
346 carry_part = order[num - 1]
346 carry_part = order[num - 1]
347
347
348 if deltas[part] < 0:
348 if deltas[part] < 0:
349 deltas[part] += length
349 deltas[part] += length
350 deltas[carry_part] -= 1
350 deltas[carry_part] -= 1
351
351
352 # Same thing for days except that the increment depends on the (variable)
352 # Same thing for days except that the increment depends on the (variable)
353 # number of days in the month
353 # number of days in the month
354 month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
354 month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
355 if deltas['day'] < 0:
355 if deltas['day'] < 0:
356 if get_month(prevdate) == 2 and _is_leap_year(get_year(prevdate)):
356 if get_month(prevdate) == 2 and _is_leap_year(get_year(prevdate)):
357 deltas['day'] += 29
357 deltas['day'] += 29
358 else:
358 else:
359 deltas['day'] += month_lengths[get_month(prevdate) - 1]
359 deltas['day'] += month_lengths[get_month(prevdate) - 1]
360
360
361 deltas['month'] -= 1
361 deltas['month'] -= 1
362
362
363 if deltas['month'] < 0:
363 if deltas['month'] < 0:
364 deltas['month'] += 12
364 deltas['month'] += 12
365 deltas['year'] -= 1
365 deltas['year'] -= 1
366
366
367 # Format the result
367 # Format the result
368 if short_format:
368 if short_format:
369 fmt_funcs = {
369 fmt_funcs = {
370 'year': lambda d: u'%dy' % d,
370 'year': lambda d: '%dy' % d,
371 'month': lambda d: u'%dm' % d,
371 'month': lambda d: '%dm' % d,
372 'day': lambda d: u'%dd' % d,
372 'day': lambda d: '%dd' % d,
373 'hour': lambda d: u'%dh' % d,
373 'hour': lambda d: '%dh' % d,
374 'minute': lambda d: u'%dmin' % d,
374 'minute': lambda d: '%dmin' % d,
375 'second': lambda d: u'%dsec' % d,
375 'second': lambda d: '%dsec' % d,
376 }
376 }
377 else:
377 else:
378 fmt_funcs = {
378 fmt_funcs = {
379 'year': lambda d: _pluralize(u'${num} year', u'${num} years', d, mapping={'num': d}).interpolate(),
379 'year': lambda d: _pluralize('${num} year', '${num} years', d, mapping={'num': d}).interpolate(),
380 'month': lambda d: _pluralize(u'${num} month', u'${num} months', d, mapping={'num': d}).interpolate(),
380 'month': lambda d: _pluralize('${num} month', '${num} months', d, mapping={'num': d}).interpolate(),
381 'day': lambda d: _pluralize(u'${num} day', u'${num} days', d, mapping={'num': d}).interpolate(),
381 'day': lambda d: _pluralize('${num} day', '${num} days', d, mapping={'num': d}).interpolate(),
382 'hour': lambda d: _pluralize(u'${num} hour', u'${num} hours', d, mapping={'num': d}).interpolate(),
382 'hour': lambda d: _pluralize('${num} hour', '${num} hours', d, mapping={'num': d}).interpolate(),
383 'minute': lambda d: _pluralize(u'${num} minute', u'${num} minutes', d, mapping={'num': d}).interpolate(),
383 'minute': lambda d: _pluralize('${num} minute', '${num} minutes', d, mapping={'num': d}).interpolate(),
384 'second': lambda d: _pluralize(u'${num} second', u'${num} seconds', d, mapping={'num': d}).interpolate(),
384 'second': lambda d: _pluralize('${num} second', '${num} seconds', d, mapping={'num': d}).interpolate(),
385 }
385 }
386
386
387 i = 0
387 i = 0
388 for part in order:
388 for part in order:
389 value = deltas[part]
389 value = deltas[part]
390 if value != 0:
390 if value != 0:
391
391
392 if i < 5:
392 if i < 5:
393 sub_part = order[i + 1]
393 sub_part = order[i + 1]
394 sub_value = deltas[sub_part]
394 sub_value = deltas[sub_part]
395 else:
395 else:
396 sub_value = 0
396 sub_value = 0
397
397
398 if sub_value == 0 or show_short_version:
398 if sub_value == 0 or show_short_version:
399 _val = fmt_funcs[part](value)
399 _val = fmt_funcs[part](value)
400 if future:
400 if future:
401 if show_suffix:
401 if show_suffix:
402 return _(u'in ${ago}', mapping={'ago': _val})
402 return _('in ${ago}', mapping={'ago': _val})
403 else:
403 else:
404 return _(_val)
404 return _(_val)
405
405
406 else:
406 else:
407 if show_suffix:
407 if show_suffix:
408 return _(u'${ago} ago', mapping={'ago': _val})
408 return _('${ago} ago', mapping={'ago': _val})
409 else:
409 else:
410 return _(_val)
410 return _(_val)
411
411
412 val = fmt_funcs[part](value)
412 val = fmt_funcs[part](value)
413 val_detail = fmt_funcs[sub_part](sub_value)
413 val_detail = fmt_funcs[sub_part](sub_value)
414 mapping = {'val': val, 'detail': val_detail}
414 mapping = {'val': val, 'detail': val_detail}
415
415
416 if short_format:
416 if short_format:
417 datetime_tmpl = _(u'${val}, ${detail}', mapping=mapping)
417 datetime_tmpl = _('${val}, ${detail}', mapping=mapping)
418 if show_suffix:
418 if show_suffix:
419 datetime_tmpl = _(u'${val}, ${detail} ago', mapping=mapping)
419 datetime_tmpl = _('${val}, ${detail} ago', mapping=mapping)
420 if future:
420 if future:
421 datetime_tmpl = _(u'in ${val}, ${detail}', mapping=mapping)
421 datetime_tmpl = _('in ${val}, ${detail}', mapping=mapping)
422 else:
422 else:
423 datetime_tmpl = _(u'${val} and ${detail}', mapping=mapping)
423 datetime_tmpl = _('${val} and ${detail}', mapping=mapping)
424 if show_suffix:
424 if show_suffix:
425 datetime_tmpl = _(u'${val} and ${detail} ago', mapping=mapping)
425 datetime_tmpl = _('${val} and ${detail} ago', mapping=mapping)
426 if future:
426 if future:
427 datetime_tmpl = _(u'in ${val} and ${detail}', mapping=mapping)
427 datetime_tmpl = _('in ${val} and ${detail}', mapping=mapping)
428
428
429 return datetime_tmpl
429 return datetime_tmpl
430 i += 1
430 i += 1
431 return _(u'just now')
431 return _('just now')
432
432
433
433
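# A short sketch of the formatting variants (the times are illustrative):
#
# >>> import datetime
# >>> prev = datetime.datetime.now() - datetime.timedelta(days=1, hours=23)
# >>> age(prev)                           # '1 day and 23 hours ago'
# >>> age(prev, show_short_version=True)  # '1 day ago'
# >>> age(prev, short_format=True)        # '1d, 23h ago'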
434 def age_from_seconds(seconds):
434 def age_from_seconds(seconds):
435 seconds = safe_int(seconds) or 0
435 seconds = safe_int(seconds) or 0
436 prevdate = time_to_datetime(time.time() + seconds)
436 prevdate = time_to_datetime(time.time() + seconds)
437 return age(prevdate, show_suffix=False, show_short_version=True)
437 return age(prevdate, show_suffix=False, show_short_version=True)
438
438
439
439
440 def cleaned_uri(uri):
440 def cleaned_uri(uri):
441 """
441 """
442 Quotes '[' and ']' from uri if there is only one of them.
442 Quotes '[' and ']' from uri if there is only one of them.
443 According to RFC 3986 we cannot use such chars in a uri.
443 According to RFC 3986 we cannot use such chars in a uri.
444 :param uri:
444 :param uri:
445 :return: uri without these chars
445 :return: uri without these chars
446 """
446 """
447 return urllib.parse.quote(uri, safe='@$:/')
447 return urllib.parse.quote(uri, safe='@$:/')
448
448
449
449
450 def credentials_filter(uri):
450 def credentials_filter(uri):
451 """
451 """
452 Returns a url with removed credentials
452 Returns a url with removed credentials
453
453
454 :param uri:
454 :param uri:
455 """
455 """
456 import urlobject
456 import urlobject
457 if isinstance(uri, rhodecode.lib.encrypt.InvalidDecryptedValue):
457 if isinstance(uri, rhodecode.lib.encrypt.InvalidDecryptedValue):
458 return 'InvalidDecryptionKey'
458 return 'InvalidDecryptionKey'
459
459
460 url_obj = urlobject.URLObject(cleaned_uri(uri))
460 url_obj = urlobject.URLObject(cleaned_uri(uri))
461 url_obj = url_obj.without_password().without_username()
461 url_obj = url_obj.without_password().without_username()
462
462
463 return url_obj
463 return url_obj
464
464
465
465
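# A small sketch with a hypothetical URL; the returned URLObject is a str
# subclass, shown here via str():
#
# >>> str(credentials_filter('https://user:s3cret@example.com/repo'))
# 'https://example.com/repo'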
466 def get_host_info(request):
466 def get_host_info(request):
467 """
467 """
468 Generate host info. To obtain a full url e.g. https://server.com,
468 Generate host info. To obtain a full url e.g. https://server.com,
469 use
469 use
470 `{scheme}://{netloc}`
470 `{scheme}://{netloc}`
471 """
471 """
472 if not request:
472 if not request:
473 return {}
473 return {}
474
474
475 qualified_home_url = request.route_url('home')
475 qualified_home_url = request.route_url('home')
476 parsed_url = urlobject.URLObject(qualified_home_url)
476 parsed_url = urlobject.URLObject(qualified_home_url)
477 decoded_path = safe_unicode(urllib.parse.unquote(parsed_url.path.rstrip('/')))
477 decoded_path = safe_unicode(urllib.parse.unquote(parsed_url.path.rstrip('/')))
478
478
479 return {
479 return {
480 'scheme': parsed_url.scheme,
480 'scheme': parsed_url.scheme,
481 'netloc': parsed_url.netloc+decoded_path,
481 'netloc': parsed_url.netloc+decoded_path,
482 'hostname': parsed_url.hostname,
482 'hostname': parsed_url.hostname,
483 }
483 }
484
484
485
485
486 def get_clone_url(request, uri_tmpl, repo_name, repo_id, repo_type, **override):
486 def get_clone_url(request, uri_tmpl, repo_name, repo_id, repo_type, **override):
487 qualified_home_url = request.route_url('home')
487 qualified_home_url = request.route_url('home')
488 parsed_url = urlobject.URLObject(qualified_home_url)
488 parsed_url = urlobject.URLObject(qualified_home_url)
489 decoded_path = safe_unicode(urllib.parse.unquote(parsed_url.path.rstrip('/')))
489 decoded_path = safe_unicode(urllib.parse.unquote(parsed_url.path.rstrip('/')))
490
490
491 args = {
491 args = {
492 'scheme': parsed_url.scheme,
492 'scheme': parsed_url.scheme,
493 'user': '',
493 'user': '',
494 'sys_user': getpass.getuser(),
494 'sys_user': getpass.getuser(),
495 # path if we use proxy-prefix
495 # path if we use proxy-prefix
496 'netloc': parsed_url.netloc+decoded_path,
496 'netloc': parsed_url.netloc+decoded_path,
497 'hostname': parsed_url.hostname,
497 'hostname': parsed_url.hostname,
498 'prefix': decoded_path,
498 'prefix': decoded_path,
499 'repo': repo_name,
499 'repo': repo_name,
500 'repoid': str(repo_id),
500 'repoid': str(repo_id),
501 'repo_type': repo_type
501 'repo_type': repo_type
502 }
502 }
503 args.update(override)
503 args.update(override)
504 args['user'] = urllib.parse.quote(safe_str(args['user']))
504 args['user'] = urllib.parse.quote(safe_str(args['user']))
505
505
506 for k, v in args.items():
506 for k, v in args.items():
507 uri_tmpl = uri_tmpl.replace('{%s}' % k, v)
507 uri_tmpl = uri_tmpl.replace('{%s}' % k, v)
508
508
509 # special case for SVN clone url
509 # special case for SVN clone url
510 if repo_type == 'svn':
510 if repo_type == 'svn':
511 uri_tmpl = uri_tmpl.replace('ssh://', 'svn+ssh://')
511 uri_tmpl = uri_tmpl.replace('ssh://', 'svn+ssh://')
512
512
513 # remove leading @ sign if it's present. Case of empty user
513 # remove leading @ sign if it's present. Case of empty user
514 url_obj = urlobject.URLObject(uri_tmpl)
514 url_obj = urlobject.URLObject(uri_tmpl)
515 url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))
515 url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))
516
516
517 return safe_unicode(url)
517 return safe_unicode(url)
518
518
519
519
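# The substitution above is plain '{key}' replacement; a self-contained
# sketch with a hypothetical template and illustrative values:
#
# >>> tmpl = '{scheme}://{user}@{netloc}/{repo}'
# >>> args = {'scheme': 'https', 'user': 'bob',
# ...         'netloc': 'server.com', 'repo': 'project'}
# >>> for k, v in args.items():
# ...     tmpl = tmpl.replace('{%s}' % k, v)
# >>> tmpl
# 'https://bob@server.com/project'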
520 def get_commit_safe(repo, commit_id=None, commit_idx=None, pre_load=None,
520 def get_commit_safe(repo, commit_id=None, commit_idx=None, pre_load=None,
521 maybe_unreachable=False, reference_obj=None):
521 maybe_unreachable=False, reference_obj=None):
522 """
522 """
523 Safe version of get_commit: if this commit doesn't exist for a
523 Safe version of get_commit: if this commit doesn't exist for a
524 repository, it returns a Dummy one instead
524 repository, it returns a Dummy one instead
525
525
526 :param repo: repository instance
526 :param repo: repository instance
527 :param commit_id: commit id as str
527 :param commit_id: commit id as str
528 :param commit_idx: numeric commit index
528 :param commit_idx: numeric commit index
529 :param pre_load: optional list of commit attributes to load
529 :param pre_load: optional list of commit attributes to load
530 :param maybe_unreachable: translate unreachable commits on git repos
530 :param maybe_unreachable: translate unreachable commits on git repos
531 :param reference_obj: explicitly search via a reference obj in git. E.g "branch:123" would mean branch "123"
531 :param reference_obj: explicitly search via a reference obj in git. E.g "branch:123" would mean branch "123"
532 """
532 """
533 # TODO(skreft): remove these circular imports
533 # TODO(skreft): remove these circular imports
534 from rhodecode.lib.vcs.backends.base import BaseRepository, EmptyCommit
534 from rhodecode.lib.vcs.backends.base import BaseRepository, EmptyCommit
535 from rhodecode.lib.vcs.exceptions import RepositoryError
535 from rhodecode.lib.vcs.exceptions import RepositoryError
536 if not isinstance(repo, BaseRepository):
536 if not isinstance(repo, BaseRepository):
537 raise Exception('You must pass a Repository '
537 raise Exception('You must pass a Repository '
538 'object as first argument, got %s' % type(repo))
538 'object as first argument, got %s' % type(repo))
539
539
540 try:
540 try:
541 commit = repo.get_commit(
541 commit = repo.get_commit(
542 commit_id=commit_id, commit_idx=commit_idx, pre_load=pre_load,
542 commit_id=commit_id, commit_idx=commit_idx, pre_load=pre_load,
543 maybe_unreachable=maybe_unreachable, reference_obj=reference_obj)
543 maybe_unreachable=maybe_unreachable, reference_obj=reference_obj)
544 except (RepositoryError, LookupError):
544 except (RepositoryError, LookupError):
545 commit = EmptyCommit()
545 commit = EmptyCommit()
546 return commit
546 return commit
547
547
548
548
549 def datetime_to_time(dt):
549 def datetime_to_time(dt):
550 if dt:
550 if dt:
551 return time.mktime(dt.timetuple())
551 return time.mktime(dt.timetuple())
552
552
553
553
554 def time_to_datetime(tm):
554 def time_to_datetime(tm):
555 if tm:
555 if tm:
556 if isinstance(tm, str):
556 if isinstance(tm, str):
557 try:
557 try:
558 tm = float(tm)
558 tm = float(tm)
559 except ValueError:
559 except ValueError:
560 return
560 return
561 return datetime.datetime.fromtimestamp(tm)
561 return datetime.datetime.fromtimestamp(tm)
562
562
563
563
564 def time_to_utcdatetime(tm):
564 def time_to_utcdatetime(tm):
565 if tm:
565 if tm:
566 if isinstance(tm, str):
566 if isinstance(tm, str):
567 try:
567 try:
568 tm = float(tm)
568 tm = float(tm)
569 except ValueError:
569 except ValueError:
570 return
570 return
571 return datetime.datetime.utcfromtimestamp(tm)
571 return datetime.datetime.utcfromtimestamp(tm)
572
572
573
573
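# The two converters round-trip at seconds precision (local timezone):
#
# >>> dt = time_to_datetime(1600000000)
# >>> datetime_to_time(dt)
# 1600000000.0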
574 MENTIONS_REGEX = re.compile(
574 MENTIONS_REGEX = re.compile(
575 # ^@ or @ without any special chars in front
575 # ^@ or @ without any special chars in front
576 r'(?:^@|[^a-zA-Z0-9\-\_\.]@)'
576 r'(?:^@|[^a-zA-Z0-9\-\_\.]@)'
577 # main body starts with letter, then can be . - _
577 # main body starts with letter, then can be . - _
578 r'([a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+)',
578 r'([a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+)',
579 re.VERBOSE | re.MULTILINE)
579 re.VERBOSE | re.MULTILINE)
580
580
581
581
582 def extract_mentioned_users(s):
582 def extract_mentioned_users(s):
583 """
583 """
584 Returns unique usernames from given string s that have @mention
584 Returns unique usernames from given string s that have @mention
585
585
586 :param s: string to get mentions
586 :param s: string to get mentions
587 """
587 """
588 usrs = set()
588 usrs = set()
589 for username in MENTIONS_REGEX.findall(s):
589 for username in MENTIONS_REGEX.findall(s):
590 usrs.add(username)
590 usrs.add(username)
591
591
592 return sorted(list(usrs), key=lambda k: k.lower())
592 return sorted(list(usrs), key=lambda k: k.lower())
593
593
594
594
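# A quick sketch: mentions keep their case, the sort is case-insensitive,
# and e-mail-like tokens are ignored:
#
# >>> extract_mentioned_users('ping @Alice and @bob, but not mail@host')
# ['Alice', 'bob']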
595 class AttributeDictBase(dict):
595 class AttributeDictBase(dict):
596 def __getstate__(self):
596 def __getstate__(self):
597 odict = self.__dict__ # get attribute dictionary
597 odict = self.__dict__ # get attribute dictionary
598 return odict
598 return odict
599
599
600 def __setstate__(self, state):
600 def __setstate__(self, state):
601 self.__dict__ = state
601 self.__dict__ = state
602
602
603 __setattr__ = dict.__setitem__
603 __setattr__ = dict.__setitem__
604 __delattr__ = dict.__delitem__
604 __delattr__ = dict.__delitem__
605
605
606
606
607 class StrictAttributeDict(AttributeDictBase):
607 class StrictAttributeDict(AttributeDictBase):
608 """
608 """
609 Strict version of AttributeDict which raises an AttributeError when
609 Strict version of AttributeDict which raises an AttributeError when
610 the requested attribute is not set
610 the requested attribute is not set
611 """
611 """
612 def __getattr__(self, attr):
612 def __getattr__(self, attr):
613 try:
613 try:
614 return self[attr]
614 return self[attr]
615 except KeyError:
615 except KeyError:
616 raise AttributeError('%s object has no attribute %s' % (
616 raise AttributeError('%s object has no attribute %s' % (
617 self.__class__, attr))
617 self.__class__, attr))
618
618
619
619
620 class AttributeDict(AttributeDictBase):
620 class AttributeDict(AttributeDictBase):
621 def __getattr__(self, attr):
621 def __getattr__(self, attr):
622 return self.get(attr, None)
622 return self.get(attr, None)
623
623
624
624
625 def fix_PATH(os_=None):
625 def fix_PATH(os_=None):
626 """
626 """
627 Get current active python path, and append it to PATH variable to fix
627 Get current active python path, and append it to PATH variable to fix
628 issues of subprocess calls and different python versions
628 issues of subprocess calls and different python versions
629 """
629 """
630 if os_ is None:
630 if os_ is None:
631 import os
631 import os
632 else:
632 else:
633 os = os_
633 os = os_
634
634
635 cur_path = os.path.split(sys.executable)[0]
635 cur_path = os.path.split(sys.executable)[0]
636 if not os.environ['PATH'].startswith(cur_path):
636 if not os.environ['PATH'].startswith(cur_path):
637 os.environ['PATH'] = '%s:%s' % (cur_path, os.environ['PATH'])
637 os.environ['PATH'] = '%s:%s' % (cur_path, os.environ['PATH'])
638
638
639
639
640 def obfuscate_url_pw(engine):
640 def obfuscate_url_pw(engine):
641 _url = engine or ''
641 _url = engine or ''
642 try:
642 try:
643 _url = sqlalchemy.engine.url.make_url(engine)
643 _url = sqlalchemy.engine.url.make_url(engine)
644 if _url.password:
644 if _url.password:
645 _url.password = 'XXXXX'
645 _url.password = 'XXXXX'
646 except Exception:
646 except Exception:
647 pass
647 pass
648 return str(_url)
648 return str(_url)
649
649
650
650
651 def get_server_url(environ):
651 def get_server_url(environ):
652 req = webob.Request(environ)
652 req = webob.Request(environ)
653 return req.host_url + req.script_name
653 return req.host_url + req.script_name
654
654
655
655
656 def unique_id(hexlen=32):
656 def unique_id(hexlen=32):
657 alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
657 alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
658 return suuid(truncate_to=hexlen, alphabet=alphabet)
658 return suuid(truncate_to=hexlen, alphabet=alphabet)
659
659
660
660
661 def suuid(url=None, truncate_to=22, alphabet=None):
661 def suuid(url=None, truncate_to=22, alphabet=None):
662 """
662 """
663 Generate and return a short URL safe UUID.
663 Generate and return a short URL safe UUID.
664
664
665 If the url parameter is provided, set the namespace to the provided
665 If the url parameter is provided, set the namespace to the provided
666 URL and generate a UUID.
666 URL and generate a UUID.
667
667
668 :param url: url to get the uuid for
668 :param url: url to get the uuid for
669 :param truncate_to: truncate the basic 22-char UUID to a shorter version
669 :param truncate_to: truncate the basic 22-char UUID to a shorter version
670
670
671 The IDs won't be universally unique any longer, but the probability of
671 The IDs won't be universally unique any longer, but the probability of
672 a collision will still be very low.
672 a collision will still be very low.
673 """
673 """
674 # Define our alphabet.
674 # Define our alphabet.
675 _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"
675 _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"
676
676
677 # If no URL is given, generate a random UUID.
677 # If no URL is given, generate a random UUID.
678 if url is None:
678 if url is None:
679 unique_id = uuid.uuid4().int
679 unique_id = uuid.uuid4().int
680 else:
680 else:
681 unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int
681 unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int
682
682
683 alphabet_length = len(_ALPHABET)
683 alphabet_length = len(_ALPHABET)
684 output = []
684 output = []
685 while unique_id > 0:
685 while unique_id > 0:
686 digit = unique_id % alphabet_length
686 digit = unique_id % alphabet_length
687 output.append(_ALPHABET[digit])
687 output.append(_ALPHABET[digit])
688 unique_id = unique_id // alphabet_length
688 unique_id = unique_id // alphabet_length
689 return "".join(output)[:truncate_to]
689 return "".join(output)[:truncate_to]
690
690
691
691
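# A brief sketch of both helpers (the random outputs vary):
#
# >>> len(unique_id(hexlen=16))
# 16
# >>> suuid(url='https://example.com')  # deterministic for a given url
# >>> suuid()                           # random, up to 22 chars by default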
692 def get_current_rhodecode_user(request=None):
692 def get_current_rhodecode_user(request=None):
693 """
693 """
694 Gets rhodecode user from request
694 Gets rhodecode user from request
695 """
695 """
696 pyramid_request = request or pyramid.threadlocal.get_current_request()
696 pyramid_request = request or pyramid.threadlocal.get_current_request()
697
697
698 # web case
698 # web case
699 if pyramid_request and hasattr(pyramid_request, 'user'):
699 if pyramid_request and hasattr(pyramid_request, 'user'):
700 return pyramid_request.user
700 return pyramid_request.user
701
701
702 # api case
702 # api case
703 if pyramid_request and hasattr(pyramid_request, 'rpc_user'):
703 if pyramid_request and hasattr(pyramid_request, 'rpc_user'):
704 return pyramid_request.rpc_user
704 return pyramid_request.rpc_user
705
705
706 return None
706 return None
707
707
708
708
709 def action_logger_generic(action, namespace=''):
709 def action_logger_generic(action, namespace=''):
710 """
710 """
711 A generic logger for actions useful to the system overview; tries to find
711 A generic logger for actions useful to the system overview; tries to find
712 an acting user for the context of the call, otherwise reports an unknown user
712 an acting user for the context of the call, otherwise reports an unknown user
713
713
714 :param action: logging message eg 'comment 5 deleted'
714 :param action: logging message eg 'comment 5 deleted'
715 :type action: string
715 :type action: string
716
716
717 :param namespace: namespace of the logging message eg. 'repo.comments'
717 :param namespace: namespace of the logging message eg. 'repo.comments'
718 :type namespace: string
718 :type namespace: string
719
719
720 """
720 """
721
721
722 logger_name = 'rhodecode.actions'
722 logger_name = 'rhodecode.actions'
723
723
724 if namespace:
724 if namespace:
725 logger_name += '.' + namespace
725 logger_name += '.' + namespace
726
726
727 log = logging.getLogger(logger_name)
727 log = logging.getLogger(logger_name)
728
728
729 # get a user if we can
729 # get a user if we can
730 user = get_current_rhodecode_user()
730 user = get_current_rhodecode_user()
731
731
732 logfunc = log.info
732 logfunc = log.info
733
733
734 if not user:
734 if not user:
735 user = '<unknown user>'
735 user = '<unknown user>'
736 logfunc = log.warning
736 logfunc = log.warning
737
737
738 logfunc('Logging action by {}: {}'.format(user, action))
738 logfunc('Logging action by {}: {}'.format(user, action))
739
739
740
740
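# A one-line usage sketch, reusing the values from the docstring:
#
# >>> action_logger_generic('comment 5 deleted', namespace='repo.comments')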
741 def escape_split(text, sep=',', maxsplit=-1):
741 def escape_split(text, sep=',', maxsplit=-1):
742 r"""
742 r"""
743 Allows for escaping of the separator: e.g. arg='foo\, bar'
743 Allows for escaping of the separator: e.g. arg='foo\, bar'
744
744
745 It should be noted that the way bash et al. do command line parsing, those
745 It should be noted that the way bash et al. do command line parsing, those
746 single quotes are required.
746 single quotes are required.
747 """
747 """
748 escaped_sep = r'\%s' % sep
748 escaped_sep = r'\%s' % sep
749
749
750 if escaped_sep not in text:
750 if escaped_sep not in text:
751 return text.split(sep, maxsplit)
751 return text.split(sep, maxsplit)
752
752
753 before, _mid, after = text.partition(escaped_sep)
753 before, _mid, after = text.partition(escaped_sep)
754 startlist = before.split(sep, maxsplit) # a regular split is fine here
754 startlist = before.split(sep, maxsplit) # a regular split is fine here
755 unfinished = startlist[-1]
755 unfinished = startlist[-1]
756 startlist = startlist[:-1]
756 startlist = startlist[:-1]
757
757
758 # recurse because there may be more escaped separators
758 # recurse because there may be more escaped separators
759 endlist = escape_split(after, sep, maxsplit)
759 endlist = escape_split(after, sep, maxsplit)
760
760
761 # finish building the escaped value. we use endlist[0] because the first
761 # finish building the escaped value. we use endlist[0] because the first
762 # part of the string sent in recursion is the rest of the escaped value.
762 # part of the string sent in recursion is the rest of the escaped value.
763 unfinished += sep + endlist[0]
763 unfinished += sep + endlist[0]
764
764
765 return startlist + [unfinished] + endlist[1:] # put together all the parts
765 return startlist + [unfinished] + endlist[1:] # put together all the parts
766
766
767
767
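# A minimal sketch of the escaping behaviour:
#
# >>> escape_split('foo\\, bar,baz')
# ['foo, bar', 'baz']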
768 class OptionalAttr(object):
768 class OptionalAttr(object):
769 """
769 """
770 Special Optional Option that defines other attribute. Example::
770 Special Optional Option that defines other attribute. Example::
771
771
772 def test(apiuser, userid=Optional(OAttr('apiuser'))):
772 def test(apiuser, userid=Optional(OAttr('apiuser'))):
773 user = Optional.extract(userid)
773 user = Optional.extract(userid)
774 # calls
774 # calls
775
775
776 """
776 """
777
777
778 def __init__(self, attr_name):
778 def __init__(self, attr_name):
779 self.attr_name = attr_name
779 self.attr_name = attr_name
780
780
781 def __repr__(self):
781 def __repr__(self):
782 return '<OptionalAttr:%s>' % self.attr_name
782 return '<OptionalAttr:%s>' % self.attr_name
783
783
784 def __call__(self):
784 def __call__(self):
785 return self
785 return self
786
786
787
787
788 # alias
788 # alias
789 OAttr = OptionalAttr
789 OAttr = OptionalAttr
790
790
791
791
792 class Optional(object):
792 class Optional(object):
793 """
793 """
794 Defines an optional parameter::
794 Defines an optional parameter::
795
795
796 param = param.getval() if isinstance(param, Optional) else param
796 param = param.getval() if isinstance(param, Optional) else param
797 param = param() if isinstance(param, Optional) else param
797 param = param() if isinstance(param, Optional) else param
798
798
799 is equivalent to::
799 is equivalent to::
800
800
801 param = Optional.extract(param)
801 param = Optional.extract(param)
802
802
803 """
803 """
804
804
805 def __init__(self, type_):
805 def __init__(self, type_):
806 self.type_ = type_
806 self.type_ = type_
807
807
808 def __repr__(self):
808 def __repr__(self):
809 return '<Optional:%s>' % self.type_.__repr__()
809 return '<Optional:%s>' % self.type_.__repr__()
810
810
811 def __call__(self):
811 def __call__(self):
812 return self.getval()
812 return self.getval()
813
813
814 def getval(self):
814 def getval(self):
815 """
815 """
816 returns value from this Optional instance
816 returns value from this Optional instance
817 """
817 """
818 if isinstance(self.type_, OAttr):
818 if isinstance(self.type_, OAttr):
819 # use params name
819 # use params name
820 return self.type_.attr_name
820 return self.type_.attr_name
821 return self.type_
821 return self.type_
822
822
823 @classmethod
823 @classmethod
824 def extract(cls, val):
824 def extract(cls, val):
825 """
825 """
826 Extracts value from Optional() instance
826 Extracts value from Optional() instance
827
827
828 :param val:
828 :param val:
829 :return: original value if it's not Optional instance else
829 :return: original value if it's not Optional instance else
830 value of instance
830 value of instance
831 """
831 """
832 if isinstance(val, cls):
832 if isinstance(val, cls):
833 return val.getval()
833 return val.getval()
834 return val
834 return val
835
835
836
836
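# A small sketch of how Optional defaults are unwrapped (values are
# hypothetical):
#
# >>> Optional.extract(Optional('default'))
# 'default'
# >>> Optional.extract('explicit')
# 'explicit'
# >>> Optional.extract(Optional(OAttr('apiuser')))
# 'apiuser'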
837 def glob2re(pat):
837 def glob2re(pat):
838 """
838 """
839 Translate a shell PATTERN to a regular expression.
839 Translate a shell PATTERN to a regular expression.
840
840
841 There is no way to quote meta-characters.
841 There is no way to quote meta-characters.
842 """
842 """
843
843
844 i, n = 0, len(pat)
844 i, n = 0, len(pat)
845 res = ''
845 res = ''
846 while i < n:
846 while i < n:
847 c = pat[i]
847 c = pat[i]
848 i = i+1
848 i = i+1
849 if c == '*':
849 if c == '*':
850 #res = res + '.*'
850 #res = res + '.*'
851 res = res + '[^/]*'
851 res = res + '[^/]*'
852 elif c == '?':
852 elif c == '?':
853 #res = res + '.'
853 #res = res + '.'
854 res = res + '[^/]'
854 res = res + '[^/]'
855 elif c == '[':
855 elif c == '[':
856 j = i
856 j = i
857 if j < n and pat[j] == '!':
857 if j < n and pat[j] == '!':
858 j = j+1
858 j = j+1
859 if j < n and pat[j] == ']':
859 if j < n and pat[j] == ']':
860 j = j+1
860 j = j+1
861 while j < n and pat[j] != ']':
861 while j < n and pat[j] != ']':
862 j = j+1
862 j = j+1
863 if j >= n:
863 if j >= n:
864 res = res + '\\['
864 res = res + '\\['
865 else:
865 else:
866 stuff = pat[i:j].replace('\\','\\\\')
866 stuff = pat[i:j].replace('\\','\\\\')
867 i = j+1
867 i = j+1
868 if stuff[0] == '!':
868 if stuff[0] == '!':
869 stuff = '^' + stuff[1:]
869 stuff = '^' + stuff[1:]
870 elif stuff[0] == '^':
870 elif stuff[0] == '^':
871 stuff = '\\' + stuff
871 stuff = '\\' + stuff
872 res = '%s[%s]' % (res, stuff)
872 res = '%s[%s]' % (res, stuff)
873 else:
873 else:
874 res = res + re.escape(c)
874 res = res + re.escape(c)
875 return res + r'\Z(?ms)'
875 return res + r'\Z(?ms)'
876
876
877
877
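# A quick sketch; unlike fnmatch, '*' and '?' deliberately do not cross '/':
#
# >>> import re
# >>> bool(re.match(glob2re('*.py'), 'setup.py'))
# True
# >>> bool(re.match(glob2re('*.py'), 'pkg/setup.py'))
# False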
878 def parse_byte_string(size_str):
878 def parse_byte_string(size_str):
879 match = re.match(r'(\d+)(MB|KB)', size_str, re.IGNORECASE)
879 match = re.match(r'(\d+)(MB|KB)', size_str, re.IGNORECASE)
880 if not match:
880 if not match:
881 raise ValueError('Given size: %s is invalid, please make sure '
881 raise ValueError('Given size: %s is invalid, please make sure '
882 'to use format of <num>(MB|KB)' % size_str)
882 'to use format of <num>(MB|KB)' % size_str)
883
883
884 _parts = match.groups()
884 _parts = match.groups()
885 num, type_ = _parts
885 num, type_ = _parts
886 return int(num) * {'mb': 1024*1024, 'kb': 1024}[type_.lower()]
886 return int(num) * {'mb': 1024*1024, 'kb': 1024}[type_.lower()]
887
887
888
888
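# The two supported suffixes, matched case-insensitively:
#
# >>> parse_byte_string('10MB')
# 10485760
# >>> parse_byte_string('512kb')
# 524288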
889 class CachedProperty(object):
889 class CachedProperty(object):
890 """
890 """
891 Lazy Attributes. With option to invalidate the cache by running a method
891 Lazy Attributes. With option to invalidate the cache by running a method
892
892
893 >>> class Foo(object):
893 >>> class Foo(object):
894 ...
894 ...
895 ... @CachedProperty
895 ... @CachedProperty
896 ... def heavy_func(self):
896 ... def heavy_func(self):
897 ... return 'super-calculation'
897 ... return 'super-calculation'
898 ...
898 ...
899 >>> foo = Foo()
899 >>> foo = Foo()
900 >>> foo.heavy_func  # first computation
900 >>> foo.heavy_func  # first computation
901 >>> foo.heavy_func  # fetch from cache
901 >>> foo.heavy_func  # fetch from cache
902 >>> foo._invalidate_prop_cache('heavy_func')
902 >>> foo._invalidate_prop_cache('heavy_func')
903
903
904 # at this point accessing foo.heavy_func will be re-computed
904 # at this point accessing foo.heavy_func will be re-computed
905 """
905 """
906
906
907 def __init__(self, func, func_name=None):
907 def __init__(self, func, func_name=None):
908
908
909 if func_name is None:
909 if func_name is None:
910 func_name = func.__name__
910 func_name = func.__name__
911 self.data = (func, func_name)
911 self.data = (func, func_name)
912 update_wrapper(self, func)
912 update_wrapper(self, func)
913
913
914 def __get__(self, inst, class_):
914 def __get__(self, inst, class_):
915 if inst is None:
915 if inst is None:
916 return self
916 return self
917
917
918 func, func_name = self.data
918 func, func_name = self.data
919 value = func(inst)
919 value = func(inst)
920 inst.__dict__[func_name] = value
920 inst.__dict__[func_name] = value
921 if '_invalidate_prop_cache' not in inst.__dict__:
921 if '_invalidate_prop_cache' not in inst.__dict__:
922 inst.__dict__['_invalidate_prop_cache'] = partial(
922 inst.__dict__['_invalidate_prop_cache'] = partial(
923 self._invalidate_prop_cache, inst)
923 self._invalidate_prop_cache, inst)
924 return value
924 return value
925
925
926 def _invalidate_prop_cache(self, inst, name):
926 def _invalidate_prop_cache(self, inst, name):
927 inst.__dict__.pop(name, None)
927 inst.__dict__.pop(name, None)
928
928
929
929
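# A brief sketch of the invalidation flow (mirrors the docstring above):
#
# >>> class Foo(object):
# ...     @CachedProperty
# ...     def heavy(self):
# ...         print('computing')
# ...         return 42
# >>> foo = Foo()
# >>> foo.heavy   # prints 'computing', returns 42
# >>> foo.heavy   # served from the instance __dict__ cache
# >>> foo._invalidate_prop_cache('heavy')
# >>> foo.heavy   # prints 'computing' again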
930 def retry(func=None, exception=Exception, n_tries=5, delay=5, backoff=1, logger=True):
930 def retry(func=None, exception=Exception, n_tries=5, delay=5, backoff=1, logger=True):
931 """
931 """
932 Retry decorator with exponential backoff.
932 Retry decorator with exponential backoff.
933
933
934 Parameters
934 Parameters
935 ----------
935 ----------
936 func : typing.Callable, optional
936 func : typing.Callable, optional
937 Callable on which the decorator is applied, by default None
937 Callable on which the decorator is applied, by default None
938 exception : Exception or tuple of Exceptions, optional
938 exception : Exception or tuple of Exceptions, optional
939 Exception(s) that invoke retry, by default Exception
939 Exception(s) that invoke retry, by default Exception
940 n_tries : int, optional
940 n_tries : int, optional
941 Number of tries before giving up, by default 5
941 Number of tries before giving up, by default 5
942 delay : int, optional
942 delay : int, optional
943 Initial delay between retries in seconds, by default 5
943 Initial delay between retries in seconds, by default 5
944 backoff : int, optional
944 backoff : int, optional
945 Backoff multiplier e.g. value of 2 will double the delay, by default 1
945 Backoff multiplier e.g. value of 2 will double the delay, by default 1
946 logger : bool, optional
946 logger : bool, optional
947 Option to log or print, by default True
947 Option to log or print, by default True
948
948
949 Returns
949 Returns
950 -------
950 -------
951 typing.Callable
951 typing.Callable
952 Decorated callable that calls itself when exception(s) occur.
952 Decorated callable that calls itself when exception(s) occur.
953
953
954 Examples
954 Examples
955 --------
955 --------
956 >>> import random
956 >>> import random
957 >>> @retry(exception=Exception, n_tries=3)
957 >>> @retry(exception=Exception, n_tries=3)
958 ... def test_random(text):
958 ... def test_random(text):
959 ... x = random.random()
959 ... x = random.random()
960 ... if x < 0.5:
960 ... if x < 0.5:
961 ... raise Exception("Fail")
961 ... raise Exception("Fail")
962 ... else:
962 ... else:
963 ... print("Success: ", text)
963 ... print("Success: ", text)
964 >>> test_random("It works!")
964 >>> test_random("It works!")
965 """
965 """
966
966
967 if func is None:
967 if func is None:
968 return partial(
968 return partial(
969 retry,
969 retry,
970 exception=exception,
970 exception=exception,
971 n_tries=n_tries,
971 n_tries=n_tries,
972 delay=delay,
972 delay=delay,
973 backoff=backoff,
973 backoff=backoff,
974 logger=logger,
974 logger=logger,
975 )
975 )
976
976
977 @wraps(func)
977 @wraps(func)
978 def wrapper(*args, **kwargs):
978 def wrapper(*args, **kwargs):
979 _n_tries, n_delay = n_tries, delay
979 _n_tries, n_delay = n_tries, delay
980 log = logging.getLogger('rhodecode.retry')
980 log = logging.getLogger('rhodecode.retry')
981
981
982 while _n_tries > 1:
982 while _n_tries > 1:
983 try:
983 try:
984 return func(*args, **kwargs)
984 return func(*args, **kwargs)
985 except exception as e:
985 except exception as e:
986 e_details = repr(e)
986 e_details = repr(e)
987 msg = "Exception on calling func {func}: {e}, " \
987 msg = "Exception on calling func {func}: {e}, " \
988 "Retrying in {n_delay} seconds..."\
988 "Retrying in {n_delay} seconds..."\
989 .format(func=func, e=e_details, n_delay=n_delay)
989 .format(func=func, e=e_details, n_delay=n_delay)
990 if logger:
990 if logger:
991 log.warning(msg)
991 log.warning(msg)
992 else:
992 else:
993 print(msg)
993 print(msg)
994 time.sleep(n_delay)
994 time.sleep(n_delay)
995 _n_tries -= 1
995 _n_tries -= 1
996 n_delay *= backoff
996 n_delay *= backoff
997
997
998 return func(*args, **kwargs)
998 return func(*args, **kwargs)
999
999
1000 return wrapper
1000 return wrapper
1001
1001
1002
1002
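# The decorator works with and without arguments thanks to the `partial`
# re-entry above; a sketch with hypothetical functions:
#
# >>> @retry
# ... def flaky():
# ...     pass
# >>> @retry(exception=IOError, n_tries=3, delay=1, backoff=2)
# ... def flaky_io():
# ...     pass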
1003 def user_agent_normalizer(user_agent_raw, safe=True):
1003 def user_agent_normalizer(user_agent_raw, safe=True):
1004 log = logging.getLogger('rhodecode.user_agent_normalizer')
1004 log = logging.getLogger('rhodecode.user_agent_normalizer')
1005 ua = (user_agent_raw or '').strip().lower()
1005 ua = (user_agent_raw or '').strip().lower()
1006 ua = ua.replace('"', '')
1006 ua = ua.replace('"', '')
1007
1007
1008 try:
1008 try:
1009 if 'mercurial/proto-1.0' in ua:
1009 if 'mercurial/proto-1.0' in ua:
1010 ua = ua.replace('mercurial/proto-1.0', '')
1010 ua = ua.replace('mercurial/proto-1.0', '')
1011 ua = ua.replace('(', '').replace(')', '').strip()
1011 ua = ua.replace('(', '').replace(')', '').strip()
1012 ua = ua.replace('mercurial ', 'mercurial/')
1012 ua = ua.replace('mercurial ', 'mercurial/')
1013 elif ua.startswith('git'):
1013 elif ua.startswith('git'):
1014 parts = ua.split(' ')
1014 parts = ua.split(' ')
1015 if parts:
1015 if parts:
1016 ua = parts[0]
1016 ua = parts[0]
1017 ua = re.sub(r'\.windows\.\d', '', ua).strip()
1017 ua = re.sub(r'\.windows\.\d', '', ua).strip()
1018
1018
1019 return ua
1019 return ua
1020 except Exception:
1020 except Exception:
1021 log.exception('Failed to parse scm user-agent')
1021 log.exception('Failed to parse scm user-agent')
1022 if not safe:
1022 if not safe:
1023 raise
1023 raise
1024
1024
1025 return ua
1025 return ua
1026
1026
1027
1027
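# A sketch of the normalization; the user-agent strings are illustrative:
#
# >>> user_agent_normalizer('git/2.37.1.windows.1')
# 'git/2.37.1'
# >>> user_agent_normalizer('mercurial/proto-1.0 (Mercurial 5.9)')
# 'mercurial/5.9'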
1028 def get_available_port(min_port=40000, max_port=55555, use_range=False):
1028 def get_available_port(min_port=40000, max_port=55555, use_range=False):
1029 hostname = ''
1029 hostname = ''
1030 for _ in range(min_port, max_port):
1030 for _ in range(min_port, max_port):
1031 pick_port = 0
1031 pick_port = 0
1032 if use_range:
1032 if use_range:
1033 pick_port = random.randint(min_port, max_port)
1033 pick_port = random.randint(min_port, max_port)
1034
1034
1035 with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
1035 with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
1036 try:
1036 try:
1037 s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1037 s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1038 s.bind((hostname, pick_port))
1038 s.bind((hostname, pick_port))
1039 return s.getsockname()[1]
1039 return s.getsockname()[1]
1040 except OSError:
1040 except OSError:
1041 continue
1041 continue
1042 except socket.error as e:
1042 except socket.error as e:
1043 if e.args[0] in [errno.EADDRINUSE, errno.ECONNREFUSED]:
1043 if e.args[0] in [errno.EADDRINUSE, errno.ECONNREFUSED]:
1044 continue
1044 continue
1045 raise
1045 raise
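# Usage sketch: with the defaults the OS assigns any free ephemeral port;
# use_range=True restricts the pick to [min_port, max_port]:
#
# >>> port = get_available_port()
# >>> port_in_range = get_available_port(use_range=True)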
@@ -1,494 +1,494 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2014-2020 RhodeCode GmbH
3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 GIT commit module
22 GIT commit module
23 """
23 """
24
24
25 import re
25 import re
26 import io
26 import stat
27 import stat
27 import configparser
28 import configparser
28 from itertools import chain
29 from itertools import chain
29 from io import StringIO
30
30
31 from zope.cachedescriptors.property import Lazy as LazyProperty
31 from zope.cachedescriptors.property import Lazy as LazyProperty
32
32
33 from rhodecode.lib.datelib import utcdate_fromtimestamp
33 from rhodecode.lib.datelib import utcdate_fromtimestamp
34 from rhodecode.lib.utils import safe_unicode, safe_str
34 from rhodecode.lib.utils import safe_unicode, safe_str
35 from rhodecode.lib.utils2 import safe_int
35 from rhodecode.lib.utils2 import safe_int
36 from rhodecode.lib.vcs.conf import settings
36 from rhodecode.lib.vcs.conf import settings
37 from rhodecode.lib.vcs.backends import base
37 from rhodecode.lib.vcs.backends import base
38 from rhodecode.lib.vcs.exceptions import CommitError, NodeDoesNotExistError
38 from rhodecode.lib.vcs.exceptions import CommitError, NodeDoesNotExistError
39 from rhodecode.lib.vcs.nodes import (
39 from rhodecode.lib.vcs.nodes import (
40 FileNode, DirNode, NodeKind, RootNode, SubModuleNode,
40 FileNode, DirNode, NodeKind, RootNode, SubModuleNode,
41 ChangedFileNodesGenerator, AddedFileNodesGenerator,
41 ChangedFileNodesGenerator, AddedFileNodesGenerator,
42 RemovedFileNodesGenerator, LargeFileNode)
42 RemovedFileNodesGenerator, LargeFileNode)
43
43
44
44
45 class GitCommit(base.BaseCommit):
45 class GitCommit(base.BaseCommit):
46 """
46 """
47 Represents state of the repository at single commit id.
47 Represents state of the repository at single commit id.
48 """
48 """
49
49
50 _filter_pre_load = [
50 _filter_pre_load = [
51 # done through a more complex tree walk on parents
51 # done through a more complex tree walk on parents
52 "affected_files",
52 "affected_files",
53 # done through subprocess not remote call
53 # done through subprocess not remote call
54 "children",
54 "children",
55 # done through a more complex tree walk on parents
55 # done through a more complex tree walk on parents
56 "status",
56 "status",
57 # mercurial specific property not supported here
57 # mercurial specific property not supported here
58 "_file_paths",
58 "_file_paths",
59 # mercurial specific property not supported here
59 # mercurial specific property not supported here
60 'obsolete',
60 'obsolete',
61 # mercurial specific property not supported here
61 # mercurial specific property not supported here
62 'phase',
62 'phase',
63 # mercurial specific property not supported here
63 # mercurial specific property not supported here
64 'hidden'
64 'hidden'
65 ]
65 ]
66
66
67 def __init__(self, repository, raw_id, idx, pre_load=None):
67 def __init__(self, repository, raw_id, idx, pre_load=None):
68 self.repository = repository
68 self.repository = repository
69 self._remote = repository._remote
69 self._remote = repository._remote
70 # TODO: johbo: Tweak of raw_id should not be necessary
70 # TODO: johbo: Tweak of raw_id should not be necessary
71 self.raw_id = safe_str(raw_id)
71 self.raw_id = safe_str(raw_id)
72 self.idx = idx
72 self.idx = idx
73
73
74 self._set_bulk_properties(pre_load)
74 self._set_bulk_properties(pre_load)
75
75
76 # caches
76 # caches
77 self._stat_modes = {} # stat info for paths
77 self._stat_modes = {} # stat info for paths
78 self._paths = {} # path processed with parse_tree
78 self._paths = {} # path processed with parse_tree
79 self.nodes = {}
79 self.nodes = {}
80 self._submodules = None
80 self._submodules = None
81
81
82 def _set_bulk_properties(self, pre_load):
82 def _set_bulk_properties(self, pre_load):
83
83
84 if not pre_load:
84 if not pre_load:
85 return
85 return
86 pre_load = [entry for entry in pre_load
86 pre_load = [entry for entry in pre_load
87 if entry not in self._filter_pre_load]
87 if entry not in self._filter_pre_load]
88 if not pre_load:
88 if not pre_load:
89 return
89 return
90
90
91 result = self._remote.bulk_request(self.raw_id, pre_load)
91 result = self._remote.bulk_request(self.raw_id, pre_load)
92 for attr, value in result.items():
92 for attr, value in result.items():
93 if attr in ["author", "message"]:
93 if attr in ["author", "message"]:
94 if value:
94 if value:
95 value = safe_unicode(value)
95 value = safe_unicode(value)
96 elif attr == "date":
96 elif attr == "date":
97 value = utcdate_fromtimestamp(*value)
97 value = utcdate_fromtimestamp(*value)
98 elif attr == "parents":
98 elif attr == "parents":
99 value = self._make_commits(value)
99 value = self._make_commits(value)
100 elif attr == "branch":
100 elif attr == "branch":
101 value = self._set_branch(value)
101 value = self._set_branch(value)
102 self.__dict__[attr] = value
102 self.__dict__[attr] = value
103
103
104 @LazyProperty
104 @LazyProperty
105 def _commit(self):
105 def _commit(self):
106 return self._remote[self.raw_id]
106 return self._remote[self.raw_id]
107
107
108 @LazyProperty
108 @LazyProperty
109 def _tree_id(self):
109 def _tree_id(self):
110 return self._remote[self._commit['tree']]['id']
110 return self._remote[self._commit['tree']]['id']
111
111
112 @LazyProperty
112 @LazyProperty
113 def id(self):
113 def id(self):
114 return self.raw_id
114 return self.raw_id
115
115
116 @LazyProperty
116 @LazyProperty
117 def short_id(self):
117 def short_id(self):
118 return self.raw_id[:12]
118 return self.raw_id[:12]
119
119
120 @LazyProperty
120 @LazyProperty
121 def message(self):
121 def message(self):
122 return safe_unicode(self._remote.message(self.id))
122 return safe_unicode(self._remote.message(self.id))
123
123
124 @LazyProperty
124 @LazyProperty
125 def committer(self):
125 def committer(self):
126 return safe_unicode(self._remote.author(self.id))
126 return safe_unicode(self._remote.author(self.id))
127
127
128 @LazyProperty
128 @LazyProperty
129 def author(self):
129 def author(self):
130 return safe_unicode(self._remote.author(self.id))
130 return safe_unicode(self._remote.author(self.id))
131
131
132 @LazyProperty
132 @LazyProperty
133 def date(self):
133 def date(self):
134 unix_ts, tz = self._remote.date(self.raw_id)
134 unix_ts, tz = self._remote.date(self.raw_id)
135 return utcdate_fromtimestamp(unix_ts, tz)
135 return utcdate_fromtimestamp(unix_ts, tz)
136
136
137 @LazyProperty
137 @LazyProperty
138 def status(self):
138 def status(self):
139 """
139 """
140 Returns modified, added, removed, deleted files for current commit
140 Returns modified, added, removed, deleted files for current commit
141 """
141 """
142 return self.changed, self.added, self.removed
142 return self.changed, self.added, self.removed
143
143
144 @LazyProperty
144 @LazyProperty
145 def tags(self):
145 def tags(self):
146 tags = [safe_unicode(name) for name, commit_id in
146 tags = [safe_unicode(name) for name, commit_id in
147 self.repository.tags.items()
147 self.repository.tags.items()
148 if commit_id == self.raw_id]
148 if commit_id == self.raw_id]
149 return tags
149 return tags
150
150
151 @LazyProperty
151 @LazyProperty
152 def commit_branches(self):
152 def commit_branches(self):
153 branches = []
153 branches = []
154 for name, commit_id in self.repository.branches.items():
154 for name, commit_id in self.repository.branches.items():
155 if commit_id == self.raw_id:
155 if commit_id == self.raw_id:
156 branches.append(name)
156 branches.append(name)
157 return branches
157 return branches
158
158
159 def _set_branch(self, branches):
159 def _set_branch(self, branches):
160 if branches:
160 if branches:
161 # actually commit can have multiple branches in git
161 # actually commit can have multiple branches in git
162 return safe_unicode(branches[0])
162 return safe_unicode(branches[0])
163
163
164 @LazyProperty
164 @LazyProperty
165 def branch(self):
165 def branch(self):
166 branches = self._remote.branch(self.raw_id)
166 branches = self._remote.branch(self.raw_id)
167 return self._set_branch(branches)
167 return self._set_branch(branches)
168
168
169 def _get_tree_id_for_path(self, path):
169 def _get_tree_id_for_path(self, path):
170 path = safe_str(path)
170 path = safe_str(path)
171 if path in self._paths:
171 if path in self._paths:
172 return self._paths[path]
172 return self._paths[path]
173
173
174 tree_id = self._tree_id
174 tree_id = self._tree_id
175
175
176 path = path.strip('/')
176 path = path.strip('/')
177 if path == '':
177 if path == '':
178 data = [tree_id, "tree"]
178 data = [tree_id, "tree"]
179 self._paths[''] = data
179 self._paths[''] = data
180 return data
180 return data
181
181
182 tree_id, tree_type, tree_mode = \
182 tree_id, tree_type, tree_mode = \
183 self._remote.tree_and_type_for_path(self.raw_id, path)
183 self._remote.tree_and_type_for_path(self.raw_id, path)
184 if tree_id is None:
184 if tree_id is None:
185 raise self.no_node_at_path(path)
185 raise self.no_node_at_path(path)
186
186
187 self._paths[path] = [tree_id, tree_type]
187 self._paths[path] = [tree_id, tree_type]
188 self._stat_modes[path] = tree_mode
188 self._stat_modes[path] = tree_mode
189
189
190 if path not in self._paths:
190 if path not in self._paths:
191 raise self.no_node_at_path(path)
191 raise self.no_node_at_path(path)
192
192
193 return self._paths[path]
193 return self._paths[path]
194
194
195 def _get_kind(self, path):
195 def _get_kind(self, path):
196 tree_id, type_ = self._get_tree_id_for_path(path)
196 tree_id, type_ = self._get_tree_id_for_path(path)
197 if type_ == 'blob':
197 if type_ == 'blob':
198 return NodeKind.FILE
198 return NodeKind.FILE
199 elif type_ == 'tree':
199 elif type_ == 'tree':
200 return NodeKind.DIR
200 return NodeKind.DIR
201 elif type_ == 'link':
201 elif type_ == 'link':
202 return NodeKind.SUBMODULE
202 return NodeKind.SUBMODULE
203 return None
203 return None
204
204
205 def _get_filectx(self, path):
205 def _get_filectx(self, path):
206 path = self._fix_path(path)
206 path = self._fix_path(path)
207 if self._get_kind(path) != NodeKind.FILE:
207 if self._get_kind(path) != NodeKind.FILE:
208 raise CommitError(
208 raise CommitError(
209 "File does not exist for commit %s at '%s'" % (self.raw_id, path))
209 "File does not exist for commit %s at '%s'" % (self.raw_id, path))
210 return path
210 return path
211
211
212 def _get_file_nodes(self):
212 def _get_file_nodes(self):
213 return chain(*(t[2] for t in self.walk()))
213 return chain(*(t[2] for t in self.walk()))
214
214
215 @LazyProperty
215 @LazyProperty
216 def parents(self):
216 def parents(self):
217 """
217 """
218 Returns list of parent commits.
218 Returns list of parent commits.
219 """
219 """
220 parent_ids = self._remote.parents(self.id)
220 parent_ids = self._remote.parents(self.id)
221 return self._make_commits(parent_ids)
221 return self._make_commits(parent_ids)
222
222
223 @LazyProperty
223 @LazyProperty
224 def children(self):
224 def children(self):
225 """
225 """
226 Returns list of child commits.
226 Returns list of child commits.
227 """
227 """
228
228
229 children = self._remote.children(self.raw_id)
229 children = self._remote.children(self.raw_id)
230 return self._make_commits(children)
230 return self._make_commits(children)
231
231
232 def _make_commits(self, commit_ids):
232 def _make_commits(self, commit_ids):
233 def commit_maker(_commit_id):
233 def commit_maker(_commit_id):
234 return self.repository.get_commit(commit_id=_commit_id)
234 return self.repository.get_commit(commit_id=_commit_id)
235
235
236 return [commit_maker(commit_id) for commit_id in commit_ids]
236 return [commit_maker(commit_id) for commit_id in commit_ids]
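# Note on the fix above: commit_maker() must use its own `_commit_id`
# parameter; referencing the comprehension variable `commit_id` from
# inside the nested function would raise NameError under Python 3
# scoping rules, since comprehensions get their own scope.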
237
237
238 def get_file_mode(self, path):
238 def get_file_mode(self, path):
239 """
239 """
240 Returns stat mode of the file at the given `path`.
240 Returns stat mode of the file at the given `path`.
241 """
241 """
242 path = safe_str(path)
242 path = safe_str(path)
243 # ensure path is traversed
243 # ensure path is traversed
244 self._get_tree_id_for_path(path)
244 self._get_tree_id_for_path(path)
245 return self._stat_modes[path]
245 return self._stat_modes[path]
246
246
247 def is_link(self, path):
247 def is_link(self, path):
248 return stat.S_ISLNK(self.get_file_mode(path))
248 return stat.S_ISLNK(self.get_file_mode(path))
249
249
250 def is_node_binary(self, path):
250 def is_node_binary(self, path):
251 tree_id, _ = self._get_tree_id_for_path(path)
251 tree_id, _ = self._get_tree_id_for_path(path)
252 return self._remote.is_binary(tree_id)
252 return self._remote.is_binary(tree_id)
253
253
254 def get_file_content(self, path):
254 def get_file_content(self, path):
255 """
255 """
256 Returns content of the file at given `path`.
256 Returns content of the file at given `path`.
257 """
257 """
258 tree_id, _ = self._get_tree_id_for_path(path)
258 tree_id, _ = self._get_tree_id_for_path(path)
259 return self._remote.blob_as_pretty_string(tree_id)
259 return self._remote.blob_as_pretty_string(tree_id)
260
260
261 def get_file_content_streamed(self, path):
261 def get_file_content_streamed(self, path):
262 tree_id, _ = self._get_tree_id_for_path(path)
262 tree_id, _ = self._get_tree_id_for_path(path)
263 stream_method = getattr(self._remote, 'stream:blob_as_pretty_string')
263 stream_method = getattr(self._remote, 'stream:blob_as_pretty_string')
264 return stream_method(tree_id)
264 return stream_method(tree_id)
265
265
266 def get_file_size(self, path):
266 def get_file_size(self, path):
267 """
267 """
268 Returns size of the file at given `path`.
268 Returns size of the file at given `path`.
269 """
269 """
270 tree_id, _ = self._get_tree_id_for_path(path)
270 tree_id, _ = self._get_tree_id_for_path(path)
271 return self._remote.blob_raw_length(tree_id)
271 return self._remote.blob_raw_length(tree_id)
272
272
273 def get_path_history(self, path, limit=None, pre_load=None):
273 def get_path_history(self, path, limit=None, pre_load=None):
274 """
274 """
275 Returns history of file as reversed list of `GitCommit` objects for
275 Returns history of file as reversed list of `GitCommit` objects for
276 which file at given `path` has been modified.
276 which file at given `path` has been modified.
277 """
277 """
278
278
279 path = self._get_filectx(path)
279 path = self._get_filectx(path)
280 hist = self._remote.node_history(self.raw_id, path, limit)
280 hist = self._remote.node_history(self.raw_id, path, limit)
281 return [
281 return [
282 self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
282 self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
283 for commit_id in hist]
283 for commit_id in hist]
284
284
285 def get_file_annotate(self, path, pre_load=None):
285 def get_file_annotate(self, path, pre_load=None):
286 """
286 """
287 Returns a generator of four-element tuples of
287 Returns a generator of four-element tuples of
288 (lineno, commit_id, lazy commit loader, line content)
288 (lineno, commit_id, lazy commit loader, line content)
289 """
289 """
290
290
291 result = self._remote.node_annotate(self.raw_id, path)
291 result = self._remote.node_annotate(self.raw_id, path)
292
292
293 for ln_no, commit_id, content in result:
293 for ln_no, commit_id, content in result:
294 yield (
294 yield (
295 ln_no, commit_id,
295 ln_no, commit_id,
296 lambda commit_id=commit_id: self.repository.get_commit(commit_id=commit_id, pre_load=pre_load),
296 lambda commit_id=commit_id: self.repository.get_commit(commit_id=commit_id, pre_load=pre_load),
297 content)
297 content)
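# A usage sketch (file name assumed): each yielded loader now binds its
# own commit_id via the default argument, so results stay correct even
# when the generator is materialized before any loader is called:
#
#   annotation = list(commit.get_file_annotate('setup.py'))
#   line_no, commit_id, loader, line = annotation[0]
#   author = loader().author  # the commit is fetched only on demand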
298
298
299 def get_nodes(self, path):
299 def get_nodes(self, path):
300
300
301 if self._get_kind(path) != NodeKind.DIR:
301 if self._get_kind(path) != NodeKind.DIR:
302 raise CommitError(
302 raise CommitError(
303 "Directory does not exist for commit %s at '%s'" % (self.raw_id, path))
303 "Directory does not exist for commit %s at '%s'" % (self.raw_id, path))
304 path = self._fix_path(path)
304 path = self._fix_path(path)
305
305
306 tree_id, _ = self._get_tree_id_for_path(path)
306 tree_id, _ = self._get_tree_id_for_path(path)
307
307
308 dirnodes = []
308 dirnodes = []
309 filenodes = []
309 filenodes = []
310
310
311 # extracted tree ID gives us our files...
311 # extracted tree ID gives us our files...
312 bytes_path = safe_str(path) # libgit operates on bytes
312 bytes_path = safe_str(path) # libgit operates on bytes
313 for name, stat_, id_, type_ in self._remote.tree_items(tree_id):
313 for name, stat_, id_, type_ in self._remote.tree_items(tree_id):
314 if type_ == 'link':
314 if type_ == 'link':
315 url = self._get_submodule_url('/'.join((bytes_path, name)))
315 url = self._get_submodule_url('/'.join((bytes_path, name)))
316 dirnodes.append(SubModuleNode(
316 dirnodes.append(SubModuleNode(
317 name, url=url, commit=id_, alias=self.repository.alias))
317 name, url=url, commit=id_, alias=self.repository.alias))
318 continue
318 continue
319
319
320 if bytes_path != '':
320 if bytes_path != '':
321 obj_path = '/'.join((bytes_path, name))
321 obj_path = '/'.join((bytes_path, name))
322 else:
322 else:
323 obj_path = name
323 obj_path = name
324 if obj_path not in self._stat_modes:
324 if obj_path not in self._stat_modes:
325 self._stat_modes[obj_path] = stat_
325 self._stat_modes[obj_path] = stat_
326
326
327 if type_ == 'tree':
327 if type_ == 'tree':
328 dirnodes.append(DirNode(obj_path, commit=self))
328 dirnodes.append(DirNode(obj_path, commit=self))
329 elif type_ == 'blob':
329 elif type_ == 'blob':
330 filenodes.append(FileNode(obj_path, commit=self, mode=stat_))
330 filenodes.append(FileNode(obj_path, commit=self, mode=stat_))
331 else:
331 else:
332 raise CommitError(
332 raise CommitError(
333 "Requested object should be Tree or Blob, is %s", type_)
333 "Requested object should be Tree or Blob, is %s", type_)
334
334
335 nodes = dirnodes + filenodes
335 nodes = dirnodes + filenodes
336 for node in nodes:
336 for node in nodes:
337 if node.path not in self.nodes:
337 if node.path not in self.nodes:
338 self.nodes[node.path] = node
338 self.nodes[node.path] = node
339 nodes.sort()
339 nodes.sort()
340 return nodes
340 return nodes
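# A usage sketch (directory name assumed): the combined dir/file nodes
# come back sorted, and each node is also cached in `self.nodes`:
#
#   for node in commit.get_nodes('docs'):
#       print(node.kind, node.path)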
341
341
342 def get_node(self, path, pre_load=None):
342 def get_node(self, path, pre_load=None):
343 path = self._fix_path(path)
343 path = self._fix_path(path)
344 if path not in self.nodes:
344 if path not in self.nodes:
345 try:
345 try:
346 tree_id, type_ = self._get_tree_id_for_path(path)
346 tree_id, type_ = self._get_tree_id_for_path(path)
347 except CommitError:
347 except CommitError:
348 raise NodeDoesNotExistError(
348 raise NodeDoesNotExistError(
349 "Cannot find one of parents' directories for a given "
349 "Cannot find one of parents' directories for a given "
350 "path: %s" % path)
350 "path: %s" % path)
351
351
352 if type_ in ['link', 'commit']:
352 if type_ in ['link', 'commit']:
353 url = self._get_submodule_url(path)
353 url = self._get_submodule_url(path)
354 node = SubModuleNode(path, url=url, commit=tree_id,
354 node = SubModuleNode(path, url=url, commit=tree_id,
355 alias=self.repository.alias)
355 alias=self.repository.alias)
356 elif type_ == 'tree':
356 elif type_ == 'tree':
357 if path == '':
357 if path == '':
358 node = RootNode(commit=self)
358 node = RootNode(commit=self)
359 else:
359 else:
360 node = DirNode(path, commit=self)
360 node = DirNode(path, commit=self)
361 elif type_ == 'blob':
361 elif type_ == 'blob':
362 node = FileNode(path, commit=self, pre_load=pre_load)
362 node = FileNode(path, commit=self, pre_load=pre_load)
363 self._stat_modes[path] = node.mode
363 self._stat_modes[path] = node.mode
364 else:
364 else:
365 raise self.no_node_at_path(path)
365 raise self.no_node_at_path(path)
366
366
367 # cache node
367 # cache node
368 self.nodes[path] = node
368 self.nodes[path] = node
369
369
370 return self.nodes[path]
370 return self.nodes[path]
371
371
372 def get_largefile_node(self, path):
372 def get_largefile_node(self, path):
373 tree_id, _ = self._get_tree_id_for_path(path)
373 tree_id, _ = self._get_tree_id_for_path(path)
374 pointer_spec = self._remote.is_large_file(tree_id)
374 pointer_spec = self._remote.is_large_file(tree_id)
375
375
376 if pointer_spec:
376 if pointer_spec:
377 # the content of the regular FileNode for that file is the hash of the largefile
377 # the content of the regular FileNode for that file is the hash of the largefile
378 file_id = pointer_spec.get('oid_hash')
378 file_id = pointer_spec.get('oid_hash')
379 if self._remote.in_largefiles_store(file_id):
379 if self._remote.in_largefiles_store(file_id):
380 lf_path = self._remote.store_path(file_id)
380 lf_path = self._remote.store_path(file_id)
381 return LargeFileNode(lf_path, commit=self, org_path=path)
381 return LargeFileNode(lf_path, commit=self, org_path=path)
382
382
383 @LazyProperty
383 @LazyProperty
384 def affected_files(self):
384 def affected_files(self):
385 """
385 """
386 Gets fast-accessible file changes for the given commit
386 Gets fast-accessible file changes for the given commit
387 """
387 """
388 added, modified, deleted = self._changes_cache
388 added, modified, deleted = self._changes_cache
389 return list(added.union(modified).union(deleted))
389 return list(added.union(modified).union(deleted))
390
390
391 @LazyProperty
391 @LazyProperty
392 def _changes_cache(self):
392 def _changes_cache(self):
393 added = set()
393 added = set()
394 modified = set()
394 modified = set()
395 deleted = set()
395 deleted = set()
396 _r = self._remote
396 _r = self._remote
397
397
398 parents = self.parents
398 parents = self.parents
399 if not self.parents:
399 if not self.parents:
400 parents = [base.EmptyCommit()]
400 parents = [base.EmptyCommit()]
401 for parent in parents:
401 for parent in parents:
402 if isinstance(parent, base.EmptyCommit):
402 if isinstance(parent, base.EmptyCommit):
403 oid = None
403 oid = None
404 else:
404 else:
405 oid = parent.raw_id
405 oid = parent.raw_id
406 changes = _r.tree_changes(oid, self.raw_id)
406 changes = _r.tree_changes(oid, self.raw_id)
407 for (oldpath, newpath), (_, _), (_, _) in changes:
407 for (oldpath, newpath), (_, _), (_, _) in changes:
408 if newpath and oldpath:
408 if newpath and oldpath:
409 modified.add(newpath)
409 modified.add(newpath)
410 elif newpath and not oldpath:
410 elif newpath and not oldpath:
411 added.add(newpath)
411 added.add(newpath)
412 elif not newpath and oldpath:
412 elif not newpath and oldpath:
413 deleted.add(oldpath)
413 deleted.add(oldpath)
414 return added, modified, deleted
414 return added, modified, deleted
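# Classification sketch: tree_changes() yields (oldpath, newpath) pairs
# per parent, which map onto the three sets as follows:
#
#   ('a.py', 'a.py')  -> modified  (path present on both sides)
#   (None,   'b.py')  -> added     (no old path)
#   ('c.py', None)    -> deleted   (no new path)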
415
415
416 def _get_paths_for_status(self, status):
416 def _get_paths_for_status(self, status):
417 """
417 """
418 Returns sorted list of paths for given ``status``.
418 Returns sorted list of paths for given ``status``.
419
419
420 :param status: one of: *added*, *modified* or *deleted*
420 :param status: one of: *added*, *modified* or *deleted*
421 """
421 """
422 added, modified, deleted = self._changes_cache
422 added, modified, deleted = self._changes_cache
423 return sorted({
423 return sorted({
424 'added': list(added),
424 'added': list(added),
425 'modified': list(modified),
425 'modified': list(modified),
426 'deleted': list(deleted)}[status]
426 'deleted': list(deleted)}[status]
427 )
427 )
428
428
429 @LazyProperty
429 @LazyProperty
430 def added(self):
430 def added(self):
431 """
431 """
432 Returns list of added ``FileNode`` objects.
432 Returns list of added ``FileNode`` objects.
433 """
433 """
434 if not self.parents:
434 if not self.parents:
435 return list(self._get_file_nodes())
435 return list(self._get_file_nodes())
436 return AddedFileNodesGenerator(self.added_paths, self)
436 return AddedFileNodesGenerator(self.added_paths, self)
437
437
438 @LazyProperty
438 @LazyProperty
439 def added_paths(self):
439 def added_paths(self):
440 return [n for n in self._get_paths_for_status('added')]
440 return [n for n in self._get_paths_for_status('added')]
441
441
442 @LazyProperty
442 @LazyProperty
443 def changed(self):
443 def changed(self):
444 """
444 """
445 Returns list of modified ``FileNode`` objects.
445 Returns list of modified ``FileNode`` objects.
446 """
446 """
447 if not self.parents:
447 if not self.parents:
448 return []
448 return []
449 return ChangedFileNodesGenerator(self.changed_paths, self)
449 return ChangedFileNodesGenerator(self.changed_paths, self)
450
450
451 @LazyProperty
451 @LazyProperty
452 def changed_paths(self):
452 def changed_paths(self):
453 return [n for n in self._get_paths_for_status('modified')]
453 return [n for n in self._get_paths_for_status('modified')]
454
454
455 @LazyProperty
455 @LazyProperty
456 def removed(self):
456 def removed(self):
457 """
457 """
458 Returns list of removed ``FileNode`` objects.
458 Returns list of removed ``FileNode`` objects.
459 """
459 """
460 if not self.parents:
460 if not self.parents:
461 return []
461 return []
462 return RemovedFileNodesGenerator(self.removed_paths, self)
462 return RemovedFileNodesGenerator(self.removed_paths, self)
463
463
464 @LazyProperty
464 @LazyProperty
465 def removed_paths(self):
465 def removed_paths(self):
466 return [n for n in self._get_paths_for_status('deleted')]
466 return [n for n in self._get_paths_for_status('deleted')]
467
467
468 def _get_submodule_url(self, submodule_path):
468 def _get_submodule_url(self, submodule_path):
469 git_modules_path = '.gitmodules'
469 git_modules_path = '.gitmodules'
470
470
471 if self._submodules is None:
471 if self._submodules is None:
472 self._submodules = {}
472 self._submodules = {}
473
473
474 try:
474 try:
475 submodules_node = self.get_node(git_modules_path)
475 submodules_node = self.get_node(git_modules_path)
476 except NodeDoesNotExistError:
476 except NodeDoesNotExistError:
477 return None
477 return None
478
478
479 # ConfigParser fails if there is leading whitespace, and it needs an
479 # ConfigParser fails if there is leading whitespace, and it needs an
480 # iterable of lines (file-like content)
480 # iterable of lines (file-like content)
481 def iter_content(_content):
481 def iter_content(_content):
482 for line in _content.splitlines():
482 for line in _content.splitlines():
483 yield line
483 yield line
484
484
485 parser = configparser.RawConfigParser()
485 parser = configparser.RawConfigParser()
486 parser.read_file(iter_content(submodules_node.content))
486 parser.read_file(iter_content(submodules_node.content))
487
487
488 for section in parser.sections():
488 for section in parser.sections():
489 path = parser.get(section, 'path')
489 path = parser.get(section, 'path')
490 url = parser.get(section, 'url')
490 url = parser.get(section, 'url')
491 if path and url:
491 if path and url:
492 self._submodules[path.strip('/')] = url
492 self._submodules[path.strip('/')] = url
493
493
494 return self._submodules.get(submodule_path.strip('/'))
494 return self._submodules.get(submodule_path.strip('/'))
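# A parsing sketch (contents hypothetical): for a .gitmodules entry like
#
#   [submodule "vendor/lib"]
#   path = vendor/lib
#   url = https://example.com/vendor/lib.git
#
# the loop above builds {'vendor/lib': 'https://example.com/vendor/lib.git'},
# and a lookup for 'vendor/lib/' still matches because both sides strip '/'.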
@@ -1,401 +1,402 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2014-2020 RhodeCode GmbH
3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 HG commit module
22 HG commit module
23 """
23 """
24
24
25 import os
25 import os
26
26
27 from zope.cachedescriptors.property import Lazy as LazyProperty
27 from zope.cachedescriptors.property import Lazy as LazyProperty
28
28
29 from rhodecode.lib.datelib import utcdate_fromtimestamp
29 from rhodecode.lib.datelib import utcdate_fromtimestamp
30 from rhodecode.lib.utils import safe_str, safe_unicode
30 from rhodecode.lib.utils import safe_str, safe_unicode
31 from rhodecode.lib.vcs import path as vcspath
31 from rhodecode.lib.vcs import path as vcspath
32 from rhodecode.lib.vcs.backends import base
32 from rhodecode.lib.vcs.backends import base
33 from rhodecode.lib.vcs.backends.hg.diff import MercurialDiff
33 from rhodecode.lib.vcs.backends.hg.diff import MercurialDiff
34 from rhodecode.lib.vcs.exceptions import CommitError
34 from rhodecode.lib.vcs.exceptions import CommitError
35 from rhodecode.lib.vcs.nodes import (
35 from rhodecode.lib.vcs.nodes import (
36 AddedFileNodesGenerator, ChangedFileNodesGenerator, DirNode, FileNode,
36 AddedFileNodesGenerator, ChangedFileNodesGenerator, DirNode, FileNode,
37 NodeKind, RemovedFileNodesGenerator, RootNode, SubModuleNode,
37 NodeKind, RemovedFileNodesGenerator, RootNode, SubModuleNode,
38 LargeFileNode, LARGEFILE_PREFIX)
38 LargeFileNode, LARGEFILE_PREFIX)
39 from rhodecode.lib.vcs.utils.paths import get_dirs_for_path
39 from rhodecode.lib.vcs.utils.paths import get_dirs_for_path
40
40
41
41
42 class MercurialCommit(base.BaseCommit):
42 class MercurialCommit(base.BaseCommit):
43 """
43 """
44 Represents state of the repository at the single commit.
44 Represents state of the repository at the single commit.
45 """
45 """
46
46
47 _filter_pre_load = [
47 _filter_pre_load = [
48 # git specific property not supported here
48 # git specific property not supported here
49 "_commit",
49 "_commit",
50 ]
50 ]
51
51
52 def __init__(self, repository, raw_id, idx, pre_load=None):
52 def __init__(self, repository, raw_id, idx, pre_load=None):
53 raw_id = safe_str(raw_id)
53 raw_id = safe_str(raw_id)
54
54
55 self.repository = repository
55 self.repository = repository
56 self._remote = repository._remote
56 self._remote = repository._remote
57
57
58 self.raw_id = raw_id
58 self.raw_id = raw_id
59 self.idx = idx
59 self.idx = idx
60
60
61 self._set_bulk_properties(pre_load)
61 self._set_bulk_properties(pre_load)
62
62
63 # caches
63 # caches
64 self.nodes = {}
64 self.nodes = {}
65
65
66 def _set_bulk_properties(self, pre_load):
66 def _set_bulk_properties(self, pre_load):
67 if not pre_load:
67 if not pre_load:
68 return
68 return
69 pre_load = [entry for entry in pre_load
69 pre_load = [entry for entry in pre_load
70 if entry not in self._filter_pre_load]
70 if entry not in self._filter_pre_load]
71 if not pre_load:
71 if not pre_load:
72 return
72 return
73
73
74 result = self._remote.bulk_request(self.raw_id, pre_load)
74 result = self._remote.bulk_request(self.raw_id, pre_load)
75
75 for attr, value in result.items():
76 for attr, value in result.items():
76 if attr in ["author", "branch", "message"]:
77 if attr in ["author", "branch", "message"]:
77 value = safe_unicode(value)
78 value = safe_unicode(value)
78 elif attr == "affected_files":
79 elif attr == "affected_files":
79 value = list(map(safe_unicode, value))  # map() is lazy in py3, cache a list
80 value = list(map(safe_unicode, value))  # map() is lazy in py3, cache a list
80 elif attr == "date":
81 elif attr == "date":
81 value = utcdate_fromtimestamp(*value)
82 value = utcdate_fromtimestamp(*value)
82 elif attr in ["children", "parents"]:
83 elif attr in ["children", "parents"]:
83 value = self._make_commits(value)
84 value = self._make_commits(value)
84 elif attr in ["phase"]:
85 elif attr in ["phase"]:
85 value = self._get_phase_text(value)
86 value = self._get_phase_text(value)
86 self.__dict__[attr] = value
87 self.__dict__[attr] = value
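# A usage sketch (`repo` and `raw_id` are hypothetical): pre_load lets
# one bulk RPC prime several lazy attributes at once:
#
#   commit = repo.get_commit(commit_id=raw_id,
#                            pre_load=['author', 'date', 'branch'])
#   commit.author  # already populated, no extra remote call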
87
88
88 @LazyProperty
89 @LazyProperty
89 def tags(self):
90 def tags(self):
90 tags = [name for name, commit_id in self.repository.tags.items()
91 tags = [name for name, commit_id in self.repository.tags.items()
91 if commit_id == self.raw_id]
92 if commit_id == self.raw_id]
92 return tags
93 return tags
93
94
94 @LazyProperty
95 @LazyProperty
95 def branch(self):
96 def branch(self):
96 return safe_unicode(self._remote.ctx_branch(self.raw_id))
97 return safe_unicode(self._remote.ctx_branch(self.raw_id))
97
98
98 @LazyProperty
99 @LazyProperty
99 def bookmarks(self):
100 def bookmarks(self):
100 bookmarks = [
101 bookmarks = [
101 name for name, commit_id in self.repository.bookmarks.items()
102 name for name, commit_id in self.repository.bookmarks.items()
102 if commit_id == self.raw_id]
103 if commit_id == self.raw_id]
103 return bookmarks
104 return bookmarks
104
105
105 @LazyProperty
106 @LazyProperty
106 def message(self):
107 def message(self):
107 return safe_unicode(self._remote.ctx_description(self.raw_id))
108 return safe_unicode(self._remote.ctx_description(self.raw_id))
108
109
109 @LazyProperty
110 @LazyProperty
110 def committer(self):
111 def committer(self):
111 return safe_unicode(self.author)
112 return safe_unicode(self.author)
112
113
113 @LazyProperty
114 @LazyProperty
114 def author(self):
115 def author(self):
115 return safe_unicode(self._remote.ctx_user(self.raw_id))
116 return safe_unicode(self._remote.ctx_user(self.raw_id))
116
117
117 @LazyProperty
118 @LazyProperty
118 def date(self):
119 def date(self):
119 return utcdate_fromtimestamp(*self._remote.ctx_date(self.raw_id))
120 return utcdate_fromtimestamp(*self._remote.ctx_date(self.raw_id))
120
121
121 @LazyProperty
122 @LazyProperty
122 def status(self):
123 def status(self):
123 """
124 """
124 Returns modified, added, removed and deleted files for the current commit
125 Returns modified, added, removed and deleted files for the current commit
125 """
126 """
126 return self._remote.ctx_status(self.raw_id)
127 return self._remote.ctx_status(self.raw_id)
127
128
128 @LazyProperty
129 @LazyProperty
129 def _file_paths(self):
130 def _file_paths(self):
130 return self._remote.ctx_list(self.raw_id)
131 return self._remote.ctx_list(self.raw_id)
131
132
132 @LazyProperty
133 @LazyProperty
133 def _dir_paths(self):
134 def _dir_paths(self):
134 p = list(set(get_dirs_for_path(*self._file_paths)))
135 p = list(set(get_dirs_for_path(*self._file_paths)))
135 p.insert(0, '')
136 p.insert(0, '')
136 return p
137 return p
137
138
138 @LazyProperty
139 @LazyProperty
139 def _paths(self):
140 def _paths(self):
140 return self._dir_paths + self._file_paths
141 return self._dir_paths + self._file_paths
141
142
142 @LazyProperty
143 @LazyProperty
143 def id(self):
144 def id(self):
144 if self.last:
145 if self.last:
145 return u'tip'
146 return u'tip'
146 return self.short_id
147 return self.short_id
147
148
148 @LazyProperty
149 @LazyProperty
149 def short_id(self):
150 def short_id(self):
150 return self.raw_id[:12]
151 return self.raw_id[:12]
151
152
152 def _make_commits(self, commit_ids, pre_load=None):
153 def _make_commits(self, commit_ids, pre_load=None):
153 return [self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
154 return [self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
154 for commit_id in commit_ids]
155 for commit_id in commit_ids]
155
156
156 @LazyProperty
157 @LazyProperty
157 def parents(self):
158 def parents(self):
158 """
159 """
159 Returns list of parent commits.
160 Returns list of parent commits.
160 """
161 """
161 parents = self._remote.ctx_parents(self.raw_id)
162 parents = self._remote.ctx_parents(self.raw_id)
162 return self._make_commits(parents)
163 return self._make_commits(parents)
163
164
164 def _get_phase_text(self, phase_id):
165 def _get_phase_text(self, phase_id):
165 return {
166 return {
166 0: 'public',
167 0: 'public',
167 1: 'draft',
168 1: 'draft',
168 2: 'secret',
169 2: 'secret',
169 }.get(phase_id) or ''
170 }.get(phase_id) or ''
170
171
171 @LazyProperty
172 @LazyProperty
172 def phase(self):
173 def phase(self):
173 phase_id = self._remote.ctx_phase(self.raw_id)
174 phase_id = self._remote.ctx_phase(self.raw_id)
174 phase_text = self._get_phase_text(phase_id)
175 phase_text = self._get_phase_text(phase_id)
175
176
176 return safe_unicode(phase_text)
177 return safe_unicode(phase_text)
177
178
178 @LazyProperty
179 @LazyProperty
179 def obsolete(self):
180 def obsolete(self):
180 obsolete = self._remote.ctx_obsolete(self.raw_id)
181 obsolete = self._remote.ctx_obsolete(self.raw_id)
181 return obsolete
182 return obsolete
182
183
183 @LazyProperty
184 @LazyProperty
184 def hidden(self):
185 def hidden(self):
185 hidden = self._remote.ctx_hidden(self.raw_id)
186 hidden = self._remote.ctx_hidden(self.raw_id)
186 return hidden
187 return hidden
187
188
188 @LazyProperty
189 @LazyProperty
189 def children(self):
190 def children(self):
190 """
191 """
191 Returns list of child commits.
192 Returns list of child commits.
192 """
193 """
193 children = self._remote.ctx_children(self.raw_id)
194 children = self._remote.ctx_children(self.raw_id)
194 return self._make_commits(children)
195 return self._make_commits(children)
195
196
196 def _fix_path(self, path):
197 def _fix_path(self, path):
197 """
198 """
198 Mercurial keeps filenodes as str so we need to encode from unicode
199 Mercurial keeps filenodes as str so we need to encode from unicode
199 to str.
200 to str.
200 """
201 """
201 return safe_str(super(MercurialCommit, self)._fix_path(path))
202 return safe_str(super(MercurialCommit, self)._fix_path(path))
202
203
203 def _get_kind(self, path):
204 def _get_kind(self, path):
204 path = self._fix_path(path)
205 path = self._fix_path(path)
205 if path in self._file_paths:
206 if path in self._file_paths:
206 return NodeKind.FILE
207 return NodeKind.FILE
207 elif path in self._dir_paths:
208 elif path in self._dir_paths:
208 return NodeKind.DIR
209 return NodeKind.DIR
209 else:
210 else:
210 raise CommitError(
211 raise CommitError(
211 "Node does not exist at the given path '%s'" % (path, ))
212 "Node does not exist at the given path '%s'" % (path, ))
212
213
213 def _get_filectx(self, path):
214 def _get_filectx(self, path):
214 path = self._fix_path(path)
215 path = self._fix_path(path)
215 if self._get_kind(path) != NodeKind.FILE:
216 if self._get_kind(path) != NodeKind.FILE:
216 raise CommitError(
217 raise CommitError(
217 "File does not exist for idx %s at '%s'" % (self.raw_id, path))
218 "File does not exist for idx %s at '%s'" % (self.raw_id, path))
218 return path
219 return path
219
220
220 def get_file_mode(self, path):
221 def get_file_mode(self, path):
221 """
222 """
222 Returns stat mode of the file at the given ``path``.
223 Returns stat mode of the file at the given ``path``.
223 """
224 """
224 path = self._get_filectx(path)
225 path = self._get_filectx(path)
225 if 'x' in self._remote.fctx_flags(self.raw_id, path):
226 if 'x' in self._remote.fctx_flags(self.raw_id, path):
226 return base.FILEMODE_EXECUTABLE
227 return base.FILEMODE_EXECUTABLE
227 else:
228 else:
228 return base.FILEMODE_DEFAULT
229 return base.FILEMODE_DEFAULT
229
230
230 def is_link(self, path):
231 def is_link(self, path):
231 path = self._get_filectx(path)
232 path = self._get_filectx(path)
232 return 'l' in self._remote.fctx_flags(self.raw_id, path)
233 return 'l' in self._remote.fctx_flags(self.raw_id, path)
233
234
234 def is_node_binary(self, path):
235 def is_node_binary(self, path):
235 path = self._get_filectx(path)
236 path = self._get_filectx(path)
236 return self._remote.is_binary(self.raw_id, path)
237 return self._remote.is_binary(self.raw_id, path)
237
238
238 def get_file_content(self, path):
239 def get_file_content(self, path):
239 """
240 """
240 Returns content of the file at given ``path``.
241 Returns content of the file at given ``path``.
241 """
242 """
242 path = self._get_filectx(path)
243 path = self._get_filectx(path)
243 return self._remote.fctx_node_data(self.raw_id, path)
244 return self._remote.fctx_node_data(self.raw_id, path)
244
245
245 def get_file_content_streamed(self, path):
246 def get_file_content_streamed(self, path):
246 path = self._get_filectx(path)
247 path = self._get_filectx(path)
247 stream_method = getattr(self._remote, 'stream:fctx_node_data')
248 stream_method = getattr(self._remote, 'stream:fctx_node_data')
248 return stream_method(self.raw_id, path)
249 return stream_method(self.raw_id, path)
249
250
250 def get_file_size(self, path):
251 def get_file_size(self, path):
251 """
252 """
252 Returns size of the file at given ``path``.
253 Returns size of the file at given ``path``.
253 """
254 """
254 path = self._get_filectx(path)
255 path = self._get_filectx(path)
255 return self._remote.fctx_size(self.raw_id, path)
256 return self._remote.fctx_size(self.raw_id, path)
256
257
257 def get_path_history(self, path, limit=None, pre_load=None):
258 def get_path_history(self, path, limit=None, pre_load=None):
258 """
259 """
259 Returns history of file as reversed list of `MercurialCommit` objects
260 Returns history of file as reversed list of `MercurialCommit` objects
260 for which file at given ``path`` has been modified.
261 for which file at given ``path`` has been modified.
261 """
262 """
262 path = self._get_filectx(path)
263 path = self._get_filectx(path)
263 hist = self._remote.node_history(self.raw_id, path, limit)
264 hist = self._remote.node_history(self.raw_id, path, limit)
264 return [
265 return [
265 self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
266 self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
266 for commit_id in hist]
267 for commit_id in hist]
267
268
268 def get_file_annotate(self, path, pre_load=None):
269 def get_file_annotate(self, path, pre_load=None):
269 """
270 """
270 Returns a generator of four-element tuples of
271 Returns a generator of four-element tuples of
271 (lineno, commit_id, lazy commit loader, line content)
272 (lineno, commit_id, lazy commit loader, line content)
272 """
273 """
273 result = self._remote.fctx_annotate(self.raw_id, path)
274 result = self._remote.fctx_annotate(self.raw_id, path)
274
275
275 for ln_no, commit_id, content in result:
276 for ln_no, commit_id, content in result:
276 yield (
277 yield (
277 ln_no, commit_id,
278 ln_no, commit_id,
278 lambda commit_id=commit_id: self.repository.get_commit(commit_id=commit_id, pre_load=pre_load),
279 lambda commit_id=commit_id: self.repository.get_commit(commit_id=commit_id, pre_load=pre_load),
279 content)
280 content)
280
281
281 def get_nodes(self, path):
282 def get_nodes(self, path):
282 """
283 """
283 Returns a combined list of ``DirNode`` and ``FileNode`` objects
284 Returns a combined list of ``DirNode`` and ``FileNode`` objects
284 representing the state of the commit at the given ``path``. If the
285 representing the state of the commit at the given ``path``. If the
285 node at that ``path`` is not a ``DirNode``, a CommitError is raised.
286 node at that ``path`` is not a ``DirNode``, a CommitError is raised.
286 """
287 """
287
288
288 if self._get_kind(path) != NodeKind.DIR:
289 if self._get_kind(path) != NodeKind.DIR:
289 raise CommitError(
290 raise CommitError(
290 "Directory does not exist for idx %s at '%s'" % (self.raw_id, path))
291 "Directory does not exist for idx %s at '%s'" % (self.raw_id, path))
291 path = self._fix_path(path)
292 path = self._fix_path(path)
292
293
293 filenodes = [
294 filenodes = [
294 FileNode(f, commit=self) for f in self._file_paths
295 FileNode(f, commit=self) for f in self._file_paths
295 if os.path.dirname(f) == path]
296 if os.path.dirname(f) == path]
296 # the old `path == '' and '' or [...]` idiom always yielded this list ('' is falsy)
297 # the old `path == '' and '' or [...]` idiom always yielded this list ('' is falsy)
297 dirs = [
298 dirs = [
298 d for d in self._dir_paths
299 d for d in self._dir_paths
299 if d and vcspath.dirname(d) == path]
300 if d and vcspath.dirname(d) == path]
300 dirnodes = [
301 dirnodes = [
301 DirNode(d, commit=self) for d in dirs
302 DirNode(d, commit=self) for d in dirs
302 if os.path.dirname(d) == path]
303 if os.path.dirname(d) == path]
303
304
304 alias = self.repository.alias
305 alias = self.repository.alias
305 for k, vals in self._submodules.items():
306 for k, vals in self._submodules.items():
306 if vcspath.dirname(k) == path:
307 if vcspath.dirname(k) == path:
307 loc = vals[0]
308 loc = vals[0]
308 commit = vals[1]
309 commit = vals[1]
309 dirnodes.append(SubModuleNode(k, url=loc, commit=commit, alias=alias))
310 dirnodes.append(SubModuleNode(k, url=loc, commit=commit, alias=alias))
310
311
311 nodes = dirnodes + filenodes
312 nodes = dirnodes + filenodes
312 for node in nodes:
313 for node in nodes:
313 if node.path not in self.nodes:
314 if node.path not in self.nodes:
314 self.nodes[node.path] = node
315 self.nodes[node.path] = node
315 nodes.sort()
316 nodes.sort()
316
317
317 return nodes
318 return nodes
318
319
319 def get_node(self, path, pre_load=None):
320 def get_node(self, path, pre_load=None):
320 """
321 """
321 Returns `Node` object from the given `path`. If there is no node at
322 Returns `Node` object from the given `path`. If there is no node at
322 the given `path`, `NodeDoesNotExistError` would be raised.
323 the given `path`, `NodeDoesNotExistError` would be raised.
323 """
324 """
324 path = self._fix_path(path)
325 path = self._fix_path(path)
325
326
326 if path not in self.nodes:
327 if path not in self.nodes:
327 if path in self._file_paths:
328 if path in self._file_paths:
328 node = FileNode(path, commit=self, pre_load=pre_load)
329 node = FileNode(path, commit=self, pre_load=pre_load)
329 elif path in self._dir_paths:
330 elif path in self._dir_paths:
330 if path == '':
331 if path == '':
331 node = RootNode(commit=self)
332 node = RootNode(commit=self)
332 else:
333 else:
333 node = DirNode(path, commit=self)
334 node = DirNode(path, commit=self)
334 else:
335 else:
335 raise self.no_node_at_path(path)
336 raise self.no_node_at_path(path)
336
337
337 # cache node
338 # cache node
338 self.nodes[path] = node
339 self.nodes[path] = node
339 return self.nodes[path]
340 return self.nodes[path]
340
341
341 def get_largefile_node(self, path):
342 def get_largefile_node(self, path):
342 pointer_spec = self._remote.is_large_file(self.raw_id, path)
343 pointer_spec = self._remote.is_large_file(self.raw_id, path)
343 if pointer_spec:
344 if pointer_spec:
344 # the content of the regular FileNode for that file is the hash of the largefile
345 # the content of the regular FileNode for that file is the hash of the largefile
345 file_id = self.get_file_content(path).strip()
346 file_id = self.get_file_content(path).strip()
346
347
347 if self._remote.in_largefiles_store(file_id):
348 if self._remote.in_largefiles_store(file_id):
348 lf_path = self._remote.store_path(file_id)
349 lf_path = self._remote.store_path(file_id)
349 return LargeFileNode(lf_path, commit=self, org_path=path)
350 return LargeFileNode(lf_path, commit=self, org_path=path)
350 elif self._remote.in_user_cache(file_id):
351 elif self._remote.in_user_cache(file_id):
351 lf_path = self._remote.store_path(file_id)
352 lf_path = self._remote.store_path(file_id)
352 self._remote.link(file_id, path)
353 self._remote.link(file_id, path)
353 return LargeFileNode(lf_path, commit=self, org_path=path)
354 return LargeFileNode(lf_path, commit=self, org_path=path)
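# Resolution sketch: the tracked file itself stores only the largefile
# hash; the lookup above prefers the local largefiles store, then falls
# back to the user cache (linking the blob into place first), and
# implicitly returns None when the pointer cannot be resolved.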
354
355
355 @LazyProperty
356 @LazyProperty
356 def _submodules(self):
357 def _submodules(self):
357 """
358 """
358 Returns a dictionary with submodule information from substate file
359 Returns a dictionary with submodule information from substate file
359 of hg repository.
360 of hg repository.
360 """
361 """
361 return self._remote.ctx_substate(self.raw_id)
362 return self._remote.ctx_substate(self.raw_id)
362
363
363 @LazyProperty
364 @LazyProperty
364 def affected_files(self):
365 def affected_files(self):
365 """
366 """
366 Gets fast-accessible file changes for the given commit
367 Gets fast-accessible file changes for the given commit
367 """
368 """
368 return self._remote.ctx_files(self.raw_id)
369 return self._remote.ctx_files(self.raw_id)
369
370
370 @property
371 @property
371 def added(self):
372 def added(self):
372 """
373 """
373 Returns list of added ``FileNode`` objects.
374 Returns list of added ``FileNode`` objects.
374 """
375 """
375 return AddedFileNodesGenerator(self.added_paths, self)
376 return AddedFileNodesGenerator(self.added_paths, self)
376
377
377 @LazyProperty
378 @LazyProperty
378 def added_paths(self):
379 def added_paths(self):
379 return [n for n in self.status[1]]
380 return [n for n in self.status[1]]
380
381
381 @property
382 @property
382 def changed(self):
383 def changed(self):
383 """
384 """
384 Returns list of modified ``FileNode`` objects.
385 Returns list of modified ``FileNode`` objects.
385 """
386 """
386 return ChangedFileNodesGenerator(self.changed_paths, self)
387 return ChangedFileNodesGenerator(self.changed_paths, self)
387
388
388 @LazyProperty
389 @LazyProperty
389 def changed_paths(self):
390 def changed_paths(self):
390 return [n for n in self.status[0]]
391 return [n for n in self.status[0]]
391
392
392 @property
393 @property
393 def removed(self):
394 def removed(self):
394 """
395 """
395 Returns list of removed ``FileNode`` objects.
396 Returns list of removed ``FileNode`` objects.
396 """
397 """
397 return RemovedFileNodesGenerator(self.removed_paths, self)
398 return RemovedFileNodesGenerator(self.removed_paths, self)
398
399
399 @LazyProperty
400 @LazyProperty
400 def removed_paths(self):
401 def removed_paths(self):
401 return [n for n in self.status[2]]
402 return [n for n in self.status[2]]
@@ -1,600 +1,600 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 permissions model for RhodeCode
22 permissions model for RhodeCode
23 """
23 """
24 import collections
24 import collections
25 import logging
25 import logging
26 import traceback
26 import traceback
27
27
28 from sqlalchemy.exc import DatabaseError
28 from sqlalchemy.exc import DatabaseError
29
29
30 from rhodecode import events
30 from rhodecode import events
31 from rhodecode.model import BaseModel
31 from rhodecode.model import BaseModel
32 from rhodecode.model.db import (
32 from rhodecode.model.db import (
33 User, Permission, UserToPerm, UserRepoToPerm, UserRepoGroupToPerm,
33 User, Permission, UserToPerm, UserRepoToPerm, UserRepoGroupToPerm,
34 UserUserGroupToPerm, UserGroup, UserGroupToPerm, UserToRepoBranchPermission)
34 UserUserGroupToPerm, UserGroup, UserGroupToPerm, UserToRepoBranchPermission)
35 from rhodecode.lib.utils2 import str2bool, safe_int
35 from rhodecode.lib.utils2 import str2bool, safe_int
36
36
37 log = logging.getLogger(__name__)
37 log = logging.getLogger(__name__)
38
38
39
39
40 class PermissionModel(BaseModel):
40 class PermissionModel(BaseModel):
41 """
41 """
42 Permissions model for RhodeCode
42 Permissions model for RhodeCode
43 """
43 """
44 FORKING_DISABLED = 'hg.fork.none'
44 FORKING_DISABLED = 'hg.fork.none'
45 FORKING_ENABLED = 'hg.fork.repository'
45 FORKING_ENABLED = 'hg.fork.repository'
46
46
47 cls = Permission
47 cls = Permission
48 global_perms = {
48 global_perms = {
49 'default_repo_create': None,
49 'default_repo_create': None,
50 # special case for creating repos with write access to a group
50 # special case for creating repos with write access to a group
51 'default_repo_create_on_write': None,
51 'default_repo_create_on_write': None,
52 'default_repo_group_create': None,
52 'default_repo_group_create': None,
53 'default_user_group_create': None,
53 'default_user_group_create': None,
54 'default_fork_create': None,
54 'default_fork_create': None,
55 'default_inherit_default_permissions': None,
55 'default_inherit_default_permissions': None,
56 'default_register': None,
56 'default_register': None,
57 'default_password_reset': None,
57 'default_password_reset': None,
58 'default_extern_activate': None,
58 'default_extern_activate': None,
59
59
60 # object permissions below
60 # object permissions below
61 'default_repo_perm': None,
61 'default_repo_perm': None,
62 'default_group_perm': None,
62 'default_group_perm': None,
63 'default_user_group_perm': None,
63 'default_user_group_perm': None,
64
64
65 # branch
65 # branch
66 'default_branch_perm': None,
66 'default_branch_perm': None,
67 }
67 }
68
68
69 def set_global_permission_choices(self, c_obj, gettext_translator):
69 def set_global_permission_choices(self, c_obj, gettext_translator):
70 _ = gettext_translator
70 _ = gettext_translator
71
71
72 c_obj.repo_perms_choices = [
72 c_obj.repo_perms_choices = [
73 ('repository.none', _('None'),),
73 ('repository.none', _('None'),),
74 ('repository.read', _('Read'),),
74 ('repository.read', _('Read'),),
75 ('repository.write', _('Write'),),
75 ('repository.write', _('Write'),),
76 ('repository.admin', _('Admin'),)]
76 ('repository.admin', _('Admin'),)]
77
77
78 c_obj.group_perms_choices = [
78 c_obj.group_perms_choices = [
79 ('group.none', _('None'),),
79 ('group.none', _('None'),),
80 ('group.read', _('Read'),),
80 ('group.read', _('Read'),),
81 ('group.write', _('Write'),),
81 ('group.write', _('Write'),),
82 ('group.admin', _('Admin'),)]
82 ('group.admin', _('Admin'),)]
83
83
84 c_obj.user_group_perms_choices = [
84 c_obj.user_group_perms_choices = [
85 ('usergroup.none', _('None'),),
85 ('usergroup.none', _('None'),),
86 ('usergroup.read', _('Read'),),
86 ('usergroup.read', _('Read'),),
87 ('usergroup.write', _('Write'),),
87 ('usergroup.write', _('Write'),),
88 ('usergroup.admin', _('Admin'),)]
88 ('usergroup.admin', _('Admin'),)]
89
89
90 c_obj.branch_perms_choices = [
90 c_obj.branch_perms_choices = [
91 ('branch.none', _('Protected/No Access'),),
91 ('branch.none', _('Protected/No Access'),),
92 ('branch.merge', _('Web merge'),),
92 ('branch.merge', _('Web merge'),),
93 ('branch.push', _('Push'),),
93 ('branch.push', _('Push'),),
94 ('branch.push_force', _('Force Push'),)]
94 ('branch.push_force', _('Force Push'),)]
95
95
96 c_obj.register_choices = [
96 c_obj.register_choices = [
97 ('hg.register.none', _('Disabled')),
97 ('hg.register.none', _('Disabled')),
98 ('hg.register.manual_activate', _('Allowed with manual account activation')),
98 ('hg.register.manual_activate', _('Allowed with manual account activation')),
99 ('hg.register.auto_activate', _('Allowed with automatic account activation')),]
99 ('hg.register.auto_activate', _('Allowed with automatic account activation'))]
100
100
101 c_obj.password_reset_choices = [
101 c_obj.password_reset_choices = [
102 ('hg.password_reset.enabled', _('Allow password recovery')),
102 ('hg.password_reset.enabled', _('Allow password recovery')),
103 ('hg.password_reset.hidden', _('Hide password recovery link')),
103 ('hg.password_reset.hidden', _('Hide password recovery link')),
104 ('hg.password_reset.disabled', _('Disable password recovery')),]
104 ('hg.password_reset.disabled', _('Disable password recovery'))]
105
105
106 c_obj.extern_activate_choices = [
106 c_obj.extern_activate_choices = [
107 ('hg.extern_activate.manual', _('Manual activation of external account')),
107 ('hg.extern_activate.manual', _('Manual activation of external account')),
108 ('hg.extern_activate.auto', _('Automatic activation of external account')),]
108 ('hg.extern_activate.auto', _('Automatic activation of external account'))]
109
109
110 c_obj.repo_create_choices = [
110 c_obj.repo_create_choices = [
111 ('hg.create.none', _('Disabled')),
111 ('hg.create.none', _('Disabled')),
112 ('hg.create.repository', _('Enabled'))]
112 ('hg.create.repository', _('Enabled'))]
113
113
114 c_obj.repo_create_on_write_choices = [
114 c_obj.repo_create_on_write_choices = [
115 ('hg.create.write_on_repogroup.false', _('Disabled')),
115 ('hg.create.write_on_repogroup.false', _('Disabled')),
116 ('hg.create.write_on_repogroup.true', _('Enabled'))]
116 ('hg.create.write_on_repogroup.true', _('Enabled'))]
117
117
118 c_obj.user_group_create_choices = [
118 c_obj.user_group_create_choices = [
119 ('hg.usergroup.create.false', _('Disabled')),
119 ('hg.usergroup.create.false', _('Disabled')),
120 ('hg.usergroup.create.true', _('Enabled'))]
120 ('hg.usergroup.create.true', _('Enabled'))]
121
121
122 c_obj.repo_group_create_choices = [
122 c_obj.repo_group_create_choices = [
123 ('hg.repogroup.create.false', _('Disabled')),
123 ('hg.repogroup.create.false', _('Disabled')),
124 ('hg.repogroup.create.true', _('Enabled'))]
124 ('hg.repogroup.create.true', _('Enabled'))]
125
125
126 c_obj.fork_choices = [
126 c_obj.fork_choices = [
127 (self.FORKING_DISABLED, _('Disabled')),
127 (self.FORKING_DISABLED, _('Disabled')),
128 (self.FORKING_ENABLED, _('Enabled'))]
128 (self.FORKING_ENABLED, _('Enabled'))]
129
129
130 c_obj.inherit_default_permission_choices = [
130 c_obj.inherit_default_permission_choices = [
131 ('hg.inherit_default_perms.false', _('Disabled')),
131 ('hg.inherit_default_perms.false', _('Disabled')),
132 ('hg.inherit_default_perms.true', _('Enabled'))]
132 ('hg.inherit_default_perms.true', _('Enabled'))]
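# A minimal usage sketch (the context class is hypothetical): any
# attribute bag works, and gettext_translator can be the identity
# function in tests:
#
#   class Ctx(object):
#       pass
#   PermissionModel().set_global_permission_choices(Ctx(), lambda s: s)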
133
133
134 def get_default_perms(self, object_perms, suffix):
134 def get_default_perms(self, object_perms, suffix):
135 defaults = {}
135 defaults = {}
136 for perm in object_perms:
136 for perm in object_perms:
137 # perms
137 # perms
138 if perm.permission.permission_name.startswith('repository.'):
138 if perm.permission.permission_name.startswith('repository.'):
139 defaults['default_repo_perm' + suffix] = perm.permission.permission_name
139 defaults['default_repo_perm' + suffix] = perm.permission.permission_name
140
140
141 if perm.permission.permission_name.startswith('group.'):
141 if perm.permission.permission_name.startswith('group.'):
142 defaults['default_group_perm' + suffix] = perm.permission.permission_name
142 defaults['default_group_perm' + suffix] = perm.permission.permission_name
143
143
144 if perm.permission.permission_name.startswith('usergroup.'):
144 if perm.permission.permission_name.startswith('usergroup.'):
145 defaults['default_user_group_perm' + suffix] = perm.permission.permission_name
145 defaults['default_user_group_perm' + suffix] = perm.permission.permission_name
146
146
147 # branch
147 # branch
148 if perm.permission.permission_name.startswith('branch.'):
148 if perm.permission.permission_name.startswith('branch.'):
149 defaults['default_branch_perm' + suffix] = perm.permission.permission_name
149 defaults['default_branch_perm' + suffix] = perm.permission.permission_name
150
150
151 # creation of objects
151 # creation of objects
152 if perm.permission.permission_name.startswith('hg.create.write_on_repogroup'):
152 if perm.permission.permission_name.startswith('hg.create.write_on_repogroup'):
153 defaults['default_repo_create_on_write' + suffix] = perm.permission.permission_name
153 defaults['default_repo_create_on_write' + suffix] = perm.permission.permission_name
154
154
155 elif perm.permission.permission_name.startswith('hg.create.'):
155 elif perm.permission.permission_name.startswith('hg.create.'):
156 defaults['default_repo_create' + suffix] = perm.permission.permission_name
156 defaults['default_repo_create' + suffix] = perm.permission.permission_name
157
157
158 if perm.permission.permission_name.startswith('hg.fork.'):
158 if perm.permission.permission_name.startswith('hg.fork.'):
159 defaults['default_fork_create' + suffix] = perm.permission.permission_name
159 defaults['default_fork_create' + suffix] = perm.permission.permission_name
160
160
161 if perm.permission.permission_name.startswith('hg.inherit_default_perms.'):
161 if perm.permission.permission_name.startswith('hg.inherit_default_perms.'):
162 defaults['default_inherit_default_permissions' + suffix] = perm.permission.permission_name
162 defaults['default_inherit_default_permissions' + suffix] = perm.permission.permission_name
163
163
164 if perm.permission.permission_name.startswith('hg.repogroup.'):
164 if perm.permission.permission_name.startswith('hg.repogroup.'):
165 defaults['default_repo_group_create' + suffix] = perm.permission.permission_name
165 defaults['default_repo_group_create' + suffix] = perm.permission.permission_name
166
166
167 if perm.permission.permission_name.startswith('hg.usergroup.'):
167 if perm.permission.permission_name.startswith('hg.usergroup.'):
168 defaults['default_user_group_create' + suffix] = perm.permission.permission_name
168 defaults['default_user_group_create' + suffix] = perm.permission.permission_name
169
169
170 # registration and external account activation
170 # registration and external account activation
171 if perm.permission.permission_name.startswith('hg.register.'):
171 if perm.permission.permission_name.startswith('hg.register.'):
172 defaults['default_register' + suffix] = perm.permission.permission_name
172 defaults['default_register' + suffix] = perm.permission.permission_name
173
173
174 if perm.permission.permission_name.startswith('hg.password_reset.'):
174 if perm.permission.permission_name.startswith('hg.password_reset.'):
175 defaults['default_password_reset' + suffix] = perm.permission.permission_name
175 defaults['default_password_reset' + suffix] = perm.permission.permission_name
176
176
177 if perm.permission.permission_name.startswith('hg.extern_activate.'):
177 if perm.permission.permission_name.startswith('hg.extern_activate.'):
178 defaults['default_extern_activate' + suffix] = perm.permission.permission_name
178 defaults['default_extern_activate' + suffix] = perm.permission.permission_name
179
179
180 return defaults
180 return defaults
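# Mapping sketch: each permission_name prefix selects one defaults key;
# with suffix='' a perm named 'repository.read' contributes
# {'default_repo_perm': 'repository.read'}, while names matching no
# prefix are simply skipped.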
181
181
182 def _make_new_user_perm(self, user, perm_name):
182 def _make_new_user_perm(self, user, perm_name):
183 log.debug('Creating new user permission: %s', perm_name)
183 log.debug('Creating new user permission: %s', perm_name)
184 new = UserToPerm()
184 new = UserToPerm()
185 new.user = user
185 new.user = user
186 new.permission = Permission.get_by_key(perm_name)
186 new.permission = Permission.get_by_key(perm_name)
187 return new
187 return new
188
188
189 def _make_new_user_group_perm(self, user_group, perm_name):
189 def _make_new_user_group_perm(self, user_group, perm_name):
190 log.debug('Creating new user group permission: %s', perm_name)
190 log.debug('Creating new user group permission: %s', perm_name)
191 new = UserGroupToPerm()
191 new = UserGroupToPerm()
192 new.users_group = user_group
192 new.users_group = user_group
193 new.permission = Permission.get_by_key(perm_name)
193 new.permission = Permission.get_by_key(perm_name)
194 return new
194 return new
195
195
196 def _keep_perm(self, perm_name, keep_fields):
196 def _keep_perm(self, perm_name, keep_fields):
197 def get_pat(field_name):
197 def get_pat(field_name):
198 return {
198 return {
199 # global perms
199 # global perms
200 'default_repo_create': 'hg.create.',
200 'default_repo_create': 'hg.create.',
201 # special case for creating repos with write access to a group
201 # special case for creating repos with write access to a group
202 'default_repo_create_on_write': 'hg.create.write_on_repogroup.',
202 'default_repo_create_on_write': 'hg.create.write_on_repogroup.',
203 'default_repo_group_create': 'hg.repogroup.create.',
203 'default_repo_group_create': 'hg.repogroup.create.',
204 'default_user_group_create': 'hg.usergroup.create.',
204 'default_user_group_create': 'hg.usergroup.create.',
205 'default_fork_create': 'hg.fork.',
205 'default_fork_create': 'hg.fork.',
206 'default_inherit_default_permissions': 'hg.inherit_default_perms.',
206 'default_inherit_default_permissions': 'hg.inherit_default_perms.',
207
207
208 # application perms
208 # application perms
209 'default_register': 'hg.register.',
209 'default_register': 'hg.register.',
210 'default_password_reset': 'hg.password_reset.',
210 'default_password_reset': 'hg.password_reset.',
211 'default_extern_activate': 'hg.extern_activate.',
211 'default_extern_activate': 'hg.extern_activate.',
212
212
213 # object permissions below
213 # object permissions below
214 'default_repo_perm': 'repository.',
214 'default_repo_perm': 'repository.',
215 'default_group_perm': 'group.',
215 'default_group_perm': 'group.',
216 'default_user_group_perm': 'usergroup.',
216 'default_user_group_perm': 'usergroup.',
217 # branch
217 # branch
218 'default_branch_perm': 'branch.',
218 'default_branch_perm': 'branch.',
219
219
220 }[field_name]
220 }[field_name]
221 for field in keep_fields:
221 for field in keep_fields:
222 pat = get_pat(field)
222 pat = get_pat(field)
223 if perm_name.startswith(pat):
223 if perm_name.startswith(pat):
224 return True
224 return True
225 return False
225 return False
226
226
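Editor's note: `_keep_perm` is a pure prefix match; a permission survives a reset only when its name starts with the pattern mapped from one of the requested `keep_fields`. A self-contained sketch of the same check, using two patterns taken from the mapping above:

# Standalone version of the prefix check performed by _keep_perm.
FIELD_PATTERNS = {
    'default_repo_perm': 'repository.',
    'default_register': 'hg.register.',
}

def keep_perm(perm_name, keep_fields):
    return any(perm_name.startswith(FIELD_PATTERNS[f]) for f in keep_fields)

assert keep_perm('repository.read', ['default_repo_perm'])
assert not keep_perm('hg.create.repository', ['default_repo_perm'])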
227 def _clear_object_perm(self, object_perms, preserve=None):
227 def _clear_object_perm(self, object_perms, preserve=None):
228 preserve = preserve or []
228 preserve = preserve or []
229 _deleted = []
229 _deleted = []
230 for perm in object_perms:
230 for perm in object_perms:
231 perm_name = perm.permission.permission_name
231 perm_name = perm.permission.permission_name
232 if not self._keep_perm(perm_name, keep_fields=preserve):
232 if not self._keep_perm(perm_name, keep_fields=preserve):
233 _deleted.append(perm_name)
233 _deleted.append(perm_name)
234 self.sa.delete(perm)
234 self.sa.delete(perm)
235 return _deleted
235 return _deleted
236
236
237 def _clear_user_perms(self, user_id, preserve=None):
237 def _clear_user_perms(self, user_id, preserve=None):
238 perms = self.sa.query(UserToPerm)\
238 perms = self.sa.query(UserToPerm)\
239 .filter(UserToPerm.user_id == user_id)\
239 .filter(UserToPerm.user_id == user_id)\
240 .all()
240 .all()
241 return self._clear_object_perm(perms, preserve=preserve)
241 return self._clear_object_perm(perms, preserve=preserve)
242
242
243 def _clear_user_group_perms(self, user_group_id, preserve=None):
243 def _clear_user_group_perms(self, user_group_id, preserve=None):
244 perms = self.sa.query(UserGroupToPerm)\
244 perms = self.sa.query(UserGroupToPerm)\
245 .filter(UserGroupToPerm.users_group_id == user_group_id)\
245 .filter(UserGroupToPerm.users_group_id == user_group_id)\
246 .all()
246 .all()
247 return self._clear_object_perm(perms, preserve=preserve)
247 return self._clear_object_perm(perms, preserve=preserve)
248
248
249 def _set_new_object_perms(self, obj_type, object, form_result, preserve=None):
249 def _set_new_object_perms(self, obj_type, to_object, form_result, preserve=None):
250 # clear current entries, to make this function idempotent
250 # clear current entries, to make this function idempotent
251 # it also recovers if more permissions are defined later or some
251 # it also recovers if more permissions are defined later or some
252 # are somehow missing
252 # are somehow missing
253 preserve = preserve or []
253 preserve = preserve or []
254 _global_perms = self.global_perms.copy()
254 _global_perms = self.global_perms.copy()
255 if obj_type not in ['user', 'user_group']:
255 if obj_type not in ['user', 'user_group']:
256 raise ValueError("obj_type must be one of 'user' or 'user_group'")
256 raise ValueError("obj_type must be one of 'user' or 'user_group'")
257 global_perms = len(_global_perms)
257 global_perms = len(_global_perms)
258 default_user_perms = len(Permission.DEFAULT_USER_PERMISSIONS)
258 default_user_perms = len(Permission.DEFAULT_USER_PERMISSIONS)
259 if global_perms != default_user_perms:
259 if global_perms != default_user_perms:
260 raise Exception(
260 raise Exception(
261 'Inconsistent permissions definition. Got {} vs {}'.format(
261 'Inconsistent permissions definition. Got {} vs {}'.format(
262 global_perms, default_user_perms))
262 global_perms, default_user_perms))
263
263
264 if obj_type == 'user':
264 if obj_type == 'user':
265 self._clear_user_perms(object.user_id, preserve)
265 self._clear_user_perms(to_object.user_id, preserve)
266 if obj_type == 'user_group':
266 if obj_type == 'user_group':
267 self._clear_user_group_perms(object.users_group_id, preserve)
267 self._clear_user_group_perms(to_object.users_group_id, preserve)
268
268
269 # now drop the keys we want to preserve, so the form does not overwrite them.
269 # now drop the keys we want to preserve, so the form does not overwrite them.
270 for key in preserve:
270 for key in preserve:
271 del _global_perms[key]
271 del _global_perms[key]
272
272
273 for k in _global_perms.copy():
273 for k in _global_perms.copy():
274 _global_perms[k] = form_result[k]
274 _global_perms[k] = form_result[k]
275
275
276 # at that stage we validate all are passed inside form_result
276 # at that stage we validate all are passed inside form_result
277 for _perm_key, perm_value in _global_perms.items():
277 for _perm_key, perm_value in _global_perms.items():
278 if perm_value is None:
278 if perm_value is None:
279 raise ValueError('Missing permission for %s' % (_perm_key,))
279 raise ValueError('Missing permission for %s' % (_perm_key,))
280
280
281 if obj_type == 'user':
281 if obj_type == 'user':
282 p = self._make_new_user_perm(object, perm_value)
282 p = self._make_new_user_perm(to_object, perm_value)
283 self.sa.add(p)
283 self.sa.add(p)
284 if obj_type == 'user_group':
284 if obj_type == 'user_group':
285 p = self._make_new_user_group_perm(object, perm_value)
285 p = self._make_new_user_group_perm(to_object, perm_value)
286 self.sa.add(p)
286 self.sa.add(p)
287
287
288 def _set_new_user_perms(self, user, form_result, preserve=None):
288 def _set_new_user_perms(self, user, form_result, preserve=None):
289 return self._set_new_object_perms(
289 return self._set_new_object_perms(
290 'user', user, form_result, preserve)
290 'user', user, form_result, preserve)
291
291
292 def _set_new_user_group_perms(self, user_group, form_result, preserve=None):
292 def _set_new_user_group_perms(self, user_group, form_result, preserve=None):
293 return self._set_new_object_perms(
293 return self._set_new_object_perms(
294 'user_group', user_group, form_result, preserve)
294 'user_group', user_group, form_result, preserve)
295
295
296 def set_new_user_perms(self, user, form_result):
296 def set_new_user_perms(self, user, form_result):
297 # calculate what to preserve from what is given in form_result
297 # calculate what to preserve from what is given in form_result
298 preserve = set(self.global_perms.keys()).difference(set(form_result.keys()))
298 preserve = set(self.global_perms.keys()).difference(set(form_result.keys()))
299 return self._set_new_user_perms(user, form_result, preserve)
299 return self._set_new_user_perms(user, form_result, preserve)
300
300
301 def set_new_user_group_perms(self, user_group, form_result):
301 def set_new_user_group_perms(self, user_group, form_result):
302 # calculate what to preserve from what is given in form_result
302 # calculate what to preserve from what is given in form_result
303 preserve = set(self.global_perms.keys()).difference(set(form_result.keys()))
303 preserve = set(self.global_perms.keys()).difference(set(form_result.keys()))
304 return self._set_new_user_group_perms(user_group, form_result, preserve)
304 return self._set_new_user_group_perms(user_group, form_result, preserve)
305
305
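Editor's note: in both `set_new_*` wrappers above, the `preserve` set is simply the global permission keys missing from the submitted form, so a partial form leaves unrelated defaults untouched. A small runnable example of that set difference (the key names are examples):

# Sketch of the preserve calculation; keys absent from the form survive.
global_perm_keys = {'default_repo_perm', 'default_register', 'default_fork_create'}
form_result = {'default_register': 'hg.register.manual_activate'}
preserve = global_perm_keys.difference(form_result.keys())
print(sorted(preserve))  # ['default_fork_create', 'default_repo_perm']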
306 def create_permissions(self):
306 def create_permissions(self):
307 """
307 """
308 Create permissions for whole system
308 Create permissions for whole system
309 """
309 """
310 for p in Permission.PERMS:
310 for p in Permission.PERMS:
311 if not Permission.get_by_key(p[0]):
311 if not Permission.get_by_key(p[0]):
312 new_perm = Permission()
312 new_perm = Permission()
313 new_perm.permission_name = p[0]
313 new_perm.permission_name = p[0]
314 new_perm.permission_longname = p[0] # translation err with p[1]
314 new_perm.permission_longname = p[0] # translation err with p[1]
315 self.sa.add(new_perm)
315 self.sa.add(new_perm)
316
316
317 def _create_default_object_permission(self, obj_type, obj, obj_perms,
317 def _create_default_object_permission(self, obj_type, obj, obj_perms,
318 force=False):
318 force=False):
319 if obj_type not in ['user', 'user_group']:
319 if obj_type not in ['user', 'user_group']:
320 raise ValueError("obj_type must be one of 'user' or 'user_group'")
320 raise ValueError("obj_type must be one of 'user' or 'user_group'")
321
321
322 def _get_group(perm_name):
322 def _get_group(perm_name):
323 return '.'.join(perm_name.split('.')[:1])
323 return '.'.join(perm_name.split('.')[:1])
324
324
325 defined_perms_groups = map(
325 defined_perms_groups = map(
326 _get_group, (x.permission.permission_name for x in obj_perms))
326 _get_group, (x.permission.permission_name for x in obj_perms))
327 log.debug('GOT ALREADY DEFINED:%s', obj_perms)
327 log.debug('GOT ALREADY DEFINED:%s', obj_perms)
328
328
329 if force:
329 if force:
330 self._clear_object_perm(obj_perms)
330 self._clear_object_perm(obj_perms)
331 self.sa.commit()
331 self.sa.commit()
332 defined_perms_groups = []
332 defined_perms_groups = []
333 # for every default permission that needs to be created, we check if
333 # for every default permission that needs to be created, we check if
334 # its group is already defined; if it's not, we create the default perm
334 # its group is already defined; if it's not, we create the default perm
335 for perm_name in Permission.DEFAULT_USER_PERMISSIONS:
335 for perm_name in Permission.DEFAULT_USER_PERMISSIONS:
336 gr = _get_group(perm_name)
336 gr = _get_group(perm_name)
337 if gr not in defined_perms_groups:
337 if gr not in defined_perms_groups:
338 log.debug('GR:%s not found, creating permission %s',
338 log.debug('GR:%s not found, creating permission %s',
339 gr, perm_name)
339 gr, perm_name)
340 if obj_type == 'user':
340 if obj_type == 'user':
341 new_perm = self._make_new_user_perm(obj, perm_name)
341 new_perm = self._make_new_user_perm(obj, perm_name)
342 self.sa.add(new_perm)
342 self.sa.add(new_perm)
343 if obj_type == 'user_group':
343 if obj_type == 'user_group':
344 new_perm = self._make_new_user_group_perm(obj, perm_name)
344 new_perm = self._make_new_user_group_perm(obj, perm_name)
345 self.sa.add(new_perm)
345 self.sa.add(new_perm)
346
346
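Editor's note: `_get_group` keeps only the first dotted segment of a permission name, which is what drives the dedup-by-group logic above. A quick standalone check of that behaviour:

# _get_group reduced to a standalone function.
def get_group(perm_name):
    return '.'.join(perm_name.split('.')[:1])

assert get_group('repository.read') == 'repository'
assert get_group('hg.create.repository') == 'hg'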
347 def create_default_user_permissions(self, user, force=False):
347 def create_default_user_permissions(self, user, force=False):
348 """
348 """
349 Creates only missing default permissions for user, if force is set it
349 Creates only missing default permissions for user, if force is set it
350 resets the default permissions for that user
350 resets the default permissions for that user
351
351
352 :param user:
352 :param user:
353 :param force:
353 :param force:
354 """
354 """
355 user = self._get_user(user)
355 user = self._get_user(user)
356 obj_perms = UserToPerm.query().filter(UserToPerm.user == user).all()
356 obj_perms = UserToPerm.query().filter(UserToPerm.user == user).all()
357 return self._create_default_object_permission(
357 return self._create_default_object_permission(
358 'user', user, obj_perms, force)
358 'user', user, obj_perms, force)
359
359
360 def create_default_user_group_permissions(self, user_group, force=False):
360 def create_default_user_group_permissions(self, user_group, force=False):
361 """
361 """
362 Creates only missing default permissions for user group, if force is
362 Creates only missing default permissions for user group, if force is
363 set it resets the default permissions for that user group
363 set it resets the default permissions for that user group
364
364
365 :param user_group:
365 :param user_group:
366 :param force:
366 :param force:
367 """
367 """
368 user_group = self._get_user_group(user_group)
368 user_group = self._get_user_group(user_group)
369 obj_perms = UserGroupToPerm.query().filter(UserGroupToPerm.users_group == user_group).all()
369 obj_perms = UserGroupToPerm.query().filter(UserGroupToPerm.users_group == user_group).all()
370 return self._create_default_object_permission(
370 return self._create_default_object_permission(
371 'user_group', user_group, obj_perms, force)
371 'user_group', user_group, obj_perms, force)
372
372
373 def update_application_permissions(self, form_result):
373 def update_application_permissions(self, form_result):
374 if 'perm_user_id' in form_result:
374 if 'perm_user_id' in form_result:
375 perm_user = User.get(safe_int(form_result['perm_user_id']))
375 perm_user = User.get(safe_int(form_result['perm_user_id']))
376 else:
376 else:
377 # used mostly to do lookup for default user
377 # used mostly to do lookup for default user
378 perm_user = User.get_by_username(form_result['perm_user_name'])
378 perm_user = User.get_by_username(form_result['perm_user_name'])
379
379
380 try:
380 try:
381 # stage 1 set anonymous access
381 # stage 1 set anonymous access
382 if perm_user.username == User.DEFAULT_USER:
382 if perm_user.username == User.DEFAULT_USER:
383 perm_user.active = str2bool(form_result['anonymous'])
383 perm_user.active = str2bool(form_result['anonymous'])
384 self.sa.add(perm_user)
384 self.sa.add(perm_user)
385
385
386 # stage 2 reset defaults and set them from form data
386 # stage 2 reset defaults and set them from form data
387 self._set_new_user_perms(perm_user, form_result, preserve=[
387 self._set_new_user_perms(perm_user, form_result, preserve=[
388 'default_repo_perm',
388 'default_repo_perm',
389 'default_group_perm',
389 'default_group_perm',
390 'default_user_group_perm',
390 'default_user_group_perm',
391 'default_branch_perm',
391 'default_branch_perm',
392
392
393 'default_repo_group_create',
393 'default_repo_group_create',
394 'default_user_group_create',
394 'default_user_group_create',
395 'default_repo_create_on_write',
395 'default_repo_create_on_write',
396 'default_repo_create',
396 'default_repo_create',
397 'default_fork_create',
397 'default_fork_create',
398 'default_inherit_default_permissions',])
398 'default_inherit_default_permissions'])
399
399
400 self.sa.commit()
400 self.sa.commit()
401 except (DatabaseError,):
401 except (DatabaseError,):
402 log.error(traceback.format_exc())
402 log.error(traceback.format_exc())
403 self.sa.rollback()
403 self.sa.rollback()
404 raise
404 raise
405
405
406 def update_user_permissions(self, form_result):
406 def update_user_permissions(self, form_result):
407 if 'perm_user_id' in form_result:
407 if 'perm_user_id' in form_result:
408 perm_user = User.get(safe_int(form_result['perm_user_id']))
408 perm_user = User.get(safe_int(form_result['perm_user_id']))
409 else:
409 else:
410 # used mostly to do lookup for default user
410 # used mostly to do lookup for default user
411 perm_user = User.get_by_username(form_result['perm_user_name'])
411 perm_user = User.get_by_username(form_result['perm_user_name'])
412 try:
412 try:
413 # stage 2 reset defaults and set them from form data
413 # stage 2 reset defaults and set them from form data
414 self._set_new_user_perms(perm_user, form_result, preserve=[
414 self._set_new_user_perms(perm_user, form_result, preserve=[
415 'default_repo_perm',
415 'default_repo_perm',
416 'default_group_perm',
416 'default_group_perm',
417 'default_user_group_perm',
417 'default_user_group_perm',
418 'default_branch_perm',
418 'default_branch_perm',
419
419
420 'default_register',
420 'default_register',
421 'default_password_reset',
421 'default_password_reset',
422 'default_extern_activate'])
422 'default_extern_activate'])
423 self.sa.commit()
423 self.sa.commit()
424 except (DatabaseError,):
424 except (DatabaseError,):
425 log.error(traceback.format_exc())
425 log.error(traceback.format_exc())
426 self.sa.rollback()
426 self.sa.rollback()
427 raise
427 raise
428
428
429 def update_user_group_permissions(self, form_result):
429 def update_user_group_permissions(self, form_result):
430 if 'perm_user_group_id' in form_result:
430 if 'perm_user_group_id' in form_result:
431 perm_user_group = UserGroup.get(safe_int(form_result['perm_user_group_id']))
431 perm_user_group = UserGroup.get(safe_int(form_result['perm_user_group_id']))
432 else:
432 else:
433 # used mostly to do lookup for default user
433 # used mostly to do lookup for default user
434 perm_user_group = UserGroup.get_by_group_name(form_result['perm_user_group_name'])
434 perm_user_group = UserGroup.get_by_group_name(form_result['perm_user_group_name'])
435 try:
435 try:
436 # stage 2 reset defaults and set them from form data
436 # stage 2 reset defaults and set them from form data
437 self._set_new_user_group_perms(perm_user_group, form_result, preserve=[
437 self._set_new_user_group_perms(perm_user_group, form_result, preserve=[
438 'default_repo_perm',
438 'default_repo_perm',
439 'default_group_perm',
439 'default_group_perm',
440 'default_user_group_perm',
440 'default_user_group_perm',
441 'default_branch_perm',
441 'default_branch_perm',
442
442
443 'default_register',
443 'default_register',
444 'default_password_reset',
444 'default_password_reset',
445 'default_extern_activate'])
445 'default_extern_activate'])
446 self.sa.commit()
446 self.sa.commit()
447 except (DatabaseError,):
447 except (DatabaseError,):
448 log.error(traceback.format_exc())
448 log.error(traceback.format_exc())
449 self.sa.rollback()
449 self.sa.rollback()
450 raise
450 raise
451
451
452 def update_object_permissions(self, form_result):
452 def update_object_permissions(self, form_result):
453 if 'perm_user_id' in form_result:
453 if 'perm_user_id' in form_result:
454 perm_user = User.get(safe_int(form_result['perm_user_id']))
454 perm_user = User.get(safe_int(form_result['perm_user_id']))
455 else:
455 else:
456 # used mostly to do lookup for default user
456 # used mostly to do lookup for default user
457 perm_user = User.get_by_username(form_result['perm_user_name'])
457 perm_user = User.get_by_username(form_result['perm_user_name'])
458 try:
458 try:
459
459
460 # stage 2 reset defaults and set them from form data
460 # stage 2 reset defaults and set them from form data
461 self._set_new_user_perms(perm_user, form_result, preserve=[
461 self._set_new_user_perms(perm_user, form_result, preserve=[
462 'default_repo_group_create',
462 'default_repo_group_create',
463 'default_user_group_create',
463 'default_user_group_create',
464 'default_repo_create_on_write',
464 'default_repo_create_on_write',
465 'default_repo_create',
465 'default_repo_create',
466 'default_fork_create',
466 'default_fork_create',
467 'default_inherit_default_permissions',
467 'default_inherit_default_permissions',
468 'default_branch_perm',
468 'default_branch_perm',
469
469
470 'default_register',
470 'default_register',
471 'default_password_reset',
471 'default_password_reset',
472 'default_extern_activate'])
472 'default_extern_activate'])
473
473
474 # overwrite default repo permissions
474 # overwrite default repo permissions
475 if form_result['overwrite_default_repo']:
475 if form_result['overwrite_default_repo']:
476 _def_name = form_result['default_repo_perm'].split('repository.')[-1]
476 _def_name = form_result['default_repo_perm'].split('repository.')[-1]
477 _def = Permission.get_by_key('repository.' + _def_name)
477 _def = Permission.get_by_key('repository.' + _def_name)
478 for r2p in self.sa.query(UserRepoToPerm)\
478 for r2p in self.sa.query(UserRepoToPerm)\
479 .filter(UserRepoToPerm.user == perm_user)\
479 .filter(UserRepoToPerm.user == perm_user)\
480 .all():
480 .all():
481 # don't reset PRIVATE repositories
481 # don't reset PRIVATE repositories
482 if not r2p.repository.private:
482 if not r2p.repository.private:
483 r2p.permission = _def
483 r2p.permission = _def
484 self.sa.add(r2p)
484 self.sa.add(r2p)
485
485
486 # overwrite default repo group permissions
486 # overwrite default repo group permissions
487 if form_result['overwrite_default_group']:
487 if form_result['overwrite_default_group']:
488 _def_name = form_result['default_group_perm'].split('group.')[-1]
488 _def_name = form_result['default_group_perm'].split('group.')[-1]
489 _def = Permission.get_by_key('group.' + _def_name)
489 _def = Permission.get_by_key('group.' + _def_name)
490 for g2p in self.sa.query(UserRepoGroupToPerm)\
490 for g2p in self.sa.query(UserRepoGroupToPerm)\
491 .filter(UserRepoGroupToPerm.user == perm_user)\
491 .filter(UserRepoGroupToPerm.user == perm_user)\
492 .all():
492 .all():
493 g2p.permission = _def
493 g2p.permission = _def
494 self.sa.add(g2p)
494 self.sa.add(g2p)
495
495
496 # overwrite default user group permissions
496 # overwrite default user group permissions
497 if form_result['overwrite_default_user_group']:
497 if form_result['overwrite_default_user_group']:
498 _def_name = form_result['default_user_group_perm'].split('usergroup.')[-1]
498 _def_name = form_result['default_user_group_perm'].split('usergroup.')[-1]
499 # user groups
499 # user groups
500 _def = Permission.get_by_key('usergroup.' + _def_name)
500 _def = Permission.get_by_key('usergroup.' + _def_name)
501 for g2p in self.sa.query(UserUserGroupToPerm)\
501 for g2p in self.sa.query(UserUserGroupToPerm)\
502 .filter(UserUserGroupToPerm.user == perm_user)\
502 .filter(UserUserGroupToPerm.user == perm_user)\
503 .all():
503 .all():
504 g2p.permission = _def
504 g2p.permission = _def
505 self.sa.add(g2p)
505 self.sa.add(g2p)
506
506
507 # COMMIT
507 # COMMIT
508 self.sa.commit()
508 self.sa.commit()
509 except (DatabaseError,):
509 except (DatabaseError,):
510 log.exception('Failed to set default object permissions')
510 log.exception('Failed to set default object permissions')
511 self.sa.rollback()
511 self.sa.rollback()
512 raise
512 raise
513
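Editor's note: the overwrite branches above all rebuild a Permission key from the form value by splitting off the known prefix and re-attaching it. A tiny sketch of that round-trip; the form value is an example:

# Example form value; the split/rejoin mirrors the overwrite logic above.
default_repo_perm = 'repository.write'
_def_name = default_repo_perm.split('repository.')[-1]
assert 'repository.' + _def_name == 'repository.write'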
513
514 def update_branch_permissions(self, form_result):
514 def update_branch_permissions(self, form_result):
515 if 'perm_user_id' in form_result:
515 if 'perm_user_id' in form_result:
516 perm_user = User.get(safe_int(form_result['perm_user_id']))
516 perm_user = User.get(safe_int(form_result['perm_user_id']))
517 else:
517 else:
518 # used mostly to do lookup for default user
518 # used mostly to do lookup for default user
519 perm_user = User.get_by_username(form_result['perm_user_name'])
519 perm_user = User.get_by_username(form_result['perm_user_name'])
520 try:
520 try:
521
521
522 # stage 2 reset defaults and set them from form data
522 # stage 2 reset defaults and set them from form data
523 self._set_new_user_perms(perm_user, form_result, preserve=[
523 self._set_new_user_perms(perm_user, form_result, preserve=[
524 'default_repo_perm',
524 'default_repo_perm',
525 'default_group_perm',
525 'default_group_perm',
526 'default_user_group_perm',
526 'default_user_group_perm',
527
527
528 'default_repo_group_create',
528 'default_repo_group_create',
529 'default_user_group_create',
529 'default_user_group_create',
530 'default_repo_create_on_write',
530 'default_repo_create_on_write',
531 'default_repo_create',
531 'default_repo_create',
532 'default_fork_create',
532 'default_fork_create',
533 'default_inherit_default_permissions',
533 'default_inherit_default_permissions',
534
534
535 'default_register',
535 'default_register',
536 'default_password_reset',
536 'default_password_reset',
537 'default_extern_activate'])
537 'default_extern_activate'])
538
538
539 # overwrite default branch permissions
539 # overwrite default branch permissions
540 if form_result['overwrite_default_branch']:
540 if form_result['overwrite_default_branch']:
541 _def_name = \
541 _def_name = \
542 form_result['default_branch_perm'].split('branch.')[-1]
542 form_result['default_branch_perm'].split('branch.')[-1]
543
543
544 _def = Permission.get_by_key('branch.' + _def_name)
544 _def = Permission.get_by_key('branch.' + _def_name)
545
545
546 user_perms = UserToRepoBranchPermission.query()\
546 user_perms = UserToRepoBranchPermission.query()\
547 .join(UserToRepoBranchPermission.user_repo_to_perm)\
547 .join(UserToRepoBranchPermission.user_repo_to_perm)\
548 .filter(UserRepoToPerm.user == perm_user).all()
548 .filter(UserRepoToPerm.user == perm_user).all()
549
549
550 for g2p in user_perms:
550 for g2p in user_perms:
551 g2p.permission = _def
551 g2p.permission = _def
552 self.sa.add(g2p)
552 self.sa.add(g2p)
553
553
554 # COMMIT
554 # COMMIT
555 self.sa.commit()
555 self.sa.commit()
556 except (DatabaseError,):
556 except (DatabaseError,):
557 log.exception('Failed to set default branch permissions')
557 log.exception('Failed to set default branch permissions')
558 self.sa.rollback()
558 self.sa.rollback()
559 raise
559 raise
560
560
561 def get_users_with_repo_write(self, db_repo):
561 def get_users_with_repo_write(self, db_repo):
562 write_plus = ['repository.write', 'repository.admin']
562 write_plus = ['repository.write', 'repository.admin']
563 default_user_id = User.get_default_user_id()
563 default_user_id = User.get_default_user_id()
564 user_write_permissions = collections.OrderedDict()
564 user_write_permissions = collections.OrderedDict()
565
565
566 # write or higher and DEFAULT user for inheritance
566 # write or higher and DEFAULT user for inheritance
567 for perm in db_repo.permissions():
567 for perm in db_repo.permissions():
568 if perm.permission in write_plus or perm.user_id == default_user_id:
568 if perm.permission in write_plus or perm.user_id == default_user_id:
569 user_write_permissions[perm.user_id] = perm
569 user_write_permissions[perm.user_id] = perm
570 return user_write_permissions
570 return user_write_permissions
571
571
572 def get_user_groups_with_repo_write(self, db_repo):
572 def get_user_groups_with_repo_write(self, db_repo):
573 write_plus = ['repository.write', 'repository.admin']
573 write_plus = ['repository.write', 'repository.admin']
574 user_group_write_permissions = collections.OrderedDict()
574 user_group_write_permissions = collections.OrderedDict()
575
575
576 # write or higher
576 # write or higher
577 for p in db_repo.permission_user_groups():
577 for p in db_repo.permission_user_groups():
578 if p.permission in write_plus:
578 if p.permission in write_plus:
579 user_group_write_permissions[p.users_group_id] = p
579 user_group_write_permissions[p.users_group_id] = p
580 return user_group_write_permissions
580 return user_group_write_permissions
581
581
582 def trigger_permission_flush(self, affected_user_ids=None):
582 def trigger_permission_flush(self, affected_user_ids=None):
583 affected_user_ids = affected_user_ids or User.get_all_user_ids()
583 affected_user_ids = affected_user_ids or User.get_all_user_ids()
584 events.trigger(events.UserPermissionsChange(affected_user_ids))
584 events.trigger(events.UserPermissionsChange(affected_user_ids))
585
585
586 def flush_user_permission_caches(self, changes, affected_user_ids=None):
586 def flush_user_permission_caches(self, changes, affected_user_ids=None):
587 affected_user_ids = affected_user_ids or []
587 affected_user_ids = affected_user_ids or []
588
588
589 for change in changes['added'] + changes['updated'] + changes['deleted']:
589 for change in changes['added'] + changes['updated'] + changes['deleted']:
590 if change['type'] == 'user':
590 if change['type'] == 'user':
591 affected_user_ids.append(change['id'])
591 affected_user_ids.append(change['id'])
592 if change['type'] == 'user_group':
592 if change['type'] == 'user_group':
593 user_group = UserGroup.get(safe_int(change['id']))
593 user_group = UserGroup.get(safe_int(change['id']))
594 if user_group:
594 if user_group:
595 group_members_ids = [x.user_id for x in user_group.members]
595 group_members_ids = [x.user_id for x in user_group.members]
596 affected_user_ids.extend(group_members_ids)
596 affected_user_ids.extend(group_members_ids)
597
597
598 self.trigger_permission_flush(affected_user_ids)
598 self.trigger_permission_flush(affected_user_ids)
599
599
600 return affected_user_ids
600 return affected_user_ids
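Editor's note: `flush_user_permission_caches` expands user-group changes into their member user ids before triggering a single flush event. A compact sketch of that expansion, using plain dicts as hypothetical stand-ins for the DB objects:

# Hypothetical stand-ins for the DB lookups used above.
changes = {
    'added': [{'type': 'user', 'id': 7}],
    'updated': [{'type': 'user_group', 'id': 3}],
    'deleted': [],
}
group_members = {3: [10, 11]}  # user_group_id -> member user ids

affected_user_ids = []
for change in changes['added'] + changes['updated'] + changes['deleted']:
    if change['type'] == 'user':
        affected_user_ids.append(change['id'])
    if change['type'] == 'user_group':
        affected_user_ids.extend(group_members.get(change['id'], []))
print(affected_user_ids)  # [7, 10, 11]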
@@ -1,1028 +1,1028 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 Scm model for RhodeCode
22 Scm model for RhodeCode
23 """
23 """
24
24
25 import os.path
25 import os.path
26 import traceback
26 import traceback
27 import logging
27 import logging
28 from io import StringIO
28 import io
29
29
30 from sqlalchemy import func
30 from sqlalchemy import func
31 from zope.cachedescriptors.property import Lazy as LazyProperty
31 from zope.cachedescriptors.property import Lazy as LazyProperty
32
32
33 import rhodecode
33 import rhodecode
34 from rhodecode.lib.vcs import get_backend
34 from rhodecode.lib.vcs import get_backend
35 from rhodecode.lib.vcs.exceptions import RepositoryError, NodeNotChangedError
35 from rhodecode.lib.vcs.exceptions import RepositoryError, NodeNotChangedError
36 from rhodecode.lib.vcs.nodes import FileNode
36 from rhodecode.lib.vcs.nodes import FileNode
37 from rhodecode.lib.vcs.backends.base import EmptyCommit
37 from rhodecode.lib.vcs.backends.base import EmptyCommit
38 from rhodecode.lib import helpers as h, rc_cache
38 from rhodecode.lib import helpers as h, rc_cache
39 from rhodecode.lib.auth import (
39 from rhodecode.lib.auth import (
40 HasRepoPermissionAny, HasRepoGroupPermissionAny,
40 HasRepoPermissionAny, HasRepoGroupPermissionAny,
41 HasUserGroupPermissionAny)
41 HasUserGroupPermissionAny)
42 from rhodecode.lib.exceptions import NonRelativePathError, IMCCommitError
42 from rhodecode.lib.exceptions import NonRelativePathError, IMCCommitError
43 from rhodecode.lib import hooks_utils
43 from rhodecode.lib import hooks_utils
44 from rhodecode.lib.utils import (
44 from rhodecode.lib.utils import (
45 get_filesystem_repos, make_db_config)
45 get_filesystem_repos, make_db_config)
46 from rhodecode.lib.utils2 import (safe_str, safe_unicode)
46 from rhodecode.lib.utils2 import (safe_str, safe_unicode)
47 from rhodecode.lib.system_info import get_system_info
47 from rhodecode.lib.system_info import get_system_info
48 from rhodecode.model import BaseModel
48 from rhodecode.model import BaseModel
49 from rhodecode.model.db import (
49 from rhodecode.model.db import (
50 or_, false,
50 or_, false,
51 Repository, CacheKey, UserFollowing, UserLog, User, RepoGroup,
51 Repository, CacheKey, UserFollowing, UserLog, User, RepoGroup,
52 PullRequest, FileStore)
52 PullRequest, FileStore)
53 from rhodecode.model.settings import VcsSettingsModel
53 from rhodecode.model.settings import VcsSettingsModel
54 from rhodecode.model.validation_schema.validators import url_validator, InvalidCloneUrl
54 from rhodecode.model.validation_schema.validators import url_validator, InvalidCloneUrl
55
55
56 log = logging.getLogger(__name__)
56 log = logging.getLogger(__name__)
57
57
58
58
59 class UserTemp(object):
59 class UserTemp(object):
60 def __init__(self, user_id):
60 def __init__(self, user_id):
61 self.user_id = user_id
61 self.user_id = user_id
62
62
63 def __repr__(self):
63 def __repr__(self):
64 return "<%s('id:%s')>" % (self.__class__.__name__, self.user_id)
64 return "<%s('id:%s')>" % (self.__class__.__name__, self.user_id)
65
65
66
66
67 class RepoTemp(object):
67 class RepoTemp(object):
68 def __init__(self, repo_id):
68 def __init__(self, repo_id):
69 self.repo_id = repo_id
69 self.repo_id = repo_id
70
70
71 def __repr__(self):
71 def __repr__(self):
72 return "<%s('id:%s')>" % (self.__class__.__name__, self.repo_id)
72 return "<%s('id:%s')>" % (self.__class__.__name__, self.repo_id)
73
73
74
74
75 class SimpleCachedRepoList(object):
75 class SimpleCachedRepoList(object):
76 """
76 """
77 Lighter version of iteration over repos without the scm initialisation,
77 Lighter version of iteration over repos without the scm initialisation,
78 and with cache usage
78 and with cache usage
79 """
79 """
80 def __init__(self, db_repo_list, repos_path, order_by=None, perm_set=None):
80 def __init__(self, db_repo_list, repos_path, order_by=None, perm_set=None):
81 self.db_repo_list = db_repo_list
81 self.db_repo_list = db_repo_list
82 self.repos_path = repos_path
82 self.repos_path = repos_path
83 self.order_by = order_by
83 self.order_by = order_by
84 self.reversed = (order_by or '').startswith('-')
84 self.reversed = (order_by or '').startswith('-')
85 if not perm_set:
85 if not perm_set:
86 perm_set = ['repository.read', 'repository.write',
86 perm_set = ['repository.read', 'repository.write',
87 'repository.admin']
87 'repository.admin']
88 self.perm_set = perm_set
88 self.perm_set = perm_set
89
89
90 def __len__(self):
90 def __len__(self):
91 return len(self.db_repo_list)
91 return len(self.db_repo_list)
92
92
93 def __repr__(self):
93 def __repr__(self):
94 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
94 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
95
95
96 def __iter__(self):
96 def __iter__(self):
97 for dbr in self.db_repo_list:
97 for dbr in self.db_repo_list:
98 # check permission at this level
98 # check permission at this level
99 has_perm = HasRepoPermissionAny(*self.perm_set)(
99 has_perm = HasRepoPermissionAny(*self.perm_set)(
100 dbr.repo_name, 'SimpleCachedRepoList check')
100 dbr.repo_name, 'SimpleCachedRepoList check')
101 if not has_perm:
101 if not has_perm:
102 continue
102 continue
103
103
104 tmp_d = {
104 tmp_d = {
105 'name': dbr.repo_name,
105 'name': dbr.repo_name,
106 'dbrepo': dbr.get_dict(),
106 'dbrepo': dbr.get_dict(),
107 'dbrepo_fork': dbr.fork.get_dict() if dbr.fork else {}
107 'dbrepo_fork': dbr.fork.get_dict() if dbr.fork else {}
108 }
108 }
109 yield tmp_d
109 yield tmp_d
110
110
111
111
112 class _PermCheckIterator(object):
112 class _PermCheckIterator(object):
113
113
114 def __init__(
114 def __init__(
115 self, obj_list, obj_attr, perm_set, perm_checker,
115 self, obj_list, obj_attr, perm_set, perm_checker,
116 extra_kwargs=None):
116 extra_kwargs=None):
117 """
117 """
118 Creates iterator from given list of objects, additionally
118 Creates iterator from given list of objects, additionally
119 checking permissions for them against the perm_set
119 checking permissions for them against the perm_set
120
120
121 :param obj_list: list of db objects
121 :param obj_list: list of db objects
122 :param obj_attr: attribute of object to pass into perm_checker
122 :param obj_attr: attribute of object to pass into perm_checker
123 :param perm_set: list of permissions to check
123 :param perm_set: list of permissions to check
124 :param perm_checker: callable to check permissions against
124 :param perm_checker: callable to check permissions against
125 """
125 """
126 self.obj_list = obj_list
126 self.obj_list = obj_list
127 self.obj_attr = obj_attr
127 self.obj_attr = obj_attr
128 self.perm_set = perm_set
128 self.perm_set = perm_set
129 self.perm_checker = perm_checker(*self.perm_set)
129 self.perm_checker = perm_checker(*self.perm_set)
130 self.extra_kwargs = extra_kwargs or {}
130 self.extra_kwargs = extra_kwargs or {}
131
131
132 def __len__(self):
132 def __len__(self):
133 return len(self.obj_list)
133 return len(self.obj_list)
134
134
135 def __repr__(self):
135 def __repr__(self):
136 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
136 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
137
137
138 def __iter__(self):
138 def __iter__(self):
139 for db_obj in self.obj_list:
139 for db_obj in self.obj_list:
140 # check permission at this level
140 # check permission at this level
141 # NOTE(marcink): the __dict__.get() is ~4x faster than getattr()
141 # NOTE(marcink): the __dict__.get() is ~4x faster than getattr()
142 name = db_obj.__dict__.get(self.obj_attr, None)
142 name = db_obj.__dict__.get(self.obj_attr, None)
143 if not self.perm_checker(name, self.__class__.__name__, **self.extra_kwargs):
143 if not self.perm_checker(name, self.__class__.__name__, **self.extra_kwargs):
144 continue
144 continue
145
145
146 yield db_obj
146 yield db_obj
147
147
148
148
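Editor's note: the three list classes below only bind an attribute name, a default permission set and a checker; the filtering itself is the generator above. A minimal standalone analogue, with a stubbed checker and dicts in place of DB objects (all hypothetical):

# Minimal analogue of _PermCheckIterator with a stubbed permission checker.
class PermCheckIterator:
    def __init__(self, obj_list, obj_attr, checker):
        self.obj_list = obj_list
        self.obj_attr = obj_attr
        self.checker = checker

    def __iter__(self):
        for obj in self.obj_list:
            if self.checker(obj.get(self.obj_attr)):
                yield obj

rows = [{'_repo_name': 'public'}, {'_repo_name': 'secret'}]
for row in PermCheckIterator(rows, '_repo_name', lambda name: name != 'secret'):
    print(row['_repo_name'])  # only 'public' is yielded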
149 class RepoList(_PermCheckIterator):
149 class RepoList(_PermCheckIterator):
150
150
151 def __init__(self, db_repo_list, perm_set=None, extra_kwargs=None):
151 def __init__(self, db_repo_list, perm_set=None, extra_kwargs=None):
152 if not perm_set:
152 if not perm_set:
153 perm_set = ['repository.read', 'repository.write', 'repository.admin']
153 perm_set = ['repository.read', 'repository.write', 'repository.admin']
154
154
155 super(RepoList, self).__init__(
155 super(RepoList, self).__init__(
156 obj_list=db_repo_list,
156 obj_list=db_repo_list,
157 obj_attr='_repo_name', perm_set=perm_set,
157 obj_attr='_repo_name', perm_set=perm_set,
158 perm_checker=HasRepoPermissionAny,
158 perm_checker=HasRepoPermissionAny,
159 extra_kwargs=extra_kwargs)
159 extra_kwargs=extra_kwargs)
160
160
161
161
162 class RepoGroupList(_PermCheckIterator):
162 class RepoGroupList(_PermCheckIterator):
163
163
164 def __init__(self, db_repo_group_list, perm_set=None, extra_kwargs=None):
164 def __init__(self, db_repo_group_list, perm_set=None, extra_kwargs=None):
165 if not perm_set:
165 if not perm_set:
166 perm_set = ['group.read', 'group.write', 'group.admin']
166 perm_set = ['group.read', 'group.write', 'group.admin']
167
167
168 super(RepoGroupList, self).__init__(
168 super(RepoGroupList, self).__init__(
169 obj_list=db_repo_group_list,
169 obj_list=db_repo_group_list,
170 obj_attr='_group_name', perm_set=perm_set,
170 obj_attr='_group_name', perm_set=perm_set,
171 perm_checker=HasRepoGroupPermissionAny,
171 perm_checker=HasRepoGroupPermissionAny,
172 extra_kwargs=extra_kwargs)
172 extra_kwargs=extra_kwargs)
173
173
174
174
175 class UserGroupList(_PermCheckIterator):
175 class UserGroupList(_PermCheckIterator):
176
176
177 def __init__(self, db_user_group_list, perm_set=None, extra_kwargs=None):
177 def __init__(self, db_user_group_list, perm_set=None, extra_kwargs=None):
178 if not perm_set:
178 if not perm_set:
179 perm_set = ['usergroup.read', 'usergroup.write', 'usergroup.admin']
179 perm_set = ['usergroup.read', 'usergroup.write', 'usergroup.admin']
180
180
181 super(UserGroupList, self).__init__(
181 super(UserGroupList, self).__init__(
182 obj_list=db_user_group_list,
182 obj_list=db_user_group_list,
183 obj_attr='users_group_name', perm_set=perm_set,
183 obj_attr='users_group_name', perm_set=perm_set,
184 perm_checker=HasUserGroupPermissionAny,
184 perm_checker=HasUserGroupPermissionAny,
185 extra_kwargs=extra_kwargs)
185 extra_kwargs=extra_kwargs)
186
186
187
187
188 class ScmModel(BaseModel):
188 class ScmModel(BaseModel):
189 """
189 """
190 Generic Scm Model
190 Generic Scm Model
191 """
191 """
192
192
193 @LazyProperty
193 @LazyProperty
194 def repos_path(self):
194 def repos_path(self):
195 """
195 """
196 Gets the repositories root path from database
196 Gets the repositories root path from database
197 """
197 """
198
198
199 settings_model = VcsSettingsModel(sa=self.sa)
199 settings_model = VcsSettingsModel(sa=self.sa)
200 return settings_model.get_repos_location()
200 return settings_model.get_repos_location()
201
201
202 def repo_scan(self, repos_path=None):
202 def repo_scan(self, repos_path=None):
203 """
203 """
204 Listing of repositories in the given path. This path should not be a
204 Listing of repositories in the given path. This path should not be a
205 repository itself. Return a dictionary of repository objects
205 repository itself. Return a dictionary of repository objects
206
206
207 :param repos_path: path to directory containing repositories
207 :param repos_path: path to directory containing repositories
208 """
208 """
209
209
210 if repos_path is None:
210 if repos_path is None:
211 repos_path = self.repos_path
211 repos_path = self.repos_path
212
212
213 log.info('scanning for repositories in %s', repos_path)
213 log.info('scanning for repositories in %s', repos_path)
214
214
215 config = make_db_config()
215 config = make_db_config()
216 config.set('extensions', 'largefiles', '')
216 config.set('extensions', 'largefiles', '')
217 repos = {}
217 repos = {}
218
218
219 for name, path in get_filesystem_repos(repos_path, recursive=True):
219 for name, path in get_filesystem_repos(repos_path, recursive=True):
220 # name needs to be decomposed and put back together using the /
220 # name needs to be decomposed and put back together using the /
221 # since that is the internal storage separator for rhodecode
221 # since that is the internal storage separator for rhodecode
222 name = Repository.normalize_repo_name(name)
222 name = Repository.normalize_repo_name(name)
223
223
224 try:
224 try:
225 if name in repos:
225 if name in repos:
226 raise RepositoryError('Duplicate repository name %s '
226 raise RepositoryError('Duplicate repository name %s '
227 'found in %s' % (name, path))
227 'found in %s' % (name, path))
228 elif path[0] in rhodecode.BACKENDS:
228 elif path[0] in rhodecode.BACKENDS:
229 backend = get_backend(path[0])
229 backend = get_backend(path[0])
230 repos[name] = backend(path[1], config=config,
230 repos[name] = backend(path[1], config=config,
231 with_wire={"cache": False})
231 with_wire={"cache": False})
232 except OSError:
232 except OSError:
233 continue
233 continue
234 except RepositoryError:
234 except RepositoryError:
235 log.exception('Failed to create a repo')
235 log.exception('Failed to create a repo')
236 continue
236 continue
237
237
238 log.debug('found %s paths with repositories', len(repos))
238 log.debug('found %s paths with repositories', len(repos))
239 return repos
239 return repos
240
240
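Editor's note: `repo_scan` returns a flat mapping of normalized repository name to backend instance, raising on duplicate names. A hedged usage sketch; this assumes a configured RhodeCode environment and the path is an example:

# Usage sketch only -- not runnable outside a RhodeCode setup.
# scm = ScmModel()
# repos = scm.repo_scan('/srv/repos')
# for name, backend in sorted(repos.items()):
#     print(name, backend.alias)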
241 def get_repos(self, all_repos=None, sort_key=None):
241 def get_repos(self, all_repos=None, sort_key=None):
242 """
242 """
243 Get all repositories from the db and for each repo create its
243 Get all repositories from the db and for each repo create its
244 backend instance and fill that backend with information from the database
244 backend instance and fill that backend with information from the database
245
245
246 :param all_repos: list of repository names as strings
246 :param all_repos: list of repository names as strings
247 pass a specific list of repositories, useful for filtering
247 pass a specific list of repositories, useful for filtering
248
248
249 :param sort_key: initial sorting of repositories
249 :param sort_key: initial sorting of repositories
250 """
250 """
251 if all_repos is None:
251 if all_repos is None:
252 all_repos = self.sa.query(Repository)\
252 all_repos = self.sa.query(Repository)\
253 .filter(Repository.group_id == None)\
253 .filter(Repository.group_id == None)\
254 .order_by(func.lower(Repository.repo_name)).all()
254 .order_by(func.lower(Repository.repo_name)).all()
255 repo_iter = SimpleCachedRepoList(
255 repo_iter = SimpleCachedRepoList(
256 all_repos, repos_path=self.repos_path, order_by=sort_key)
256 all_repos, repos_path=self.repos_path, order_by=sort_key)
257 return repo_iter
257 return repo_iter
258
258
259 def get_repo_groups(self, all_groups=None):
259 def get_repo_groups(self, all_groups=None):
260 if all_groups is None:
260 if all_groups is None:
261 all_groups = RepoGroup.query()\
261 all_groups = RepoGroup.query()\
262 .filter(RepoGroup.group_parent_id == None).all()
262 .filter(RepoGroup.group_parent_id == None).all()
263 return [x for x in RepoGroupList(all_groups)]
263 return [x for x in RepoGroupList(all_groups)]
264
264
265 def mark_for_invalidation(self, repo_name, delete=False):
265 def mark_for_invalidation(self, repo_name, delete=False):
266 """
266 """
267 Mark caches of this repo invalid in the database. `delete` flag
267 Mark caches of this repo invalid in the database. `delete` flag
268 removes the cache entries
268 removes the cache entries
269
269
270 :param repo_name: the repo_name for which caches should be marked
270 :param repo_name: the repo_name for which caches should be marked
271 invalid, or deleted
271 invalid, or deleted
272 :param delete: delete the entry keys instead of setting bool
272 :param delete: delete the entry keys instead of setting bool
273 flag on them, and also purge caches used by the dogpile
273 flag on them, and also purge caches used by the dogpile
274 """
274 """
275 repo = Repository.get_by_repo_name(repo_name)
275 repo = Repository.get_by_repo_name(repo_name)
276
276
277 if repo:
277 if repo:
278 invalidation_namespace = CacheKey.REPO_INVALIDATION_NAMESPACE.format(
278 invalidation_namespace = CacheKey.REPO_INVALIDATION_NAMESPACE.format(
279 repo_id=repo.repo_id)
279 repo_id=repo.repo_id)
280 CacheKey.set_invalidate(invalidation_namespace, delete=delete)
280 CacheKey.set_invalidate(invalidation_namespace, delete=delete)
281
281
282 repo_id = repo.repo_id
282 repo_id = repo.repo_id
283 config = repo._config
283 config = repo._config
284 config.set('extensions', 'largefiles', '')
284 config.set('extensions', 'largefiles', '')
285 repo.update_commit_cache(config=config, cs_cache=None)
285 repo.update_commit_cache(config=config, cs_cache=None)
286 if delete:
286 if delete:
287 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
287 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
288 rc_cache.clear_cache_namespace(
288 rc_cache.clear_cache_namespace(
289 'cache_repo', cache_namespace_uid, invalidate=True)
289 'cache_repo', cache_namespace_uid, invalidate=True)
290
290
291 def toggle_following_repo(self, follow_repo_id, user_id):
291 def toggle_following_repo(self, follow_repo_id, user_id):
292
292
293 f = self.sa.query(UserFollowing)\
293 f = self.sa.query(UserFollowing)\
294 .filter(UserFollowing.follows_repo_id == follow_repo_id)\
294 .filter(UserFollowing.follows_repo_id == follow_repo_id)\
295 .filter(UserFollowing.user_id == user_id).scalar()
295 .filter(UserFollowing.user_id == user_id).scalar()
296
296
297 if f is not None:
297 if f is not None:
298 try:
298 try:
299 self.sa.delete(f)
299 self.sa.delete(f)
300 return
300 return
301 except Exception:
301 except Exception:
302 log.error(traceback.format_exc())
302 log.error(traceback.format_exc())
303 raise
303 raise
304
304
305 try:
305 try:
306 f = UserFollowing()
306 f = UserFollowing()
307 f.user_id = user_id
307 f.user_id = user_id
308 f.follows_repo_id = follow_repo_id
308 f.follows_repo_id = follow_repo_id
309 self.sa.add(f)
309 self.sa.add(f)
310 except Exception:
310 except Exception:
311 log.error(traceback.format_exc())
311 log.error(traceback.format_exc())
312 raise
312 raise
313
313
314 def toggle_following_user(self, follow_user_id, user_id):
314 def toggle_following_user(self, follow_user_id, user_id):
315 f = self.sa.query(UserFollowing)\
315 f = self.sa.query(UserFollowing)\
316 .filter(UserFollowing.follows_user_id == follow_user_id)\
316 .filter(UserFollowing.follows_user_id == follow_user_id)\
317 .filter(UserFollowing.user_id == user_id).scalar()
317 .filter(UserFollowing.user_id == user_id).scalar()
318
318
319 if f is not None:
319 if f is not None:
320 try:
320 try:
321 self.sa.delete(f)
321 self.sa.delete(f)
322 return
322 return
323 except Exception:
323 except Exception:
324 log.error(traceback.format_exc())
324 log.error(traceback.format_exc())
325 raise
325 raise
326
326
327 try:
327 try:
328 f = UserFollowing()
328 f = UserFollowing()
329 f.user_id = user_id
329 f.user_id = user_id
330 f.follows_user_id = follow_user_id
330 f.follows_user_id = follow_user_id
331 self.sa.add(f)
331 self.sa.add(f)
332 except Exception:
332 except Exception:
333 log.error(traceback.format_exc())
333 log.error(traceback.format_exc())
334 raise
334 raise
335
335
336 def is_following_repo(self, repo_name, user_id, cache=False):
336 def is_following_repo(self, repo_name, user_id, cache=False):
337 r = self.sa.query(Repository)\
337 r = self.sa.query(Repository)\
338 .filter(Repository.repo_name == repo_name).scalar()
338 .filter(Repository.repo_name == repo_name).scalar()
339
339
340 f = self.sa.query(UserFollowing)\
340 f = self.sa.query(UserFollowing)\
341 .filter(UserFollowing.follows_repository == r)\
341 .filter(UserFollowing.follows_repository == r)\
342 .filter(UserFollowing.user_id == user_id).scalar()
342 .filter(UserFollowing.user_id == user_id).scalar()
343
343
344 return f is not None
344 return f is not None
345
345
346 def is_following_user(self, username, user_id, cache=False):
346 def is_following_user(self, username, user_id, cache=False):
347 u = User.get_by_username(username)
347 u = User.get_by_username(username)
348
348
349 f = self.sa.query(UserFollowing)\
349 f = self.sa.query(UserFollowing)\
350 .filter(UserFollowing.follows_user == u)\
350 .filter(UserFollowing.follows_user == u)\
351 .filter(UserFollowing.user_id == user_id).scalar()
351 .filter(UserFollowing.user_id == user_id).scalar()
352
352
353 return f is not None
353 return f is not None
354
354
355 def get_followers(self, repo):
355 def get_followers(self, repo):
356 repo = self._get_repo(repo)
356 repo = self._get_repo(repo)
357
357
358 return self.sa.query(UserFollowing)\
358 return self.sa.query(UserFollowing)\
359 .filter(UserFollowing.follows_repository == repo).count()
359 .filter(UserFollowing.follows_repository == repo).count()
360
360
361 def get_forks(self, repo):
361 def get_forks(self, repo):
362 repo = self._get_repo(repo)
362 repo = self._get_repo(repo)
363 return self.sa.query(Repository)\
363 return self.sa.query(Repository)\
364 .filter(Repository.fork == repo).count()
364 .filter(Repository.fork == repo).count()
365
365
366 def get_pull_requests(self, repo):
366 def get_pull_requests(self, repo):
367 repo = self._get_repo(repo)
367 repo = self._get_repo(repo)
368 return self.sa.query(PullRequest)\
368 return self.sa.query(PullRequest)\
369 .filter(PullRequest.target_repo == repo)\
369 .filter(PullRequest.target_repo == repo)\
370 .filter(PullRequest.status != PullRequest.STATUS_CLOSED).count()
370 .filter(PullRequest.status != PullRequest.STATUS_CLOSED).count()
371
371
372 def get_artifacts(self, repo):
372 def get_artifacts(self, repo):
373 repo = self._get_repo(repo)
373 repo = self._get_repo(repo)
374 return self.sa.query(FileStore)\
374 return self.sa.query(FileStore)\
375 .filter(FileStore.repo == repo)\
375 .filter(FileStore.repo == repo)\
376 .filter(or_(FileStore.hidden == None, FileStore.hidden == false())).count()
376 .filter(or_(FileStore.hidden == None, FileStore.hidden == false())).count()
377
377
378 def mark_as_fork(self, repo, fork, user):
378 def mark_as_fork(self, repo, fork, user):
379 repo = self._get_repo(repo)
379 repo = self._get_repo(repo)
380 fork = self._get_repo(fork)
380 fork = self._get_repo(fork)
381 if fork and repo.repo_id == fork.repo_id:
381 if fork and repo.repo_id == fork.repo_id:
382 raise Exception("Cannot set repository as fork of itself")
382 raise Exception("Cannot set repository as fork of itself")
383
383
384 if fork and repo.repo_type != fork.repo_type:
384 if fork and repo.repo_type != fork.repo_type:
385 raise RepositoryError(
385 raise RepositoryError(
386 "Cannot set repository as fork of repository with other type")
386 "Cannot set repository as fork of repository with other type")
387
387
388 repo.fork = fork
388 repo.fork = fork
389 self.sa.add(repo)
389 self.sa.add(repo)
390 return repo
390 return repo
391
391
392 def pull_changes(self, repo, username, remote_uri=None, validate_uri=True):
392 def pull_changes(self, repo, username, remote_uri=None, validate_uri=True):
393 dbrepo = self._get_repo(repo)
393 dbrepo = self._get_repo(repo)
394 remote_uri = remote_uri or dbrepo.clone_uri
394 remote_uri = remote_uri or dbrepo.clone_uri
395 if not remote_uri:
395 if not remote_uri:
396 raise Exception("This repository doesn't have a clone uri")
396 raise Exception("This repository doesn't have a clone uri")
397
397
398 repo = dbrepo.scm_instance(cache=False)
398 repo = dbrepo.scm_instance(cache=False)
399 repo.config.clear_section('hooks')
399 repo.config.clear_section('hooks')
400
400
401 try:
401 try:
402 # NOTE(marcink): add extra validation so we skip invalid urls
402 # NOTE(marcink): add extra validation so we skip invalid urls
403 # this is because these tasks can be executed via the scheduler without
403 # this is because these tasks can be executed via the scheduler without
404 # proper validation of remote_uri
404 # proper validation of remote_uri
405 if validate_uri:
405 if validate_uri:
406 config = make_db_config(clear_session=False)
406 config = make_db_config(clear_session=False)
407 url_validator(remote_uri, dbrepo.repo_type, config)
407 url_validator(remote_uri, dbrepo.repo_type, config)
408 except InvalidCloneUrl:
408 except InvalidCloneUrl:
409 raise
409 raise
410
410
411 repo_name = dbrepo.repo_name
411 repo_name = dbrepo.repo_name
412 try:
412 try:
413 # TODO: we need to make sure those operations call proper hooks !
413 # TODO: we need to make sure those operations call proper hooks !
414 repo.fetch(remote_uri)
414 repo.fetch(remote_uri)
415
415
416 self.mark_for_invalidation(repo_name)
416 self.mark_for_invalidation(repo_name)
417 except Exception:
417 except Exception:
418 log.error(traceback.format_exc())
418 log.error(traceback.format_exc())
419 raise
419 raise
420
420
421 def push_changes(self, repo, username, remote_uri=None, validate_uri=True):
421 def push_changes(self, repo, username, remote_uri=None, validate_uri=True):
422 dbrepo = self._get_repo(repo)
422 dbrepo = self._get_repo(repo)
423 remote_uri = remote_uri or dbrepo.push_uri
423 remote_uri = remote_uri or dbrepo.push_uri
424 if not remote_uri:
424 if not remote_uri:
425 raise Exception("This repository doesn't have a push uri")
425 raise Exception("This repository doesn't have a push uri")
426
426
427 repo = dbrepo.scm_instance(cache=False)
427 repo = dbrepo.scm_instance(cache=False)
428 repo.config.clear_section('hooks')
428 repo.config.clear_section('hooks')
429
429
430 try:
430 try:
431 # NOTE(marcink): add extra validation so we skip invalid urls
431 # NOTE(marcink): add extra validation so we skip invalid urls
432 # this is because these tasks can be executed via the scheduler without
432 # this is because these tasks can be executed via the scheduler without
433 # proper validation of remote_uri
433 # proper validation of remote_uri
434 if validate_uri:
434 if validate_uri:
435 config = make_db_config(clear_session=False)
435 config = make_db_config(clear_session=False)
436 url_validator(remote_uri, dbrepo.repo_type, config)
436 url_validator(remote_uri, dbrepo.repo_type, config)
437 except InvalidCloneUrl:
437 except InvalidCloneUrl:
438 raise
438 raise
439
439
440 try:
440 try:
441 repo.push(remote_uri)
441 repo.push(remote_uri)
442 except Exception:
442 except Exception:
443 log.error(traceback.format_exc())
443 log.error(traceback.format_exc())
444 raise
444 raise
445
445
446 def commit_change(self, repo, repo_name, commit, user, author, message,
446 def commit_change(self, repo, repo_name, commit, user, author, message,
447 content, f_path):
447 content, f_path):
448 """
448 """
449 Commits changes
449 Commits changes
450
450
451 :param repo: SCM instance
451 :param repo: SCM instance
452
452
453 """
453 """
454 user = self._get_user(user)
454 user = self._get_user(user)
455
455
456 # decoding here ensures that we have properly encoded values;
456 # decoding here ensures that we have properly encoded values;
457 # in any other case this will raise an exception and deny the commit
457 # in any other case this will raise an exception and deny the commit
458 content = safe_str(content)
458 content = safe_str(content)
459 path = safe_str(f_path)
459 path = safe_str(f_path)
460 # message and author need to be unicode;
460 # message and author need to be unicode;
461 # the proper backend should then translate that into the required type
461 # the proper backend should then translate that into the required type
462 message = safe_unicode(message)
462 message = safe_unicode(message)
463 author = safe_unicode(author)
463 author = safe_unicode(author)
464 imc = repo.in_memory_commit
464 imc = repo.in_memory_commit
465 imc.change(FileNode(path, content, mode=commit.get_file_mode(f_path)))
465 imc.change(FileNode(path, content, mode=commit.get_file_mode(f_path)))
466 try:
466 try:
467 # TODO: handle pre-push action !
467 # TODO: handle pre-push action !
468 tip = imc.commit(
468 tip = imc.commit(
469 message=message, author=author, parents=[commit],
469 message=message, author=author, parents=[commit],
470 branch=commit.branch)
470 branch=commit.branch)
471 except Exception as e:
471 except Exception as e:
472 log.error(traceback.format_exc())
472 log.error(traceback.format_exc())
473 raise IMCCommitError(str(e))
473 raise IMCCommitError(str(e))
474 finally:
474 finally:
475 # always clear caches; if the commit fails we still want fresh objects
475 # always clear caches; if the commit fails we still want fresh objects
476 self.mark_for_invalidation(repo_name)
476 self.mark_for_invalidation(repo_name)
477
477
478 # We trigger the post-push action
478 # We trigger the post-push action
479 hooks_utils.trigger_post_push_hook(
479 hooks_utils.trigger_post_push_hook(
480 username=user.username, action='push_local', hook_type='post_push',
480 username=user.username, action='push_local', hook_type='post_push',
481 repo_name=repo_name, repo_type=repo.alias, commit_ids=[tip.raw_id])
481 repo_name=repo_name, repo_type=repo.alias, commit_ids=[tip.raw_id])
482 return tip
482 return tip
483
483
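A usage sketch for commit_change; every object below (SCM instance, parent commit, user) is a placeholder for a value obtained from an initialized environment:

    tip = model.commit_change(
        repo=scm_repo, repo_name='example-repo', commit=parent_commit,
        user=request_user, author='Jane Doe <jane@example.com>',
        message='fix typo in README', content='new file body\n',
        f_path='README.rst')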
484 def _sanitize_path(self, f_path):
484 def _sanitize_path(self, f_path):
485 if f_path.startswith('/') or f_path.startswith('./') or '../' in f_path:
485 if f_path.startswith('/') or f_path.startswith('./') or '../' in f_path:
486 raise NonRelativePathError('%s is not a relative path' % f_path)
486 raise NonRelativePathError('%s is not a relative path' % f_path)
487 if f_path:
487 if f_path:
488 f_path = os.path.normpath(f_path)
488 f_path = os.path.normpath(f_path)
489 return f_path
489 return f_path
490
490
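The sanitizer's behaviour, illustrated with example paths (derived directly from the checks above):

    # _sanitize_path('docs/index.rst')    -> 'docs/index.rst'
    # _sanitize_path('docs/./index.rst')  -> 'docs/index.rst' (via normpath)
    # _sanitize_path('/etc/passwd')       -> raises NonRelativePathError
    # _sanitize_path('../outside.txt')    -> raises NonRelativePathError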
491 def get_dirnode_metadata(self, request, commit, dir_node):
491 def get_dirnode_metadata(self, request, commit, dir_node):
492 if not dir_node.is_dir():
492 if not dir_node.is_dir():
493 return []
493 return []
494
494
495 data = []
495 data = []
496 for node in dir_node:
496 for node in dir_node:
497 if not node.is_file():
497 if not node.is_file():
498 # we skip everything that is not a file node
498 # we skip everything that is not a file node
499 continue
499 continue
500
500
501 last_commit = node.last_commit
501 last_commit = node.last_commit
502 last_commit_date = last_commit.date
502 last_commit_date = last_commit.date
503 data.append({
503 data.append({
504 'name': node.name,
504 'name': node.name,
505 'size': h.format_byte_size_binary(node.size),
505 'size': h.format_byte_size_binary(node.size),
506 'modified_at': h.format_date(last_commit_date),
506 'modified_at': h.format_date(last_commit_date),
507 'modified_ts': last_commit_date.isoformat(),
507 'modified_ts': last_commit_date.isoformat(),
508 'revision': last_commit.revision,
508 'revision': last_commit.revision,
509 'short_id': last_commit.short_id,
509 'short_id': last_commit.short_id,
510 'message': h.escape(last_commit.message),
510 'message': h.escape(last_commit.message),
511 'author': h.escape(last_commit.author),
511 'author': h.escape(last_commit.author),
512 'user_profile': h.gravatar_with_user(
512 'user_profile': h.gravatar_with_user(
513 request, last_commit.author),
513 request, last_commit.author),
514 })
514 })
515
515
516 return data
516 return data
517
517
518 def get_nodes(self, repo_name, commit_id, root_path='/', flat=True,
518 def get_nodes(self, repo_name, commit_id, root_path='/', flat=True,
519 extended_info=False, content=False, max_file_bytes=None):
519 extended_info=False, content=False, max_file_bytes=None):
520 """
520 """
521 recursively walk the root dir and return a set of all paths in that dir,
521 recursively walk the root dir and return a set of all paths in that dir,
522 based on the repository walk function
522 based on the repository walk function
523
523
524 :param repo_name: name of repository
524 :param repo_name: name of repository
525 :param commit_id: commit id for which to list nodes
525 :param commit_id: commit id for which to list nodes
526 :param root_path: root path to list
526 :param root_path: root path to list
527 :param flat: return as a list, if False returns a dict with description
527 :param flat: return as a list, if False returns a dict with description
528 :param extended_info: show additional info such as md5, binary, size etc
528 :param extended_info: show additional info such as md5, binary, size etc
529 :param content: add nodes content to the return data
529 :param content: add nodes content to the return data
530 :param max_file_bytes: will not return file contents over this limit
530 :param max_file_bytes: will not return file contents over this limit
531
531
532 """
532 """
533 _files = list()
533 _files = list()
534 _dirs = list()
534 _dirs = list()
535 try:
535 try:
536 _repo = self._get_repo(repo_name)
536 _repo = self._get_repo(repo_name)
537 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
537 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
538 root_path = root_path.lstrip('/')
538 root_path = root_path.lstrip('/')
539 for __, dirs, files in commit.walk(root_path):
539 for __, dirs, files in commit.walk(root_path):
540
540
541 for f in files:
541 for f in files:
542 _content = None
542 _content = None
543 _data = f_name = f.unicode_path
543 _data = f_name = f.unicode_path
544
544
545 if not flat:
545 if not flat:
546 _data = {
546 _data = {
547 "name": h.escape(f_name),
547 "name": h.escape(f_name),
548 "type": "file",
548 "type": "file",
549 }
549 }
550 if extended_info:
550 if extended_info:
551 _data.update({
551 _data.update({
552 "md5": f.md5,
552 "md5": f.md5,
553 "binary": f.is_binary,
553 "binary": f.is_binary,
554 "size": f.size,
554 "size": f.size,
555 "extension": f.extension,
555 "extension": f.extension,
556 "mimetype": f.mimetype,
556 "mimetype": f.mimetype,
557 "lines": f.lines()[0]
557 "lines": f.lines()[0]
558 })
558 })
559
559
560 if content:
560 if content:
561 over_size_limit = (max_file_bytes is not None
561 over_size_limit = (max_file_bytes is not None
562 and f.size > max_file_bytes)
562 and f.size > max_file_bytes)
563 full_content = None
563 full_content = None
564 if not f.is_binary and not over_size_limit:
564 if not f.is_binary and not over_size_limit:
565 full_content = safe_str(f.content)
565 full_content = safe_str(f.content)
566
566
567 _data.update({
567 _data.update({
568 "content": full_content,
568 "content": full_content,
569 })
569 })
570 _files.append(_data)
570 _files.append(_data)
571
571
572 for d in dirs:
572 for d in dirs:
573 _data = d_name = d.unicode_path
573 _data = d_name = d.unicode_path
574 if not flat:
574 if not flat:
575 _data = {
575 _data = {
576 "name": h.escape(d_name),
576 "name": h.escape(d_name),
577 "type": "dir",
577 "type": "dir",
578 }
578 }
579 if extended_info:
579 if extended_info:
580 _data.update({
580 _data.update({
581 "md5": None,
581 "md5": None,
582 "binary": None,
582 "binary": None,
583 "size": None,
583 "size": None,
584 "extension": None,
584 "extension": None,
585 })
585 })
586 if content:
586 if content:
587 _data.update({
587 _data.update({
588 "content": None
588 "content": None
589 })
589 })
590 _dirs.append(_data)
590 _dirs.append(_data)
591 except RepositoryError:
591 except RepositoryError:
592 log.exception("Exception in get_nodes")
592 log.exception("Exception in get_nodes")
593 raise
593 raise
594
594
595 return _dirs, _files
595 return _dirs, _files
596
596
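A usage sketch for get_nodes; the repository name and commit id are placeholders:

    dirs, files = model.get_nodes(
        'example-repo', 'tip', root_path='/', flat=False,
        extended_info=True, content=False, max_file_bytes=1024 * 1024)
    # with flat=True (the default) both lists contain plain path strings
    # instead of dicts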
597 def get_quick_filter_nodes(self, repo_name, commit_id, root_path='/'):
597 def get_quick_filter_nodes(self, repo_name, commit_id, root_path='/'):
598 """
598 """
599 Generate files for quick filter in files view
599 Generate files for quick filter in files view
600 """
600 """
601
601
602 _files = list()
602 _files = list()
603 _dirs = list()
603 _dirs = list()
604 try:
604 try:
605 _repo = self._get_repo(repo_name)
605 _repo = self._get_repo(repo_name)
606 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
606 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
607 root_path = root_path.lstrip('/')
607 root_path = root_path.lstrip('/')
608 for __, dirs, files in commit.walk(root_path):
608 for __, dirs, files in commit.walk(root_path):
609
609
610 for f in files:
610 for f in files:
611
611
612 _data = {
612 _data = {
613 "name": h.escape(f.unicode_path),
613 "name": h.escape(f.unicode_path),
614 "type": "file",
614 "type": "file",
615 }
615 }
616
616
617 _files.append(_data)
617 _files.append(_data)
618
618
619 for d in dirs:
619 for d in dirs:
620
620
621 _data = {
621 _data = {
622 "name": h.escape(d.unicode_path),
622 "name": h.escape(d.unicode_path),
623 "type": "dir",
623 "type": "dir",
624 }
624 }
625
625
626 _dirs.append(_data)
626 _dirs.append(_data)
627 except RepositoryError:
627 except RepositoryError:
628 log.exception("Exception in get_quick_filter_nodes")
628 log.exception("Exception in get_quick_filter_nodes")
629 raise
629 raise
630
630
631 return _dirs, _files
631 return _dirs, _files
632
632
633 def get_node(self, repo_name, commit_id, file_path,
633 def get_node(self, repo_name, commit_id, file_path,
634 extended_info=False, content=False, max_file_bytes=None, cache=True):
634 extended_info=False, content=False, max_file_bytes=None, cache=True):
635 """
635 """
636 retrieve single node from commit
636 retrieve single node from commit
637 """
637 """
638 try:
638 try:
639
639
640 _repo = self._get_repo(repo_name)
640 _repo = self._get_repo(repo_name)
641 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
641 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
642
642
643 file_node = commit.get_node(file_path)
643 file_node = commit.get_node(file_path)
644 if file_node.is_dir():
644 if file_node.is_dir():
645 raise RepositoryError('The given path is a directory')
645 raise RepositoryError('The given path is a directory')
646
646
647 _content = None
647 _content = None
648 f_name = file_node.unicode_path
648 f_name = file_node.unicode_path
649
649
650 file_data = {
650 file_data = {
651 "name": h.escape(f_name),
651 "name": h.escape(f_name),
652 "type": "file",
652 "type": "file",
653 }
653 }
654
654
655 if extended_info:
655 if extended_info:
656 file_data.update({
656 file_data.update({
657 "extension": file_node.extension,
657 "extension": file_node.extension,
658 "mimetype": file_node.mimetype,
658 "mimetype": file_node.mimetype,
659 })
659 })
660
660
661 if cache:
661 if cache:
662 md5 = file_node.md5
662 md5 = file_node.md5
663 is_binary = file_node.is_binary
663 is_binary = file_node.is_binary
664 size = file_node.size
664 size = file_node.size
665 else:
665 else:
666 is_binary, md5, size, _content = file_node.metadata_uncached()
666 is_binary, md5, size, _content = file_node.metadata_uncached()
667
667
668 file_data.update({
668 file_data.update({
669 "md5": md5,
669 "md5": md5,
670 "binary": is_binary,
670 "binary": is_binary,
671 "size": size,
671 "size": size,
672 })
672 })
673
673
674 if content and cache:
674 if content and cache:
675 # get content + cache
675 # get content + cache
676 size = file_node.size
676 size = file_node.size
677 over_size_limit = (max_file_bytes is not None and size > max_file_bytes)
677 over_size_limit = (max_file_bytes is not None and size > max_file_bytes)
678 full_content = None
678 full_content = None
679 all_lines = 0
679 all_lines = 0
680 if not file_node.is_binary and not over_size_limit:
680 if not file_node.is_binary and not over_size_limit:
681 full_content = safe_unicode(file_node.content)
681 full_content = safe_unicode(file_node.content)
682 all_lines, empty_lines = file_node.count_lines(full_content)
682 all_lines, empty_lines = file_node.count_lines(full_content)
683
683
684 file_data.update({
684 file_data.update({
685 "content": full_content,
685 "content": full_content,
686 "lines": all_lines
686 "lines": all_lines
687 })
687 })
688 elif content:
688 elif content:
689 # get content *without* cache
689 # get content *without* cache
690 if _content is None:
690 if _content is None:
691 is_binary, md5, size, _content = file_node.metadata_uncached()
691 is_binary, md5, size, _content = file_node.metadata_uncached()
692
692
693 over_size_limit = (max_file_bytes is not None and size > max_file_bytes)
693 over_size_limit = (max_file_bytes is not None and size > max_file_bytes)
694 full_content = None
694 full_content = None
695 all_lines = 0
695 all_lines = 0
696 if not is_binary and not over_size_limit:
696 if not is_binary and not over_size_limit:
697 full_content = safe_unicode(_content)
697 full_content = safe_unicode(_content)
698 all_lines, empty_lines = file_node.count_lines(full_content)
698 all_lines, empty_lines = file_node.count_lines(full_content)
699
699
700 file_data.update({
700 file_data.update({
701 "content": full_content,
701 "content": full_content,
702 "lines": all_lines
702 "lines": all_lines
703 })
703 })
704
704
705 except RepositoryError:
705 except RepositoryError:
706 log.exception("Exception in get_node")
706 log.exception("Exception in get_node")
707 raise
707 raise
708
708
709 return file_data
709 return file_data
710
710
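A usage sketch for get_node, again with placeholder repository and path values:

    file_data = model.get_node(
        'example-repo', 'tip', 'README.rst',
        extended_info=True, content=True, cache=False)
    # cache=False routes through metadata_uncached(), bypassing the cached
    # md5/binary/size attributes; cache=True reads the cached node properties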
711 def get_fts_data(self, repo_name, commit_id, root_path='/'):
711 def get_fts_data(self, repo_name, commit_id, root_path='/'):
712 """
712 """
713 Fetch node tree for usage in full text search
713 Fetch node tree for usage in full text search
714 """
714 """
715
715
716 tree_info = list()
716 tree_info = list()
717
717
718 try:
718 try:
719 _repo = self._get_repo(repo_name)
719 _repo = self._get_repo(repo_name)
720 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
720 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
721 root_path = root_path.lstrip('/')
721 root_path = root_path.lstrip('/')
722 for __, dirs, files in commit.walk(root_path):
722 for __, dirs, files in commit.walk(root_path):
723
723
724 for f in files:
724 for f in files:
725 is_binary, md5, size, _content = f.metadata_uncached()
725 is_binary, md5, size, _content = f.metadata_uncached()
726 _data = {
726 _data = {
727 "name": f.unicode_path,
727 "name": f.unicode_path,
728 "md5": md5,
728 "md5": md5,
729 "extension": f.extension,
729 "extension": f.extension,
730 "binary": is_binary,
730 "binary": is_binary,
731 "size": size
731 "size": size
732 }
732 }
733
733
734 tree_info.append(_data)
734 tree_info.append(_data)
735
735
736 except RepositoryError:
736 except RepositoryError:
737 log.exception("Exception in get_fts_data")
737 log.exception("Exception in get_fts_data")
738 raise
738 raise
739
739
740 return tree_info
740 return tree_info
741
741
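A sketch of consuming get_fts_data from an indexer; index_document is a hypothetical callback, not part of this module:

    for entry in model.get_fts_data('example-repo', 'tip'):
        if not entry['binary']:
            index_document(entry['name'], entry['md5'])  # hypothetical indexer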
742 def create_nodes(self, user, repo, message, nodes, parent_commit=None,
742 def create_nodes(self, user, repo, message, nodes, parent_commit=None,
743 author=None, trigger_push_hook=True):
743 author=None, trigger_push_hook=True):
744 """
744 """
745 Commits the given nodes into `repo`
745 Commits the given nodes into `repo`
746
746
747 :param user: RhodeCode User object or user_id, the committer
747 :param user: RhodeCode User object or user_id, the committer
748 :param repo: RhodeCode Repository object
748 :param repo: RhodeCode Repository object
749 :param message: commit message
749 :param message: commit message
750 :param nodes: mapping {filename:{'content':content},...}
750 :param nodes: mapping {filename:{'content':content},...}
751 :param parent_commit: parent commit; can be empty, then it's the
751 :param parent_commit: parent commit; can be empty, then it's the
752 initial commit
752 initial commit
753 :param author: author of the commit; can be different than the
753 :param author: author of the commit; can be different than the
754 committer, but only for git
754 committer, but only for git
755 :param trigger_push_hook: trigger push hooks
755 :param trigger_push_hook: trigger push hooks
756
756
757 :returns: new committed commit
757 :returns: new committed commit
758 """
758 """
759
759
760 user = self._get_user(user)
760 user = self._get_user(user)
761 scm_instance = repo.scm_instance(cache=False)
761 scm_instance = repo.scm_instance(cache=False)
762
762
763 processed_nodes = []
763 processed_nodes = []
764 for f_path in nodes:
764 for f_path in nodes:
765 f_path = self._sanitize_path(f_path)
765 f_path = self._sanitize_path(f_path)
766 content = nodes[f_path]['content']
766 content = nodes[f_path]['content']
767 f_path = safe_str(f_path)
767 f_path = safe_str(f_path)
768 # decoding here will force that we have proper encoded values
768 # decoding here will force that we have proper encoded values
769 # in any other case this will throw exceptions and deny commit
769 # in any other case this will throw exceptions and deny commit
770 if isinstance(content, (str,)):
770 if isinstance(content, (str,)):
771 content = safe_str(content)
771 content = safe_str(content)
772 elif hasattr(content, 'read'):
772 elif hasattr(content, 'read'):
773 content = content.read()
773 content = content.read()
774 else:
774 else:
775 raise Exception('Content is of unrecognized type %s' % (
775 raise Exception('Content is of unrecognized type %s' % (
776 type(content)
776 type(content)
777 ))
777 ))
778 processed_nodes.append((f_path, content))
778 processed_nodes.append((f_path, content))
779
779
780 message = safe_unicode(message)
780 message = safe_unicode(message)
781 commiter = user.full_contact
781 commiter = user.full_contact
782 author = safe_unicode(author) if author else commiter
782 author = safe_unicode(author) if author else commiter
783
783
784 imc = scm_instance.in_memory_commit
784 imc = scm_instance.in_memory_commit
785
785
786 if not parent_commit:
786 if not parent_commit:
787 parent_commit = EmptyCommit(alias=scm_instance.alias)
787 parent_commit = EmptyCommit(alias=scm_instance.alias)
788
788
789 if isinstance(parent_commit, EmptyCommit):
789 if isinstance(parent_commit, EmptyCommit):
790 # EmptyCommit means we're editing an empty repository
790 # EmptyCommit means we're editing an empty repository
791 parents = None
791 parents = None
792 else:
792 else:
793 parents = [parent_commit]
793 parents = [parent_commit]
794 # add multiple nodes
794 # add multiple nodes
795 for path, content in processed_nodes:
795 for path, content in processed_nodes:
796 imc.add(FileNode(path, content=content))
796 imc.add(FileNode(path, content=content))
797 # TODO: handle pre push scenario
797 # TODO: handle pre push scenario
798 tip = imc.commit(message=message,
798 tip = imc.commit(message=message,
799 author=author,
799 author=author,
800 parents=parents,
800 parents=parents,
801 branch=parent_commit.branch)
801 branch=parent_commit.branch)
802
802
803 self.mark_for_invalidation(repo.repo_name)
803 self.mark_for_invalidation(repo.repo_name)
804 if trigger_push_hook:
804 if trigger_push_hook:
805 hooks_utils.trigger_post_push_hook(
805 hooks_utils.trigger_post_push_hook(
806 username=user.username, action='push_local',
806 username=user.username, action='push_local',
807 repo_name=repo.repo_name, repo_type=scm_instance.alias,
807 repo_name=repo.repo_name, repo_type=scm_instance.alias,
808 hook_type='post_push',
808 hook_type='post_push',
809 commit_ids=[tip.raw_id])
809 commit_ids=[tip.raw_id])
810 return tip
810 return tip
811
811
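A usage sketch for create_nodes with placeholder user and repo objects:

    tip = model.create_nodes(
        user=request_user, repo=db_repo, message='add docs',
        nodes={'docs/intro.rst': {'content': 'Intro text\n'}},
        parent_commit=None,  # None becomes EmptyCommit -> initial commit
        trigger_push_hook=True)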
812 def update_nodes(self, user, repo, message, nodes, parent_commit=None,
812 def update_nodes(self, user, repo, message, nodes, parent_commit=None,
813 author=None, trigger_push_hook=True):
813 author=None, trigger_push_hook=True):
814 user = self._get_user(user)
814 user = self._get_user(user)
815 scm_instance = repo.scm_instance(cache=False)
815 scm_instance = repo.scm_instance(cache=False)
816
816
817 message = safe_unicode(message)
817 message = safe_unicode(message)
818 commiter = user.full_contact
818 commiter = user.full_contact
819 author = safe_unicode(author) if author else commiter
819 author = safe_unicode(author) if author else commiter
820
820
821 imc = scm_instance.in_memory_commit
821 imc = scm_instance.in_memory_commit
822
822
823 if not parent_commit:
823 if not parent_commit:
824 parent_commit = EmptyCommit(alias=scm_instance.alias)
824 parent_commit = EmptyCommit(alias=scm_instance.alias)
825
825
826 if isinstance(parent_commit, EmptyCommit):
826 if isinstance(parent_commit, EmptyCommit):
827 # EmptyCommit means we're editing an empty repository
827 # EmptyCommit means we're editing an empty repository
828 parents = None
828 parents = None
829 else:
829 else:
830 parents = [parent_commit]
830 parents = [parent_commit]
831
831
832 # add multiple nodes
832 # add multiple nodes
833 for _filename, data in nodes.items():
833 for _filename, data in nodes.items():
834 # new filename, can be renamed from the old one; also sanitize
834 # new filename, can be renamed from the old one; also sanitize
835 # the path against any tricks with relative paths like ../../ etc.
835 # the path against any tricks with relative paths like ../../ etc.
836 filename = self._sanitize_path(data['filename'])
836 filename = self._sanitize_path(data['filename'])
837 old_filename = self._sanitize_path(_filename)
837 old_filename = self._sanitize_path(_filename)
838 content = data['content']
838 content = data['content']
839 file_mode = data.get('mode')
839 file_mode = data.get('mode')
840 filenode = FileNode(old_filename, content=content, mode=file_mode)
840 filenode = FileNode(old_filename, content=content, mode=file_mode)
841 op = data['op']
841 op = data['op']
842 if op == 'add':
842 if op == 'add':
843 imc.add(filenode)
843 imc.add(filenode)
844 elif op == 'del':
844 elif op == 'del':
845 imc.remove(filenode)
845 imc.remove(filenode)
846 elif op == 'mod':
846 elif op == 'mod':
847 if filename != old_filename:
847 if filename != old_filename:
848 # TODO: handle renames more efficiently, needs vcs lib changes
848 # TODO: handle renames more efficiently, needs vcs lib changes
849 imc.remove(filenode)
849 imc.remove(filenode)
850 imc.add(FileNode(filename, content=content, mode=file_mode))
850 imc.add(FileNode(filename, content=content, mode=file_mode))
851 else:
851 else:
852 imc.change(filenode)
852 imc.change(filenode)
853
853
854 try:
854 try:
855 # TODO: handle pre push scenario commit changes
855 # TODO: handle pre push scenario commit changes
856 tip = imc.commit(message=message,
856 tip = imc.commit(message=message,
857 author=author,
857 author=author,
858 parents=parents,
858 parents=parents,
859 branch=parent_commit.branch)
859 branch=parent_commit.branch)
860 except NodeNotChangedError:
860 except NodeNotChangedError:
861 raise
861 raise
862 except Exception as e:
862 except Exception as e:
863 log.exception("Unexpected exception during call to imc.commit")
863 log.exception("Unexpected exception during call to imc.commit")
864 raise IMCCommitError(str(e))
864 raise IMCCommitError(str(e))
865 finally:
865 finally:
866 # always clear caches; if the commit fails we still want fresh objects
866 # always clear caches; if the commit fails we still want fresh objects
867 self.mark_for_invalidation(repo.repo_name)
867 self.mark_for_invalidation(repo.repo_name)
868
868
869 if trigger_push_hook:
869 if trigger_push_hook:
870 hooks_utils.trigger_post_push_hook(
870 hooks_utils.trigger_post_push_hook(
871 username=user.username, action='push_local', hook_type='post_push',
871 username=user.username, action='push_local', hook_type='post_push',
872 repo_name=repo.repo_name, repo_type=scm_instance.alias,
872 repo_name=repo.repo_name, repo_type=scm_instance.alias,
873 commit_ids=[tip.raw_id])
873 commit_ids=[tip.raw_id])
874
874
875 return tip
875 return tip
876
876
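A usage sketch for update_nodes showing the nodes mapping this method expects; all names are placeholders:

    model.update_nodes(
        user=request_user, repo=db_repo, message='rename and edit',
        nodes={
            'old_name.txt': {
                'filename': 'new_name.txt',  # rename: handled as remove + add
                'content': 'updated body\n',
                'op': 'mod',
            },
        })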
877 def delete_nodes(self, user, repo, message, nodes, parent_commit=None,
877 def delete_nodes(self, user, repo, message, nodes, parent_commit=None,
878 author=None, trigger_push_hook=True):
878 author=None, trigger_push_hook=True):
879 """
879 """
880 Deletes the given nodes from `repo`
880 Deletes the given nodes from `repo`
881
881
882 :param user: RhodeCode User object or user_id, the committer
882 :param user: RhodeCode User object or user_id, the committer
883 :param repo: RhodeCode Repository object
883 :param repo: RhodeCode Repository object
884 :param message: commit message
884 :param message: commit message
885 :param nodes: mapping {filename:{'content':content},...}
885 :param nodes: mapping {filename:{'content':content},...}
886 :param parent_commit: parent commit; can be empty, then it's the initial
886 :param parent_commit: parent commit; can be empty, then it's the initial
887 commit
887 commit
888 :param author: author of the commit; can be different than the committer,
888 :param author: author of the commit; can be different than the committer,
889 but only for git
889 but only for git
890 :param trigger_push_hook: trigger push hooks
890 :param trigger_push_hook: trigger push hooks
891
891
892 :returns: new commit after deletion
892 :returns: new commit after deletion
893 """
893 """
894
894
895 user = self._get_user(user)
895 user = self._get_user(user)
896 scm_instance = repo.scm_instance(cache=False)
896 scm_instance = repo.scm_instance(cache=False)
897
897
898 processed_nodes = []
898 processed_nodes = []
899 for f_path in nodes:
899 for f_path in nodes:
900 f_path = self._sanitize_path(f_path)
900 f_path = self._sanitize_path(f_path)
901 # content can be empty, but for compatibility this accepts the same dict
901 # content can be empty, but for compatibility this accepts the same dict
902 # structure as create_nodes
902 # structure as create_nodes
903 content = nodes[f_path].get('content')
903 content = nodes[f_path].get('content')
904 processed_nodes.append((f_path, content))
904 processed_nodes.append((f_path, content))
905
905
906 message = safe_unicode(message)
906 message = safe_unicode(message)
907 commiter = user.full_contact
907 commiter = user.full_contact
908 author = safe_unicode(author) if author else commiter
908 author = safe_unicode(author) if author else commiter
909
909
910 imc = scm_instance.in_memory_commit
910 imc = scm_instance.in_memory_commit
911
911
912 if not parent_commit:
912 if not parent_commit:
913 parent_commit = EmptyCommit(alias=scm_instance.alias)
913 parent_commit = EmptyCommit(alias=scm_instance.alias)
914
914
915 if isinstance(parent_commit, EmptyCommit):
915 if isinstance(parent_commit, EmptyCommit):
916 # EmptyCommit means we're editing an empty repository
916 # EmptyCommit means we're editing an empty repository
917 parents = None
917 parents = None
918 else:
918 else:
919 parents = [parent_commit]
919 parents = [parent_commit]
920 # add multiple nodes
920 # add multiple nodes
921 for path, content in processed_nodes:
921 for path, content in processed_nodes:
922 imc.remove(FileNode(path, content=content))
922 imc.remove(FileNode(path, content=content))
923
923
924 # TODO: handle pre push scenario
924 # TODO: handle pre push scenario
925 tip = imc.commit(message=message,
925 tip = imc.commit(message=message,
926 author=author,
926 author=author,
927 parents=parents,
927 parents=parents,
928 branch=parent_commit.branch)
928 branch=parent_commit.branch)
929
929
930 self.mark_for_invalidation(repo.repo_name)
930 self.mark_for_invalidation(repo.repo_name)
931 if trigger_push_hook:
931 if trigger_push_hook:
932 hooks_utils.trigger_post_push_hook(
932 hooks_utils.trigger_post_push_hook(
933 username=user.username, action='push_local', hook_type='post_push',
933 username=user.username, action='push_local', hook_type='post_push',
934 repo_name=repo.repo_name, repo_type=scm_instance.alias,
934 repo_name=repo.repo_name, repo_type=scm_instance.alias,
935 commit_ids=[tip.raw_id])
935 commit_ids=[tip.raw_id])
936 return tip
936 return tip
937
937
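A usage sketch for delete_nodes; the 'content' key may be omitted for deletes, since the code above only calls .get('content'):

    model.delete_nodes(
        user=request_user, repo=db_repo, message='remove obsolete file',
        nodes={'docs/old.rst': {}})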
938 def strip(self, repo, commit_id, branch):
938 def strip(self, repo, commit_id, branch):
939 scm_instance = repo.scm_instance(cache=False)
939 scm_instance = repo.scm_instance(cache=False)
940 scm_instance.config.clear_section('hooks')
940 scm_instance.config.clear_section('hooks')
941 scm_instance.strip(commit_id, branch)
941 scm_instance.strip(commit_id, branch)
942 self.mark_for_invalidation(repo.repo_name)
942 self.mark_for_invalidation(repo.repo_name)
943
943
944 def get_unread_journal(self):
944 def get_unread_journal(self):
945 return self.sa.query(UserLog).count()
945 return self.sa.query(UserLog).count()
946
946
947 @classmethod
947 @classmethod
948 def backend_landing_ref(cls, repo_type):
948 def backend_landing_ref(cls, repo_type):
949 """
949 """
950 Return a default landing ref based on a repository type.
950 Return a default landing ref based on a repository type.
951 """
951 """
952
952
953 landing_ref = {
953 landing_ref = {
954 'hg': ('branch:default', 'default'),
954 'hg': ('branch:default', 'default'),
955 'git': ('branch:master', 'master'),
955 'git': ('branch:master', 'master'),
956 'svn': ('rev:tip', 'latest tip'),
956 'svn': ('rev:tip', 'latest tip'),
957 'default': ('rev:tip', 'latest tip'),
957 'default': ('rev:tip', 'latest tip'),
958 }
958 }
959
959
960 return landing_ref.get(repo_type) or landing_ref['default']
960 return landing_ref.get(repo_type) or landing_ref['default']
961
961
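The mapping above makes the classmethod's behaviour easy to pin down; for example (class name assumed to be ScmModel):

    ScmModel.backend_landing_ref('hg')       # ('branch:default', 'default')
    ScmModel.backend_landing_ref('git')      # ('branch:master', 'master')
    ScmModel.backend_landing_ref('unknown')  # ('rev:tip', 'latest tip') fallback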
962 def get_repo_landing_revs(self, translator, repo=None):
962 def get_repo_landing_revs(self, translator, repo=None):
963 """
963 """
964 Generates select options with tags, branches and bookmarks
964 Generates select options with tags, branches and bookmarks
965 (bookmarks for hg only), grouped by type
965 (bookmarks for hg only), grouped by type
966
966
967 :param repo:
967 :param repo:
968 """
968 """
969 from rhodecode.lib.vcs.backends.git import GitRepository
969 from rhodecode.lib.vcs.backends.git import GitRepository
970
970
971 _ = translator
971 _ = translator
972 repo = self._get_repo(repo)
972 repo = self._get_repo(repo)
973
973
974 if repo:
974 if repo:
975 repo_type = repo.repo_type
975 repo_type = repo.repo_type
976 else:
976 else:
977 repo_type = 'default'
977 repo_type = 'default'
978
978
979 default_landing_ref, landing_ref_lbl = self.backend_landing_ref(repo_type)
979 default_landing_ref, landing_ref_lbl = self.backend_landing_ref(repo_type)
980
980
981 default_ref_options = [
981 default_ref_options = [
982 [default_landing_ref, landing_ref_lbl]
982 [default_landing_ref, landing_ref_lbl]
983 ]
983 ]
984 default_choices = [
984 default_choices = [
985 default_landing_ref
985 default_landing_ref
986 ]
986 ]
987
987
988 if not repo:
988 if not repo:
989 # presented at NEW repo creation
989 # presented at NEW repo creation
990 return default_choices, default_ref_options
990 return default_choices, default_ref_options
991
991
992 repo = repo.scm_instance()
992 repo = repo.scm_instance()
993
993
994 ref_options = [(default_landing_ref, landing_ref_lbl)]
994 ref_options = [(default_landing_ref, landing_ref_lbl)]
995 choices = [default_landing_ref]
995 choices = [default_landing_ref]
996
996
997 # branches
997 # branches
998 branch_group = [(u'branch:%s' % safe_unicode(b), safe_unicode(b)) for b in repo.branches]
998 branch_group = [(u'branch:%s' % safe_unicode(b), safe_unicode(b)) for b in repo.branches]
999 if not branch_group:
999 if not branch_group:
1000 # new repo, or maybe a repo without any branches?
1000 # new repo, or maybe a repo without any branches?
1001 branch_group = default_ref_options
1001 branch_group = default_ref_options
1002
1002
1003 branches_group = (branch_group, _("Branches"))
1003 branches_group = (branch_group, _("Branches"))
1004 ref_options.append(branches_group)
1004 ref_options.append(branches_group)
1005 choices.extend([x[0] for x in branches_group[0]])
1005 choices.extend([x[0] for x in branches_group[0]])
1006
1006
1007 # bookmarks for HG
1007 # bookmarks for HG
1008 if repo.alias == 'hg':
1008 if repo.alias == 'hg':
1009 bookmarks_group = (
1009 bookmarks_group = (
1010 [(u'book:%s' % safe_unicode(b), safe_unicode(b))
1010 [(u'book:%s' % safe_unicode(b), safe_unicode(b))
1011 for b in repo.bookmarks],
1011 for b in repo.bookmarks],
1012 _("Bookmarks"))
1012 _("Bookmarks"))
1013 ref_options.append(bookmarks_group)
1013 ref_options.append(bookmarks_group)
1014 choices.extend([x[0] for x in bookmarks_group[0]])
1014 choices.extend([x[0] for x in bookmarks_group[0]])
1015
1015
1016 # tags
1016 # tags
1017 tags_group = (
1017 tags_group = (
1018 [(u'tag:%s' % safe_unicode(t), safe_unicode(t))
1018 [(u'tag:%s' % safe_unicode(t), safe_unicode(t))
1019 for t in repo.tags],
1019 for t in repo.tags],
1020 _("Tags"))
1020 _("Tags"))
1021 ref_options.append(tags_group)
1021 ref_options.append(tags_group)
1022 choices.extend([x[0] for x in tags_group[0]])
1022 choices.extend([x[0] for x in tags_group[0]])
1023
1023
1024 return choices, ref_options
1024 return choices, ref_options
1025
1025
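A sketch of the return shapes, with a placeholder translator and repo:

    choices, ref_options = model.get_repo_landing_revs(
        request.translate, repo=db_repo)
    # choices: flat list such as ['branch:default', 'tag:v1.0', ...]
    # ref_options: the default entry followed by grouped (options, label)
    # tuples, ready for a select widget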
1026 def get_server_info(self, environ=None):
1026 def get_server_info(self, environ=None):
1027 server_info = get_system_info(environ)
1027 server_info = get_system_info(environ)
1028 return server_info
1028 return server_info
@@ -1,1115 +1,1115 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 Set of generic validators
22 Set of generic validators
23 """
23 """
24
24
25
25
26 import os
26 import os
27 import re
27 import re
28 import logging
28 import logging
29 import collections
29 import collections
30
30
31 import formencode
31 import formencode
32 import ipaddress
32 import ipaddress
33 from formencode.validators import (
33 from formencode.validators import (
34 UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set,
34 UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set,
35 NotEmpty, IPAddress, CIDR, String, FancyValidator
35 NotEmpty, IPAddress, CIDR, String, FancyValidator
36 )
36 )
37
37
38 from sqlalchemy.sql.expression import true
38 from sqlalchemy.sql.expression import true
39 from sqlalchemy.util import OrderedSet
39 from sqlalchemy.util import OrderedSet
40
40
41 from rhodecode.authentication import (
41 from rhodecode.authentication import (
42 legacy_plugin_prefix, _import_legacy_plugin)
42 legacy_plugin_prefix, _import_legacy_plugin)
43 from rhodecode.authentication.base import loadplugin
43 from rhodecode.authentication.base import loadplugin
44 from rhodecode.apps._base import ADMIN_PREFIX
44 from rhodecode.apps._base import ADMIN_PREFIX
45 from rhodecode.lib.auth import HasRepoGroupPermissionAny, HasPermissionAny
45 from rhodecode.lib.auth import HasRepoGroupPermissionAny, HasPermissionAny
46 from rhodecode.lib.utils import repo_name_slug, make_db_config
46 from rhodecode.lib.utils import repo_name_slug, make_db_config
47 from rhodecode.lib.utils2 import safe_int, str2bool, aslist, md5, safe_unicode
47 from rhodecode.lib.utils2 import safe_int, str2bool, aslist, md5, safe_unicode
48 from rhodecode.lib.vcs.backends.git.repository import GitRepository
48 from rhodecode.lib.vcs.backends.git.repository import GitRepository
49 from rhodecode.lib.vcs.backends.hg.repository import MercurialRepository
49 from rhodecode.lib.vcs.backends.hg.repository import MercurialRepository
50 from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository
50 from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository
51 from rhodecode.model.db import (
51 from rhodecode.model.db import (
52 RepoGroup, Repository, UserGroup, User, ChangesetStatus, Gist)
52 RepoGroup, Repository, UserGroup, User, ChangesetStatus, Gist)
53 from rhodecode.model.settings import VcsSettingsModel
53 from rhodecode.model.settings import VcsSettingsModel
54
54
55 # silence warnings and pylint
55 # silence warnings and pylint
56 UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set, \
56 UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set, \
57 NotEmpty, IPAddress, CIDR, String, FancyValidator
57 NotEmpty, IPAddress, CIDR, String, FancyValidator
58
58
59 log = logging.getLogger(__name__)
59 log = logging.getLogger(__name__)
60
60
61
61
62 class _Missing(object):
62 class _Missing(object):
63 pass
63 pass
64
64
65
65
66 Missing = _Missing()
66 Missing = _Missing()
67
67
68
68
69 def M(self, key, state, **kwargs):
69 def M(self, key, state, **kwargs):
70 """
70 """
71 returns a string from self.message based on the given key;
71 returns a string from self.message based on the given key;
72 passed kw params are used to substitute %(named)s params inside
72 passed kw params are used to substitute %(named)s params inside
73 the translated strings
73 the translated strings
74 
74 
75 :param key:
75 :param key:
76 :param state:
76 :param state:
77 """
77 """
78
78
79 #state._ = staticmethod(_)
79 #state._ = staticmethod(_)
80 # inject validator into state object
80 # inject validator into state object
81 return self.message(key, state, **kwargs)
81 return self.message(key, state, **kwargs)
82
82
83
83
84 def UniqueList(localizer, convert=None):
84 def UniqueList(localizer, convert=None):
85 _ = localizer
85 _ = localizer
86
86
87 class _validator(formencode.FancyValidator):
87 class _validator(formencode.FancyValidator):
88 """
88 """
89 Unique List !
89 Unique List !
90 """
90 """
91 messages = {
91 messages = {
92 'empty': _(u'Value cannot be an empty list'),
92 'empty': _('Value cannot be an empty list'),
93 'missing_value': _(u'Value cannot be an empty list'),
93 'missing_value': _('Value cannot be an empty list'),
94 }
94 }
95
95
96 def _to_python(self, value, state):
96 def _to_python(self, value, state):
97 ret_val = []
97 ret_val = []
98
98
99 def make_unique(value):
99 def make_unique(value):
100 seen = []
100 seen = []
101 return [c for c in value if not (c in seen or seen.append(c))]
101 return [c for c in value if not (c in seen or seen.append(c))]
102
102
103 if isinstance(value, list):
103 if isinstance(value, list):
104 ret_val = make_unique(value)
104 ret_val = make_unique(value)
105 elif isinstance(value, set):
105 elif isinstance(value, set):
106 ret_val = make_unique(list(value))
106 ret_val = make_unique(list(value))
107 elif isinstance(value, tuple):
107 elif isinstance(value, tuple):
108 ret_val = make_unique(list(value))
108 ret_val = make_unique(list(value))
109 elif value is None:
109 elif value is None:
110 ret_val = []
110 ret_val = []
111 else:
111 else:
112 ret_val = [value]
112 ret_val = [value]
113
113
114 if convert:
114 if convert:
115 ret_val = list(map(convert, ret_val))
115 ret_val = list(map(convert, ret_val))
116 return ret_val
116 return ret_val
117
117
118 def empty_value(self, value):
118 def empty_value(self, value):
119 return []
119 return []
120 return _validator
120 return _validator
121
121
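A quick sketch of the factory in use, assuming `_` is a localizer; to_python dispatches to the _to_python defined above:

    validator = UniqueList(_)()
    validator.to_python(['a', 'b', 'a', 'c'])  # -> ['a', 'b', 'c']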
122
122
123 def UniqueListFromString(localizer):
123 def UniqueListFromString(localizer):
124 _ = localizer
124 _ = localizer
125
125
126 class _validator(UniqueList(localizer)):
126 class _validator(UniqueList(localizer)):
127 def _to_python(self, value, state):
127 def _to_python(self, value, state):
128 if isinstance(value, str):
128 if isinstance(value, str):
129 value = aslist(value, ',')
129 value = aslist(value, ',')
130 return super(_validator, self)._to_python(value, state)
130 return super(_validator, self)._to_python(value, state)
131 return _validator
131 return _validator
132
132
133
133
134 def ValidSvnPattern(localizer, section, repo_name=None):
134 def ValidSvnPattern(localizer, section, repo_name=None):
135 _ = localizer
135 _ = localizer
136
136
137 class _validator(formencode.validators.FancyValidator):
137 class _validator(formencode.validators.FancyValidator):
138 messages = {
138 messages = {
139 'pattern_exists': _(u'Pattern already exists'),
139 'pattern_exists': _('Pattern already exists'),
140 }
140 }
141
141
142 def validate_python(self, value, state):
142 def validate_python(self, value, state):
143 if not value:
143 if not value:
144 return
144 return
145 model = VcsSettingsModel(repo=repo_name)
145 model = VcsSettingsModel(repo=repo_name)
146 ui_settings = model.get_svn_patterns(section=section)
146 ui_settings = model.get_svn_patterns(section=section)
147 for entry in ui_settings:
147 for entry in ui_settings:
148 if value == entry.value:
148 if value == entry.value:
149 msg = M(self, 'pattern_exists', state)
149 msg = M(self, 'pattern_exists', state)
150 raise formencode.Invalid(msg, value, state)
150 raise formencode.Invalid(msg, value, state)
151 return _validator
151 return _validator
152
152
153
153
154 def ValidUsername(localizer, edit=False, old_data=None):
154 def ValidUsername(localizer, edit=False, old_data=None):
155 _ = localizer
155 _ = localizer
156 old_data = old_data or {}
156 old_data = old_data or {}
157
157
158 class _validator(formencode.validators.FancyValidator):
158 class _validator(formencode.validators.FancyValidator):
159 messages = {
159 messages = {
160 'username_exists': _(u'Username "%(username)s" already exists'),
160 'username_exists': _('Username "%(username)s" already exists'),
161 'system_invalid_username':
161 'system_invalid_username':
162 _(u'Username "%(username)s" is forbidden'),
162 _('Username "%(username)s" is forbidden'),
163 'invalid_username':
163 'invalid_username':
164 _(u'Username may only contain alphanumeric characters, '
164 _('Username may only contain alphanumeric characters, '
165 u'underscores, periods or dashes and must begin with an '
165 'underscores, periods or dashes and must begin with an '
166 u'alphanumeric character or underscore')
166 'alphanumeric character or underscore')
167 }
167 }
168
168
169 def validate_python(self, value, state):
169 def validate_python(self, value, state):
170 if value in ['default', 'new_user']:
170 if value in ['default', 'new_user']:
171 msg = M(self, 'system_invalid_username', state, username=value)
171 msg = M(self, 'system_invalid_username', state, username=value)
172 raise formencode.Invalid(msg, value, state)
172 raise formencode.Invalid(msg, value, state)
173 # check if user is unique
173 # check if user is unique
174 old_un = None
174 old_un = None
175 if edit:
175 if edit:
176 old_un = User.get(old_data.get('user_id')).username
176 old_un = User.get(old_data.get('user_id')).username
177
177
178 if old_un != value or not edit:
178 if old_un != value or not edit:
179 if User.get_by_username(value, case_insensitive=True):
179 if User.get_by_username(value, case_insensitive=True):
180 msg = M(self, 'username_exists', state, username=value)
180 msg = M(self, 'username_exists', state, username=value)
181 raise formencode.Invalid(msg, value, state)
181 raise formencode.Invalid(msg, value, state)
182
182
183 if (re.match(r'^[\w]{1}[\w\-\.]{0,254}$', value)
183 if (re.match(r'^[\w]{1}[\w\-\.]{0,254}$', value)
184 is None):
184 is None):
185 msg = M(self, 'invalid_username', state)
185 msg = M(self, 'invalid_username', state)
186 raise formencode.Invalid(msg, value, state)
186 raise formencode.Invalid(msg, value, state)
187 return _validator
187 return _validator
188
188
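These factories are typically wired into formencode schemas; a hypothetical sketch:

    class _UserForm(formencode.Schema):
        username = ValidUsername(localizer, edit=False)()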
189
189
190 def ValidRepoUser(localizer, allow_disabled=False):
190 def ValidRepoUser(localizer, allow_disabled=False):
191 _ = localizer
191 _ = localizer
192
192
193 class _validator(formencode.validators.FancyValidator):
193 class _validator(formencode.validators.FancyValidator):
194 messages = {
194 messages = {
195 'invalid_username': _(u'Username %(username)s is not valid'),
195 'invalid_username': _('Username %(username)s is not valid'),
196 'disabled_username': _(u'Username %(username)s is disabled')
196 'disabled_username': _('Username %(username)s is disabled')
197 }
197 }
198
198
199 def validate_python(self, value, state):
199 def validate_python(self, value, state):
200 try:
200 try:
201 user = User.query().filter(User.username == value).one()
201 user = User.query().filter(User.username == value).one()
202 except Exception:
202 except Exception:
203 msg = M(self, 'invalid_username', state, username=value)
203 msg = M(self, 'invalid_username', state, username=value)
204 raise formencode.Invalid(
204 raise formencode.Invalid(
205 msg, value, state, error_dict={'username': msg}
205 msg, value, state, error_dict={'username': msg}
206 )
206 )
207 if user and (not allow_disabled and not user.active):
207 if user and (not allow_disabled and not user.active):
208 msg = M(self, 'disabled_username', state, username=value)
208 msg = M(self, 'disabled_username', state, username=value)
209 raise formencode.Invalid(
209 raise formencode.Invalid(
210 msg, value, state, error_dict={'username': msg}
210 msg, value, state, error_dict={'username': msg}
211 )
211 )
212 return _validator
212 return _validator
213
213
214
214
215 def ValidUserGroup(localizer, edit=False, old_data=None):
215 def ValidUserGroup(localizer, edit=False, old_data=None):
216 _ = localizer
216 _ = localizer
217 old_data = old_data or {}
217 old_data = old_data or {}
218
218
219 class _validator(formencode.validators.FancyValidator):
219 class _validator(formencode.validators.FancyValidator):
220 messages = {
220 messages = {
221 'invalid_group': _(u'Invalid user group name'),
221 'invalid_group': _('Invalid user group name'),
222 'group_exist': _(u'User group `%(usergroup)s` already exists'),
222 'group_exist': _('User group `%(usergroup)s` already exists'),
223 'invalid_usergroup_name':
223 'invalid_usergroup_name':
224 _(u'user group name may only contain alphanumeric '
224 _('user group name may only contain alphanumeric '
225 u'characters, underscores, periods or dashes and must begin '
225 'characters, underscores, periods or dashes and must begin '
226 u'with an alphanumeric character')
226 'with an alphanumeric character')
227 }
227 }
228
228
229 def validate_python(self, value, state):
229 def validate_python(self, value, state):
230 if value in ['default']:
230 if value in ['default']:
231 msg = M(self, 'invalid_group', state)
231 msg = M(self, 'invalid_group', state)
232 raise formencode.Invalid(
232 raise formencode.Invalid(
233 msg, value, state, error_dict={'users_group_name': msg}
233 msg, value, state, error_dict={'users_group_name': msg}
234 )
234 )
235 # check if group is unique
235 # check if group is unique
236 old_ugname = None
236 old_ugname = None
237 if edit:
237 if edit:
238 old_id = old_data.get('users_group_id')
238 old_id = old_data.get('users_group_id')
239 old_ugname = UserGroup.get(old_id).users_group_name
239 old_ugname = UserGroup.get(old_id).users_group_name
240
240
241 if old_ugname != value or not edit:
241 if old_ugname != value or not edit:
242 is_existing_group = UserGroup.get_by_group_name(
242 is_existing_group = UserGroup.get_by_group_name(
243 value, case_insensitive=True)
243 value, case_insensitive=True)
244 if is_existing_group:
244 if is_existing_group:
245 msg = M(self, 'group_exist', state, usergroup=value)
245 msg = M(self, 'group_exist', state, usergroup=value)
246 raise formencode.Invalid(
246 raise formencode.Invalid(
247 msg, value, state, error_dict={'users_group_name': msg}
247 msg, value, state, error_dict={'users_group_name': msg}
248 )
248 )
249
249
250 if re.match(r'^[a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+$', value) is None:
250 if re.match(r'^[a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+$', value) is None:
251 msg = M(self, 'invalid_usergroup_name', state)
251 msg = M(self, 'invalid_usergroup_name', state)
252 raise formencode.Invalid(
252 raise formencode.Invalid(
253 msg, value, state, error_dict={'users_group_name': msg}
253 msg, value, state, error_dict={'users_group_name': msg}
254 )
254 )
255 return _validator
255 return _validator
256
256
257
257
258 def ValidRepoGroup(localizer, edit=False, old_data=None, can_create_in_root=False):
258 def ValidRepoGroup(localizer, edit=False, old_data=None, can_create_in_root=False):
259 _ = localizer
259 _ = localizer
260 old_data = old_data or {}
260 old_data = old_data or {}
261
261
262 class _validator(formencode.validators.FancyValidator):
262 class _validator(formencode.validators.FancyValidator):
263 messages = {
263 messages = {
264 'group_parent_id': _(u'Cannot assign this group as parent'),
264 'group_parent_id': _('Cannot assign this group as parent'),
265 'group_exists': _(u'Group "%(group_name)s" already exists'),
265 'group_exists': _('Group "%(group_name)s" already exists'),
266 'repo_exists': _(u'Repository with name "%(group_name)s" '
266 'repo_exists': _('Repository with name "%(group_name)s" '
267 u'already exists'),
267 'already exists'),
268 'permission_denied': _(u"no permission to store repository group "
268 'permission_denied': _("no permission to store repository group "
269 u"in this location"),
269 "in this location"),
270 'permission_denied_root': _(
270 'permission_denied_root': _(
271 u"no permission to store repository group "
271 "no permission to store repository group "
272 u"in root location")
272 "in root location")
273 }
273 }
274
274
275 def _to_python(self, value, state):
275 def _to_python(self, value, state):
276 group_name = repo_name_slug(value.get('group_name', ''))
276 group_name = repo_name_slug(value.get('group_name', ''))
277 group_parent_id = safe_int(value.get('group_parent_id'))
277 group_parent_id = safe_int(value.get('group_parent_id'))
278 gr = RepoGroup.get(group_parent_id)
278 gr = RepoGroup.get(group_parent_id)
279 if gr:
279 if gr:
280 parent_group_path = gr.full_path
280 parent_group_path = gr.full_path
281 # value needs to be aware of the full group name in order to
281 # value needs to be aware of the full group name in order to
282 # check the db key; this is just the name to store in the
282 # check the db key; this is just the name to store in the
283 # database
283 # database
284 group_name_full = (
284 group_name_full = (
285 parent_group_path + RepoGroup.url_sep() + group_name)
285 parent_group_path + RepoGroup.url_sep() + group_name)
286 else:
286 else:
287 group_name_full = group_name
287 group_name_full = group_name
288
288
289 value['group_name'] = group_name
289 value['group_name'] = group_name
290 value['group_name_full'] = group_name_full
290 value['group_name_full'] = group_name_full
291 value['group_parent_id'] = group_parent_id
291 value['group_parent_id'] = group_parent_id
292 return value
292 return value
293
293
294 def validate_python(self, value, state):
294 def validate_python(self, value, state):
295
295
296 old_group_name = None
296 old_group_name = None
297 group_name = value.get('group_name')
297 group_name = value.get('group_name')
298 group_name_full = value.get('group_name_full')
298 group_name_full = value.get('group_name_full')
299 group_parent_id = safe_int(value.get('group_parent_id'))
299 group_parent_id = safe_int(value.get('group_parent_id'))
300 if group_parent_id == -1:
300 if group_parent_id == -1:
301 group_parent_id = None
301 group_parent_id = None
302
302
303 group_obj = RepoGroup.get(old_data.get('group_id'))
303 group_obj = RepoGroup.get(old_data.get('group_id'))
304 parent_group_changed = False
304 parent_group_changed = False
305 if edit:
305 if edit:
306 old_group_name = group_obj.group_name
306 old_group_name = group_obj.group_name
307 old_group_parent_id = group_obj.group_parent_id
307 old_group_parent_id = group_obj.group_parent_id
308
308
309 if group_parent_id != old_group_parent_id:
309 if group_parent_id != old_group_parent_id:
310 parent_group_changed = True
310 parent_group_changed = True
311
311
312 # TODO: mikhail: the following if statement is not reached
312 # TODO: mikhail: the following if statement is not reached
313 # since group_parent_id's OneOf validation fails before.
313 # since group_parent_id's OneOf validation fails before.
314 # Can be removed.
314 # Can be removed.
315
315
316 # check against setting a parent of self
316 # check against setting a parent of self
317 parent_of_self = (
317 parent_of_self = (
318 old_data['group_id'] == group_parent_id
318 old_data['group_id'] == group_parent_id
319 if group_parent_id else False
319 if group_parent_id else False
320 )
320 )
321 if parent_of_self:
321 if parent_of_self:
322 msg = M(self, 'group_parent_id', state)
322 msg = M(self, 'group_parent_id', state)
323 raise formencode.Invalid(
323 raise formencode.Invalid(
324 msg, value, state, error_dict={'group_parent_id': msg}
324 msg, value, state, error_dict={'group_parent_id': msg}
325 )
325 )
326
326
327 # group we're moving current group inside
327 # group we're moving current group inside
328 child_group = None
328 child_group = None
329 if group_parent_id:
329 if group_parent_id:
330 child_group = RepoGroup.query().filter(
330 child_group = RepoGroup.query().filter(
331 RepoGroup.group_id == group_parent_id).scalar()
331 RepoGroup.group_id == group_parent_id).scalar()
332
332
333 # do a special check that we cannot move a group to one of
333 # do a special check that we cannot move a group to one of
334 # its children
334 # its children
335 if edit and child_group:
335 if edit and child_group:
336 parents = [x.group_id for x in child_group.parents]
336 parents = [x.group_id for x in child_group.parents]
337 move_to_children = old_data['group_id'] in parents
337 move_to_children = old_data['group_id'] in parents
338 if move_to_children:
338 if move_to_children:
339 msg = M(self, 'group_parent_id', state)
339 msg = M(self, 'group_parent_id', state)
340 raise formencode.Invalid(
340 raise formencode.Invalid(
341 msg, value, state, error_dict={'group_parent_id': msg})
341 msg, value, state, error_dict={'group_parent_id': msg})
342
342
343 # Check if we have permission to store in the parent.
343 # Check if we have permission to store in the parent.
344 # Only check if the parent group changed.
344 # Only check if the parent group changed.
345 if parent_group_changed:
345 if parent_group_changed:
346 if child_group is None:
346 if child_group is None:
347 if not can_create_in_root:
347 if not can_create_in_root:
348 msg = M(self, 'permission_denied_root', state)
348 msg = M(self, 'permission_denied_root', state)
349 raise formencode.Invalid(
349 raise formencode.Invalid(
350 msg, value, state,
350 msg, value, state,
351 error_dict={'group_parent_id': msg})
351 error_dict={'group_parent_id': msg})
352 else:
352 else:
353 valid = HasRepoGroupPermissionAny('group.admin')
353 valid = HasRepoGroupPermissionAny('group.admin')
354 forbidden = not valid(
354 forbidden = not valid(
355 child_group.group_name, 'can create group validator')
355 child_group.group_name, 'can create group validator')
356 if forbidden:
356 if forbidden:
357 msg = M(self, 'permission_denied', state)
357 msg = M(self, 'permission_denied', state)
358 raise formencode.Invalid(
358 raise formencode.Invalid(
359 msg, value, state,
359 msg, value, state,
360 error_dict={'group_parent_id': msg})
360 error_dict={'group_parent_id': msg})
361
361
362 # if we change the name or it's a new group, check for existing names
362 # if we change the name or it's a new group, check for existing names
363 # or repositories with the same name
363 # or repositories with the same name
364 if old_group_name != group_name_full or not edit:
364 if old_group_name != group_name_full or not edit:
365 # check group
365 # check group
366 gr = RepoGroup.get_by_group_name(group_name_full)
366 gr = RepoGroup.get_by_group_name(group_name_full)
367 if gr:
367 if gr:
368 msg = M(self, 'group_exists', state, group_name=group_name)
368 msg = M(self, 'group_exists', state, group_name=group_name)
369 raise formencode.Invalid(
369 raise formencode.Invalid(
370 msg, value, state, error_dict={'group_name': msg})
370 msg, value, state, error_dict={'group_name': msg})
371
371
372 # check for same repo
372 # check for same repo
373 repo = Repository.get_by_repo_name(group_name_full)
373 repo = Repository.get_by_repo_name(group_name_full)
374 if repo:
374 if repo:
375 msg = M(self, 'repo_exists', state, group_name=group_name)
375 msg = M(self, 'repo_exists', state, group_name=group_name)
376 raise formencode.Invalid(
376 raise formencode.Invalid(
377 msg, value, state, error_dict={'group_name': msg})
377 msg, value, state, error_dict={'group_name': msg})
378 return _validator
378 return _validator
379
379
380
380
381 def ValidPassword(localizer):
381 def ValidPassword(localizer):
382 _ = localizer
382 _ = localizer
383
383
384 class _validator(formencode.validators.FancyValidator):
384 class _validator(formencode.validators.FancyValidator):
385 messages = {
385 messages = {
386 'invalid_password':
386 'invalid_password':
387 _(u'Invalid characters (non-ascii) in password')
387 _('Invalid characters (non-ascii) in password')
388 }
388 }
389
389
390 def validate_python(self, value, state):
390 def validate_python(self, value, state):
391 try:
391 try:
392 (value or '').encode('ascii')
392 (value or '').encode('ascii')
393 except UnicodeError:
393 except UnicodeError:
394 msg = M(self, 'invalid_password', state)
394 msg = M(self, 'invalid_password', state)
395 raise formencode.Invalid(msg, value, state,)
395 raise formencode.Invalid(msg, value, state,)
396 return _validator
396 return _validator
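Each of these factories returns a formencode validator class bound to a localizer. A minimal usage sketch (the no-op localizer below is a hypothetical stand-in for the real Pyramid translator; note that on Python 3 the ASCII check must use encode(), since str has no decode()):

    import formencode

    _noop_localizer = lambda s: s  # hypothetical stand-in for request.translate
    PasswordValidator = ValidPassword(_noop_localizer)
    try:
        PasswordValidator().to_python('p\xe4ssword')  # non-ascii -> formencode.Invalid
    except formencode.Invalid as exc:
        print(exc)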
397
397
398
398
399 def ValidPasswordsMatch(
399 def ValidPasswordsMatch(
400 localizer, passwd='new_password',
400 localizer, passwd='new_password',
401 passwd_confirmation='password_confirmation'):
401 passwd_confirmation='password_confirmation'):
402 _ = localizer
402 _ = localizer
403
403
404 class _validator(formencode.validators.FancyValidator):
404 class _validator(formencode.validators.FancyValidator):
405 messages = {
405 messages = {
406 'password_mismatch': _(u'Passwords do not match'),
406 'password_mismatch': _('Passwords do not match'),
407 }
407 }
408
408
409 def validate_python(self, value, state):
409 def validate_python(self, value, state):
410
410
411 pass_val = value.get('password') or value.get(passwd)
411 pass_val = value.get('password') or value.get(passwd)
412 if pass_val != value[passwd_confirmation]:
412 if pass_val != value[passwd_confirmation]:
413 msg = M(self, 'password_mismatch', state)
413 msg = M(self, 'password_mismatch', state)
414 raise formencode.Invalid(
414 raise formencode.Invalid(
415 msg, value, state,
415 msg, value, state,
416 error_dict={passwd: msg, passwd_confirmation: msg}
416 error_dict={passwd: msg, passwd_confirmation: msg}
417 )
417 )
418 return _validator
418 return _validator
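ValidPasswordsMatch validates the whole value dict rather than a single field, so it is typically attached as a chained validator. A hedged sketch of such a schema (the field set here is illustrative, not the actual RhodeCode form):

    import formencode

    _ = lambda s: s  # hypothetical localizer stand-in

    class PasswordChangeForm(formencode.Schema):
        allow_extra_fields = True
        new_password = formencode.validators.String(not_empty=True)
        password_confirmation = formencode.validators.String(not_empty=True)
        chained_validators = [ValidPasswordsMatch(_)()]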
419
419
420
420
421 def ValidAuth(localizer):
421 def ValidAuth(localizer):
422 _ = localizer
422 _ = localizer
423
423
424 class _validator(formencode.validators.FancyValidator):
424 class _validator(formencode.validators.FancyValidator):
425 messages = {
425 messages = {
426 'invalid_password': _(u'invalid password'),
426 'invalid_password': _('invalid password'),
427 'invalid_username': _(u'invalid user name'),
427 'invalid_username': _('invalid user name'),
428 'disabled_account': _(u'Your account is disabled')
428 'disabled_account': _('Your account is disabled')
429 }
429 }
430
430
431 def validate_python(self, value, state):
431 def validate_python(self, value, state):
432 from rhodecode.authentication.base import authenticate, HTTP_TYPE
432 from rhodecode.authentication.base import authenticate, HTTP_TYPE
433
433
434 password = value['password']
434 password = value['password']
435 username = value['username']
435 username = value['username']
436
436
437 if not authenticate(username, password, '', HTTP_TYPE,
437 if not authenticate(username, password, '', HTTP_TYPE,
438 skip_missing=True):
438 skip_missing=True):
439 user = User.get_by_username(username)
439 user = User.get_by_username(username)
440 if user and not user.active:
440 if user and not user.active:
441 log.warning('user %s is disabled', username)
441 log.warning('user %s is disabled', username)
442 msg = M(self, 'disabled_account', state)
442 msg = M(self, 'disabled_account', state)
443 raise formencode.Invalid(
443 raise formencode.Invalid(
444 msg, value, state, error_dict={'username': msg}
444 msg, value, state, error_dict={'username': msg}
445 )
445 )
446 else:
446 else:
447 log.warning('user `%s` failed to authenticate', username)
447 log.warning('user `%s` failed to authenticate', username)
448 msg = M(self, 'invalid_username', state)
448 msg = M(self, 'invalid_username', state)
449 msg2 = M(self, 'invalid_password', state)
449 msg2 = M(self, 'invalid_password', state)
450 raise formencode.Invalid(
450 raise formencode.Invalid(
451 msg, value, state,
451 msg, value, state,
452 error_dict={'username': msg, 'password': msg2}
452 error_dict={'username': msg, 'password': msg2}
453 )
453 )
454 return _validator
454 return _validator
455
455
456
456
457 def ValidRepoName(localizer, edit=False, old_data=None):
457 def ValidRepoName(localizer, edit=False, old_data=None):
458 old_data = old_data or {}
458 old_data = old_data or {}
459 _ = localizer
459 _ = localizer
460
460
461 class _validator(formencode.validators.FancyValidator):
461 class _validator(formencode.validators.FancyValidator):
462 messages = {
462 messages = {
463 'invalid_repo_name':
463 'invalid_repo_name':
464 _(u'Repository name %(repo)s is disallowed'),
464 _('Repository name %(repo)s is disallowed'),
465 # top level
465 # top level
466 'repository_exists': _(u'Repository with name %(repo)s '
466 'repository_exists': _('Repository with name %(repo)s '
467 u'already exists'),
467 'already exists'),
468 'group_exists': _(u'Repository group with name "%(repo)s" '
468 'group_exists': _('Repository group with name "%(repo)s" '
469 u'already exists'),
469 'already exists'),
470 # inside a group
470 # inside a group
471 'repository_in_group_exists': _(u'Repository with name %(repo)s '
471 'repository_in_group_exists': _('Repository with name %(repo)s '
472 u'exists in group "%(group)s"'),
472 'exists in group "%(group)s"'),
473 'group_in_group_exists': _(
473 'group_in_group_exists': _(
474 u'Repository group with name "%(repo)s" '
474 'Repository group with name "%(repo)s" '
475 u'exists in group "%(group)s"'),
475 'exists in group "%(group)s"'),
476 }
476 }
477
477
478 def _to_python(self, value, state):
478 def _to_python(self, value, state):
479 repo_name = repo_name_slug(value.get('repo_name', ''))
479 repo_name = repo_name_slug(value.get('repo_name', ''))
480 repo_group = value.get('repo_group')
480 repo_group = value.get('repo_group')
481 if repo_group:
481 if repo_group:
482 gr = RepoGroup.get(repo_group)
482 gr = RepoGroup.get(repo_group)
483 group_path = gr.full_path
483 group_path = gr.full_path
484 group_name = gr.group_name
484 group_name = gr.group_name
485 # value needs to be aware of group name in order to check
485 # value needs to be aware of group name in order to check
486 # db key. This is actually just the name to store in the
486 # db key. This is actually just the name to store in the
487 # database
487 # database
488 repo_name_full = group_path + RepoGroup.url_sep() + repo_name
488 repo_name_full = group_path + RepoGroup.url_sep() + repo_name
489 else:
489 else:
490 group_name = group_path = ''
490 group_name = group_path = ''
491 repo_name_full = repo_name
491 repo_name_full = repo_name
492
492
493 value['repo_name'] = repo_name
493 value['repo_name'] = repo_name
494 value['repo_name_full'] = repo_name_full
494 value['repo_name_full'] = repo_name_full
495 value['group_path'] = group_path
495 value['group_path'] = group_path
496 value['group_name'] = group_name
496 value['group_name'] = group_name
497 return value
497 return value
498
498
499 def validate_python(self, value, state):
499 def validate_python(self, value, state):
500
500
501 repo_name = value.get('repo_name')
501 repo_name = value.get('repo_name')
502 repo_name_full = value.get('repo_name_full')
502 repo_name_full = value.get('repo_name_full')
503 group_path = value.get('group_path')
503 group_path = value.get('group_path')
504 group_name = value.get('group_name')
504 group_name = value.get('group_name')
505
505
506 if repo_name in [ADMIN_PREFIX, '']:
506 if repo_name in [ADMIN_PREFIX, '']:
507 msg = M(self, 'invalid_repo_name', state, repo=repo_name)
507 msg = M(self, 'invalid_repo_name', state, repo=repo_name)
508 raise formencode.Invalid(
508 raise formencode.Invalid(
509 msg, value, state, error_dict={'repo_name': msg})
509 msg, value, state, error_dict={'repo_name': msg})
510
510
511 rename = old_data.get('repo_name') != repo_name_full
511 rename = old_data.get('repo_name') != repo_name_full
512 create = not edit
512 create = not edit
513 if rename or create:
513 if rename or create:
514
514
515 if group_path:
515 if group_path:
516 if Repository.get_by_repo_name(repo_name_full):
516 if Repository.get_by_repo_name(repo_name_full):
517 msg = M(self, 'repository_in_group_exists', state,
517 msg = M(self, 'repository_in_group_exists', state,
518 repo=repo_name, group=group_name)
518 repo=repo_name, group=group_name)
519 raise formencode.Invalid(
519 raise formencode.Invalid(
520 msg, value, state, error_dict={'repo_name': msg})
520 msg, value, state, error_dict={'repo_name': msg})
521 if RepoGroup.get_by_group_name(repo_name_full):
521 if RepoGroup.get_by_group_name(repo_name_full):
522 msg = M(self, 'group_in_group_exists', state,
522 msg = M(self, 'group_in_group_exists', state,
523 repo=repo_name, group=group_name)
523 repo=repo_name, group=group_name)
524 raise formencode.Invalid(
524 raise formencode.Invalid(
525 msg, value, state, error_dict={'repo_name': msg})
525 msg, value, state, error_dict={'repo_name': msg})
526 else:
526 else:
527 if RepoGroup.get_by_group_name(repo_name_full):
527 if RepoGroup.get_by_group_name(repo_name_full):
528 msg = M(self, 'group_exists', state, repo=repo_name)
528 msg = M(self, 'group_exists', state, repo=repo_name)
529 raise formencode.Invalid(
529 raise formencode.Invalid(
530 msg, value, state, error_dict={'repo_name': msg})
530 msg, value, state, error_dict={'repo_name': msg})
531
531
532 if Repository.get_by_repo_name(repo_name_full):
532 if Repository.get_by_repo_name(repo_name_full):
533 msg = M(
533 msg = M(
534 self, 'repository_exists', state, repo=repo_name)
534 self, 'repository_exists', state, repo=repo_name)
535 raise formencode.Invalid(
535 raise formencode.Invalid(
536 msg, value, state, error_dict={'repo_name': msg})
536 msg, value, state, error_dict={'repo_name': msg})
537 return value
537 return value
538 return _validator
538 return _validator
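The split between _to_python (normalisation) and validate_python (checks) means the full name is assembled once and then reused. A standalone sketch of that assembly, assuming RepoGroup.url_sep() returns '/':

    def assemble_full_name(repo_name, group_path=''):
        # mirrors: group_path + RepoGroup.url_sep() + repo_name
        return group_path + '/' + repo_name if group_path else repo_name

    assert assemble_full_name('my-repo', 'company/projects') == 'company/projects/my-repo'
    assert assemble_full_name('my-repo') == 'my-repo'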
539
539
540
540
541 def ValidForkName(localizer, *args, **kwargs):
541 def ValidForkName(localizer, *args, **kwargs):
542 _ = localizer
542 _ = localizer
543
543
544 return ValidRepoName(localizer, *args, **kwargs)
544 return ValidRepoName(localizer, *args, **kwargs)
545
545
546
546
547 def SlugifyName(localizer):
547 def SlugifyName(localizer):
548 _ = localizer
548 _ = localizer
549
549
550 class _validator(formencode.validators.FancyValidator):
550 class _validator(formencode.validators.FancyValidator):
551
551
552 def _to_python(self, value, state):
552 def _to_python(self, value, state):
553 return repo_name_slug(value)
553 return repo_name_slug(value)
554
554
555 def validate_python(self, value, state):
555 def validate_python(self, value, state):
556 pass
556 pass
557 return _validator
557 return _validator
558
558
559
559
560 def CannotHaveGitSuffix(localizer):
560 def CannotHaveGitSuffix(localizer):
561 _ = localizer
561 _ = localizer
562
562
563 class _validator(formencode.validators.FancyValidator):
563 class _validator(formencode.validators.FancyValidator):
564 messages = {
564 messages = {
565 'has_git_suffix':
565 'has_git_suffix':
566 _(u'Repository name cannot end with .git'),
566 _('Repository name cannot end with .git'),
567 }
567 }
568
568
569 def _to_python(self, value, state):
569 def _to_python(self, value, state):
570 return value
570 return value
571
571
572 def validate_python(self, value, state):
572 def validate_python(self, value, state):
573 if value and value.endswith('.git'):
573 if value and value.endswith('.git'):
574 msg = M(
574 msg = M(
575 self, 'has_git_suffix', state)
575 self, 'has_git_suffix', state)
576 raise formencode.Invalid(
576 raise formencode.Invalid(
577 msg, value, state, error_dict={'repo_name': msg})
577 msg, value, state, error_dict={'repo_name': msg})
578 return _validator
578 return _validator
579
579
580
580
581 def ValidCloneUri(localizer):
581 def ValidCloneUri(localizer):
582 _ = localizer
582 _ = localizer
583
583
584 class InvalidCloneUrl(Exception):
584 class InvalidCloneUrl(Exception):
585 allowed_prefixes = ()
585 allowed_prefixes = ()
586
586
587 def url_handler(repo_type, url):
587 def url_handler(repo_type, url):
588 config = make_db_config(clear_session=False)
588 config = make_db_config(clear_session=False)
589 if repo_type == 'hg':
589 if repo_type == 'hg':
590 allowed_prefixes = ('http', 'svn+http', 'git+http')
590 allowed_prefixes = ('http', 'svn+http', 'git+http')
591
591
592 if 'http' in url[:4]:
592 if 'http' in url[:4]:
593 # initially check if it's at least the proper URL
593 # initially check if it's at least the proper URL
594 # or whether it passes basic auth
594 # or whether it passes basic auth
595 MercurialRepository.check_url(url, config)
595 MercurialRepository.check_url(url, config)
596 elif 'svn+http' in url[:8]: # svn->hg import
596 elif 'svn+http' in url[:8]: # svn->hg import
597 SubversionRepository.check_url(url, config)
597 SubversionRepository.check_url(url, config)
598 elif 'git+http' in url[:8]: # git->hg import
598 elif 'git+http' in url[:8]: # git->hg import
599 raise NotImplementedError()
599 raise NotImplementedError()
600 else:
600 else:
601 exc = InvalidCloneUrl('Clone from URI %s not allowed. '
601 exc = InvalidCloneUrl('Clone from URI %s not allowed. '
602 'Allowed URL must start with one of %s'
602 'Allowed URL must start with one of %s'
603 % (url, ','.join(allowed_prefixes)))
603 % (url, ','.join(allowed_prefixes)))
604 exc.allowed_prefixes = allowed_prefixes
604 exc.allowed_prefixes = allowed_prefixes
605 raise exc
605 raise exc
606
606
607 elif repo_type == 'git':
607 elif repo_type == 'git':
608 allowed_prefixes = ('http', 'svn+http', 'hg+http')
608 allowed_prefixes = ('http', 'svn+http', 'hg+http')
609 if 'http' in url[:4]:
609 if 'http' in url[:4]:
610 # initially check if it's at least the proper URL
610 # initially check if it's at least the proper URL
611 # or whether it passes basic auth
611 # or whether it passes basic auth
612 GitRepository.check_url(url, config)
612 GitRepository.check_url(url, config)
613 elif 'svn+http' in url[:8]: # svn->git import
613 elif 'svn+http' in url[:8]: # svn->git import
614 raise NotImplementedError()
614 raise NotImplementedError()
615 elif 'hg+http' in url[:8]: # hg->git import
615 elif 'hg+http' in url[:8]: # hg->git import
616 raise NotImplementedError()
616 raise NotImplementedError()
617 else:
617 else:
618 exc = InvalidCloneUrl('Clone from URI %s not allowed. '
618 exc = InvalidCloneUrl('Clone from URI %s not allowed. '
619 'Allowed URL must start with one of %s'
619 'Allowed URL must start with one of %s'
620 % (url, ','.join(allowed_prefixes)))
620 % (url, ','.join(allowed_prefixes)))
621 exc.allowed_prefixes = allowed_prefixes
621 exc.allowed_prefixes = allowed_prefixes
622 raise exc
622 raise exc
623
623
624 class _validator(formencode.validators.FancyValidator):
624 class _validator(formencode.validators.FancyValidator):
625 messages = {
625 messages = {
626 'clone_uri': _(u'Invalid clone URL or credentials for %(rtype)s repository'),
626 'clone_uri': _('Invalid clone URL or credentials for %(rtype)s repository'),
627 'invalid_clone_uri': _(
627 'invalid_clone_uri': _(
628 u'Invalid clone URL, provide a valid clone '
628 'Invalid clone URL, provide a valid clone '
629 u'URL starting with one of %(allowed_prefixes)s')
629 'URL starting with one of %(allowed_prefixes)s')
630 }
630 }
631
631
632 def validate_python(self, value, state):
632 def validate_python(self, value, state):
633 repo_type = value.get('repo_type')
633 repo_type = value.get('repo_type')
634 url = value.get('clone_uri')
634 url = value.get('clone_uri')
635
635
636 if url:
636 if url:
637 try:
637 try:
638 url_handler(repo_type, url)
638 url_handler(repo_type, url)
639 except InvalidCloneUrl as e:
639 except InvalidCloneUrl as e:
640 log.warning(e)
640 log.warning(e)
641 msg = M(self, 'invalid_clone_uri', state, rtype=repo_type,
641 msg = M(self, 'invalid_clone_uri', state, rtype=repo_type,
642 allowed_prefixes=','.join(e.allowed_prefixes))
642 allowed_prefixes=','.join(e.allowed_prefixes))
643 raise formencode.Invalid(msg, value, state,
643 raise formencode.Invalid(msg, value, state,
644 error_dict={'clone_uri': msg})
644 error_dict={'clone_uri': msg})
645 except Exception:
645 except Exception:
646 log.exception('Url validation failed')
646 log.exception('Url validation failed')
647 msg = M(self, 'clone_uri', state, rtype=repo_type)
647 msg = M(self, 'clone_uri', state, rtype=repo_type)
648 raise formencode.Invalid(msg, value, state,
648 raise formencode.Invalid(msg, value, state,
649 error_dict={'clone_uri': msg})
649 error_dict={'clone_uri': msg})
650 return _validator
650 return _validator
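The per-type branches above boil down to a prefix whitelist plus a backend-specific check. A condensed, self-contained sketch of that dispatch (an assumption drawn from the branches, not a RhodeCode API):

    ALLOWED_CLONE_PREFIXES = {
        'hg': ('http', 'svn+http', 'git+http'),
        'git': ('http', 'svn+http', 'hg+http'),
    }

    def clone_prefix_ok(repo_type, url):
        """True if the URL scheme is one url_handler would even consider."""
        return url.startswith(ALLOWED_CLONE_PREFIXES.get(repo_type, ()))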
651
651
652
652
653 def ValidForkType(localizer, old_data=None):
653 def ValidForkType(localizer, old_data=None):
654 _ = localizer
654 _ = localizer
655 old_data = old_data or {}
655 old_data = old_data or {}
656
656
657 class _validator(formencode.validators.FancyValidator):
657 class _validator(formencode.validators.FancyValidator):
658 messages = {
658 messages = {
659 'invalid_fork_type': _(u'Fork has to be the same type as the parent')
659 'invalid_fork_type': _('Fork has to be the same type as the parent')
660 }
660 }
661
661
662 def validate_python(self, value, state):
662 def validate_python(self, value, state):
663 if old_data['repo_type'] != value:
663 if old_data['repo_type'] != value:
664 msg = M(self, 'invalid_fork_type', state)
664 msg = M(self, 'invalid_fork_type', state)
665 raise formencode.Invalid(
665 raise formencode.Invalid(
666 msg, value, state, error_dict={'repo_type': msg}
666 msg, value, state, error_dict={'repo_type': msg}
667 )
667 )
668 return _validator
668 return _validator
669
669
670
670
671 def CanWriteGroup(localizer, old_data=None):
671 def CanWriteGroup(localizer, old_data=None):
672 _ = localizer
672 _ = localizer
673
673
674 class _validator(formencode.validators.FancyValidator):
674 class _validator(formencode.validators.FancyValidator):
675 messages = {
675 messages = {
676 'permission_denied': _(
676 'permission_denied': _(
677 u"You do not have the permission "
677 "You do not have the permission "
678 u"to create repositories in this group."),
678 "to create repositories in this group."),
679 'permission_denied_root': _(
679 'permission_denied_root': _(
680 u"You do not have the permission to store repositories in "
680 "You do not have the permission to store repositories in "
681 u"the root location.")
681 "the root location.")
682 }
682 }
683
683
684 def _to_python(self, value, state):
684 def _to_python(self, value, state):
685 # root location
685 # root location
686 if value in [-1, "-1"]:
686 if value in [-1, "-1"]:
687 return None
687 return None
688 return value
688 return value
689
689
690 def validate_python(self, value, state):
690 def validate_python(self, value, state):
691 gr = RepoGroup.get(value)
691 gr = RepoGroup.get(value)
692 gr_name = gr.group_name if gr else None # None means ROOT location
692 gr_name = gr.group_name if gr else None # None means ROOT location
693 # create repositories with write permission on group is set to true
693 # create repositories with write permission on group is set to true
694 create_on_write = HasPermissionAny(
694 create_on_write = HasPermissionAny(
695 'hg.create.write_on_repogroup.true')()
695 'hg.create.write_on_repogroup.true')()
696 group_admin = HasRepoGroupPermissionAny('group.admin')(
696 group_admin = HasRepoGroupPermissionAny('group.admin')(
697 gr_name, 'can write into group validator')
697 gr_name, 'can write into group validator')
698 group_write = HasRepoGroupPermissionAny('group.write')(
698 group_write = HasRepoGroupPermissionAny('group.write')(
699 gr_name, 'can write into group validator')
699 gr_name, 'can write into group validator')
700 forbidden = not (group_admin or (group_write and create_on_write))
700 forbidden = not (group_admin or (group_write and create_on_write))
701 can_create_repos = HasPermissionAny(
701 can_create_repos = HasPermissionAny(
702 'hg.admin', 'hg.create.repository')
702 'hg.admin', 'hg.create.repository')
703 gid = (old_data['repo_group'].get('group_id')
703 gid = (old_data['repo_group'].get('group_id')
704 if (old_data and 'repo_group' in old_data) else None)
704 if (old_data and 'repo_group' in old_data) else None)
705 value_changed = gid != safe_int(value)
705 value_changed = gid != safe_int(value)
706 new = not old_data
706 new = not old_data
707 # only check permissions if we changed the value; there's a case
707 # only check permissions if we changed the value; there's a case
708 # where someone had write permission to a repository they created
708 # where someone had write permission to a repository they created
709 # revoked; we don't need to check permission if they didn't change
709 # revoked; we don't need to check permission if they didn't change
710 # the group value in the form box
710 # the group value in the form box
711 if value_changed or new:
711 if value_changed or new:
712 # parent group needs to exist
712 # parent group needs to exist
713 if gr and forbidden:
713 if gr and forbidden:
714 msg = M(self, 'permission_denied', state)
714 msg = M(self, 'permission_denied', state)
715 raise formencode.Invalid(
715 raise formencode.Invalid(
716 msg, value, state, error_dict={'repo_type': msg}
716 msg, value, state, error_dict={'repo_type': msg}
717 )
717 )
718 # check if we can write to root location !
718 # check if we can write to root location !
719 elif gr is None and not can_create_repos():
719 elif gr is None and not can_create_repos():
720 msg = M(self, 'permission_denied_root', state)
720 msg = M(self, 'permission_denied_root', state)
721 raise formencode.Invalid(
721 raise formencode.Invalid(
722 msg, value, state, error_dict={'repo_type': msg}
722 msg, value, state, error_dict={'repo_type': msg}
723 )
723 )
724 return _validator
724 return _validator
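The write check combines three flags; isolating the expression makes the rule easier to test. A sketch:

    def group_write_forbidden(group_admin, group_write, create_on_write):
        # mirrors: not (group_admin or (group_write and create_on_write))
        return not (group_admin or (group_write and create_on_write))

    # write permission alone is not enough when write_on_repogroup is disabled:
    assert group_write_forbidden(False, True, False)
    assert not group_write_forbidden(True, False, False)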
725
725
726
726
727 def ValidPerms(localizer, type_='repo'):
727 def ValidPerms(localizer, type_='repo'):
728 _ = localizer
728 _ = localizer
729 if type_ == 'repo_group':
729 if type_ == 'repo_group':
730 EMPTY_PERM = 'group.none'
730 EMPTY_PERM = 'group.none'
731 elif type_ == 'repo':
731 elif type_ == 'repo':
732 EMPTY_PERM = 'repository.none'
732 EMPTY_PERM = 'repository.none'
733 elif type_ == 'user_group':
733 elif type_ == 'user_group':
734 EMPTY_PERM = 'usergroup.none'
734 EMPTY_PERM = 'usergroup.none'
735
735
736 class _validator(formencode.validators.FancyValidator):
736 class _validator(formencode.validators.FancyValidator):
737 messages = {
737 messages = {
738 'perm_new_member_name':
738 'perm_new_member_name':
739 _(u'This username or user group name is not valid')
739 _('This username or user group name is not valid')
740 }
740 }
741
741
742 def _to_python(self, value, state):
742 def _to_python(self, value, state):
743 perm_updates = OrderedSet()
743 perm_updates = OrderedSet()
744 perm_additions = OrderedSet()
744 perm_additions = OrderedSet()
745 perm_deletions = OrderedSet()
745 perm_deletions = OrderedSet()
746 # build lists of permissions to update/delete, and new permissions
746 # build lists of permissions to update/delete, and new permissions
747
747
748 # Read the perm_new_member/perm_del_member attributes and group
748 # Read the perm_new_member/perm_del_member attributes and group
749 # them by their IDs
749 # them by their IDs
750 new_perms_group = collections.defaultdict(dict)
750 new_perms_group = collections.defaultdict(dict)
751 del_perms_group = collections.defaultdict(dict)
751 del_perms_group = collections.defaultdict(dict)
752 for k, v in value.copy().items():
752 for k, v in value.copy().items():
753 if k.startswith('perm_del_member'):
753 if k.startswith('perm_del_member'):
754 # delete from the original storage so we don't process it later
754 # delete from the original storage so we don't process it later
755 del value[k]
755 del value[k]
756 # part is `id`, `type`
756 # part is `id`, `type`
757 _type, part = k.split('perm_del_member_')
757 _type, part = k.split('perm_del_member_')
758 args = part.split('_')
758 args = part.split('_')
759 if len(args) == 2:
759 if len(args) == 2:
760 _key, pos = args
760 _key, pos = args
761 del_perms_group[pos][_key] = v
761 del_perms_group[pos][_key] = v
762 if k.startswith('perm_new_member'):
762 if k.startswith('perm_new_member'):
763 # delete from the original storage so we don't process it later
763 # delete from the original storage so we don't process it later
764 del value[k]
764 del value[k]
765 # part is `id`, `type`, `perm`
765 # part is `id`, `type`, `perm`
766 _type, part = k.split('perm_new_member_')
766 _type, part = k.split('perm_new_member_')
767 args = part.split('_')
767 args = part.split('_')
768 if len(args) == 2:
768 if len(args) == 2:
769 _key, pos = args
769 _key, pos = args
770 new_perms_group[pos][_key] = v
770 new_perms_group[pos][_key] = v
771
771
772 # store the deletes
772 # store the deletes
773 for k in sorted(del_perms_group.keys()):
773 for k in sorted(del_perms_group.keys()):
774 perm_dict = del_perms_group[k]
774 perm_dict = del_perms_group[k]
775 del_member = perm_dict.get('id')
775 del_member = perm_dict.get('id')
776 del_type = perm_dict.get('type')
776 del_type = perm_dict.get('type')
777 if del_member and del_type:
777 if del_member and del_type:
778 perm_deletions.add(
778 perm_deletions.add(
779 (del_member, None, del_type))
779 (del_member, None, del_type))
780
780
781 # store additions in the order they were added in the web form
781 # store additions in the order they were added in the web form
782 for k in sorted(new_perms_group.keys()):
782 for k in sorted(new_perms_group.keys()):
783 perm_dict = new_perms_group[k]
783 perm_dict = new_perms_group[k]
784 new_member = perm_dict.get('id')
784 new_member = perm_dict.get('id')
785 new_type = perm_dict.get('type')
785 new_type = perm_dict.get('type')
786 new_perm = perm_dict.get('perm')
786 new_perm = perm_dict.get('perm')
787 if new_member and new_perm and new_type:
787 if new_member and new_perm and new_type:
788 perm_additions.add(
788 perm_additions.add(
789 (new_member, new_perm, new_type))
789 (new_member, new_perm, new_type))
790
790
791 # get updates of permissions
791 # get updates of permissions
792 # (read the existing radio button states)
792 # (read the existing radio button states)
793 default_user_id = User.get_default_user_id()
793 default_user_id = User.get_default_user_id()
794
794
795 for k, update_value in value.items():
795 for k, update_value in value.items():
796 if k.startswith('u_perm_') or k.startswith('g_perm_'):
796 if k.startswith('u_perm_') or k.startswith('g_perm_'):
797 obj_type = k[0]
797 obj_type = k[0]
798 obj_id = k[7:]
798 obj_id = k[7:]
799 update_type = {'u': 'user',
799 update_type = {'u': 'user',
800 'g': 'user_group'}[obj_type]
800 'g': 'user_group'}[obj_type]
801
801
802 if obj_type == 'u' and safe_int(obj_id) == default_user_id:
802 if obj_type == 'u' and safe_int(obj_id) == default_user_id:
803 if str2bool(value.get('repo_private')):
803 if str2bool(value.get('repo_private')):
804 # prevent from updating default user permissions
804 # prevent from updating default user permissions
805 # when this repository is marked as private
805 # when this repository is marked as private
806 update_value = EMPTY_PERM
806 update_value = EMPTY_PERM
807
807
808 perm_updates.add(
808 perm_updates.add(
809 (obj_id, update_value, update_type))
809 (obj_id, update_value, update_type))
810
810
811 value['perm_additions'] = [] # propagated later
811 value['perm_additions'] = [] # propagated later
812 value['perm_updates'] = list(perm_updates)
812 value['perm_updates'] = list(perm_updates)
813 value['perm_deletions'] = list(perm_deletions)
813 value['perm_deletions'] = list(perm_deletions)
814
814
815 updates_map = dict(
815 updates_map = dict(
816 (x[0], (x[1], x[2])) for x in value['perm_updates'])
816 (x[0], (x[1], x[2])) for x in value['perm_updates'])
817 # make sure Additions don't override updates.
817 # make sure Additions don't override updates.
818 for member_id, perm, member_type in list(perm_additions):
818 for member_id, perm, member_type in list(perm_additions):
819 if member_id in updates_map:
819 if member_id in updates_map:
820 perm = updates_map[member_id][0]
820 perm = updates_map[member_id][0]
821 value['perm_additions'].append((member_id, perm, member_type))
821 value['perm_additions'].append((member_id, perm, member_type))
822
822
823 # on new entries, validate that users exist and are active!
823 # on new entries, validate that users exist and are active!
824 # this gives feedback in the form
824 # this gives feedback in the form
825 try:
825 try:
826 if member_type == 'user':
826 if member_type == 'user':
827 User.query()\
827 User.query()\
828 .filter(User.active == true())\
828 .filter(User.active == true())\
829 .filter(User.user_id == member_id).one()
829 .filter(User.user_id == member_id).one()
830 if member_type == 'user_group':
830 if member_type == 'user_group':
831 UserGroup.query()\
831 UserGroup.query()\
832 .filter(UserGroup.users_group_active == true())\
832 .filter(UserGroup.users_group_active == true())\
833 .filter(UserGroup.users_group_id == member_id)\
833 .filter(UserGroup.users_group_id == member_id)\
834 .one()
834 .one()
835
835
836 except Exception:
836 except Exception:
837 log.exception('Updated permission failed: org_exc:')
837 log.exception('Updated permission failed: org_exc:')
838 msg = M(self, 'perm_new_member_name', state)
838 msg = M(self, 'perm_new_member_name', state)
839 raise formencode.Invalid(
839 raise formencode.Invalid(
840 msg, value, state, error_dict={
840 msg, value, state, error_dict={
841 'perm_new_member_name': msg}
841 'perm_new_member_name': msg}
842 )
842 )
843 return value
843 return value
844 return _validator
844 return _validator
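The form encodes membership changes as flat keys of the shape perm_new_member_<key>_<pos> / perm_del_member_<key>_<pos>, plus u_perm_<id> / g_perm_<id> radio-button states. An illustrative input (the field names follow the parsing above; the concrete values are invented):

    form_value = {
        'perm_new_member_id_1': '5',
        'perm_new_member_type_1': 'user',
        'perm_new_member_perm_1': 'repository.read',
        'perm_del_member_id_2': '7',
        'perm_del_member_type_2': 'user_group',
        'u_perm_3': 'repository.write',  # permission update for user id 3
    }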
845
845
846
846
847 def ValidPath(localizer):
847 def ValidPath(localizer):
848 _ = localizer
848 _ = localizer
849
849
850 class _validator(formencode.validators.FancyValidator):
850 class _validator(formencode.validators.FancyValidator):
851 messages = {
851 messages = {
852 'invalid_path': _(u'This is not a valid path')
852 'invalid_path': _('This is not a valid path')
853 }
853 }
854
854
855 def validate_python(self, value, state):
855 def validate_python(self, value, state):
856 if not os.path.isdir(value):
856 if not os.path.isdir(value):
857 msg = M(self, 'invalid_path', state)
857 msg = M(self, 'invalid_path', state)
858 raise formencode.Invalid(
858 raise formencode.Invalid(
859 msg, value, state, error_dict={'paths_root_path': msg}
859 msg, value, state, error_dict={'paths_root_path': msg}
860 )
860 )
861 return _validator
861 return _validator
862
862
863
863
864 def UniqSystemEmail(localizer, old_data=None):
864 def UniqSystemEmail(localizer, old_data=None):
865 _ = localizer
865 _ = localizer
866 old_data = old_data or {}
866 old_data = old_data or {}
867
867
868 class _validator(formencode.validators.FancyValidator):
868 class _validator(formencode.validators.FancyValidator):
869 messages = {
869 messages = {
870 'email_taken': _(u'This e-mail address is already taken')
870 'email_taken': _('This e-mail address is already taken')
871 }
871 }
872
872
873 def _to_python(self, value, state):
873 def _to_python(self, value, state):
874 return value.lower()
874 return value.lower()
875
875
876 def validate_python(self, value, state):
876 def validate_python(self, value, state):
877 if (old_data.get('email') or '').lower() != value:
877 if (old_data.get('email') or '').lower() != value:
878 user = User.get_by_email(value, case_insensitive=True)
878 user = User.get_by_email(value, case_insensitive=True)
879 if user:
879 if user:
880 msg = M(self, 'email_taken', state)
880 msg = M(self, 'email_taken', state)
881 raise formencode.Invalid(
881 raise formencode.Invalid(
882 msg, value, state, error_dict={'email': msg}
882 msg, value, state, error_dict={'email': msg}
883 )
883 )
884 return _validator
884 return _validator
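Lowercasing in _to_python plus the case_insensitive lookup makes the uniqueness check effectively case-insensitive. A standalone sketch of the comparison:

    def email_changed(old_email, new_email):
        # mirrors: (old_data.get('email') or '').lower() != value
        return (old_email or '').lower() != new_email.lower()

    assert not email_changed('Admin@Example.com', 'admin@example.com')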
885
885
886
886
887 def ValidSystemEmail(localizer):
887 def ValidSystemEmail(localizer):
888 _ = localizer
888 _ = localizer
889
889
890 class _validator(formencode.validators.FancyValidator):
890 class _validator(formencode.validators.FancyValidator):
891 messages = {
891 messages = {
892 'non_existing_email': _(u'e-mail "%(email)s" does not exist.')
892 'non_existing_email': _('e-mail "%(email)s" does not exist.')
893 }
893 }
894
894
895 def _to_python(self, value, state):
895 def _to_python(self, value, state):
896 return value.lower()
896 return value.lower()
897
897
898 def validate_python(self, value, state):
898 def validate_python(self, value, state):
899 user = User.get_by_email(value, case_insensitive=True)
899 user = User.get_by_email(value, case_insensitive=True)
900 if user is None:
900 if user is None:
901 msg = M(self, 'non_existing_email', state, email=value)
901 msg = M(self, 'non_existing_email', state, email=value)
902 raise formencode.Invalid(
902 raise formencode.Invalid(
903 msg, value, state, error_dict={'email': msg}
903 msg, value, state, error_dict={'email': msg}
904 )
904 )
905 return _validator
905 return _validator
906
906
907
907
908 def NotReviewedRevisions(localizer, repo_id):
908 def NotReviewedRevisions(localizer, repo_id):
909 _ = localizer
909 _ = localizer
910 class _validator(formencode.validators.FancyValidator):
910 class _validator(formencode.validators.FancyValidator):
911 messages = {
911 messages = {
912 'rev_already_reviewed':
912 'rev_already_reviewed':
913 _(u'Revisions %(revs)s are already part of a pull request '
913 _('Revisions %(revs)s are already part of a pull request '
914 u'or have a status set'),
914 'or have a status set'),
915 }
915 }
916
916
917 def validate_python(self, value, state):
917 def validate_python(self, value, state):
918 # check revisions if they are not reviewed, or a part of another
918 # check revisions if they are not reviewed, or a part of another
919 # pull request
919 # pull request
920 statuses = ChangesetStatus.query()\
920 statuses = ChangesetStatus.query()\
921 .filter(ChangesetStatus.revision.in_(value))\
921 .filter(ChangesetStatus.revision.in_(value))\
922 .filter(ChangesetStatus.repo_id == repo_id)\
922 .filter(ChangesetStatus.repo_id == repo_id)\
923 .all()
923 .all()
924
924
925 errors = []
925 errors = []
926 for status in statuses:
926 for status in statuses:
927 if status.pull_request_id:
927 if status.pull_request_id:
928 errors.append(['pull_req', status.revision[:12]])
928 errors.append(['pull_req', status.revision[:12]])
929 elif status.status:
929 elif status.status:
930 errors.append(['status', status.revision[:12]])
930 errors.append(['status', status.revision[:12]])
931
931
932 if errors:
932 if errors:
933 revs = ','.join([x[1] for x in errors])
933 revs = ','.join([x[1] for x in errors])
934 msg = M(self, 'rev_already_reviewed', state, revs=revs)
934 msg = M(self, 'rev_already_reviewed', state, revs=revs)
935 raise formencode.Invalid(
935 raise formencode.Invalid(
936 msg, value, state, error_dict={'revisions': revs})
936 msg, value, state, error_dict={'revisions': revs})
937
937
938 return _validator
938 return _validator
939
939
940
940
941 def ValidIp(localizer):
941 def ValidIp(localizer):
942 _ = localizer
942 _ = localizer
943
943
944 class _validator(CIDR):
944 class _validator(CIDR):
945 messages = {
945 messages = {
946 'badFormat': _(u'Please enter a valid IPv4 or IPv6 address'),
946 'badFormat': _('Please enter a valid IPv4 or IPv6 address'),
947 'illegalBits': _(
947 'illegalBits': _(
948 u'The network size (bits) must be within the range '
948 'The network size (bits) must be within the range '
949 u'of 0-32 (not %(bits)r)'),
949 'of 0-32 (not %(bits)r)'),
950 }
950 }
951
951
952 # we override the default to_python() call
952 # we override the default to_python() call
953 def to_python(self, value, state):
953 def to_python(self, value, state):
954 v = super(_validator, self).to_python(value, state)
954 v = super(_validator, self).to_python(value, state)
955 v = safe_unicode(v.strip())
955 v = safe_unicode(v.strip())
956 net = ipaddress.ip_network(address=v, strict=False)
956 net = ipaddress.ip_network(address=v, strict=False)
957 return str(net)
957 return str(net)
958
958
959 def validate_python(self, value, state):
959 def validate_python(self, value, state):
960 try:
960 try:
961 addr = safe_unicode(value.strip())
961 addr = safe_unicode(value.strip())
962 # this raises a ValueError if the address is not IPv4 or IPv6
962 # this raises a ValueError if the address is not IPv4 or IPv6
963 ipaddress.ip_network(addr, strict=False)
963 ipaddress.ip_network(addr, strict=False)
964 except ValueError:
964 except ValueError:
965 raise formencode.Invalid(self.message('badFormat', state),
965 raise formencode.Invalid(self.message('badFormat', state),
966 value, state)
966 value, state)
967 return _validator
967 return _validator
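The normalisation in to_python relies entirely on the standard library: host bits are masked off and bare addresses become /32 (or /128) networks:

    import ipaddress

    assert str(ipaddress.ip_network('192.168.1.17/24', strict=False)) == '192.168.1.0/24'
    assert str(ipaddress.ip_network('127.0.0.1', strict=False)) == '127.0.0.1/32'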
968
968
969
969
970 def FieldKey(localizer):
970 def FieldKey(localizer):
971 _ = localizer
971 _ = localizer
972
972
973 class _validator(formencode.validators.FancyValidator):
973 class _validator(formencode.validators.FancyValidator):
974 messages = {
974 messages = {
975 'badFormat': _(
975 'badFormat': _(
976 u'Key name can only consist of letters, '
976 'Key name can only consist of letters, '
977 u'underscore, dash or numbers'),
977 'underscore, dash or numbers'),
978 }
978 }
979
979
980 def validate_python(self, value, state):
980 def validate_python(self, value, state):
981 if not re.match('[a-zA-Z0-9_-]+$', value):
981 if not re.match('[a-zA-Z0-9_-]+$', value):
982 raise formencode.Invalid(self.message('badFormat', state),
982 raise formencode.Invalid(self.message('badFormat', state),
983 value, state)
983 value, state)
984 return _validator
984 return _validator
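One subtlety of the regex above: re.match with a trailing $ also accepts a value ending in a single newline, because $ matches just before it. re.fullmatch is slightly stricter:

    import re

    assert re.match('[a-zA-Z0-9_-]+$', 'field_1\n')          # matches
    assert not re.fullmatch('[a-zA-Z0-9_-]+', 'field_1\n')   # rejected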
985
985
986
986
987 def ValidAuthPlugins(localizer):
987 def ValidAuthPlugins(localizer):
988 _ = localizer
988 _ = localizer
989
989
990 class _validator(formencode.validators.FancyValidator):
990 class _validator(formencode.validators.FancyValidator):
991 messages = {
991 messages = {
992 'import_duplicate': _(
992 'import_duplicate': _(
993 u'Plugins %(loaded)s and %(next_to_load)s '
993 'Plugins %(loaded)s and %(next_to_load)s '
994 u'both export the same name'),
994 'both export the same name'),
995 'missing_includeme': _(
995 'missing_includeme': _(
996 u'The plugin "%(plugin_id)s" is missing an includeme '
996 'The plugin "%(plugin_id)s" is missing an includeme '
997 u'function.'),
997 'function.'),
998 'import_error': _(
998 'import_error': _(
999 u'Can not load plugin "%(plugin_id)s"'),
999 'Can not load plugin "%(plugin_id)s"'),
1000 'no_plugin': _(
1000 'no_plugin': _(
1001 u'No plugin available with ID "%(plugin_id)s"'),
1001 'No plugin available with ID "%(plugin_id)s"'),
1002 }
1002 }
1003
1003
1004 def _to_python(self, value, state):
1004 def _to_python(self, value, state):
1005 # filter empty values
1005 # filter empty values
1006 return [s for s in value if s not in [None, '']]
1006 return [s for s in value if s not in [None, '']]
1007
1007
1008 def _validate_legacy_plugin_id(self, plugin_id, value, state):
1008 def _validate_legacy_plugin_id(self, plugin_id, value, state):
1009 """
1009 """
1010 Validates that the plugin import works. It also checks that the
1010 Validates that the plugin import works. It also checks that the
1011 plugin has an includeme attribute.
1011 plugin has an includeme attribute.
1012 """
1012 """
1013 try:
1013 try:
1014 plugin = _import_legacy_plugin(plugin_id)
1014 plugin = _import_legacy_plugin(plugin_id)
1015 except Exception as e:
1015 except Exception as e:
1016 log.exception(
1016 log.exception(
1017 'Exception during import of auth legacy plugin "{}"'
1017 'Exception during import of auth legacy plugin "{}"'
1018 .format(plugin_id))
1018 .format(plugin_id))
1019 msg = M(self, 'import_error', state, plugin_id=plugin_id)
1019 msg = M(self, 'import_error', state, plugin_id=plugin_id)
1020 raise formencode.Invalid(msg, value, state)
1020 raise formencode.Invalid(msg, value, state)
1021
1021
1022 if not hasattr(plugin, 'includeme'):
1022 if not hasattr(plugin, 'includeme'):
1023 msg = M(self, 'missing_includeme', state, plugin_id=plugin_id)
1023 msg = M(self, 'missing_includeme', state, plugin_id=plugin_id)
1024 raise formencode.Invalid(msg, value, state)
1024 raise formencode.Invalid(msg, value, state)
1025
1025
1026 return plugin
1026 return plugin
1027
1027
1028 def _validate_plugin_id(self, plugin_id, value, state):
1028 def _validate_plugin_id(self, plugin_id, value, state):
1029 """
1029 """
1030 Plugins are already imported during app start up. Therefore this
1030 Plugins are already imported during app start up. Therefore this
1031 validation only retrieves the plugin from the plugin registry and
1031 validation only retrieves the plugin from the plugin registry and
1032 if it returns something other than None, everything is OK.
1032 if it returns something other than None, everything is OK.
1033 """
1033 """
1034 plugin = loadplugin(plugin_id)
1034 plugin = loadplugin(plugin_id)
1035
1035
1036 if plugin is None:
1036 if plugin is None:
1037 msg = M(self, 'no_plugin', state, plugin_id=plugin_id)
1037 msg = M(self, 'no_plugin', state, plugin_id=plugin_id)
1038 raise formencode.Invalid(msg, value, state)
1038 raise formencode.Invalid(msg, value, state)
1039
1039
1040 return plugin
1040 return plugin
1041
1041
1042 def validate_python(self, value, state):
1042 def validate_python(self, value, state):
1043 unique_names = {}
1043 unique_names = {}
1044 for plugin_id in value:
1044 for plugin_id in value:
1045
1045
1046 # Validate legacy or normal plugin.
1046 # Validate legacy or normal plugin.
1047 if plugin_id.startswith(legacy_plugin_prefix):
1047 if plugin_id.startswith(legacy_plugin_prefix):
1048 plugin = self._validate_legacy_plugin_id(
1048 plugin = self._validate_legacy_plugin_id(
1049 plugin_id, value, state)
1049 plugin_id, value, state)
1050 else:
1050 else:
1051 plugin = self._validate_plugin_id(plugin_id, value, state)
1051 plugin = self._validate_plugin_id(plugin_id, value, state)
1052
1052
1053 # Only allow unique plugin names.
1053 # Only allow unique plugin names.
1054 if plugin.name in unique_names:
1054 if plugin.name in unique_names:
1055 msg = M(self, 'import_duplicate', state,
1055 msg = M(self, 'import_duplicate', state,
1056 loaded=unique_names[plugin.name],
1056 loaded=unique_names[plugin.name],
1057 next_to_load=plugin)
1057 next_to_load=plugin)
1058 raise formencode.Invalid(msg, value, state)
1058 raise formencode.Invalid(msg, value, state)
1059 unique_names[plugin.name] = plugin
1059 unique_names[plugin.name] = plugin
1060 return _validator
1060 return _validator
1061
1061
1062
1062
1063 def ValidPattern(localizer):
1063 def ValidPattern(localizer):
1064 _ = localizer
1064 _ = localizer
1065
1065
1066 class _validator(formencode.validators.FancyValidator):
1066 class _validator(formencode.validators.FancyValidator):
1067 messages = {
1067 messages = {
1068 'bad_format': _(u'Url must start with http or /'),
1068 'bad_format': _('Url must start with http or /'),
1069 }
1069 }
1070
1070
1071 def _to_python(self, value, state):
1071 def _to_python(self, value, state):
1072 patterns = []
1072 patterns = []
1073
1073
1074 prefix = 'new_pattern'
1074 prefix = 'new_pattern'
1075 for name, v in value.items():
1075 for name, v in value.items():
1076 pattern_name = '_'.join((prefix, 'pattern'))
1076 pattern_name = '_'.join((prefix, 'pattern'))
1077 if name.startswith(pattern_name):
1077 if name.startswith(pattern_name):
1078 new_item_id = name[len(pattern_name)+1:]
1078 new_item_id = name[len(pattern_name)+1:]
1079
1079
1080 def _field(name):
1080 def _field(name):
1081 return '%s_%s_%s' % (prefix, name, new_item_id)
1081 return '%s_%s_%s' % (prefix, name, new_item_id)
1082
1082
1083 values = {
1083 values = {
1084 'issuetracker_pat': value.get(_field('pattern')),
1084 'issuetracker_pat': value.get(_field('pattern')),
1085 'issuetracker_url': value.get(_field('url')),
1085 'issuetracker_url': value.get(_field('url')),
1086 'issuetracker_pref': value.get(_field('prefix')),
1086 'issuetracker_pref': value.get(_field('prefix')),
1087 'issuetracker_desc': value.get(_field('description'))
1087 'issuetracker_desc': value.get(_field('description'))
1088 }
1088 }
1089 new_uid = md5(values['issuetracker_pat'])
1089 new_uid = md5(values['issuetracker_pat'])
1090
1090
1091 has_required_fields = (
1091 has_required_fields = (
1092 values['issuetracker_pat']
1092 values['issuetracker_pat']
1093 and values['issuetracker_url'])
1093 and values['issuetracker_url'])
1094
1094
1095 if has_required_fields:
1095 if has_required_fields:
1096 # validate that the url starts with http or /
1096 # validate that the url starts with http or /
1097 # otherwise it can lead to JS injections,
1097 # otherwise it can lead to JS injections,
1098 # e.g. specifying javascript:<malicious code>
1098 # e.g. specifying javascript:<malicious code>
1099 if not values['issuetracker_url'].startswith(('http', '/')):
1099 if not values['issuetracker_url'].startswith(('http', '/')):
1100 raise formencode.Invalid(
1100 raise formencode.Invalid(
1101 self.message('bad_format', state),
1101 self.message('bad_format', state),
1102 value, state)
1102 value, state)
1103
1103
1104 settings = [
1104 settings = [
1105 ('_'.join((key, new_uid)), values[key], 'unicode')
1105 ('_'.join((key, new_uid)), values[key], 'unicode')
1106 for key in values]
1106 for key in values]
1107 patterns.append(settings)
1107 patterns.append(settings)
1108
1108
1109 value['patterns'] = patterns
1109 value['patterns'] = patterns
1110 delete_patterns = value.get('uid') or []
1110 delete_patterns = value.get('uid') or []
1111 if not isinstance(delete_patterns, (list, tuple)):
1111 if not isinstance(delete_patterns, (list, tuple)):
1112 delete_patterns = [delete_patterns]
1112 delete_patterns = [delete_patterns]
1113 value['delete_patterns'] = delete_patterns
1113 value['delete_patterns'] = delete_patterns
1114 return value
1114 return value
1115 return _validator
1115 return _validator
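The issue-tracker form uses the same flat-key scheme: new entries arrive as new_pattern_<field>_<item_id>. An illustrative input matching the parsing above (the values are invented):

    form_value = {
        'new_pattern_pattern_0': r'#(?P<issue_id>\d+)',
        'new_pattern_url_0': 'https://bugs.example.com/${issue_id}',
        'new_pattern_prefix_0': '#',
        'new_pattern_description_0': 'example bug tracker',
    }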
@@ -1,398 +1,398 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 import io
20 import io
21 import shlex
21 import shlex
22
22
23 import math
23 import math
24 import re
24 import re
25 import os
25 import os
26 import datetime
26 import datetime
27 import logging
27 import logging
28 import queue
28 import queue
29 import subprocess
29 import subprocess
30
30
31
31
32 from dateutil.parser import parse
32 from dateutil.parser import parse
33 from pyramid.threadlocal import get_current_request
33 from pyramid.threadlocal import get_current_request
34 from pyramid.interfaces import IRoutesMapper
34 from pyramid.interfaces import IRoutesMapper
35 from pyramid.settings import asbool
35 from pyramid.settings import asbool
36 from pyramid.path import AssetResolver
36 from pyramid.path import AssetResolver
37 from threading import Thread
37 from threading import Thread
38
38
39 from rhodecode.config.jsroutes import generate_jsroutes_content
39 from rhodecode.config.jsroutes import generate_jsroutes_content
40 from rhodecode.lib.base import get_auth_user
40 from rhodecode.lib.base import get_auth_user
41
41
42 import rhodecode
42 import rhodecode
43
43
44
44
45 log = logging.getLogger(__name__)
45 log = logging.getLogger(__name__)
46
46
47
47
48 def add_renderer_globals(event):
48 def add_renderer_globals(event):
49 from rhodecode.lib import helpers
49 from rhodecode.lib import helpers
50
50
51 # TODO: When executed in pyramid view context the request is not available
51 # TODO: When executed in pyramid view context the request is not available
52 # in the event. Find a better solution to get the request.
52 # in the event. Find a better solution to get the request.
53 request = event['request'] or get_current_request()
53 request = event['request'] or get_current_request()
54
54
55 # Add Pyramid translation as '_' to context
55 # Add Pyramid translation as '_' to context
56 event['_'] = request.translate
56 event['_'] = request.translate
57 event['_ungettext'] = request.plularize
57 event['_ungettext'] = request.plularize
58 event['h'] = helpers
58 event['h'] = helpers
59
59
60
60
61 def set_user_lang(event):
61 def set_user_lang(event):
62 request = event.request
62 request = event.request
63 cur_user = getattr(request, 'user', None)
63 cur_user = getattr(request, 'user', None)
64
64
65 if cur_user:
65 if cur_user:
66 user_lang = cur_user.get_instance().user_data.get('language')
66 user_lang = cur_user.get_instance().user_data.get('language')
67 if user_lang:
67 if user_lang:
68 log.debug('lang: setting current user:%s language to: %s', cur_user, user_lang)
68 log.debug('lang: setting current user:%s language to: %s', cur_user, user_lang)
69 event.request._LOCALE_ = user_lang
69 event.request._LOCALE_ = user_lang
70
70
71
71
72 def update_celery_conf(event):
72 def update_celery_conf(event):
73 from rhodecode.lib.celerylib.loader import set_celery_conf
73 from rhodecode.lib.celerylib.loader import set_celery_conf
74 log.debug('Setting celery config from new request')
74 log.debug('Setting celery config from new request')
75 set_celery_conf(request=event.request, registry=event.request.registry)
75 set_celery_conf(request=event.request, registry=event.request.registry)
76
76
77
77
78 def add_request_user_context(event):
78 def add_request_user_context(event):
79 """
79 """
80 Adds auth user into request context
80 Adds auth user into request context
81 """
81 """
82 request = event.request
82 request = event.request
83 # access req_id as soon as possible
83 # access req_id as soon as possible
84 req_id = request.req_id
84 req_id = request.req_id
85
85
86 if hasattr(request, 'vcs_call'):
86 if hasattr(request, 'vcs_call'):
87 # skip vcs calls
87 # skip vcs calls
88 return
88 return
89
89
90 if hasattr(request, 'rpc_method'):
90 if hasattr(request, 'rpc_method'):
91 # skip api calls
91 # skip api calls
92 return
92 return
93
93
94 auth_user, auth_token = get_auth_user(request)
94 auth_user, auth_token = get_auth_user(request)
95 request.user = auth_user
95 request.user = auth_user
96 request.user_auth_token = auth_token
96 request.user_auth_token = auth_token
97 request.environ['rc_auth_user'] = auth_user
97 request.environ['rc_auth_user'] = auth_user
98 request.environ['rc_auth_user_id'] = auth_user.user_id
98 request.environ['rc_auth_user_id'] = auth_user.user_id
99 request.environ['rc_req_id'] = req_id
99 request.environ['rc_req_id'] = req_id
100
100
101
101
102 def reset_log_bucket(event):
102 def reset_log_bucket(event):
103 """
103 """
104 reset the log bucket on new request
104 reset the log bucket on new request
105 """
105 """
106 request = event.request
106 request = event.request
107 request.req_id_records_init()
107 request.req_id_records_init()
108
108
109
109
110 def scan_repositories_if_enabled(event):
110 def scan_repositories_if_enabled(event):
111 """
111 """
112 This is subscribed to the `pyramid.events.ApplicationCreated` event. It
112 This is subscribed to the `pyramid.events.ApplicationCreated` event. It
113 does a repository scan if enabled in the settings.
113 does a repository scan if enabled in the settings.
114 """
114 """
115 settings = event.app.registry.settings
115 settings = event.app.registry.settings
116 vcs_server_enabled = settings['vcs.server.enable']
116 vcs_server_enabled = settings['vcs.server.enable']
117 import_on_startup = settings['startup.import_repos']
117 import_on_startup = settings['startup.import_repos']
118 if vcs_server_enabled and import_on_startup:
118 if vcs_server_enabled and import_on_startup:
119 from rhodecode.model.scm import ScmModel
119 from rhodecode.model.scm import ScmModel
120 from rhodecode.lib.utils import repo2db_mapper, get_rhodecode_base_path
120 from rhodecode.lib.utils import repo2db_mapper, get_rhodecode_base_path
121 repositories = ScmModel().repo_scan(get_rhodecode_base_path())
121 repositories = ScmModel().repo_scan(get_rhodecode_base_path())
122 repo2db_mapper(repositories, remove_obsolete=False)
122 repo2db_mapper(repositories, remove_obsolete=False)
123
123
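These functions are Pyramid event subscribers. A hedged sketch of how they are typically registered (the actual wiring lives elsewhere in rhodecode.config):

    from pyramid.events import ApplicationCreated, BeforeRender, NewRequest

    def includeme(config):
        config.add_subscriber(add_renderer_globals, BeforeRender)
        config.add_subscriber(add_request_user_context, NewRequest)
        config.add_subscriber(scan_repositories_if_enabled, ApplicationCreated)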
124
124
125 def write_metadata_if_needed(event):
125 def write_metadata_if_needed(event):
126 """
126 """
127 Writes upgrade metadata
127 Writes upgrade metadata
128 """
128 """
129 import rhodecode
129 import rhodecode
130 from rhodecode.lib import system_info
130 from rhodecode.lib import system_info
131 from rhodecode.lib import ext_json
131 from rhodecode.lib import ext_json
132
132
133 fname = '.rcmetadata.json'
133 fname = '.rcmetadata.json'
134 ini_loc = os.path.dirname(rhodecode.CONFIG.get('__file__'))
134 ini_loc = os.path.dirname(rhodecode.CONFIG.get('__file__'))
135 metadata_destination = os.path.join(ini_loc, fname)
135 metadata_destination = os.path.join(ini_loc, fname)
136
136
137 def get_update_age():
137 def get_update_age():
138 now = datetime.datetime.utcnow()
138 now = datetime.datetime.utcnow()
139
139
140 with open(metadata_destination, 'rb') as f:
140 with open(metadata_destination, 'rb') as f:
141 data = ext_json.json.loads(f.read())
141 data = ext_json.json.loads(f.read())
142 if 'created_on' in data:
142 if 'created_on' in data:
143 update_date = parse(data['created_on'])
143 update_date = parse(data['created_on'])
144 diff = now - update_date
144 diff = now - update_date
145 return diff.total_seconds() / 60.0
145 return diff.total_seconds() / 60.0
146
146
147 return 0
147 return 0
148
148
149 def write():
149 def write():
150 configuration = system_info.SysInfo(
150 configuration = system_info.SysInfo(
151 system_info.rhodecode_config)()['value']
151 system_info.rhodecode_config)()['value']
152 license_token = configuration['config']['license_token']
152 license_token = configuration['config']['license_token']
153
153
154 setup = dict(
154 setup = dict(
155 workers=configuration['config']['server:main'].get(
155 workers=configuration['config']['server:main'].get(
156 'workers', '?'),
156 'workers', '?'),
157 worker_type=configuration['config']['server:main'].get(
157 worker_type=configuration['config']['server:main'].get(
158 'worker_class', 'sync'),
158 'worker_class', 'sync'),
159 )
159 )
160 dbinfo = system_info.SysInfo(system_info.database_info)()['value']
160 dbinfo = system_info.SysInfo(system_info.database_info)()['value']
161 del dbinfo['url']
161 del dbinfo['url']
162
162
163 metadata = dict(
163 metadata = dict(
164 desc='upgrade metadata info',
164 desc='upgrade metadata info',
165 license_token=license_token,
165 license_token=license_token,
166 created_on=datetime.datetime.utcnow().isoformat(),
166 created_on=datetime.datetime.utcnow().isoformat(),
167 usage=system_info.SysInfo(system_info.usage_info)()['value'],
167 usage=system_info.SysInfo(system_info.usage_info)()['value'],
168 platform=system_info.SysInfo(system_info.platform_type)()['value'],
168 platform=system_info.SysInfo(system_info.platform_type)()['value'],
169 database=dbinfo,
169 database=dbinfo,
170 cpu=system_info.SysInfo(system_info.cpu)()['value'],
170 cpu=system_info.SysInfo(system_info.cpu)()['value'],
171 memory=system_info.SysInfo(system_info.memory)()['value'],
171 memory=system_info.SysInfo(system_info.memory)()['value'],
172 setup=setup
172 setup=setup
173 )
173 )
174
174
175 with open(metadata_destination, 'wb') as f:
175 with open(metadata_destination, 'wb') as f:
176 f.write(ext_json.json.dumps(metadata))
176 f.write(ext_json.json.dumps(metadata))
177
177
178 settings = event.app.registry.settings
178 settings = event.app.registry.settings
179 if settings.get('metadata.skip'):
179 if settings.get('metadata.skip'):
180 return
180 return
181
181
182 # only write this once every 24h; writing on every worker restart caused unwanted delays
182 # only write this once every 24h; writing on every worker restart caused unwanted delays
183 try:
183 try:
184 age_in_min = get_update_age()
184 age_in_min = get_update_age()
185 except Exception:
185 except Exception:
186 age_in_min = 0
186 age_in_min = 0
187
187
188 if age_in_min and age_in_min < 60 * 24:  # written less than 24h ago
188 if age_in_min and age_in_min < 60 * 24:  # written less than 24h ago
189 return
189 return
190
190
191 try:
191 try:
192 write()
192 write()
193 except Exception:
193 except Exception:
194 pass
194 pass
195
195
196
196
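# (Editor's note) get_update_age() above returns minutes since the metadata
# file's 'created_on' stamp and 0 when the file is missing or unreadable, so
# an age of 0 always falls through the check and the file gets (re)written.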
197 def write_usage_data(event):
197 def write_usage_data(event):
198 import rhodecode
198 import rhodecode
199 from rhodecode.lib import system_info
199 from rhodecode.lib import system_info
200 from rhodecode.lib import ext_json
200 from rhodecode.lib import ext_json
201
201
202 settings = event.app.registry.settings
202 settings = event.app.registry.settings
203 instance_tag = settings.get('metadata.write_usage_tag')
203 instance_tag = settings.get('metadata.write_usage_tag')
204 if not settings.get('metadata.write_usage'):
204 if not settings.get('metadata.write_usage'):
205 return
205 return
206
206
207 def get_update_age(dest_file):
207 def get_update_age(dest_file):
208 now = datetime.datetime.utcnow()
208 now = datetime.datetime.utcnow()
209
209
210 with open(dest_file, 'rb') as f:
210 with open(dest_file, 'rb') as f:
211 data = ext_json.json.loads(f.read())
211 data = ext_json.json.loads(f.read())
212 if 'created_on' in data:
212 if 'created_on' in data:
213 update_date = parse(data['created_on'])
213 update_date = parse(data['created_on'])
214 diff = now - update_date
214 diff = now - update_date
215 return math.ceil(diff.total_seconds() / 60.0)
215 return math.ceil(diff.total_seconds() / 60.0)
216
216
217 return 0
217 return 0
218
218
219 utc_date = datetime.datetime.utcnow()
219 utc_date = datetime.datetime.utcnow()
220 hour_quarter = int(math.ceil((utc_date.hour + utc_date.minute/60.0) / 6.))
220 hour_quarter = int(math.ceil((utc_date.hour + utc_date.minute/60.0) / 6.))
221 fname = '.rc_usage_{date.year}{date.month:02d}{date.day:02d}_{hour}.json'.format(
221 fname = '.rc_usage_{date.year}{date.month:02d}{date.day:02d}_{hour}.json'.format(
222 date=utc_date, hour=hour_quarter)
222 date=utc_date, hour=hour_quarter)
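# Worked example (editor's note): at 14:37 UTC, hour_quarter =
# ceil((14 + 37/60.0) / 6.0) = ceil(2.44) = 3, so for 2020-11-05 the file
# name becomes '.rc_usage_20201105_3.json' -- one bucket per 6 hours.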
223 ini_loc = os.path.dirname(rhodecode.CONFIG.get('__file__'))
223 ini_loc = os.path.dirname(rhodecode.CONFIG.get('__file__'))
224
224
225 usage_dir = os.path.join(ini_loc, '.rcusage')
225 usage_dir = os.path.join(ini_loc, '.rcusage')
226 if not os.path.isdir(usage_dir):
226 if not os.path.isdir(usage_dir):
227 os.makedirs(usage_dir)
227 os.makedirs(usage_dir)
228 usage_metadata_destination = os.path.join(usage_dir, fname)
228 usage_metadata_destination = os.path.join(usage_dir, fname)
229
229
230 try:
230 try:
231 age_in_min = get_update_age(usage_metadata_destination)
231 age_in_min = get_update_age(usage_metadata_destination)
232 except Exception:
232 except Exception:
233 age_in_min = 0
233 age_in_min = 0
234
234
235 # write at most once every 6 hours
235 # write at most once every 6 hours
236 if age_in_min and age_in_min < 60 * 6:
236 if age_in_min and age_in_min < 60 * 6:
237 log.debug('Usage file created %s minutes ago, skipping (threshold: %s minutes)...',
237 log.debug('Usage file created %s minutes ago, skipping (threshold: %s minutes)...',
238 age_in_min, 60 * 6)
238 age_in_min, 60 * 6)
239 return
239 return
240
240
241 def write(dest_file):
241 def write(dest_file):
242 configuration = system_info.SysInfo(system_info.rhodecode_config)()['value']
242 configuration = system_info.SysInfo(system_info.rhodecode_config)()['value']
243 license_token = configuration['config']['license_token']
243 license_token = configuration['config']['license_token']
244
244
245 metadata = dict(
245 metadata = dict(
246 desc='Usage data',
246 desc='Usage data',
247 instance_tag=instance_tag,
247 instance_tag=instance_tag,
248 license_token=license_token,
248 license_token=license_token,
249 created_on=datetime.datetime.utcnow().isoformat(),
249 created_on=datetime.datetime.utcnow().isoformat(),
250 usage=system_info.SysInfo(system_info.usage_info)()['value'],
250 usage=system_info.SysInfo(system_info.usage_info)()['value'],
251 )
251 )
252
252
253 with open(dest_file, 'wb') as f:
253 with open(dest_file, 'wb') as f:
254 f.write(ext_json.json.dumps(metadata, indent=2, sort_keys=True))
254 f.write(ext_json.json.dumps(metadata, indent=2, sort_keys=True))
255
255
256 try:
256 try:
257 log.debug('Writing usage file at: %s', usage_metadata_destination)
257 log.debug('Writing usage file at: %s', usage_metadata_destination)
258 write(usage_metadata_destination)
258 write(usage_metadata_destination)
259 except Exception:
259 except Exception:
260 pass
260 pass
261
261
262
262
263 def write_js_routes_if_enabled(event):
263 def write_js_routes_if_enabled(event):
264 registry = event.app.registry
264 registry = event.app.registry
265
265
266 mapper = registry.queryUtility(IRoutesMapper)
266 mapper = registry.queryUtility(IRoutesMapper)
267 _argument_prog = re.compile('\{(.*?)\}|:\((.*)\)')
267 _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')
268
268
269 def _extract_route_information(route):
269 def _extract_route_information(route):
270 """
270 """
271 Convert a route into tuple(name, path, args), eg:
271 Convert a route into tuple(name, path, args), eg:
272 ('show_user', '/profile/%(username)s', ['username'])
272 ('show_user', '/profile/%(username)s', ['username'])
273 """
273 """
274
274
275 routepath = route.pattern
275 routepath = route.pattern
276 pattern = route.pattern
276 pattern = route.pattern
277
277
278 def replace(matchobj):
278 def replace(matchobj):
279 if matchobj.group(1):
279 if matchobj.group(1):
280 return "%%(%s)s" % matchobj.group(1).split(':')[0]
280 return "%%(%s)s" % matchobj.group(1).split(':')[0]
281 else:
281 else:
282 return "%%(%s)s" % matchobj.group(2)
282 return "%%(%s)s" % matchobj.group(2)
283
283
284 routepath = _argument_prog.sub(replace, routepath)
284 routepath = _argument_prog.sub(replace, routepath)
285
285
286 if not routepath.startswith('/'):
286 if not routepath.startswith('/'):
287 routepath = '/'+routepath
287 routepath = '/'+routepath
288
288
289 return (
289 return (
290 route.name,
290 route.name,
291 routepath,
291 routepath,
292 [(arg[0].split(':')[0] if arg[0] != '' else arg[1])
292 [(arg[0].split(':')[0] if arg[0] != '' else arg[1])
293 for arg in _argument_prog.findall(pattern)]
293 for arg in _argument_prog.findall(pattern)]
294 )
294 )
295
295
296 def get_routes():
296 def get_routes():
297 # pyramid routes
297 # pyramid routes
298 for route in mapper.get_routes():
298 for route in mapper.get_routes():
299 if not route.name.startswith('__'):
299 if not route.name.startswith('__'):
300 yield _extract_route_information(route)
300 yield _extract_route_information(route)
301
301
302 if asbool(registry.settings.get('generate_js_files', 'false')):
302 if asbool(registry.settings.get('generate_js_files', 'false')):
303 static_path = AssetResolver().resolve('rhodecode:public').abspath()
303 static_path = AssetResolver().resolve('rhodecode:public').abspath()
304 jsroutes = get_routes()
304 jsroutes = get_routes()
305 jsroutes_file_content = generate_jsroutes_content(jsroutes)
305 jsroutes_file_content = generate_jsroutes_content(jsroutes)
306 jsroutes_file_path = os.path.join(
306 jsroutes_file_path = os.path.join(
307 static_path, 'js', 'rhodecode', 'routes.js')
307 static_path, 'js', 'rhodecode', 'routes.js')
308
308
309 try:
309 try:
310 with io.open(jsroutes_file_path, 'w', encoding='utf-8') as f:
310 with io.open(jsroutes_file_path, 'w', encoding='utf-8') as f:
311 f.write(jsroutes_file_content)
311 f.write(jsroutes_file_content)
312 except Exception:
312 except Exception:
313 log.exception('Failed to write routes.js into %s', jsroutes_file_path)
313 log.exception('Failed to write routes.js into %s', jsroutes_file_path)
314
314
315
315
316 class Subscriber(object):
316 class Subscriber(object):
317 """
317 """
318 Base class for subscribers to the pyramid event system.
318 Base class for subscribers to the pyramid event system.
319 """
319 """
320 def __call__(self, event):
320 def __call__(self, event):
321 self.run(event)
321 self.run(event)
322
322
323 def run(self, event):
323 def run(self, event):
324 raise NotImplementedError('Subclass has to implement this.')
324 raise NotImplementedError('Subclass has to implement this.')
325
325
326
326
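(Editor's note) A minimal sketch of a concrete subscriber, assuming the usual
Pyramid registration (MySubscriber is illustrative):

    class MySubscriber(Subscriber):
        def run(self, event):
            # handle the event synchronously, in the caller's thread
            log.debug('received event: %s', event)

    # config.add_subscriber(MySubscriber(), pyramid.events.ApplicationCreated)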
327 class AsyncSubscriber(Subscriber):
327 class AsyncSubscriber(Subscriber):
328 """
328 """
329 Subscriber that handles events in a separate thread so as not to
329 Subscriber that handles events in a separate thread so as not to
330 block the code which triggers the event. It puts the received
330 block the code which triggers the event. It puts the received
331 events into a queue from which a daemon worker thread takes them in
331 events into a queue from which a daemon worker thread takes them in
332 order.
332 order.
333 """
333 """
334 def __init__(self):
334 def __init__(self):
335 self._stop = False
335 self._stop = False
336 self._eventq = queue.Queue()
336 self._eventq = queue.Queue()
337 self._worker = self.create_worker()
337 self._worker = self.create_worker()
338 self._worker.start()
338 self._worker.start()
339
339
340 def __call__(self, event):
340 def __call__(self, event):
341 self._eventq.put(event)
341 self._eventq.put(event)
342
342
343 def create_worker(self):
343 def create_worker(self):
344 worker = Thread(target=self.do_work)
344 worker = Thread(target=self.do_work)
345 worker.daemon = True
345 worker.daemon = True
346 return worker
346 return worker
347
347
348 def stop_worker(self):
348 def stop_worker(self):
349 self._stop = True  # signal the do_work() loop to exit
349 self._stop = True  # signal the do_work() loop to exit
350 self._eventq.put(None)
350 self._eventq.put(None)
351 self._worker.join()
351 self._worker.join()
352
352
353 def do_work(self):
353 def do_work(self):
354 while not self._stop:
354 while not self._stop:
355 event = self._eventq.get()
355 event = self._eventq.get()
356 if event is not None:
356 if event is not None:
357 self.run(event)
357 self.run(event)
358
358
359
359
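(Editor's note) stop_worker() relies on a None sentinel: putting None wakes the
worker out of the blocking Queue.get() so the loop re-checks _stop and exits,
after which join() returns. A minimal subclass sketch (names illustrative):

    class PrintingSubscriber(AsyncSubscriber):
        def run(self, event):
            print('handled in the worker thread:', event)

    sub = PrintingSubscriber()   # __init__ starts the daemon worker thread
    sub(object())                # __call__ just enqueues the event
    sub.stop_worker()            # set _stop, enqueue the sentinel, join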
360 class AsyncSubprocessSubscriber(AsyncSubscriber):
360 class AsyncSubprocessSubscriber(AsyncSubscriber):
361 """
361 """
362 Subscriber that uses the subprocess module to execute a command if an
362 Subscriber that uses the subprocess module to execute a command if an
363 event is received. Events are handled asynchronously::
363 event is received. Events are handled asynchronously::
364
364
365 subscriber = AsyncSubprocessSubscriber('ls -la', timeout=10)
365 subscriber = AsyncSubprocessSubscriber('ls -la', timeout=10)
366 subscriber(dummyEvent) # running __call__(event)
366 subscriber(dummyEvent) # running __call__(event)
367
367
368 """
368 """
369
369
370 def __init__(self, cmd, timeout=None):
370 def __init__(self, cmd, timeout=None):
371 if not isinstance(cmd, (list, tuple)):
371 if not isinstance(cmd, (list, tuple)):
372 cmd = shlex.split(cmd)
372 cmd = shlex.split(cmd)
373 super(AsyncSubprocessSubscriber, self).__init__()
373 super(AsyncSubprocessSubscriber, self).__init__()
374 self._cmd = cmd
374 self._cmd = cmd
375 self._timeout = timeout
375 self._timeout = timeout
376
376
377 def run(self, event):
377 def run(self, event):
378 cmd = self._cmd
378 cmd = self._cmd
379 timeout = self._timeout
379 timeout = self._timeout
380 log.debug('Executing command %s.', cmd)
380 log.debug('Executing command %s.', cmd)
381
381
382 try:
382 try:
383 output = subprocess.check_output(
383 output = subprocess.check_output(
384 cmd, timeout=timeout, stderr=subprocess.STDOUT)
384 cmd, timeout=timeout, stderr=subprocess.STDOUT)
385 log.debug('Command finished %s', cmd)
385 log.debug('Command finished %s', cmd)
386 if output:
386 if output:
387 log.debug('Command output: %s', output)
387 log.debug('Command output: %s', output)
388 except subprocess.TimeoutExpired as e:
388 except subprocess.TimeoutExpired as e:
389 log.exception('Timeout while executing command.')
389 log.exception('Timeout while executing command.')
390 if e.output:
390 if e.output:
391 log.error('Command output: %s', e.output)
391 log.error('Command output: %s', e.output)
392 except subprocess.CalledProcessError as e:
392 except subprocess.CalledProcessError as e:
393 log.exception('Error while executing command.')
393 log.exception('Error while executing command.')
394 if e.output:
394 if e.output:
395 log.error('Command output: %s', e.output)
395 log.error('Command output: %s', e.output)
396 except Exception:
396 except Exception:
397 log.exception(
397 log.exception(
398 'Exception while executing command %s.', cmd)
398 'Exception while executing command %s.', cmd)
@@ -1,1404 +1,1404 b''
1 <%namespace name="base" file="/base/base.mako"/>
1 <%namespace name="base" file="/base/base.mako"/>
2 <%namespace name="commentblock" file="/changeset/changeset_file_comment.mako"/>
2 <%namespace name="commentblock" file="/changeset/changeset_file_comment.mako"/>
3
3
4 <%def name="diff_line_anchor(commit, filename, line, type)"><%
4 <%def name="diff_line_anchor(commit, filename, line, type)"><%
5 return '%s_%s_%i' % (h.md5_safe(commit+filename), type, line)
5 return '%s_%s_%i' % (h.md5_safe(commit+filename), type, line)
6 %></%def>
6 %></%def>
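(Editor's note) A plain-Python sketch of the anchor scheme above, assuming
h.md5_safe is an md5-hexdigest helper:

    import hashlib

    def diff_line_anchor(commit, filename, line, type):
        digest = hashlib.md5((commit + filename).encode('utf-8')).hexdigest()
        return '%s_%s_%i' % (digest, type, line)

    # e.g. diff_line_anchor('deadbeef', 'setup.py', 42, 'n')
    # -> '<32-char md5>_n_42'  ('o'/'n' marks the old/new side of the diff)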
7
7
8 <%def name="action_class(action)">
8 <%def name="action_class(action)">
9 <%
9 <%
10 return {
10 return {
11 '-': 'cb-deletion',
11 '-': 'cb-deletion',
12 '+': 'cb-addition',
12 '+': 'cb-addition',
13 ' ': 'cb-context',
13 ' ': 'cb-context',
14 }.get(action, 'cb-empty')
14 }.get(action, 'cb-empty')
15 %>
15 %>
16 </%def>
16 </%def>
17
17
18 <%def name="op_class(op_id)">
18 <%def name="op_class(op_id)">
19 <%
19 <%
20 return {
20 return {
21 DEL_FILENODE: 'deletion', # file deleted
21 DEL_FILENODE: 'deletion', # file deleted
22 BIN_FILENODE: 'warning' # binary diff hidden
22 BIN_FILENODE: 'warning' # binary diff hidden
23 }.get(op_id, 'addition')
23 }.get(op_id, 'addition')
24 %>
24 %>
25 </%def>
25 </%def>
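(Editor's note) Both defs above are simple lookup tables with a fallback; an
equivalent plain-Python sketch, using the same filenode constants:

    ACTION_CLASS = {'-': 'cb-deletion', '+': 'cb-addition', ' ': 'cb-context'}

    def action_class(action):
        # unknown diff actions fall back to the 'cb-empty' cell style
        return ACTION_CLASS.get(action, 'cb-empty')

    def op_class(op_id):
        # only deletions and hidden binary diffs get special pill styles
        return {DEL_FILENODE: 'deletion', BIN_FILENODE: 'warning'}.get(op_id, 'addition')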
26
26
27
27
28
28
29 <%def name="render_diffset(diffset, commit=None,
29 <%def name="render_diffset(diffset, commit=None,
30
30
31 # collapse all file diff entries when there are more than this number of files in the diff
31 # collapse all file diff entries when there are more than this number of files in the diff
32 collapse_when_files_over=20,
32 collapse_when_files_over=20,
33
33
34 # collapse lines in the diff when more than this number of lines changed in the file diff
34 # collapse lines in the diff when more than this number of lines changed in the file diff
35 lines_changed_limit=500,
35 lines_changed_limit=500,
36
36
37 # add a ruler to the output at the given character column
37 # add a ruler to the output at the given character column
38 ruler_at_chars=0,
38 ruler_at_chars=0,
39
39
40 # show inline comments
40 # show inline comments
41 use_comments=False,
41 use_comments=False,
42
42
43 # disable new comments
43 # disable new comments
44 disable_new_comments=False,
44 disable_new_comments=False,
45
45
46 # special file-comments on files that were deleted in previous versions;
46 # special file-comments on files that were deleted in previous versions;
47 # used for showing outdated comments for deleted files in a PR
47 # used for showing outdated comments for deleted files in a PR
48 deleted_files_comments=None,
48 deleted_files_comments=None,
49
49
50 # for caching purposes
50 # for caching purposes
51 inline_comments=None,
51 inline_comments=None,
52
52
53 # additional menu for PRs
53 # additional menu for PRs
54 pull_request_menu=None,
54 pull_request_menu=None,
55
55
56 # show/hide todo next to comments
56 # show/hide todo next to comments
57 show_todos=True,
57 show_todos=True,
58
58
59 )">
59 )">
60
60
61 <%
61 <%
62 diffset_container_id = h.md5(diffset.target_ref)
62 diffset_container_id = h.md5(diffset.target_ref)
63 collapse_all = len(diffset.files) > collapse_when_files_over
63 collapse_all = len(diffset.files) > collapse_when_files_over
64 active_pattern_entries = h.get_active_pattern_entries(getattr(c, 'repo_name', None))
64 active_pattern_entries = h.get_active_pattern_entries(getattr(c, 'repo_name', None))
65 from rhodecode.lib.diffs import NEW_FILENODE, DEL_FILENODE, \
65 from rhodecode.lib.diffs import NEW_FILENODE, DEL_FILENODE, \
66 MOD_FILENODE, RENAMED_FILENODE, CHMOD_FILENODE, BIN_FILENODE, COPIED_FILENODE
66 MOD_FILENODE, RENAMED_FILENODE, CHMOD_FILENODE, BIN_FILENODE, COPIED_FILENODE
67 %>
67 %>
68
68
69 %if use_comments:
69 %if use_comments:
70
70
71 ## Template for injecting comments
71 ## Template for injecting comments
72 <div id="cb-comments-inline-container-template" class="js-template">
72 <div id="cb-comments-inline-container-template" class="js-template">
73 ${inline_comments_container([])}
73 ${inline_comments_container([])}
74 </div>
74 </div>
75
75
76 <div class="js-template" id="cb-comment-inline-form-template">
76 <div class="js-template" id="cb-comment-inline-form-template">
77 <div class="comment-inline-form ac">
77 <div class="comment-inline-form ac">
78 %if not c.rhodecode_user.is_default:
78 %if not c.rhodecode_user.is_default:
79 ## render template for inline comments
79 ## render template for inline comments
80 ${commentblock.comment_form(form_type='inline')}
80 ${commentblock.comment_form(form_type='inline')}
81 %endif
81 %endif
82 </div>
82 </div>
83 </div>
83 </div>
84
84
85 %endif
85 %endif
86
86
87 %if c.user_session_attrs["diffmode"] == 'sideside':
87 %if c.user_session_attrs["diffmode"] == 'sideside':
88 <style>
88 <style>
89 .wrapper {
89 .wrapper {
90 max-width: 1600px !important;
90 max-width: 1600px !important;
91 }
91 }
92 </style>
92 </style>
93 %endif
93 %endif
94
94
95 %if ruler_at_chars:
95 %if ruler_at_chars:
96 <style>
96 <style>
97 .diff table.cb .cb-content:after {
97 .diff table.cb .cb-content:after {
98 content: "";
98 content: "";
99 border-left: 1px solid blue;
99 border-left: 1px solid blue;
100 position: absolute;
100 position: absolute;
101 top: 0;
101 top: 0;
102 height: 18px;
102 height: 18px;
103 opacity: .2;
103 opacity: .2;
104 z-index: 10;
104 z-index: 10;
105 ## +5 to account for the diff action column (+/-)
105 ## +5 to account for the diff action column (+/-)
106 left: ${ruler_at_chars + 5}ch; }
106 left: ${ruler_at_chars + 5}ch; }
107 </style>
107 </style>
108 %endif
108 %endif
109
109
110 <div class="diffset ${disable_new_comments and 'diffset-comments-disabled'}">
110 <div class="diffset ${disable_new_comments and 'diffset-comments-disabled'}">
111
111
112 <div style="height: 20px; line-height: 20px">
112 <div style="height: 20px; line-height: 20px">
113 ## expand/collapse action
113 ## expand/collapse action
114 <div class="pull-left">
114 <div class="pull-left">
115 <a class="${'collapsed' if collapse_all else ''}" href="#expand-files" onclick="toggleExpand(this, '${diffset_container_id}'); return false">
115 <a class="${'collapsed' if collapse_all else ''}" href="#expand-files" onclick="toggleExpand(this, '${diffset_container_id}'); return false">
116 % if collapse_all:
116 % if collapse_all:
117 <i class="icon-plus-squared-alt icon-no-margin"></i>${_('Expand all files')}
117 <i class="icon-plus-squared-alt icon-no-margin"></i>${_('Expand all files')}
118 % else:
118 % else:
119 <i class="icon-minus-squared-alt icon-no-margin"></i>${_('Collapse all files')}
119 <i class="icon-minus-squared-alt icon-no-margin"></i>${_('Collapse all files')}
120 % endif
120 % endif
121 </a>
121 </a>
122
122
123 </div>
123 </div>
124
124
125 ## todos
125 ## todos
126 % if show_todos and getattr(c, 'at_version', None):
126 % if show_todos and getattr(c, 'at_version', None):
127 <div class="pull-right">
127 <div class="pull-right">
128 <i class="icon-flag-filled" style="color: #949494">TODOs:</i>
128 <i class="icon-flag-filled" style="color: #949494">TODOs:</i>
129 ${_('not available in this view')}
129 ${_('not available in this view')}
130 </div>
130 </div>
131 % elif show_todos:
131 % elif show_todos:
132 <div class="pull-right">
132 <div class="pull-right">
133 <div class="comments-number" style="padding-left: 10px">
133 <div class="comments-number" style="padding-left: 10px">
134 % if hasattr(c, 'unresolved_comments') and hasattr(c, 'resolved_comments'):
134 % if hasattr(c, 'unresolved_comments') and hasattr(c, 'resolved_comments'):
135 <i class="icon-flag-filled" style="color: #949494">TODOs:</i>
135 <i class="icon-flag-filled" style="color: #949494">TODOs:</i>
136 % if c.unresolved_comments:
136 % if c.unresolved_comments:
137 <a href="#show-todos" onclick="$('#todo-box').toggle(); return false">
137 <a href="#show-todos" onclick="$('#todo-box').toggle(); return false">
138 ${_('{} unresolved').format(len(c.unresolved_comments))}
138 ${_('{} unresolved').format(len(c.unresolved_comments))}
139 </a>
139 </a>
140 % else:
140 % else:
141 ${_('0 unresolved')}
141 ${_('0 unresolved')}
142 % endif
142 % endif
143
143
144 ${_('{} Resolved').format(len(c.resolved_comments))}
144 ${_('{} Resolved').format(len(c.resolved_comments))}
145 % endif
145 % endif
146 </div>
146 </div>
147 </div>
147 </div>
148 % endif
148 % endif
149
149
150 ## ## comments
150 ## ## comments
151 ## <div class="pull-right">
151 ## <div class="pull-right">
152 ## <div class="comments-number" style="padding-left: 10px">
152 ## <div class="comments-number" style="padding-left: 10px">
153 ## % if hasattr(c, 'comments') and hasattr(c, 'inline_cnt'):
153 ## % if hasattr(c, 'comments') and hasattr(c, 'inline_cnt'):
154 ## <i class="icon-comment" style="color: #949494">COMMENTS:</i>
154 ## <i class="icon-comment" style="color: #949494">COMMENTS:</i>
155 ## % if c.comments:
155 ## % if c.comments:
156 ## <a href="#comments">${_ungettext("{} General", "{} General", len(c.comments)).format(len(c.comments))}</a>,
156 ## <a href="#comments">${_ungettext("{} General", "{} General", len(c.comments)).format(len(c.comments))}</a>,
157 ## % else:
157 ## % else:
158 ## ${_('0 General')}
158 ## ${_('0 General')}
159 ## % endif
159 ## % endif
160 ##
160 ##
161 ## % if c.inline_cnt:
161 ## % if c.inline_cnt:
162 ## <a href="#" onclick="return Rhodecode.comments.nextComment();"
162 ## <a href="#" onclick="return Rhodecode.comments.nextComment();"
163 ## id="inline-comments-counter">${_ungettext("{} Inline", "{} Inline", c.inline_cnt).format(c.inline_cnt)}
163 ## id="inline-comments-counter">${_ungettext("{} Inline", "{} Inline", c.inline_cnt).format(c.inline_cnt)}
164 ## </a>
164 ## </a>
165 ## % else:
165 ## % else:
166 ## ${_('0 Inline')}
166 ## ${_('0 Inline')}
167 ## % endif
167 ## % endif
168 ## % endif
168 ## % endif
169 ##
169 ##
170 ## % if pull_request_menu:
170 ## % if pull_request_menu:
171 ## <%
171 ## <%
172 ## outdated_comm_count_ver = pull_request_menu['outdated_comm_count_ver']
172 ## outdated_comm_count_ver = pull_request_menu['outdated_comm_count_ver']
173 ## %>
173 ## %>
174 ##
174 ##
175 ## % if outdated_comm_count_ver:
175 ## % if outdated_comm_count_ver:
176 ## <a href="#" onclick="showOutdated(); Rhodecode.comments.nextOutdatedComment(); return false;">
176 ## <a href="#" onclick="showOutdated(); Rhodecode.comments.nextOutdatedComment(); return false;">
177 ## (${_("{} Outdated").format(outdated_comm_count_ver)})
177 ## (${_("{} Outdated").format(outdated_comm_count_ver)})
178 ## </a>
178 ## </a>
179 ## <a href="#" class="showOutdatedComments" onclick="showOutdated(this); return false;"> | ${_('show outdated')}</a>
179 ## <a href="#" class="showOutdatedComments" onclick="showOutdated(this); return false;"> | ${_('show outdated')}</a>
180 ## <a href="#" class="hideOutdatedComments" style="display: none" onclick="hideOutdated(this); return false;"> | ${_('hide outdated')}</a>
180 ## <a href="#" class="hideOutdatedComments" style="display: none" onclick="hideOutdated(this); return false;"> | ${_('hide outdated')}</a>
181 ## % else:
181 ## % else:
182 ## (${_("{} Outdated").format(outdated_comm_count_ver)})
182 ## (${_("{} Outdated").format(outdated_comm_count_ver)})
183 ## % endif
183 ## % endif
184 ##
184 ##
185 ## % endif
185 ## % endif
186 ##
186 ##
187 ## </div>
187 ## </div>
188 ## </div>
188 ## </div>
189
189
190 </div>
190 </div>
191
191
192 % if diffset.limited_diff:
192 % if diffset.limited_diff:
193 <div class="diffset-heading ${(diffset.limited_diff and 'diffset-heading-warning' or '')}">
193 <div class="diffset-heading ${(diffset.limited_diff and 'diffset-heading-warning' or '')}">
194 <h2 class="clearinner">
194 <h2 class="clearinner">
195 ${_('The requested changes are too big and content was truncated.')}
195 ${_('The requested changes are too big and content was truncated.')}
196 <a href="${h.current_route_path(request, fulldiff=1)}" onclick="return confirm('${_("Showing a big diff might take some time and resources, continue?")}')">${_('Show full diff')}</a>
196 <a href="${h.current_route_path(request, fulldiff=1)}" onclick="return confirm('${_("Showing a big diff might take some time and resources, continue?")}')">${_('Show full diff')}</a>
197 </h2>
197 </h2>
198 </div>
198 </div>
199 % endif
199 % endif
200
200
201 <div id="todo-box">
201 <div id="todo-box">
202 % if hasattr(c, 'unresolved_comments') and c.unresolved_comments:
202 % if hasattr(c, 'unresolved_comments') and c.unresolved_comments:
203 % for co in c.unresolved_comments:
203 % for co in c.unresolved_comments:
204 <a class="permalink" href="#comment-${co.comment_id}"
204 <a class="permalink" href="#comment-${co.comment_id}"
205 onclick="Rhodecode.comments.scrollToComment($('#comment-${co.comment_id}'))">
205 onclick="Rhodecode.comments.scrollToComment($('#comment-${co.comment_id}'))">
206 <i class="icon-flag-filled-red"></i>
206 <i class="icon-flag-filled-red"></i>
207 ${co.comment_id}</a>${('' if loop.last else ',')}
207 ${co.comment_id}</a>${('' if loop.last else ',')}
208 % endfor
208 % endfor
209 % endif
209 % endif
210 </div>
210 </div>
211 %if diffset.has_hidden_changes:
211 %if diffset.has_hidden_changes:
212 <p class="empty_data">${_('Some changes may be hidden')}</p>
212 <p class="empty_data">${_('Some changes may be hidden')}</p>
213 %elif not diffset.files:
213 %elif not diffset.files:
214 <p class="empty_data">${_('No files')}</p>
214 <p class="empty_data">${_('No files')}</p>
215 %endif
215 %endif
216
216
217 <div class="filediffs">
217 <div class="filediffs">
218
218
219 ## initial value; recomputed for each file in the loop below
219 ## initial value; recomputed for each file in the loop below
220 <% over_lines_changed_limit = False %>
220 <% over_lines_changed_limit = False %>
221 %for i, filediff in enumerate(diffset.files):
221 %for i, filediff in enumerate(diffset.files):
222
222
223 %if filediff.source_file_path and filediff.target_file_path:
223 %if filediff.source_file_path and filediff.target_file_path:
224 %if filediff.source_file_path != filediff.target_file_path:
224 %if filediff.source_file_path != filediff.target_file_path:
225 ## file was renamed, or copied
225 ## file was renamed, or copied
226 %if RENAMED_FILENODE in filediff.patch['stats']['ops']:
226 %if RENAMED_FILENODE in filediff.patch['stats']['ops']:
227 <%
227 <%
228 final_file_name = h.literal(u'{} <i class="icon-angle-left"></i> <del>{}</del>'.format(filediff.target_file_path, filediff.source_file_path))
228 final_file_name = h.literal(u'{} <i class="icon-angle-left"></i> <del>{}</del>'.format(filediff.target_file_path, filediff.source_file_path))
229 final_path = filediff.target_file_path
229 final_path = filediff.target_file_path
230 %>
230 %>
231 %elif COPIED_FILENODE in filediff.patch['stats']['ops']:
231 %elif COPIED_FILENODE in filediff.patch['stats']['ops']:
232 <%
232 <%
233 final_file_name = h.literal(u'{} <i class="icon-angle-left"></i> {}'.format(filediff.target_file_path, filediff.source_file_path))
233 final_file_name = h.literal(u'{} <i class="icon-angle-left"></i> {}'.format(filediff.target_file_path, filediff.source_file_path))
234 final_path = filediff.target_file_path
234 final_path = filediff.target_file_path
235 %>
235 %>
236 %endif
236 %endif
237 %else:
237 %else:
238 ## file was modified
238 ## file was modified
239 <%
239 <%
240 final_file_name = filediff.source_file_path
240 final_file_name = filediff.source_file_path
241 final_path = final_file_name
241 final_path = final_file_name
242 %>
242 %>
243 %endif
243 %endif
244 %else:
244 %else:
245 %if filediff.source_file_path:
245 %if filediff.source_file_path:
246 ## file was deleted
246 ## file was deleted
247 <%
247 <%
248 final_file_name = filediff.source_file_path
248 final_file_name = filediff.source_file_path
249 final_path = final_file_name
249 final_path = final_file_name
250 %>
250 %>
251 %else:
251 %else:
252 ## file was added
252 ## file was added
253 <%
253 <%
254 final_file_name = filediff.target_file_path
254 final_file_name = filediff.target_file_path
255 final_path = final_file_name
255 final_path = final_file_name
256 %>
256 %>
257 %endif
257 %endif
258 %endif
258 %endif
259
259
260 <%
260 <%
261 lines_changed = filediff.patch['stats']['added'] + filediff.patch['stats']['deleted']
261 lines_changed = filediff.patch['stats']['added'] + filediff.patch['stats']['deleted']
262 over_lines_changed_limit = lines_changed > lines_changed_limit
262 over_lines_changed_limit = lines_changed > lines_changed_limit
263 %>
263 %>
264 ## anchor with support for the sticky header
264 ## anchor with support for the sticky header
265 <div class="anchor" id="a_${h.FID(filediff.raw_id, filediff.patch['filename'])}"></div>
265 <div class="anchor" id="a_${h.FID(filediff.raw_id, filediff.patch['filename'])}"></div>
266
266
267 <input ${(collapse_all and 'checked' or '')} class="filediff-collapse-state collapse-${diffset_container_id}" id="filediff-collapse-${id(filediff)}" type="checkbox" onchange="updateSticky();">
267 <input ${(collapse_all and 'checked' or '')} class="filediff-collapse-state collapse-${diffset_container_id}" id="filediff-collapse-${id(filediff)}" type="checkbox" onchange="updateSticky();">
268 <div
268 <div
269 class="filediff"
269 class="filediff"
270 data-f-path="${filediff.patch['filename']}"
270 data-f-path="${filediff.patch['filename']}"
271 data-anchor-id="${h.FID(filediff.raw_id, filediff.patch['filename'])}"
271 data-anchor-id="${h.FID(filediff.raw_id, filediff.patch['filename'])}"
272 >
272 >
273 <label for="filediff-collapse-${id(filediff)}" class="filediff-heading">
273 <label for="filediff-collapse-${id(filediff)}" class="filediff-heading">
274 <%
274 <%
275 file_comments = (get_inline_comments(inline_comments, filediff.patch['filename']) or {}).values()
275 file_comments = (get_inline_comments(inline_comments, filediff.patch['filename']) or {}).values()
276 total_file_comments = [_c for _c in h.itertools.chain.from_iterable(file_comments) if not (_c.outdated or _c.draft)]
276 total_file_comments = [_c for _c in h.itertools.chain.from_iterable(file_comments) if not (_c.outdated or _c.draft)]
277 %>
277 %>
278 <div class="filediff-collapse-indicator icon-"></div>
278 <div class="filediff-collapse-indicator icon-"></div>
279
279
280 ## Comments/Options PILL
280 ## Comments/Options PILL
281 <span class="pill-group pull-right">
281 <span class="pill-group pull-right">
282 <span class="pill" op="comments">
282 <span class="pill" op="comments">
283 <i class="icon-comment"></i> ${len(total_file_comments)}
283 <i class="icon-comment"></i> ${len(total_file_comments)}
284 </span>
284 </span>
285
285
286 <details class="details-reset details-inline-block">
286 <details class="details-reset details-inline-block">
287 <summary class="noselect">
287 <summary class="noselect">
288 <i class="pill icon-options cursor-pointer" op="options"></i>
288 <i class="pill icon-options cursor-pointer" op="options"></i>
289 </summary>
289 </summary>
290 <details-menu class="details-dropdown">
290 <details-menu class="details-dropdown">
291
291
292 <div class="dropdown-item">
292 <div class="dropdown-item">
293 <span>${final_path}</span>
293 <span>${final_path}</span>
294 <span class="pull-right icon-clipboard clipboard-action" data-clipboard-text="${final_path}" title="Copy file path"></span>
294 <span class="pull-right icon-clipboard clipboard-action" data-clipboard-text="${final_path}" title="Copy file path"></span>
295 </div>
295 </div>
296
296
297 <div class="dropdown-divider"></div>
297 <div class="dropdown-divider"></div>
298
298
299 <div class="dropdown-item">
299 <div class="dropdown-item">
300 <% permalink = request.current_route_url(_anchor='a_{}'.format(h.FID(filediff.raw_id, filediff.patch['filename']))) %>
300 <% permalink = request.current_route_url(_anchor='a_{}'.format(h.FID(filediff.raw_id, filediff.patch['filename']))) %>
301 <a href="${permalink}">¶ permalink</a>
301 <a href="${permalink}">¶ permalink</a>
302 <span class="pull-right icon-clipboard clipboard-action" data-clipboard-text="${permalink}" title="Copy permalink"></span>
302 <span class="pull-right icon-clipboard clipboard-action" data-clipboard-text="${permalink}" title="Copy permalink"></span>
303 </div>
303 </div>
304
304
305
305
306 </details-menu>
306 </details-menu>
307 </details>
307 </details>
308
308
309 </span>
309 </span>
310
310
311 ${diff_ops(final_file_name, filediff)}
311 ${diff_ops(final_file_name, filediff)}
312
312
313 </label>
313 </label>
314
314
315 ${diff_menu(filediff, use_comments=use_comments)}
315 ${diff_menu(filediff, use_comments=use_comments)}
316 <table id="file-${h.safeid(h.safe_unicode(filediff.patch['filename']))}" data-f-path="${filediff.patch['filename']}" data-anchor-id="${h.FID(filediff.raw_id, filediff.patch['filename'])}" class="code-visible-block cb cb-diff-${c.user_session_attrs["diffmode"]} code-highlight ${(over_lines_changed_limit and 'cb-collapsed' or '')}">
316 <table id="file-${h.safeid(h.safe_unicode(filediff.patch['filename']))}" data-f-path="${filediff.patch['filename']}" data-anchor-id="${h.FID(filediff.raw_id, filediff.patch['filename'])}" class="code-visible-block cb cb-diff-${c.user_session_attrs["diffmode"]} code-highlight ${(over_lines_changed_limit and 'cb-collapsed' or '')}">
317
317
318 ## new/deleted/empty content case
318 ## new/deleted/empty content case
319 % if not filediff.hunks:
319 % if not filediff.hunks:
320 ## Comment container on a "fake" hunk that carries all data needed to render comments
320 ## Comment container on a "fake" hunk that carries all data needed to render comments
321 ${render_hunk_lines(filediff, c.user_session_attrs["diffmode"], filediff.hunk_ops, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
321 ${render_hunk_lines(filediff, c.user_session_attrs["diffmode"], filediff.hunk_ops, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
322 % endif
322 % endif
323
323
324 %if filediff.limited_diff:
324 %if filediff.limited_diff:
325 <tr class="cb-warning cb-collapser">
325 <tr class="cb-warning cb-collapser">
326 <td class="cb-text" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=6')}>
326 <td class="cb-text" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=6')}>
327 ${_('The requested commit or file is too big and content was truncated.')} <a href="${h.current_route_path(request, fulldiff=1)}" onclick="return confirm('${_("Showing a big diff might take some time and resources, continue?")}')">${_('Show full diff')}</a>
327 ${_('The requested commit or file is too big and content was truncated.')} <a href="${h.current_route_path(request, fulldiff=1)}" onclick="return confirm('${_("Showing a big diff might take some time and resources, continue?")}')">${_('Show full diff')}</a>
328 </td>
328 </td>
329 </tr>
329 </tr>
330 %else:
330 %else:
331 %if over_lines_changed_limit:
331 %if over_lines_changed_limit:
332 <tr class="cb-warning cb-collapser">
332 <tr class="cb-warning cb-collapser">
333 <td class="cb-text" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=6')}>
333 <td class="cb-text" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=6')}>
334 ${_('This diff has been collapsed as it changes many lines (%i lines changed)' % lines_changed)}
334 ${_('This diff has been collapsed as it changes many lines (%i lines changed)' % lines_changed)}
335 <a href="#" class="cb-expand"
335 <a href="#" class="cb-expand"
336 onclick="$(this).closest('table').removeClass('cb-collapsed'); updateSticky(); return false;">${_('Show them')}
336 onclick="$(this).closest('table').removeClass('cb-collapsed'); updateSticky(); return false;">${_('Show them')}
337 </a>
337 </a>
338 <a href="#" class="cb-collapse"
338 <a href="#" class="cb-collapse"
339 onclick="$(this).closest('table').addClass('cb-collapsed'); updateSticky(); return false;">${_('Hide them')}
339 onclick="$(this).closest('table').addClass('cb-collapsed'); updateSticky(); return false;">${_('Hide them')}
340 </a>
340 </a>
341 </td>
341 </td>
342 </tr>
342 </tr>
343 %endif
343 %endif
344 %endif
344 %endif
345
345
346 % for hunk in filediff.hunks:
346 % for hunk in filediff.hunks:
347 <tr class="cb-hunk">
347 <tr class="cb-hunk">
348 <td ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=3' or '')}>
348 <td ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=3' or '')}>
349 ## TODO: dan: add ajax loading of more context here
349 ## TODO: dan: add ajax loading of more context here
350 ## <a href="#">
350 ## <a href="#">
351 <i class="icon-more"></i>
351 <i class="icon-more"></i>
352 ## </a>
352 ## </a>
353 </td>
353 </td>
354 <td ${(c.user_session_attrs["diffmode"] == 'sideside' and 'colspan=5' or '')}>
354 <td ${(c.user_session_attrs["diffmode"] == 'sideside' and 'colspan=5' or '')}>
355 @@
355 @@
356 -${hunk.source_start},${hunk.source_length}
356 -${hunk.source_start},${hunk.source_length}
357 +${hunk.target_start},${hunk.target_length}
357 +${hunk.target_start},${hunk.target_length}
358 ${hunk.section_header}
358 ${hunk.section_header}
359 </td>
359 </td>
360 </tr>
360 </tr>
361
361
362 ${render_hunk_lines(filediff, c.user_session_attrs["diffmode"], hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
362 ${render_hunk_lines(filediff, c.user_session_attrs["diffmode"], hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
363 % endfor
363 % endfor
364
364
365 <% unmatched_comments = (inline_comments or {}).get(filediff.patch['filename'], {}) %>
365 <% unmatched_comments = (inline_comments or {}).get(filediff.patch['filename'], {}) %>
366
366
367 ## outdated comments that do not fit into currently displayed lines
367 ## outdated comments that do not fit into currently displayed lines
368 % for lineno, comments in unmatched_comments.items():
368 % for lineno, comments in unmatched_comments.items():
369
369
370 %if c.user_session_attrs["diffmode"] == 'unified':
370 %if c.user_session_attrs["diffmode"] == 'unified':
371 % if loop.index == 0:
371 % if loop.index == 0:
372 <tr class="cb-hunk">
372 <tr class="cb-hunk">
373 <td colspan="3"></td>
373 <td colspan="3"></td>
374 <td>
374 <td>
375 <div>
375 <div>
376 ${_('Unmatched/outdated inline comments below')}
376 ${_('Unmatched/outdated inline comments below')}
377 </div>
377 </div>
378 </td>
378 </td>
379 </tr>
379 </tr>
380 % endif
380 % endif
381 <tr class="cb-line">
381 <tr class="cb-line">
382 <td class="cb-data cb-context"></td>
382 <td class="cb-data cb-context"></td>
383 <td class="cb-lineno cb-context"></td>
383 <td class="cb-lineno cb-context"></td>
384 <td class="cb-lineno cb-context"></td>
384 <td class="cb-lineno cb-context"></td>
385 <td class="cb-content cb-context">
385 <td class="cb-content cb-context">
386 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
386 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
387 </td>
387 </td>
388 </tr>
388 </tr>
389 %elif c.user_session_attrs["diffmode"] == 'sideside':
389 %elif c.user_session_attrs["diffmode"] == 'sideside':
390 % if loop.index == 0:
390 % if loop.index == 0:
391 <tr class="cb-comment-info">
391 <tr class="cb-comment-info">
392 <td colspan="2"></td>
392 <td colspan="2"></td>
393 <td class="cb-line">
393 <td class="cb-line">
394 <div>
394 <div>
395 ${_('Unmatched/outdated inline comments below')}
395 ${_('Unmatched/outdated inline comments below')}
396 </div>
396 </div>
397 </td>
397 </td>
398 <td colspan="2"></td>
398 <td colspan="2"></td>
399 <td class="cb-line">
399 <td class="cb-line">
400 <div>
400 <div>
401 ${_('Unmatched/outdated comments below')}
401 ${_('Unmatched/outdated comments below')}
402 </div>
402 </div>
403 </td>
403 </td>
404 </tr>
404 </tr>
405 % endif
405 % endif
406 <tr class="cb-line">
406 <tr class="cb-line">
407 <td class="cb-data cb-context"></td>
407 <td class="cb-data cb-context"></td>
408 <td class="cb-lineno cb-context"></td>
408 <td class="cb-lineno cb-context"></td>
409 <td class="cb-content cb-context">
409 <td class="cb-content cb-context">
410 % if lineno.startswith('o'):
410 % if lineno.startswith('o'):
411 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
411 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
412 % endif
412 % endif
413 </td>
413 </td>
414
414
415 <td class="cb-data cb-context"></td>
415 <td class="cb-data cb-context"></td>
416 <td class="cb-lineno cb-context"></td>
416 <td class="cb-lineno cb-context"></td>
417 <td class="cb-content cb-context">
417 <td class="cb-content cb-context">
418 % if lineno.startswith('n'):
418 % if lineno.startswith('n'):
419 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
419 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
420 % endif
420 % endif
421 </td>
421 </td>
422 </tr>
422 </tr>
423 %endif
423 %endif
424
424
425 % endfor
425 % endfor
426
426
427 </table>
427 </table>
428 </div>
428 </div>
429 %endfor
429 %endfor
430
430
431 ## outdated comments that are made for a file that has been deleted
431 ## outdated comments that are made for a file that has been deleted
432 % for filename, comments_dict in (deleted_files_comments or {}).items():
432 % for filename, comments_dict in (deleted_files_comments or {}).items():
433
433
434 <%
434 <%
435 display_state = 'display: none'
435 display_state = 'display: none'
436 open_comments_in_file = [x for x in comments_dict['comments'] if x.outdated is False]
436 open_comments_in_file = [x for x in comments_dict['comments'] if x.outdated is False]
437 if open_comments_in_file:
437 if open_comments_in_file:
438 display_state = ''
438 display_state = ''
439 fid = str(id(filename))
439 fid = str(id(filename))
440 %>
440 %>
441 <div class="filediffs filediff-outdated" style="${display_state}">
441 <div class="filediffs filediff-outdated" style="${display_state}">
442 <input ${(collapse_all and 'checked' or '')} class="filediff-collapse-state collapse-${diffset_container_id}" id="filediff-collapse-${id(filename)}" type="checkbox" onchange="updateSticky();">
442 <input ${(collapse_all and 'checked' or '')} class="filediff-collapse-state collapse-${diffset_container_id}" id="filediff-collapse-${id(filename)}" type="checkbox" onchange="updateSticky();">
443 <div class="filediff" data-f-path="${filename}" id="a_${h.FID(fid, filename)}">
443 <div class="filediff" data-f-path="${filename}" id="a_${h.FID(fid, filename)}">
444 <label for="filediff-collapse-${id(filename)}" class="filediff-heading">
444 <label for="filediff-collapse-${id(filename)}" class="filediff-heading">
445 <div class="filediff-collapse-indicator icon-"></div>
445 <div class="filediff-collapse-indicator icon-"></div>
446
446
447 <span class="pill">
447 <span class="pill">
448 ## file was deleted
448 ## file was deleted
449 ${filename}
449 ${filename}
450 </span>
450 </span>
451 <span class="pill-group pull-left" >
451 <span class="pill-group pull-left" >
452 ## file op, doesn't need translation
452 ## file op, doesn't need translation
453 <span class="pill" op="removed">unresolved comments</span>
453 <span class="pill" op="removed">unresolved comments</span>
454 </span>
454 </span>
455 <a class="pill filediff-anchor" href="#a_${h.FID(fid, filename)}"></a>
455 <a class="pill filediff-anchor" href="#a_${h.FID(fid, filename)}"></a>
456 <span class="pill-group pull-right">
456 <span class="pill-group pull-right">
457 <span class="pill" op="deleted">
457 <span class="pill" op="deleted">
458 % if comments_dict['stats'] > 0:
458 % if comments_dict['stats'] > 0:
459 -${comments_dict['stats']}
459 -${comments_dict['stats']}
460 % else:
460 % else:
461 ${comments_dict['stats']}
461 ${comments_dict['stats']}
462 % endif
462 % endif
463 </span>
463 </span>
464 </span>
464 </span>
465 </label>
465 </label>
466
466
467 <table class="cb cb-diff-${c.user_session_attrs["diffmode"]} code-highlight ${(over_lines_changed_limit and 'cb-collapsed' or '')}">
467 <table class="cb cb-diff-${c.user_session_attrs["diffmode"]} code-highlight ${(over_lines_changed_limit and 'cb-collapsed' or '')}">
468 <tr>
468 <tr>
469 % if c.user_session_attrs["diffmode"] == 'unified':
469 % if c.user_session_attrs["diffmode"] == 'unified':
470 <td></td>
470 <td></td>
471 %endif
471 %endif
472
472
473 <td></td>
473 <td></td>
474 <td class="cb-text cb-${op_class(BIN_FILENODE)}" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=5')}>
474 <td class="cb-text cb-${op_class(BIN_FILENODE)}" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=5')}>
475 <strong>${_('This file was removed from the diff during updates to this pull-request.')}</strong><br/>
475 <strong>${_('This file was removed from the diff during updates to this pull-request.')}</strong><br/>
476 ${_('There are still outdated/unresolved comments attached to it.')}
476 ${_('There are still outdated/unresolved comments attached to it.')}
477 </td>
477 </td>
478 </tr>
478 </tr>
479 %if c.user_session_attrs["diffmode"] == 'unified':
479 %if c.user_session_attrs["diffmode"] == 'unified':
480 <tr class="cb-line">
480 <tr class="cb-line">
481 <td class="cb-data cb-context"></td>
481 <td class="cb-data cb-context"></td>
482 <td class="cb-lineno cb-context"></td>
482 <td class="cb-lineno cb-context"></td>
483 <td class="cb-lineno cb-context"></td>
483 <td class="cb-lineno cb-context"></td>
484 <td class="cb-content cb-context">
484 <td class="cb-content cb-context">
485 ${inline_comments_container(comments_dict['comments'], active_pattern_entries=active_pattern_entries)}
485 ${inline_comments_container(comments_dict['comments'], active_pattern_entries=active_pattern_entries)}
486 </td>
486 </td>
487 </tr>
487 </tr>
488 %elif c.user_session_attrs["diffmode"] == 'sideside':
488 %elif c.user_session_attrs["diffmode"] == 'sideside':
489 <tr class="cb-line">
489 <tr class="cb-line">
490 <td class="cb-data cb-context"></td>
490 <td class="cb-data cb-context"></td>
491 <td class="cb-lineno cb-context"></td>
491 <td class="cb-lineno cb-context"></td>
492 <td class="cb-content cb-context"></td>
492 <td class="cb-content cb-context"></td>
493
493
494 <td class="cb-data cb-context"></td>
494 <td class="cb-data cb-context"></td>
495 <td class="cb-lineno cb-context"></td>
495 <td class="cb-lineno cb-context"></td>
496 <td class="cb-content cb-context">
496 <td class="cb-content cb-context">
497 ${inline_comments_container(comments_dict['comments'], active_pattern_entries=active_pattern_entries)}
497 ${inline_comments_container(comments_dict['comments'], active_pattern_entries=active_pattern_entries)}
498 </td>
498 </td>
499 </tr>
499 </tr>
500 %endif
500 %endif
501 </table>
501 </table>
502 </div>
502 </div>
503 </div>
503 </div>
504 % endfor
504 % endfor
505
505
506 </div>
506 </div>
507 </div>
507 </div>
508 </%def>
508 </%def>
509
509
510 <%def name="diff_ops(file_name, filediff)">
510 <%def name="diff_ops(file_name, filediff)">
511 <%
511 <%
512 from rhodecode.lib.diffs import NEW_FILENODE, DEL_FILENODE, \
512 from rhodecode.lib.diffs import NEW_FILENODE, DEL_FILENODE, \
513 MOD_FILENODE, RENAMED_FILENODE, CHMOD_FILENODE, BIN_FILENODE, COPIED_FILENODE
513 MOD_FILENODE, RENAMED_FILENODE, CHMOD_FILENODE, BIN_FILENODE, COPIED_FILENODE
514 %>
514 %>
515 <span class="pill">
515 <span class="pill">
516 <i class="icon-file-text"></i>
516 <i class="icon-file-text"></i>
517 ${file_name}
517 ${file_name}
518 </span>
518 </span>
519
519
520 <span class="pill-group pull-right">
520 <span class="pill-group pull-right">
521
521
522 ## ops pills
522 ## ops pills
523 %if filediff.limited_diff:
523 %if filediff.limited_diff:
524 <span class="pill tooltip" op="limited" title="The stats for this diff are not complete">limited diff</span>
524 <span class="pill tooltip" op="limited" title="The stats for this diff are not complete">limited diff</span>
525 %endif
525 %endif
526
526
527 %if NEW_FILENODE in filediff.patch['stats']['ops']:
527 %if NEW_FILENODE in filediff.patch['stats']['ops']:
528 <span class="pill" op="created">created</span>
528 <span class="pill" op="created">created</span>
529 %if filediff['target_mode'].startswith('120'):
529 %if filediff['target_mode'].startswith('120'):
530 <span class="pill" op="symlink">symlink</span>
530 <span class="pill" op="symlink">symlink</span>
531 %else:
531 %else:
532 <span class="pill" op="mode">${nice_mode(filediff['target_mode'])}</span>
532 <span class="pill" op="mode">${nice_mode(filediff['target_mode'])}</span>
533 %endif
533 %endif
534 %endif
534 %endif
535
535
536 %if RENAMED_FILENODE in filediff.patch['stats']['ops']:
536 %if RENAMED_FILENODE in filediff.patch['stats']['ops']:
537 <span class="pill" op="renamed">renamed</span>
537 <span class="pill" op="renamed">renamed</span>
538 %endif
538 %endif
539
539
540 %if COPIED_FILENODE in filediff.patch['stats']['ops']:
540 %if COPIED_FILENODE in filediff.patch['stats']['ops']:
541 <span class="pill" op="copied">copied</span>
541 <span class="pill" op="copied">copied</span>
542 %endif
542 %endif
543
543
544 %if DEL_FILENODE in filediff.patch['stats']['ops']:
544 %if DEL_FILENODE in filediff.patch['stats']['ops']:
545 <span class="pill" op="removed">removed</span>
545 <span class="pill" op="removed">removed</span>
546 %endif
546 %endif
547
547
548 %if CHMOD_FILENODE in filediff.patch['stats']['ops']:
548 %if CHMOD_FILENODE in filediff.patch['stats']['ops']:
549 <span class="pill" op="mode">
549 <span class="pill" op="mode">
550 ${nice_mode(filediff['source_mode'])}${nice_mode(filediff['target_mode'])}
550 ${nice_mode(filediff['source_mode'])}${nice_mode(filediff['target_mode'])}
551 </span>
551 </span>
552 %endif
552 %endif
553
553
554 %if BIN_FILENODE in filediff.patch['stats']['ops']:
554 %if BIN_FILENODE in filediff.patch['stats']['ops']:
555 <span class="pill" op="binary">binary</span>
555 <span class="pill" op="binary">binary</span>
556 %if MOD_FILENODE in filediff.patch['stats']['ops']:
556 %if MOD_FILENODE in filediff.patch['stats']['ops']:
557 <span class="pill" op="modified">modified</span>
557 <span class="pill" op="modified">modified</span>
558 %endif
558 %endif
559 %endif
559 %endif
560
560
561 <span class="pill" op="added">${('+' if filediff.patch['stats']['added'] else '')}${filediff.patch['stats']['added']}</span>
561 <span class="pill" op="added">${('+' if filediff.patch['stats']['added'] else '')}${filediff.patch['stats']['added']}</span>
562 <span class="pill" op="deleted">${((h.safe_int(filediff.patch['stats']['deleted']) or 0) * -1)}</span>
562 <span class="pill" op="deleted">${((h.safe_int(filediff.patch['stats']['deleted']) or 0) * -1)}</span>
563
563
564 </span>
564 </span>
565
565
566 </%def>
566 </%def>
567
567
568 <%def name="nice_mode(filemode)">
568 <%def name="nice_mode(filemode)">
569 ${(filemode.startswith('100') and filemode[3:] or filemode)}
569 ${(filemode.startswith('100') and filemode[3:] or filemode)}
570 </%def>
570 </%def>
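(Editor's note) nice_mode() strips the '100' prefix of regular-file modes and
leaves everything else alone; an equivalent plain-Python sketch:

    def nice_mode(filemode):
        return filemode[3:] if filemode.startswith('100') else filemode

    # nice_mode('100644') -> '644'      (regular file)
    # nice_mode('120000') -> '120000'   (symlink mode, shown verbatim)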
571
571
572 <%def name="diff_menu(filediff, use_comments=False)">
572 <%def name="diff_menu(filediff, use_comments=False)">
573 <div class="filediff-menu">
573 <div class="filediff-menu">
574
574
575 %if filediff.diffset.source_ref:
575 %if filediff.diffset.source_ref:
576
576
577 ## FILE BEFORE CHANGES
577 ## FILE BEFORE CHANGES
578 %if filediff.operation in ['D', 'M']:
578 %if filediff.operation in ['D', 'M']:
579 <a
579 <a
580 class="tooltip"
580 class="tooltip"
581 href="${h.route_path('repo_files',repo_name=filediff.diffset.target_repo_name,commit_id=filediff.diffset.source_ref,f_path=filediff.source_file_path)}"
581 href="${h.route_path('repo_files',repo_name=filediff.diffset.target_repo_name,commit_id=filediff.diffset.source_ref,f_path=filediff.source_file_path)}"
582 title="${h.tooltip(_('Show file at commit: %(commit_id)s') % {'commit_id': filediff.diffset.source_ref[:12]})}"
582 title="${h.tooltip(_('Show file at commit: %(commit_id)s') % {'commit_id': filediff.diffset.source_ref[:12]})}"
583 >
583 >
584 ${_('Show file before')}
584 ${_('Show file before')}
585 </a> |
585 </a> |
586 %else:
586 %else:
587 <span
587 <span
588 class="tooltip"
588 class="tooltip"
589 title="${h.tooltip(_('File not present at commit: %(commit_id)s') % {'commit_id': filediff.diffset.source_ref[:12]})}"
589 title="${h.tooltip(_('File not present at commit: %(commit_id)s') % {'commit_id': filediff.diffset.source_ref[:12]})}"
590 >
590 >
591 ${_('Show file before')}
591 ${_('Show file before')}
592 </span> |
592 </span> |
593 %endif
593 %endif
594
594
595 ## FILE AFTER CHANGES
595 ## FILE AFTER CHANGES
596 %if filediff.operation in ['A', 'M']:
596 %if filediff.operation in ['A', 'M']:
597 <a
597 <a
598 class="tooltip"
598 class="tooltip"
599 href="${h.route_path('repo_files',repo_name=filediff.diffset.source_repo_name,commit_id=filediff.diffset.target_ref,f_path=filediff.target_file_path)}"
599 href="${h.route_path('repo_files',repo_name=filediff.diffset.source_repo_name,commit_id=filediff.diffset.target_ref,f_path=filediff.target_file_path)}"
600 title="${h.tooltip(_('Show file at commit: %(commit_id)s') % {'commit_id': filediff.diffset.target_ref[:12]})}"
600 title="${h.tooltip(_('Show file at commit: %(commit_id)s') % {'commit_id': filediff.diffset.target_ref[:12]})}"
601 >
601 >
602 ${_('Show file after')}
602 ${_('Show file after')}
603 </a>
603 </a>
604 %else:
604 %else:
605 <span
605 <span
606 class="tooltip"
606 class="tooltip"
607 title="${h.tooltip(_('File not present at commit: %(commit_id)s') % {'commit_id': filediff.diffset.target_ref[:12]})}"
607 title="${h.tooltip(_('File not present at commit: %(commit_id)s') % {'commit_id': filediff.diffset.target_ref[:12]})}"
608 >
608 >
609 ${_('Show file after')}
609 ${_('Show file after')}
610 </span>
610 </span>
611 %endif
611 %endif
612
612
613 % if use_comments:
613 % if use_comments:
614 |
614 |
615 <a href="#" onclick="Rhodecode.comments.toggleDiffComments(this);return toggleElement(this)"
615 <a href="#" onclick="Rhodecode.comments.toggleDiffComments(this);return toggleElement(this)"
616 data-toggle-on="${_('Hide comments')}"
616 data-toggle-on="${_('Hide comments')}"
617 data-toggle-off="${_('Show comments')}">
617 data-toggle-off="${_('Show comments')}">
618 <span class="hide-comment-button">${_('Hide comments')}</span>
618 <span class="hide-comment-button">${_('Hide comments')}</span>
619 </a>
619 </a>
620 % endif
620 % endif
621
621
622 %endif
622 %endif
623
623
624 </div>
624 </div>
625 </%def>
625 </%def>
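
The menu above links to the file's content at both ends of the diff, and the operation code decides which side is a live link: a deleted or modified file ('D'/'M') existed at the source ref, while an added or modified file ('A'/'M') exists at the target ref. A hedged Python sketch of that decision:

    # 'A' = added, 'M' = modified, 'D' = deleted
    def diff_menu_links(operation):
        return {
            'show_file_before': operation in ('D', 'M'),  # present at source_ref
            'show_file_after': operation in ('A', 'M'),   # present at target_ref
        }

    assert diff_menu_links('A') == {'show_file_before': False,
                                    'show_file_after': True}
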
626
626
627
627
628 <%def name="inline_comments_container(comments, active_pattern_entries=None, line_no='', f_path='')">
628 <%def name="inline_comments_container(comments, active_pattern_entries=None, line_no='', f_path='')">
629
629
630 <div class="inline-comments">
630 <div class="inline-comments">
631 %for comment in comments:
631 %for comment in comments:
632 ${commentblock.comment_block(comment, inline=True, active_pattern_entries=active_pattern_entries)}
632 ${commentblock.comment_block(comment, inline=True, active_pattern_entries=active_pattern_entries)}
633 %endfor
633 %endfor
634
634
635 <%
635 <%
636 extra_class = ''
636 extra_class = ''
637 extra_style = ''
637 extra_style = ''
638
638
639 if comments and comments[-1].outdated_at_version(c.at_version_num):
639 if comments and comments[-1].outdated_at_version(c.at_version_num):
640 extra_class = ' comment-outdated'
640 extra_class = ' comment-outdated'
641 extra_style = 'display: none;'
641 extra_style = 'display: none;'
642
642
643 %>
643 %>
644
644
645 <div class="reply-thread-container-wrapper${extra_class}" style="${extra_style}">
645 <div class="reply-thread-container-wrapper${extra_class}" style="${extra_style}">
646 <div class="reply-thread-container${extra_class}">
646 <div class="reply-thread-container${extra_class}">
647 <div class="reply-thread-gravatar">
647 <div class="reply-thread-gravatar">
648 % if c.rhodecode_user.username != h.DEFAULT_USER:
648 % if c.rhodecode_user.username != h.DEFAULT_USER:
649 ${base.gravatar(c.rhodecode_user.email, 20, tooltip=True, user=c.rhodecode_user)}
649 ${base.gravatar(c.rhodecode_user.email, 20, tooltip=True, user=c.rhodecode_user)}
650 % endif
650 % endif
651 </div>
651 </div>
652
652
653 <div class="reply-thread-reply-button">
653 <div class="reply-thread-reply-button">
654 % if c.rhodecode_user.username != h.DEFAULT_USER:
654 % if c.rhodecode_user.username != h.DEFAULT_USER:
655 ## initial reply button, some JS logic can append here a FORM to leave a first comment.
655 ## initial reply button, some JS logic can append here a FORM to leave a first comment.
656 <button class="cb-comment-add-button" onclick="return Rhodecode.comments.createComment(this, '${f_path}', '${line_no}', null)">Reply...</button>
656 <button class="cb-comment-add-button" onclick="return Rhodecode.comments.createComment(this, '${f_path}', '${line_no}', null)">Reply...</button>
657 % endif
657 % endif
658 </div>
658 </div>
660 <div class="reply-thread-last"></div>
660 <div class="reply-thread-last"></div>
661 </div>
661 </div>
662 </div>
662 </div>
663 </div>
663 </div>
664
664
665 </%def>
665 </%def>
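
The container above collapses the reply controls when the newest comment in the thread is already outdated at the version being viewed. A small sketch of that rule; _Comment here is a hypothetical stand-in for the real comment model:

    class _Comment(object):
        # hypothetical stand-in; the real outdated_at_version()
        # consults the pull-request version history
        def __init__(self, outdated_from=None):
            self.outdated_from = outdated_from

        def outdated_at_version(self, version_num):
            return (self.outdated_from is not None
                    and version_num >= self.outdated_from)

    def reply_container_attrs(comments, at_version_num):
        # hide the reply controls when the thread's newest comment
        # is outdated at the version currently being viewed
        extra_class, extra_style = '', ''
        if comments and comments[-1].outdated_at_version(at_version_num):
            extra_class = ' comment-outdated'
            extra_style = 'display: none;'
        return extra_class, extra_style

    assert reply_container_attrs([_Comment(outdated_from=2)], 3) == (
        ' comment-outdated', 'display: none;')
    assert reply_container_attrs([_Comment()], 3) == ('', '')
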
666
666
667 <%!
667 <%!
668
668
669 def get_inline_comments(comments, filename):
669 def get_inline_comments(comments, filename):
670 if hasattr(filename, 'unicode_path'):
670 if hasattr(filename, 'unicode_path'):
671 filename = filename.unicode_path
671 filename = filename.unicode_path
672
672
673 if not isinstance(filename, (unicode, str)):
673 if not isinstance(filename, str):
674 return None
674 return None
675
675
676 if comments and filename in comments:
676 if comments and filename in comments:
677 return comments[filename]
677 return comments[filename]
678
678
679 return None
679 return None
680
680
681 def get_comments_for(diff_type, comments, filename, line_version, line_number):
681 def get_comments_for(diff_type, comments, filename, line_version, line_number):
682 if hasattr(filename, 'unicode_path'):
682 if hasattr(filename, 'unicode_path'):
683 filename = filename.unicode_path
683 filename = filename.unicode_path
684
684
685 if not isinstance(filename, (unicode, str)):
685 if not isinstance(filename, str):
686 return None
686 return None
687
687
688 file_comments = get_inline_comments(comments, filename)
688 file_comments = get_inline_comments(comments, filename)
689 if file_comments is None:
689 if file_comments is None:
690 return None
690 return None
691
691
692 line_key = '{}{}'.format(line_version, line_number) ## e.g. o37, n12
692 line_key = '{}{}'.format(line_version, line_number) ## e.g. o37, n12
693 if line_key in file_comments:
693 if line_key in file_comments:
694 data = file_comments.pop(line_key)
694 data = file_comments.pop(line_key)
695 return data
695 return data
696 %>
696 %>
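
Inline comments are looked up per file and per line key, where the key is the side marker plus the line number ('o' for the original side, 'n' for the new side). Note that get_comments_for() pops the entry, so a thread is consumed by the first column that renders it. A runnable sketch of the lookup:

    comments = {'file.py': {'o37': ['thread-a'], 'n12': ['thread-b']}}

    def comments_for(comments, filename, line_version, line_number):
        file_comments = comments.get(filename)
        if file_comments is None:
            return None
        line_key = '{}{}'.format(line_version, line_number)  # e.g. 'o37'
        # pop() so the same thread cannot render twice
        return file_comments.pop(line_key, None)

    assert comments_for(comments, 'file.py', 'o', 37) == ['thread-a']
    assert comments_for(comments, 'file.py', 'o', 37) is None
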
697
697
698 <%def name="render_hunk_lines_sideside(filediff, hunk, use_comments=False, inline_comments=None, active_pattern_entries=None)">
698 <%def name="render_hunk_lines_sideside(filediff, hunk, use_comments=False, inline_comments=None, active_pattern_entries=None)">
699
699
700 <% chunk_count = 1 %>
700 <% chunk_count = 1 %>
701 %for loop_obj, item in h.looper(hunk.sideside):
701 %for loop_obj, item in h.looper(hunk.sideside):
702 <%
702 <%
703 line = item
703 line = item
704 i = loop_obj.index
704 i = loop_obj.index
705 prev_line = loop_obj.previous
705 prev_line = loop_obj.previous
706 old_line_anchor, new_line_anchor = None, None
706 old_line_anchor, new_line_anchor = None, None
707
707
708 if line.original.lineno:
708 if line.original.lineno:
709 old_line_anchor = diff_line_anchor(filediff.raw_id, hunk.source_file_path, line.original.lineno, 'o')
709 old_line_anchor = diff_line_anchor(filediff.raw_id, hunk.source_file_path, line.original.lineno, 'o')
710 if line.modified.lineno:
710 if line.modified.lineno:
711 new_line_anchor = diff_line_anchor(filediff.raw_id, hunk.target_file_path, line.modified.lineno, 'n')
711 new_line_anchor = diff_line_anchor(filediff.raw_id, hunk.target_file_path, line.modified.lineno, 'n')
712
712
713 line_action = line.modified.action or line.original.action
713 line_action = line.modified.action or line.original.action
714 prev_line_action = prev_line and (prev_line.modified.action or prev_line.original.action)
714 prev_line_action = prev_line and (prev_line.modified.action or prev_line.original.action)
715 %>
715 %>
716
716
717 <tr class="cb-line">
717 <tr class="cb-line">
718 <td class="cb-data ${action_class(line.original.action)}"
718 <td class="cb-data ${action_class(line.original.action)}"
719 data-line-no="${line.original.lineno}"
719 data-line-no="${line.original.lineno}"
720 >
720 >
721
721
722 <% line_old_comments, line_old_comments_no_drafts = None, None %>
722 <% line_old_comments, line_old_comments_no_drafts = None, None %>
723 %if line.original.get_comment_args:
723 %if line.original.get_comment_args:
724 <%
724 <%
725 line_old_comments = get_comments_for('side-by-side', inline_comments, *line.original.get_comment_args)
725 line_old_comments = get_comments_for('side-by-side', inline_comments, *line.original.get_comment_args)
726 line_old_comments_no_drafts = [c for c in line_old_comments if not c.draft] if line_old_comments else []
726 line_old_comments_no_drafts = [c for c in line_old_comments if not c.draft] if line_old_comments else []
727 has_outdated = any(x.outdated for x in line_old_comments_no_drafts)
727 has_outdated = any(x.outdated for x in line_old_comments_no_drafts)
728 %>
728 %>
729 %endif
729 %endif
730 %if line_old_comments_no_drafts:
730 %if line_old_comments_no_drafts:
731 % if has_outdated:
731 % if has_outdated:
732 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(line_old_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
732 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(line_old_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
733 % else:
733 % else:
734 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(line_old_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
734 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(line_old_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
735 % endif
735 % endif
736 %endif
736 %endif
737 </td>
737 </td>
738 <td class="cb-lineno ${action_class(line.original.action)}"
738 <td class="cb-lineno ${action_class(line.original.action)}"
739 data-line-no="${line.original.lineno}"
739 data-line-no="${line.original.lineno}"
740 %if old_line_anchor:
740 %if old_line_anchor:
741 id="${old_line_anchor}"
741 id="${old_line_anchor}"
742 %endif
742 %endif
743 >
743 >
744 %if line.original.lineno:
744 %if line.original.lineno:
745 <a name="${old_line_anchor}" href="#${old_line_anchor}">${line.original.lineno}</a>
745 <a name="${old_line_anchor}" href="#${old_line_anchor}">${line.original.lineno}</a>
746 %endif
746 %endif
747 </td>
747 </td>
748
748
749 <% line_no = 'o{}'.format(line.original.lineno) %>
749 <% line_no = 'o{}'.format(line.original.lineno) %>
750 <td class="cb-content ${action_class(line.original.action)}"
750 <td class="cb-content ${action_class(line.original.action)}"
751 data-line-no="${line_no}"
751 data-line-no="${line_no}"
752 >
752 >
753 %if use_comments and line.original.lineno:
753 %if use_comments and line.original.lineno:
754 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
754 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
755 %endif
755 %endif
756 <span class="cb-code"><span class="cb-action ${action_class(line.original.action)}"></span>${line.original.content or '' | n}</span>
756 <span class="cb-code"><span class="cb-action ${action_class(line.original.action)}"></span>${line.original.content or '' | n}</span>
757
757
758 %if use_comments and line.original.lineno and line_old_comments:
758 %if use_comments and line.original.lineno and line_old_comments:
759 ${inline_comments_container(line_old_comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
759 ${inline_comments_container(line_old_comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
760 %endif
760 %endif
761
761
762 </td>
762 </td>
763 <td class="cb-data ${action_class(line.modified.action)}"
763 <td class="cb-data ${action_class(line.modified.action)}"
764 data-line-no="${line.modified.lineno}"
764 data-line-no="${line.modified.lineno}"
765 >
765 >
766 <div>
766 <div>
767
767
768 <% line_new_comments, line_new_comments_no_drafts = None, None %>
768 <% line_new_comments, line_new_comments_no_drafts = None, None %>
769 %if line.modified.get_comment_args:
769 %if line.modified.get_comment_args:
770 <%
770 <%
771 line_new_comments = get_comments_for('side-by-side', inline_comments, *line.modified.get_comment_args)
771 line_new_comments = get_comments_for('side-by-side', inline_comments, *line.modified.get_comment_args)
772 line_new_comments_no_drafts = [c for c in line_new_comments if not c.draft] if line_new_comments else []
772 line_new_comments_no_drafts = [c for c in line_new_comments if not c.draft] if line_new_comments else []
773 has_outdated = any(x.outdated for x in line_new_comments_no_drafts)
773 has_outdated = any(x.outdated for x in line_new_comments_no_drafts)
774 %>
774 %>
775 %endif
775 %endif
776
776
777 %if line_new_comments_no_drafts:
777 %if line_new_comments_no_drafts:
778 % if has_outdated:
778 % if has_outdated:
779 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(line_new_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
779 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(line_new_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
780 % else:
780 % else:
781 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(line_new_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
781 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(line_new_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
782 % endif
782 % endif
783 %endif
783 %endif
784 </div>
784 </div>
785 </td>
785 </td>
786 <td class="cb-lineno ${action_class(line.modified.action)}"
786 <td class="cb-lineno ${action_class(line.modified.action)}"
787 data-line-no="${line.modified.lineno}"
787 data-line-no="${line.modified.lineno}"
788 %if new_line_anchor:
788 %if new_line_anchor:
789 id="${new_line_anchor}"
789 id="${new_line_anchor}"
790 %endif
790 %endif
791 >
791 >
792 %if line.modified.lineno:
792 %if line.modified.lineno:
793 <a name="${new_line_anchor}" href="#${new_line_anchor}">${line.modified.lineno}</a>
793 <a name="${new_line_anchor}" href="#${new_line_anchor}">${line.modified.lineno}</a>
794 %endif
794 %endif
795 </td>
795 </td>
796
796
797 <% line_no = 'n{}'.format(line.modified.lineno) %>
797 <% line_no = 'n{}'.format(line.modified.lineno) %>
798 <td class="cb-content ${action_class(line.modified.action)}"
798 <td class="cb-content ${action_class(line.modified.action)}"
799 data-line-no="${line_no}"
799 data-line-no="${line_no}"
800 >
800 >
801 %if use_comments and line.modified.lineno:
801 %if use_comments and line.modified.lineno:
802 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
802 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
803 %endif
803 %endif
804 <span class="cb-code"><span class="cb-action ${action_class(line.modified.action)}"></span>${line.modified.content or '' | n}</span>
804 <span class="cb-code"><span class="cb-action ${action_class(line.modified.action)}"></span>${line.modified.content or '' | n}</span>
805 % if line_action in ['+', '-'] and prev_line_action not in ['+', '-']:
805 % if line_action in ['+', '-'] and prev_line_action not in ['+', '-']:
806 <div class="nav-chunk" style="visibility: hidden">
806 <div class="nav-chunk" style="visibility: hidden">
807 <i class="icon-eye" title="viewing diff hunk-${hunk.index}-${chunk_count}"></i>
807 <i class="icon-eye" title="viewing diff hunk-${hunk.index}-${chunk_count}"></i>
808 </div>
808 </div>
809 <% chunk_count += 1 %>
809 <% chunk_count += 1 %>
810 % endif
810 % endif
811 %if use_comments and line.modified.lineno and line_new_comments:
811 %if use_comments and line.modified.lineno and line_new_comments:
812 ${inline_comments_container(line_new_comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
812 ${inline_comments_container(line_new_comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
813 %endif
813 %endif
814
814
815 </td>
815 </td>
816 </tr>
816 </tr>
817 %endfor
817 %endfor
818 </%def>
818 </%def>
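
The side-by-side renderer drops a hidden nav-chunk marker at the first added or removed line of every consecutive run of changes, which is what the hunk-N-M ids count. A sketch of that counting rule over a sequence of line actions:

    def count_nav_chunks(actions):
        # one chunk per run of '+'/'-' lines; a chunk starts where a
        # changed line follows an unchanged (or no) line
        chunks, prev = 0, None
        for action in actions:
            if action in ('+', '-') and prev not in ('+', '-'):
                chunks += 1
            prev = action
        return chunks

    assert count_nav_chunks(['unmod', '+', '+', 'unmod', '-', '+']) == 2
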
819
819
820
820
821 <%def name="render_hunk_lines_unified(filediff, hunk, use_comments=False, inline_comments=None, active_pattern_entries=None)">
821 <%def name="render_hunk_lines_unified(filediff, hunk, use_comments=False, inline_comments=None, active_pattern_entries=None)">
822 %for old_line_no, new_line_no, action, content, comments_args in hunk.unified:
822 %for old_line_no, new_line_no, action, content, comments_args in hunk.unified:
823
823
824 <%
824 <%
825 old_line_anchor, new_line_anchor = None, None
825 old_line_anchor, new_line_anchor = None, None
826 if old_line_no:
826 if old_line_no:
827 old_line_anchor = diff_line_anchor(filediff.raw_id, hunk.source_file_path, old_line_no, 'o')
827 old_line_anchor = diff_line_anchor(filediff.raw_id, hunk.source_file_path, old_line_no, 'o')
828 if new_line_no:
828 if new_line_no:
829 new_line_anchor = diff_line_anchor(filediff.raw_id, hunk.target_file_path, new_line_no, 'n')
829 new_line_anchor = diff_line_anchor(filediff.raw_id, hunk.target_file_path, new_line_no, 'n')
830 %>
830 %>
831 <tr class="cb-line">
831 <tr class="cb-line">
832 <td class="cb-data ${action_class(action)}">
832 <td class="cb-data ${action_class(action)}">
833 <div>
833 <div>
834
834
835 <% comments, comments_no_drafts = None, None %>
835 <% comments, comments_no_drafts = None, None %>
836 %if comments_args:
836 %if comments_args:
837 <%
837 <%
838 comments = get_comments_for('unified', inline_comments, *comments_args)
838 comments = get_comments_for('unified', inline_comments, *comments_args)
839 comments_no_drafts = [c for c in comments if not c.draft] if comments else []
839 comments_no_drafts = [c for c in comments if not c.draft] if comments else []
840 has_outdated = any(x.outdated for x in comments_no_drafts)
840 has_outdated = any(x.outdated for x in comments_no_drafts)
841 %>
841 %>
842 %endif
842 %endif
843
843
844 % if comments_no_drafts:
844 % if comments_no_drafts:
845 % if has_outdated:
845 % if has_outdated:
846 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
846 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
847 % else:
847 % else:
848 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
848 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
849 % endif
849 % endif
850 % endif
850 % endif
851 </div>
851 </div>
852 </td>
852 </td>
853 <td class="cb-lineno ${action_class(action)}"
853 <td class="cb-lineno ${action_class(action)}"
854 data-line-no="${old_line_no}"
854 data-line-no="${old_line_no}"
855 %if old_line_anchor:
855 %if old_line_anchor:
856 id="${old_line_anchor}"
856 id="${old_line_anchor}"
857 %endif
857 %endif
858 >
858 >
859 %if old_line_anchor:
859 %if old_line_anchor:
860 <a name="${old_line_anchor}" href="#${old_line_anchor}">${old_line_no}</a>
860 <a name="${old_line_anchor}" href="#${old_line_anchor}">${old_line_no}</a>
861 %endif
861 %endif
862 </td>
862 </td>
863 <td class="cb-lineno ${action_class(action)}"
863 <td class="cb-lineno ${action_class(action)}"
864 data-line-no="${new_line_no}"
864 data-line-no="${new_line_no}"
865 %if new_line_anchor:
865 %if new_line_anchor:
866 id="${new_line_anchor}"
866 id="${new_line_anchor}"
867 %endif
867 %endif
868 >
868 >
869 %if new_line_anchor:
869 %if new_line_anchor:
870 <a name="${new_line_anchor}" href="#${new_line_anchor}">${new_line_no}</a>
870 <a name="${new_line_anchor}" href="#${new_line_anchor}">${new_line_no}</a>
871 %endif
871 %endif
872 </td>
872 </td>
873 <% line_no = '{}{}'.format(new_line_no and 'n' or 'o', new_line_no or old_line_no) %>
873 <% line_no = '{}{}'.format(new_line_no and 'n' or 'o', new_line_no or old_line_no) %>
874 <td class="cb-content ${action_class(action)}"
874 <td class="cb-content ${action_class(action)}"
875 data-line-no="${line_no}"
875 data-line-no="${line_no}"
876 >
876 >
877 %if use_comments:
877 %if use_comments:
878 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
878 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
879 %endif
879 %endif
880 <span class="cb-code"><span class="cb-action ${action_class(action)}"></span> ${content or '' | n}</span>
880 <span class="cb-code"><span class="cb-action ${action_class(action)}"></span> ${content or '' | n}</span>
881 %if use_comments and comments:
881 %if use_comments and comments:
882 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
882 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
883 %endif
883 %endif
884 </td>
884 </td>
885 </tr>
885 </tr>
886 %endfor
886 %endfor
887 </%def>
887 </%def>
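
In the unified renderer the comment coordinate prefers the new side: a line that still exists gets an 'n'-prefixed number, while a removed line only has its old number and gets the 'o' prefix. Equivalent Python for the one-liner above:

    def unified_line_no(old_line_no, new_line_no):
        return '{}{}'.format('n' if new_line_no else 'o',
                             new_line_no or old_line_no)

    assert unified_line_no(10, 12) == 'n12'
    assert unified_line_no(10, None) == 'o10'
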
888
888
889
889
890 <%def name="render_hunk_lines(filediff, diff_mode, hunk, use_comments, inline_comments, active_pattern_entries)">
890 <%def name="render_hunk_lines(filediff, diff_mode, hunk, use_comments, inline_comments, active_pattern_entries)">
891 % if diff_mode == 'unified':
891 % if diff_mode == 'unified':
892 ${render_hunk_lines_unified(filediff, hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
892 ${render_hunk_lines_unified(filediff, hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
893 % elif diff_mode == 'sideside':
893 % elif diff_mode == 'sideside':
894 ${render_hunk_lines_sideside(filediff, hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
894 ${render_hunk_lines_sideside(filediff, hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
895 % else:
895 % else:
896 <tr class="cb-line">
896 <tr class="cb-line">
897 <td>unknown diff mode</td>
897 <td>unknown diff mode</td>
898 </tr>
898 </tr>
899 % endif
899 % endif
900 </%def>
900 </%def>
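
render_hunk_lines() is a plain dispatch on the requested diff mode with a defensive fallback row. A minimal sketch, with stand-in functions in place of the Mako defs:

    def render_unified(hunk):
        return 'unified:{}'.format(hunk)      # stand-in

    def render_sideside(hunk):
        return 'sideside:{}'.format(hunk)     # stand-in

    RENDERERS = {'unified': render_unified, 'sideside': render_sideside}

    def render_hunk_lines(diff_mode, hunk):
        renderer = RENDERERS.get(diff_mode)
        if renderer is None:
            return '<tr class="cb-line"><td>unknown diff mode</td></tr>'
        return renderer(hunk)

    assert render_hunk_lines('unified', 'h1') == 'unified:h1'
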
901
901
902
902
903 <%def name="render_add_comment_button(line_no='', f_path='')">
903 <%def name="render_add_comment_button(line_no='', f_path='')">
904 % if not c.rhodecode_user.is_default:
904 % if not c.rhodecode_user.is_default:
905 <button class="btn btn-small btn-primary cb-comment-box-opener" onclick="return Rhodecode.comments.createComment(this, '${f_path}', '${line_no}', null)">
905 <button class="btn btn-small btn-primary cb-comment-box-opener" onclick="return Rhodecode.comments.createComment(this, '${f_path}', '${line_no}', null)">
906 <span><i class="icon-comment"></i></span>
906 <span><i class="icon-comment"></i></span>
907 </button>
907 </button>
908 % endif
908 % endif
909 </%def>
909 </%def>
910
910
911 <%def name="render_diffset_menu(diffset, range_diff_on=None, commit=None, pull_request_menu=None)">
911 <%def name="render_diffset_menu(diffset, range_diff_on=None, commit=None, pull_request_menu=None)">
912 <% diffset_container_id = h.md5(diffset.target_ref) %>
912 <% diffset_container_id = h.md5(diffset.target_ref) %>
913
913
914 <div id="diff-file-sticky" class="diffset-menu clearinner">
914 <div id="diff-file-sticky" class="diffset-menu clearinner">
915 ## auto adjustable
915 ## auto adjustable
916 <div class="sidebar__inner">
916 <div class="sidebar__inner">
917 <div class="sidebar__bar">
917 <div class="sidebar__bar">
918 <div class="pull-right">
918 <div class="pull-right">
919
919
920 <div class="btn-group" style="margin-right: 5px;">
920 <div class="btn-group" style="margin-right: 5px;">
921 <a class="tooltip btn" onclick="scrollDown();return false" title="${_('Scroll to page bottom')}">
921 <a class="tooltip btn" onclick="scrollDown();return false" title="${_('Scroll to page bottom')}">
922 <i class="icon-arrow_down"></i>
922 <i class="icon-arrow_down"></i>
923 </a>
923 </a>
924 <a class="tooltip btn" onclick="scrollUp();return false" title="${_('Scroll to page top')}">
924 <a class="tooltip btn" onclick="scrollUp();return false" title="${_('Scroll to page top')}">
925 <i class="icon-arrow_up"></i>
925 <i class="icon-arrow_up"></i>
926 </a>
926 </a>
927 </div>
927 </div>
928
928
929 <div class="btn-group">
929 <div class="btn-group">
930 <a class="btn tooltip toggle-wide-diff" href="#toggle-wide-diff" onclick="toggleWideDiff(this); return false" title="${h.tooltip(_('Toggle wide diff'))}">
930 <a class="btn tooltip toggle-wide-diff" href="#toggle-wide-diff" onclick="toggleWideDiff(this); return false" title="${h.tooltip(_('Toggle wide diff'))}">
931 <i class="icon-wide-mode"></i>
931 <i class="icon-wide-mode"></i>
932 </a>
932 </a>
933 </div>
933 </div>
934 <div class="btn-group">
934 <div class="btn-group">
935
935
936 <a
936 <a
937 class="btn ${(c.user_session_attrs["diffmode"] == 'sideside' and 'btn-active')} tooltip"
937 class="btn ${(c.user_session_attrs["diffmode"] == 'sideside' and 'btn-active')} tooltip"
938 title="${h.tooltip(_('View diff as side by side'))}"
938 title="${h.tooltip(_('View diff as side by side'))}"
939 href="${h.current_route_path(request, diffmode='sideside')}">
939 href="${h.current_route_path(request, diffmode='sideside')}">
940 <span>${_('Side by Side')}</span>
940 <span>${_('Side by Side')}</span>
941 </a>
941 </a>
942
942
943 <a
943 <a
944 class="btn ${(c.user_session_attrs["diffmode"] == 'unified' and 'btn-active')} tooltip"
944 class="btn ${(c.user_session_attrs["diffmode"] == 'unified' and 'btn-active')} tooltip"
945 title="${h.tooltip(_('View diff as unified'))}" href="${h.current_route_path(request, diffmode='unified')}">
945 title="${h.tooltip(_('View diff as unified'))}" href="${h.current_route_path(request, diffmode='unified')}">
946 <span>${_('Unified')}</span>
946 <span>${_('Unified')}</span>
947 </a>
947 </a>
948
948
949 % if range_diff_on is True:
949 % if range_diff_on is True:
950 <a
950 <a
951 title="${_('Turn off: Show the diff as commit range')}"
951 title="${_('Turn off: Show the diff as commit range')}"
952 class="btn btn-primary"
952 class="btn btn-primary"
953 href="${h.current_route_path(request, **{"range-diff":"0"})}">
953 href="${h.current_route_path(request, **{"range-diff":"0"})}">
954 <span>${_('Range Diff')}</span>
954 <span>${_('Range Diff')}</span>
955 </a>
955 </a>
956 % elif range_diff_on is False:
956 % elif range_diff_on is False:
957 <a
957 <a
958 title="${_('Show the diff as commit range')}"
958 title="${_('Show the diff as commit range')}"
959 class="btn"
959 class="btn"
960 href="${h.current_route_path(request, **{"range-diff":"1"})}">
960 href="${h.current_route_path(request, **{"range-diff":"1"})}">
961 <span>${_('Range Diff')}</span>
961 <span>${_('Range Diff')}</span>
962 </a>
962 </a>
963 % endif
963 % endif
964 </div>
964 </div>
965 <div class="btn-group">
965 <div class="btn-group">
966
966
967 <details class="details-reset details-inline-block">
967 <details class="details-reset details-inline-block">
968 <summary class="noselect btn">
968 <summary class="noselect btn">
969 <i class="icon-options cursor-pointer" op="options"></i>
969 <i class="icon-options cursor-pointer" op="options"></i>
970 </summary>
970 </summary>
971
971
972 <div>
972 <div>
973 <details-menu class="details-dropdown" style="top: 35px;">
973 <details-menu class="details-dropdown" style="top: 35px;">
974
974
975 <div class="dropdown-item">
975 <div class="dropdown-item">
976 <div style="padding: 2px 0px">
976 <div style="padding: 2px 0px">
977 % if request.GET.get('ignorews', '') == '1':
977 % if request.GET.get('ignorews', '') == '1':
978 <a href="${h.current_route_path(request, ignorews=0)}">${_('Show whitespace changes')}</a>
978 <a href="${h.current_route_path(request, ignorews=0)}">${_('Show whitespace changes')}</a>
979 % else:
979 % else:
980 <a href="${h.current_route_path(request, ignorews=1)}">${_('Hide whitespace changes')}</a>
980 <a href="${h.current_route_path(request, ignorews=1)}">${_('Hide whitespace changes')}</a>
981 % endif
981 % endif
982 </div>
982 </div>
983 </div>
983 </div>
984
984
985 <div class="dropdown-item">
985 <div class="dropdown-item">
986 <div style="padding: 2px 0px">
986 <div style="padding: 2px 0px">
987 % if request.GET.get('fullcontext', '') == '1':
987 % if request.GET.get('fullcontext', '') == '1':
988 <a href="${h.current_route_path(request, fullcontext=0)}">${_('Hide full context diff')}</a>
988 <a href="${h.current_route_path(request, fullcontext=0)}">${_('Hide full context diff')}</a>
989 % else:
989 % else:
990 <a href="${h.current_route_path(request, fullcontext=1)}">${_('Show full context diff')}</a>
990 <a href="${h.current_route_path(request, fullcontext=1)}">${_('Show full context diff')}</a>
991 % endif
991 % endif
992 </div>
992 </div>
993 </div>
993 </div>
994
994
995 </details-menu>
995 </details-menu>
996 </div>
996 </div>
997 </details>
997 </details>
998
998
999 </div>
999 </div>
1000 </div>
1000 </div>
1001 <div class="pull-left">
1001 <div class="pull-left">
1002 <div class="btn-group">
1002 <div class="btn-group">
1003 <div class="pull-left">
1003 <div class="pull-left">
1004 ${h.hidden('file_filter_{}'.format(diffset_container_id))}
1004 ${h.hidden('file_filter_{}'.format(diffset_container_id))}
1005 </div>
1005 </div>
1006
1006
1007 </div>
1007 </div>
1008 </div>
1008 </div>
1009 </div>
1009 </div>
1010 <div class="fpath-placeholder pull-left">
1010 <div class="fpath-placeholder pull-left">
1011 <i class="icon-file-text"></i>
1011 <i class="icon-file-text"></i>
1012 <strong class="fpath-placeholder-text">
1012 <strong class="fpath-placeholder-text">
1013 Context file:
1013 Context file:
1014 </strong>
1014 </strong>
1015 </div>
1015 </div>
1016 <div class="pull-right noselect">
1016 <div class="pull-right noselect">
1017 %if commit:
1017 %if commit:
1018 <span>
1018 <span>
1019 <code>${h.show_id(commit)}</code>
1019 <code>${h.show_id(commit)}</code>
1020 </span>
1020 </span>
1021 %elif pull_request_menu and pull_request_menu.get('pull_request'):
1021 %elif pull_request_menu and pull_request_menu.get('pull_request'):
1022 <span>
1022 <span>
1023 <code>!${pull_request_menu['pull_request'].pull_request_id}</code>
1023 <code>!${pull_request_menu['pull_request'].pull_request_id}</code>
1024 </span>
1024 </span>
1025 %endif
1025 %endif
1026 % if commit or pull_request_menu:
1026 % if commit or pull_request_menu:
1027 <span class="tooltip" title="${_('Navigate to previous or next change inside files.')}" id="diff_nav">Loading diff...:</span>
1027 <span class="tooltip" title="${_('Navigate to previous or next change inside files.')}" id="diff_nav">Loading diff...:</span>
1028 <span class="cursor-pointer" onclick="scrollToPrevChunk(); return false">
1028 <span class="cursor-pointer" onclick="scrollToPrevChunk(); return false">
1029 <i class="icon-angle-up"></i>
1029 <i class="icon-angle-up"></i>
1030 </span>
1030 </span>
1031 <span class="cursor-pointer" onclick="scrollToNextChunk(); return false">
1031 <span class="cursor-pointer" onclick="scrollToNextChunk(); return false">
1032 <i class="icon-angle-down"></i>
1032 <i class="icon-angle-down"></i>
1033 </span>
1033 </span>
1034 % endif
1034 % endif
1035 </div>
1035 </div>
1036 <div class="sidebar_inner_shadow"></div>
1036 <div class="sidebar_inner_shadow"></div>
1037 </div>
1037 </div>
1038 </div>
1038 </div>
1039
1039
1040 % if diffset:
1040 % if diffset:
1041 %if diffset.limited_diff:
1041 %if diffset.limited_diff:
1042 <% file_placeholder = _ungettext('%(num)s file changed', '%(num)s files changed', diffset.changed_files) % {'num': diffset.changed_files} %>
1042 <% file_placeholder = _ungettext('%(num)s file changed', '%(num)s files changed', diffset.changed_files) % {'num': diffset.changed_files} %>
1043 %else:
1043 %else:
1044 <% file_placeholder = h.literal(_ungettext('%(num)s file changed: <span class="op-added">%(linesadd)s inserted</span>, <span class="op-deleted">%(linesdel)s deleted</span>', '%(num)s files changed: <span class="op-added">%(linesadd)s inserted</span>, <span class="op-deleted">%(linesdel)s deleted</span>',
1044 <% file_placeholder = h.literal(_ungettext('%(num)s file changed: <span class="op-added">%(linesadd)s inserted</span>, <span class="op-deleted">%(linesdel)s deleted</span>', '%(num)s files changed: <span class="op-added">%(linesadd)s inserted</span>, <span class="op-deleted">%(linesdel)s deleted</span>',
1045 diffset.changed_files) % {'num': diffset.changed_files, 'linesadd': diffset.lines_added, 'linesdel': diffset.lines_deleted}) %>
1045 diffset.changed_files) % {'num': diffset.changed_files, 'linesadd': diffset.lines_added, 'linesdel': diffset.lines_deleted}) %>
1046
1046
1047 %endif
1047 %endif
1048 ## case on range-diff placeholder needs to be updated
1048 ## case on range-diff placeholder needs to be updated
1049 % if range_diff_on is True:
1049 % if range_diff_on is True:
1050 <% file_placeholder = _('Disabled on range diff') %>
1050 <% file_placeholder = _('Disabled on range diff') %>
1051 % endif
1051 % endif
1052
1052
1053 <script type="text/javascript">
1053 <script type="text/javascript">
1054 var feedFilesOptions = function (query, initialData) {
1054 var feedFilesOptions = function (query, initialData) {
1055 var data = {results: []};
1055 var data = {results: []};
1056 var isQuery = typeof query.term !== 'undefined';
1056 var isQuery = typeof query.term !== 'undefined';
1057
1057
1058 var section = _gettext('Changed files');
1058 var section = _gettext('Changed files');
1059 var filteredData = [];
1059 var filteredData = [];
1060
1060
1061 //filter results
1061 //filter results
1062 $.each(initialData.results, function (idx, value) {
1062 $.each(initialData.results, function (idx, value) {
1063
1063
1064 if (!isQuery || query.term.length === 0 || value.text.toUpperCase().indexOf(query.term.toUpperCase()) >= 0) {
1064 if (!isQuery || query.term.length === 0 || value.text.toUpperCase().indexOf(query.term.toUpperCase()) >= 0) {
1065 filteredData.push({
1065 filteredData.push({
1066 'id': this.id,
1066 'id': this.id,
1067 'text': this.text,
1067 'text': this.text,
1068 "ops": this.ops,
1068 "ops": this.ops,
1069 })
1069 })
1070 }
1070 }
1071
1071
1072 });
1072 });
1073
1073
1074 data.results = filteredData;
1074 data.results = filteredData;
1075
1075
1076 query.callback(data);
1076 query.callback(data);
1077 };
1077 };
1078
1078
1079 var selectionFormatter = function(data, escapeMarkup) {
1079 var selectionFormatter = function(data, escapeMarkup) {
1080 var container = '<div class="filelist" style="padding-right:100px">{0}</div>';
1080 var container = '<div class="filelist" style="padding-right:100px">{0}</div>';
1081 var tmpl = '<div><strong>{0}</strong></div>'.format(escapeMarkup(data['text']));
1081 var tmpl = '<div><strong>{0}</strong></div>'.format(escapeMarkup(data['text']));
1082 var pill = '<div class="pill-group" style="position: absolute; top:7px; right: 0">' +
1082 var pill = '<div class="pill-group" style="position: absolute; top:7px; right: 0">' +
1083 '<span class="pill" op="added">{0}</span>' +
1083 '<span class="pill" op="added">{0}</span>' +
1084 '<span class="pill" op="deleted">{1}</span>' +
1084 '<span class="pill" op="deleted">{1}</span>' +
1085 '</div>'
1085 '</div>'
1086 ;
1086 ;
1087 var added = data['ops']['added'];
1087 var added = data['ops']['added'];
1088 if (added === 0) {
1088 if (added === 0) {
1089 // keep a plain 0, don't render it as '+0'
1089 // keep a plain 0, don't render it as '+0'
1090 added = 0;
1090 added = 0;
1091 } else {
1091 } else {
1092 added = '+' + added;
1092 added = '+' + added;
1093 }
1093 }
1094
1094
1095 var deleted = -1*data['ops']['deleted'];
1095 var deleted = -1*data['ops']['deleted'];
1096
1096
1097 tmpl += pill.format(added, deleted);
1097 tmpl += pill.format(added, deleted);
1098 return container.format(tmpl);
1098 return container.format(tmpl);
1099 };
1099 };
1100 var formatFileResult = function(result, container, query, escapeMarkup) {
1100 var formatFileResult = function(result, container, query, escapeMarkup) {
1101 return selectionFormatter(result, escapeMarkup);
1101 return selectionFormatter(result, escapeMarkup);
1102 };
1102 };
1103
1103
1104 var formatSelection = function (data, container) {
1104 var formatSelection = function (data, container) {
1105 return '${file_placeholder}'
1105 return '${file_placeholder}'
1106 };
1106 };
1107
1107
1108 if (window.preloadFileFilterData === undefined) {
1108 if (window.preloadFileFilterData === undefined) {
1109 window.preloadFileFilterData = {}
1109 window.preloadFileFilterData = {}
1110 }
1110 }
1111
1111
1112 preloadFileFilterData["${diffset_container_id}"] = {
1112 preloadFileFilterData["${diffset_container_id}"] = {
1113 results: [
1113 results: [
1114 % for filediff in diffset.files:
1114 % for filediff in diffset.files:
1115 {id:"a_${h.FID(filediff.raw_id, filediff.patch['filename'])}",
1115 {id:"a_${h.FID(filediff.raw_id, filediff.patch['filename'])}",
1116 text:"${filediff.patch['filename']}",
1116 text:"${filediff.patch['filename']}",
1117 ops:${h.json.dumps(filediff.patch['stats'])|n}}${('' if loop.last else ',')}
1117 ops:${h.json.dumps(filediff.patch['stats'])|n}}${('' if loop.last else ',')}
1118 % endfor
1118 % endfor
1119 ]
1119 ]
1120 };
1120 };
1121
1121
1122 var diffFileFilterId = "#file_filter_" + "${diffset_container_id}";
1122 var diffFileFilterId = "#file_filter_" + "${diffset_container_id}";
1123 var diffFileFilter = $(diffFileFilterId).select2({
1123 var diffFileFilter = $(diffFileFilterId).select2({
1124 'dropdownAutoWidth': true,
1124 'dropdownAutoWidth': true,
1125 'width': 'auto',
1125 'width': 'auto',
1126
1126
1127 containerCssClass: "drop-menu",
1127 containerCssClass: "drop-menu",
1128 dropdownCssClass: "drop-menu-dropdown",
1128 dropdownCssClass: "drop-menu-dropdown",
1129 data: preloadFileFilterData["${diffset_container_id}"],
1129 data: preloadFileFilterData["${diffset_container_id}"],
1130 query: function(query) {
1130 query: function(query) {
1131 feedFilesOptions(query, preloadFileFilterData["${diffset_container_id}"]);
1131 feedFilesOptions(query, preloadFileFilterData["${diffset_container_id}"]);
1132 },
1132 },
1133 initSelection: function(element, callback) {
1133 initSelection: function(element, callback) {
1134 callback({'init': true});
1134 callback({'init': true});
1135 },
1135 },
1136 formatResult: formatFileResult,
1136 formatResult: formatFileResult,
1137 formatSelection: formatSelection
1137 formatSelection: formatSelection
1138 });
1138 });
1139
1139
1140 % if range_diff_on is True:
1140 % if range_diff_on is True:
1141 diffFileFilter.select2("enable", false);
1141 diffFileFilter.select2("enable", false);
1142 % endif
1142 % endif
1143
1143
1144 $(diffFileFilterId).on('select2-selecting', function (e) {
1144 $(diffFileFilterId).on('select2-selecting', function (e) {
1145 var idSelector = e.choice.id;
1145 var idSelector = e.choice.id;
1146
1146
1147 // expand the container if we quick-select the field
1147 // expand the container if we quick-select the field
1148 $('#'+idSelector).next().prop('checked', false);
1148 $('#'+idSelector).next().prop('checked', false);
1149 // hide the mask as we later do preventDefault()
1149 // hide the mask as we later do preventDefault()
1150 $("#select2-drop-mask").click();
1150 $("#select2-drop-mask").click();
1151
1151
1152 window.location.hash = '#'+idSelector;
1152 window.location.hash = '#'+idSelector;
1153 updateSticky();
1153 updateSticky();
1154
1154
1155 e.preventDefault();
1155 e.preventDefault();
1156 });
1156 });
1157
1157
1158 diffNavText = 'diff navigation:';
1158 diffNavText = 'diff navigation:';
1159
1159
1160 getCurrentChunk = function () {
1160 getCurrentChunk = function () {
1161
1161
1162 var chunksAll = $('.nav-chunk').filter(function () {
1162 var chunksAll = $('.nav-chunk').filter(function () {
1163 return $(this).parents('.filediff').prev().get(0).checked !== true
1163 return $(this).parents('.filediff').prev().get(0).checked !== true
1164 })
1164 })
1165 var chunkSelected = $('.nav-chunk.selected');
1165 var chunkSelected = $('.nav-chunk.selected');
1166 var initial = false;
1166 var initial = false;
1167
1167
1168 if (chunkSelected.length === 0) {
1168 if (chunkSelected.length === 0) {
1169 // no initial chunk selected, we pick first
1169 // no initial chunk selected, we pick first
1170 chunkSelected = $(chunksAll.get(0));
1170 chunkSelected = $(chunksAll.get(0));
1171 initial = true;
1171 initial = true;
1172 }
1172 }
1173
1173
1174 return {
1174 return {
1175 'all': chunksAll,
1175 'all': chunksAll,
1176 'selected': chunkSelected,
1176 'selected': chunkSelected,
1177 'initial': initial,
1177 'initial': initial,
1178 }
1178 }
1179 }
1179 }
1180
1180
1181 animateDiffNavText = function () {
1181 animateDiffNavText = function () {
1182 var $diffNav = $('#diff_nav')
1182 var $diffNav = $('#diff_nav')
1183
1183
1184 var callback = function () {
1184 var callback = function () {
1185 $diffNav.animate({'opacity': 1.00}, 200)
1185 $diffNav.animate({'opacity': 1.00}, 200)
1186 };
1186 };
1187 $diffNav.animate({'opacity': 0.15}, 200, callback);
1187 $diffNav.animate({'opacity': 0.15}, 200, callback);
1188 }
1188 }
1189
1189
1190 scrollToChunk = function (moveBy) {
1190 scrollToChunk = function (moveBy) {
1191 var chunk = getCurrentChunk();
1191 var chunk = getCurrentChunk();
1192 var all = chunk.all
1192 var all = chunk.all
1193 var selected = chunk.selected
1193 var selected = chunk.selected
1194
1194
1195 var curPos = all.index(selected);
1195 var curPos = all.index(selected);
1196 var newPos = curPos;
1196 var newPos = curPos;
1197 if (!chunk.initial) {
1197 if (!chunk.initial) {
1198 newPos = curPos + moveBy;
1198 newPos = curPos + moveBy;
1199 }
1199 }
1200
1200
1201 var curElem = all.get(newPos);
1201 var curElem = all.get(newPos);
1202
1202
1203 if (curElem === undefined) {
1203 if (curElem === undefined) {
1204 // end or back
1204 // end or back
1205 $('#diff_nav').html('no next diff element:')
1205 $('#diff_nav').html('no next diff element:')
1206 animateDiffNavText()
1206 animateDiffNavText()
1207 return
1207 return
1208 } else if (newPos < 0) {
1208 } else if (newPos < 0) {
1209 $('#diff_nav').html('no previous diff element:')
1209 $('#diff_nav').html('no previous diff element:')
1210 animateDiffNavText()
1210 animateDiffNavText()
1211 return
1211 return
1212 } else {
1212 } else {
1213 $('#diff_nav').html(diffNavText)
1213 $('#diff_nav').html(diffNavText)
1214 }
1214 }
1215
1215
1216 curElem = $(curElem)
1216 curElem = $(curElem)
1217 var offset = 100;
1217 var offset = 100;
1218 $(window).scrollTop(curElem.position().top - offset);
1218 $(window).scrollTop(curElem.position().top - offset);
1219
1219
1220 //clear selection
1220 //clear selection
1221 all.removeClass('selected')
1221 all.removeClass('selected')
1222 curElem.addClass('selected')
1222 curElem.addClass('selected')
1223 }
1223 }
1224
1224
1225 scrollToPrevChunk = function () {
1225 scrollToPrevChunk = function () {
1226 scrollToChunk(-1)
1226 scrollToChunk(-1)
1227 }
1227 }
1228 scrollToNextChunk = function () {
1228 scrollToNextChunk = function () {
1229 scrollToChunk(1)
1229 scrollToChunk(1)
1230 }
1230 }
1231
1231
1232 </script>
1232 </script>
1233 % endif
1233 % endif
1234
1234
1235 <script type="text/javascript">
1235 <script type="text/javascript">
1236 $('#diff_nav').html('loading diff...') // wait until whole page is loaded
1236 $('#diff_nav').html('loading diff...') // wait until whole page is loaded
1237
1237
1238 $(document).ready(function () {
1238 $(document).ready(function () {
1239
1239
1240 var contextPrefix = _gettext('Context file: ');
1240 var contextPrefix = _gettext('Context file: ');
1241 ## sticky sidebar
1241 ## sticky sidebar
1242 var sidebarElement = document.getElementById('diff-file-sticky');
1242 var sidebarElement = document.getElementById('diff-file-sticky');
1243 sidebar = new StickySidebar(sidebarElement, {
1243 sidebar = new StickySidebar(sidebarElement, {
1244 topSpacing: 0,
1244 topSpacing: 0,
1245 bottomSpacing: 0,
1245 bottomSpacing: 0,
1246 innerWrapperSelector: '.sidebar__inner'
1246 innerWrapperSelector: '.sidebar__inner'
1247 });
1247 });
1248 sidebarElement.addEventListener('affixed.static.stickySidebar', function () {
1248 sidebarElement.addEventListener('affixed.static.stickySidebar', function () {
1249 // reset our file so it's not holding new value
1249 // reset our file so it's not holding new value
1250 $('.fpath-placeholder-text').html(contextPrefix + ' - ')
1250 $('.fpath-placeholder-text').html(contextPrefix + ' - ')
1251 });
1251 });
1252
1252
1253 updateSticky = function () {
1253 updateSticky = function () {
1254 sidebar.updateSticky();
1254 sidebar.updateSticky();
1255 Waypoint.refreshAll();
1255 Waypoint.refreshAll();
1256 };
1256 };
1257
1257
1258 var animateText = function (fPath, anchorId) {
1258 var animateText = function (fPath, anchorId) {
1259 fPath = Select2.util.escapeMarkup(fPath);
1259 fPath = Select2.util.escapeMarkup(fPath);
1260 $('.fpath-placeholder-text').html(contextPrefix + '<a href="#a_' + anchorId + '">' + fPath + '</a>')
1260 $('.fpath-placeholder-text').html(contextPrefix + '<a href="#a_' + anchorId + '">' + fPath + '</a>')
1261 };
1261 };
1262
1262
1263 ## dynamic file waypoints
1263 ## dynamic file waypoints
1264 var setFPathInfo = function(fPath, anchorId){
1264 var setFPathInfo = function(fPath, anchorId){
1265 animateText(fPath, anchorId)
1265 animateText(fPath, anchorId)
1266 };
1266 };
1267
1267
1268 var codeBlock = $('.filediff');
1268 var codeBlock = $('.filediff');
1269
1269
1270 // forward waypoint
1270 // forward waypoint
1271 codeBlock.waypoint(
1271 codeBlock.waypoint(
1272 function(direction) {
1272 function(direction) {
1273 if (direction === "down"){
1273 if (direction === "down"){
1274 setFPathInfo($(this.element).data('fPath'), $(this.element).data('anchorId'))
1274 setFPathInfo($(this.element).data('fPath'), $(this.element).data('anchorId'))
1275 }
1275 }
1276 }, {
1276 }, {
1277 offset: function () {
1277 offset: function () {
1278 return 70;
1278 return 70;
1279 },
1279 },
1280 context: '.fpath-placeholder'
1280 context: '.fpath-placeholder'
1281 }
1281 }
1282 );
1282 );
1283
1283
1284 // backward waypoint
1284 // backward waypoint
1285 codeBlock.waypoint(
1285 codeBlock.waypoint(
1286 function(direction) {
1286 function(direction) {
1287 if (direction === "up"){
1287 if (direction === "up"){
1288 setFPathInfo($(this.element).data('fPath'), $(this.element).data('anchorId'))
1288 setFPathInfo($(this.element).data('fPath'), $(this.element).data('anchorId'))
1289 }
1289 }
1290 }, {
1290 }, {
1291 offset: function () {
1291 offset: function () {
1292 return -this.element.clientHeight + 90;
1292 return -this.element.clientHeight + 90;
1293 },
1293 },
1294 context: '.fpath-placeholder'
1294 context: '.fpath-placeholder'
1295 }
1295 }
1296 );
1296 );
1297
1297
1298 toggleWideDiff = function (el) {
1298 toggleWideDiff = function (el) {
1299 updateSticky();
1299 updateSticky();
1300 var wide = Rhodecode.comments.toggleWideMode(this);
1300 var wide = Rhodecode.comments.toggleWideMode(this);
1301 storeUserSessionAttr('rc_user_session_attr.wide_diff_mode', wide);
1301 storeUserSessionAttr('rc_user_session_attr.wide_diff_mode', wide);
1302 if (wide === true) {
1302 if (wide === true) {
1303 $(el).addClass('btn-active');
1303 $(el).addClass('btn-active');
1304 } else {
1304 } else {
1305 $(el).removeClass('btn-active');
1305 $(el).removeClass('btn-active');
1306 }
1306 }
1307 return null;
1307 return null;
1308 };
1308 };
1309
1309
1310 toggleExpand = function (el, diffsetEl) {
1310 toggleExpand = function (el, diffsetEl) {
1311 var el = $(el);
1311 var el = $(el);
1312 if (el.hasClass('collapsed')) {
1312 if (el.hasClass('collapsed')) {
1313 $('.filediff-collapse-state.collapse-{0}'.format(diffsetEl)).prop('checked', false);
1313 $('.filediff-collapse-state.collapse-{0}'.format(diffsetEl)).prop('checked', false);
1314 el.removeClass('collapsed');
1314 el.removeClass('collapsed');
1315 el.html(
1315 el.html(
1316 '<i class="icon-minus-squared-alt icon-no-margin"></i>' +
1316 '<i class="icon-minus-squared-alt icon-no-margin"></i>' +
1317 _gettext('Collapse all files'));
1317 _gettext('Collapse all files'));
1318 }
1318 }
1319 else {
1319 else {
1320 $('.filediff-collapse-state.collapse-{0}'.format(diffsetEl)).prop('checked', true);
1320 $('.filediff-collapse-state.collapse-{0}'.format(diffsetEl)).prop('checked', true);
1321 el.addClass('collapsed');
1321 el.addClass('collapsed');
1322 el.html(
1322 el.html(
1323 '<i class="icon-plus-squared-alt icon-no-margin"></i>' +
1323 '<i class="icon-plus-squared-alt icon-no-margin"></i>' +
1324 _gettext('Expand all files'));
1324 _gettext('Expand all files'));
1325 }
1325 }
1326 updateSticky()
1326 updateSticky()
1327 };
1327 };
1328
1328
1329 toggleCommitExpand = function (el) {
1329 toggleCommitExpand = function (el) {
1330 var $el = $(el);
1330 var $el = $(el);
1331 var commits = $el.data('toggleCommitsCnt');
1331 var commits = $el.data('toggleCommitsCnt');
1332 var collapseMsg = _ngettext('Collapse {0} commit', 'Collapse {0} commits', commits).format(commits);
1332 var collapseMsg = _ngettext('Collapse {0} commit', 'Collapse {0} commits', commits).format(commits);
1333 var expandMsg = _ngettext('Expand {0} commit', 'Expand {0} commits', commits).format(commits);
1333 var expandMsg = _ngettext('Expand {0} commit', 'Expand {0} commits', commits).format(commits);
1334
1334
1335 if ($el.hasClass('collapsed')) {
1335 if ($el.hasClass('collapsed')) {
1336 $('.compare_select').show();
1336 $('.compare_select').show();
1337 $('.compare_select_hidden').hide();
1337 $('.compare_select_hidden').hide();
1338
1338
1339 $el.removeClass('collapsed');
1339 $el.removeClass('collapsed');
1340 $el.html(
1340 $el.html(
1341 '<i class="icon-minus-squared-alt icon-no-margin"></i>' +
1341 '<i class="icon-minus-squared-alt icon-no-margin"></i>' +
1342 collapseMsg);
1342 collapseMsg);
1343 }
1343 }
1344 else {
1344 else {
1345 $('.compare_select').hide();
1345 $('.compare_select').hide();
1346 $('.compare_select_hidden').show();
1346 $('.compare_select_hidden').show();
1347 $el.addClass('collapsed');
1347 $el.addClass('collapsed');
1348 $el.html(
1348 $el.html(
1349 '<i class="icon-plus-squared-alt icon-no-margin"></i>' +
1349 '<i class="icon-plus-squared-alt icon-no-margin"></i>' +
1350 expandMsg);
1350 expandMsg);
1351 }
1351 }
1352 updateSticky();
1352 updateSticky();
1353 };
1353 };
1354
1354
1355 // get stored diff mode and pre-enable it
1355 // get stored diff mode and pre-enable it
1356 if (templateContext.session_attrs.wide_diff_mode === "true") {
1356 if (templateContext.session_attrs.wide_diff_mode === "true") {
1357 Rhodecode.comments.toggleWideMode(null);
1357 Rhodecode.comments.toggleWideMode(null);
1358 $('.toggle-wide-diff').addClass('btn-active');
1358 $('.toggle-wide-diff').addClass('btn-active');
1359 updateSticky();
1359 updateSticky();
1360 }
1360 }
1361
1361
1362 // DIFF NAV //
1362 // DIFF NAV //
1363
1363
1364 // element to detect scroll direction of
1364 // element to detect scroll direction of
1365 var $window = $(window);
1365 var $window = $(window);
1366
1366
1367 // initialize last scroll position
1367 // initialize last scroll position
1368 var lastScrollY = $window.scrollTop();
1368 var lastScrollY = $window.scrollTop();
1369
1369
1370 $window.on('resize scrollstop', {latency: 350}, function () {
1370 $window.on('resize scrollstop', {latency: 350}, function () {
1371 var visibleChunks = $('.nav-chunk').withinviewport({top: 75});
1371 var visibleChunks = $('.nav-chunk').withinviewport({top: 75});
1372
1372
1373 // get current scroll position
1373 // get current scroll position
1374 var currentScrollY = $window.scrollTop();
1374 var currentScrollY = $window.scrollTop();
1375
1375
1376 // determine current scroll direction
1376 // determine current scroll direction
1377 if (currentScrollY > lastScrollY) {
1377 if (currentScrollY > lastScrollY) {
1378 var y = 'down';
1378 var y = 'down';
1379 } else if (currentScrollY !== lastScrollY) {
1379 } else if (currentScrollY !== lastScrollY) {
1380 var y = 'up';
1380 var y = 'up';
1381 }
1381 }
1382
1382
1383 var pos = -1; // by default we use last element in viewport
1383 var pos = -1; // by default we use last element in viewport
1384 if (y === 'down') {
1384 if (y === 'down') {
1385 pos = -1;
1385 pos = -1;
1386 } else if (y === 'up') {
1386 } else if (y === 'up') {
1387 pos = 0;
1387 pos = 0;
1388 }
1388 }
1389
1389
1390 if (visibleChunks.length > 0) {
1390 if (visibleChunks.length > 0) {
1391 $('.nav-chunk').removeClass('selected');
1391 $('.nav-chunk').removeClass('selected');
1392 $(visibleChunks.get(pos)).addClass('selected');
1392 $(visibleChunks.get(pos)).addClass('selected');
1393 }
1393 }
1394
1394
1395 // update last scroll position to current position
1395 // update last scroll position to current position
1396 lastScrollY = currentScrollY;
1396 lastScrollY = currentScrollY;
1397
1397
1398 });
1398 });
1399 $('#diff_nav').html(diffNavText);
1399 $('#diff_nav').html(diffNavText);
1400
1400
1401 });
1401 });
1402 </script>
1402 </script>
1403
1403
1404 </%def>
1404 </%def>
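
The scrollstop handler above highlights a chunk based on scroll direction: moving down selects the last marker still inside the viewport, moving up selects the first, and an unchanged position keeps the default (last). A sketch of that selection:

    def pick_visible_chunk(last_y, current_y, visible_chunks):
        pos = -1                     # default: last element in viewport
        if current_y > last_y:
            pos = -1                 # scrolling down
        elif current_y != last_y:
            pos = 0                  # scrolling up
        return visible_chunks[pos] if visible_chunks else None

    assert pick_visible_chunk(0, 100, ['a', 'b', 'c']) == 'c'
    assert pick_visible_chunk(100, 10, ['a', 'b', 'c']) == 'a'
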
@@ -1,209 +1,208 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20 import io
21 from io import StringIO
22
21
23 import pytest
22 import pytest
24 from mock import patch, Mock
23 from mock import patch, Mock
25
24
26 from rhodecode.lib.middleware.simplesvn import SimpleSvn, SimpleSvnApp
25 from rhodecode.lib.middleware.simplesvn import SimpleSvn, SimpleSvnApp
27 from rhodecode.lib.utils import get_rhodecode_base_path
26 from rhodecode.lib.utils import get_rhodecode_base_path
28
27
29
28
30 class TestSimpleSvn(object):
29 class TestSimpleSvn(object):
31 @pytest.fixture(autouse=True)
30 @pytest.fixture(autouse=True)
32 def simple_svn(self, baseapp, request_stub):
31 def simple_svn(self, baseapp, request_stub):
33 base_path = get_rhodecode_base_path()
32 base_path = get_rhodecode_base_path()
34 self.app = SimpleSvn(
33 self.app = SimpleSvn(
35 config={'auth_ret_code': '', 'base_path': base_path},
34 config={'auth_ret_code': '', 'base_path': base_path},
36 registry=request_stub.registry)
35 registry=request_stub.registry)
37
36
38 def test_get_config(self):
37 def test_get_config(self):
39 extras = {'foo': 'FOO', 'bar': 'BAR'}
38 extras = {'foo': 'FOO', 'bar': 'BAR'}
40 config = self.app._create_config(extras, repo_name='test-repo')
39 config = self.app._create_config(extras, repo_name='test-repo')
41 assert config == extras
40 assert config == extras
42
41
43 @pytest.mark.parametrize(
42 @pytest.mark.parametrize(
44 'method', ['OPTIONS', 'PROPFIND', 'GET', 'REPORT'])
43 'method', ['OPTIONS', 'PROPFIND', 'GET', 'REPORT'])
45 def test_get_action_returns_pull(self, method):
44 def test_get_action_returns_pull(self, method):
46 environment = {'REQUEST_METHOD': method}
45 environment = {'REQUEST_METHOD': method}
47 action = self.app._get_action(environment)
46 action = self.app._get_action(environment)
48 assert action == 'pull'
47 assert action == 'pull'
49
48
50 @pytest.mark.parametrize(
49 @pytest.mark.parametrize(
51 'method', [
50 'method', [
52 'MKACTIVITY', 'PROPPATCH', 'PUT', 'CHECKOUT', 'MKCOL', 'MOVE',
51 'MKACTIVITY', 'PROPPATCH', 'PUT', 'CHECKOUT', 'MKCOL', 'MOVE',
53 'COPY', 'DELETE', 'LOCK', 'UNLOCK', 'MERGE'
52 'COPY', 'DELETE', 'LOCK', 'UNLOCK', 'MERGE'
54 ])
53 ])
55 def test_get_action_returns_push(self, method):
54 def test_get_action_returns_push(self, method):
56 environment = {'REQUEST_METHOD': method}
55 environment = {'REQUEST_METHOD': method}
57 action = self.app._get_action(environment)
56 action = self.app._get_action(environment)
58 assert action == 'push'
57 assert action == 'push'
59
58
60 @pytest.mark.parametrize(
59 @pytest.mark.parametrize(
61 'path, expected_name', [
60 'path, expected_name', [
62 ('/hello-svn', 'hello-svn'),
61 ('/hello-svn', 'hello-svn'),
63 ('/hello-svn/', 'hello-svn'),
62 ('/hello-svn/', 'hello-svn'),
64 ('/group/hello-svn/', 'group/hello-svn'),
63 ('/group/hello-svn/', 'group/hello-svn'),
65 ('/group/hello-svn/!svn/vcc/default', 'group/hello-svn'),
64 ('/group/hello-svn/!svn/vcc/default', 'group/hello-svn'),
66 ])
65 ])
67 def test_get_repository_name(self, path, expected_name):
66 def test_get_repository_name(self, path, expected_name):
68 environment = {'PATH_INFO': path}
67 environment = {'PATH_INFO': path}
69 name = self.app._get_repository_name(environment)
68 name = self.app._get_repository_name(environment)
70 assert name == expected_name
69 assert name == expected_name
71
70
72 def test_get_repository_name_subfolder(self, backend_svn):
71 def test_get_repository_name_subfolder(self, backend_svn):
73 repo = backend_svn.repo
72 repo = backend_svn.repo
74 environment = {
73 environment = {
75 'PATH_INFO': '/{}/path/with/subfolders'.format(repo.repo_name)}
74 'PATH_INFO': '/{}/path/with/subfolders'.format(repo.repo_name)}
76 name = self.app._get_repository_name(environment)
75 name = self.app._get_repository_name(environment)
77 assert name == repo.repo_name
76 assert name == repo.repo_name
78
77
79 def test_create_wsgi_app(self):
78 def test_create_wsgi_app(self):
80 with patch.object(SimpleSvn, '_is_svn_enabled') as mock_method:
79 with patch.object(SimpleSvn, '_is_svn_enabled') as mock_method:
81 mock_method.return_value = False
80 mock_method.return_value = False
82 with patch('rhodecode.lib.middleware.simplesvn.DisabledSimpleSvnApp') as (
81 with patch('rhodecode.lib.middleware.simplesvn.DisabledSimpleSvnApp') as (
83 wsgi_app_mock):
82 wsgi_app_mock):
84 config = Mock()
83 config = Mock()
85 wsgi_app = self.app._create_wsgi_app(
84 wsgi_app = self.app._create_wsgi_app(
86 repo_path='', repo_name='', config=config)
85 repo_path='', repo_name='', config=config)
87
86
88 wsgi_app_mock.assert_called_once_with(config)
87 wsgi_app_mock.assert_called_once_with(config)
89 assert wsgi_app == wsgi_app_mock()
88 assert wsgi_app == wsgi_app_mock()
90
89
91 def test_create_wsgi_app_when_enabled(self):
90 def test_create_wsgi_app_when_enabled(self):
92 with patch.object(SimpleSvn, '_is_svn_enabled') as mock_method:
91 with patch.object(SimpleSvn, '_is_svn_enabled') as mock_method:
93 mock_method.return_value = True
92 mock_method.return_value = True
94 with patch('rhodecode.lib.middleware.simplesvn.SimpleSvnApp') as (
93 with patch('rhodecode.lib.middleware.simplesvn.SimpleSvnApp') as (
95 wsgi_app_mock):
94 wsgi_app_mock):
96 config = Mock()
95 config = Mock()
97 wsgi_app = self.app._create_wsgi_app(
96 wsgi_app = self.app._create_wsgi_app(
98 repo_path='', repo_name='', config=config)
97 repo_path='', repo_name='', config=config)
99
98
100 wsgi_app_mock.assert_called_once_with(config)
99 wsgi_app_mock.assert_called_once_with(config)
101 assert wsgi_app == wsgi_app_mock()
100 assert wsgi_app == wsgi_app_mock()
102
101
103
102
104 class TestSimpleSvnApp(object):
103 class TestSimpleSvnApp(object):
105 data = '<xml></xml>'
104 data = '<xml></xml>'
106 path = '/group/my-repo'
105 path = '/group/my-repo'
107 wsgi_input = StringIO(data)
106 wsgi_input = io.StringIO(data)
108 environment = {
107 environment = {
109 'HTTP_DAV': (
108 'HTTP_DAV': (
110 'http://subversion.tigris.org/xmlns/dav/svn/depth,'
109 'http://subversion.tigris.org/xmlns/dav/svn/depth,'
111 ' http://subversion.tigris.org/xmlns/dav/svn/mergeinfo'),
110 ' http://subversion.tigris.org/xmlns/dav/svn/mergeinfo'),
112 'HTTP_USER_AGENT': 'SVN/1.8.11 (x86_64-linux) serf/1.3.8',
111 'HTTP_USER_AGENT': 'SVN/1.8.11 (x86_64-linux) serf/1.3.8',
113 'REQUEST_METHOD': 'OPTIONS',
112 'REQUEST_METHOD': 'OPTIONS',
114 'PATH_INFO': path,
113 'PATH_INFO': path,
115 'wsgi.input': wsgi_input,
114 'wsgi.input': wsgi_input,
116 'CONTENT_TYPE': 'text/xml',
115 'CONTENT_TYPE': 'text/xml',
117 'CONTENT_LENGTH': '130'
116 'CONTENT_LENGTH': '130'
118 }
117 }
119
118
120 def setup_method(self, method):
119 def setup_method(self, method):
121 self.host = 'http://localhost/'
120 self.host = 'http://localhost/'
122 base_path = get_rhodecode_base_path()
121 base_path = get_rhodecode_base_path()
123 self.app = SimpleSvnApp(
122 self.app = SimpleSvnApp(
124 config={'subversion_http_server_url': self.host,
123 config={'subversion_http_server_url': self.host,
125 'base_path': base_path})
124 'base_path': base_path})
126
125
127 def test_get_request_headers_with_content_type(self):
126 def test_get_request_headers_with_content_type(self):
128 expected_headers = {
127 expected_headers = {
129 'Dav': self.environment['HTTP_DAV'],
128 'Dav': self.environment['HTTP_DAV'],
130 'User-Agent': self.environment['HTTP_USER_AGENT'],
129 'User-Agent': self.environment['HTTP_USER_AGENT'],
131 'Content-Type': self.environment['CONTENT_TYPE'],
130 'Content-Type': self.environment['CONTENT_TYPE'],
132 'Content-Length': self.environment['CONTENT_LENGTH']
131 'Content-Length': self.environment['CONTENT_LENGTH']
133 }
132 }
134 headers = self.app._get_request_headers(self.environment)
133 headers = self.app._get_request_headers(self.environment)
135 assert headers == expected_headers
134 assert headers == expected_headers
136
135
137 def test_get_request_headers_without_content_type(self):
136 def test_get_request_headers_without_content_type(self):
138 environment = self.environment.copy()
137 environment = self.environment.copy()
139 environment.pop('CONTENT_TYPE')
138 environment.pop('CONTENT_TYPE')
140 expected_headers = {
139 expected_headers = {
141 'Dav': environment['HTTP_DAV'],
140 'Dav': environment['HTTP_DAV'],
142 'Content-Length': self.environment['CONTENT_LENGTH'],
141 'Content-Length': self.environment['CONTENT_LENGTH'],
143 'User-Agent': environment['HTTP_USER_AGENT'],
142 'User-Agent': environment['HTTP_USER_AGENT'],
144 }
143 }
145 request_headers = self.app._get_request_headers(environment)
144 request_headers = self.app._get_request_headers(environment)
146 assert request_headers == expected_headers
145 assert request_headers == expected_headers
147
146
148 def test_get_response_headers(self):
147 def test_get_response_headers(self):
149 headers = {
148 headers = {
150 'Connection': 'keep-alive',
149 'Connection': 'keep-alive',
151 'Keep-Alive': 'timeout=5, max=100',
150 'Keep-Alive': 'timeout=5, max=100',
152 'Transfer-Encoding': 'chunked',
151 'Transfer-Encoding': 'chunked',
153 'Content-Encoding': 'gzip',
152 'Content-Encoding': 'gzip',
154 'MS-Author-Via': 'DAV',
153 'MS-Author-Via': 'DAV',
155 'SVN-Supported-Posts': 'create-txn-with-props'
154 'SVN-Supported-Posts': 'create-txn-with-props'
156 }
155 }
157 expected_headers = [
156 expected_headers = [
158 ('MS-Author-Via', 'DAV'),
157 ('MS-Author-Via', 'DAV'),
159 ('SVN-Supported-Posts', 'create-txn-with-props'),
158 ('SVN-Supported-Posts', 'create-txn-with-props'),
160 ]
159 ]
161 response_headers = self.app._get_response_headers(headers)
160 response_headers = self.app._get_response_headers(headers)
162 assert sorted(response_headers) == sorted(expected_headers)
161 assert sorted(response_headers) == sorted(expected_headers)
163
162
164 @pytest.mark.parametrize('svn_http_url, path_info, expected_url', [
163 @pytest.mark.parametrize('svn_http_url, path_info, expected_url', [
165 ('http://localhost:8200', '/repo_name', 'http://localhost:8200/repo_name'),
164 ('http://localhost:8200', '/repo_name', 'http://localhost:8200/repo_name'),
166 ('http://localhost:8200///', '/repo_name', 'http://localhost:8200/repo_name'),
165 ('http://localhost:8200///', '/repo_name', 'http://localhost:8200/repo_name'),
167 ('http://localhost:8200', '/group/repo_name', 'http://localhost:8200/group/repo_name'),
166 ('http://localhost:8200', '/group/repo_name', 'http://localhost:8200/group/repo_name'),
168 ('http://localhost:8200/', '/group/repo_name', 'http://localhost:8200/group/repo_name'),
167 ('http://localhost:8200/', '/group/repo_name', 'http://localhost:8200/group/repo_name'),
169 ('http://localhost:8200/prefix', '/repo_name', 'http://localhost:8200/prefix/repo_name'),
168 ('http://localhost:8200/prefix', '/repo_name', 'http://localhost:8200/prefix/repo_name'),
170 ('http://localhost:8200/prefix', 'repo_name', 'http://localhost:8200/prefix/repo_name'),
169 ('http://localhost:8200/prefix', 'repo_name', 'http://localhost:8200/prefix/repo_name'),
171 ('http://localhost:8200/prefix', '/group/repo_name', 'http://localhost:8200/prefix/group/repo_name')
170 ('http://localhost:8200/prefix', '/group/repo_name', 'http://localhost:8200/prefix/group/repo_name')
172 ])
171 ])
173 def test_get_url(self, svn_http_url, path_info, expected_url):
172 def test_get_url(self, svn_http_url, path_info, expected_url):
174 url = self.app._get_url(svn_http_url, path_info)
173 url = self.app._get_url(svn_http_url, path_info)
175 assert url == expected_url
174 assert url == expected_url
176
175
177 def test_call(self):
176 def test_call(self):
178 start_response = Mock()
177 start_response = Mock()
179 response_mock = Mock()
178 response_mock = Mock()
180 response_mock.headers = {
179 response_mock.headers = {
181 'Content-Encoding': 'gzip',
180 'Content-Encoding': 'gzip',
182 'MS-Author-Via': 'DAV',
181 'MS-Author-Via': 'DAV',
183 'SVN-Supported-Posts': 'create-txn-with-props'
182 'SVN-Supported-Posts': 'create-txn-with-props'
184 }
183 }
185 response_mock.status_code = 200
184 response_mock.status_code = 200
186 response_mock.reason = 'OK'
185 response_mock.reason = 'OK'
187 with patch('rhodecode.lib.middleware.simplesvn.requests.request') as (
186 with patch('rhodecode.lib.middleware.simplesvn.requests.request') as (
188 request_mock):
187 request_mock):
189 request_mock.return_value = response_mock
188 request_mock.return_value = response_mock
190 self.app(self.environment, start_response)
189 self.app(self.environment, start_response)
191
190
192 expected_url = '{}{}'.format(self.host.strip('/'), self.path)
191 expected_url = '{}{}'.format(self.host.strip('/'), self.path)
193 expected_request_headers = {
192 expected_request_headers = {
194 'Dav': self.environment['HTTP_DAV'],
193 'Dav': self.environment['HTTP_DAV'],
195 'User-Agent': self.environment['HTTP_USER_AGENT'],
194 'User-Agent': self.environment['HTTP_USER_AGENT'],
196 'Content-Type': self.environment['CONTENT_TYPE'],
195 'Content-Type': self.environment['CONTENT_TYPE'],
197 'Content-Length': self.environment['CONTENT_LENGTH']
196 'Content-Length': self.environment['CONTENT_LENGTH']
198 }
197 }
199 expected_response_headers = [
198 expected_response_headers = [
200 ('SVN-Supported-Posts', 'create-txn-with-props'),
199 ('SVN-Supported-Posts', 'create-txn-with-props'),
201 ('MS-Author-Via', 'DAV'),
200 ('MS-Author-Via', 'DAV'),
202 ]
201 ]
203 request_mock.assert_called_once_with(
202 request_mock.assert_called_once_with(
204 self.environment['REQUEST_METHOD'], expected_url,
203 self.environment['REQUEST_METHOD'], expected_url,
205 data=self.data, headers=expected_request_headers, stream=False)
204 data=self.data, headers=expected_request_headers, stream=False)
206 response_mock.iter_content.assert_called_once_with(chunk_size=1024)
205 response_mock.iter_content.assert_called_once_with(chunk_size=1024)
207 args, _ = start_response.call_args
206 args, _ = start_response.call_args
208 assert args[0] == '200 OK'
207 assert args[0] == '200 OK'
209 assert sorted(args[1]) == sorted(expected_response_headers)
208 assert sorted(args[1]) == sorted(expected_response_headers)
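Note on the StringIO change in this file: the hunk above moves TestSimpleSvnApp from the py2-style "from io import StringIO" to "import io" plus io.StringIO(data). Under PEP 3333 a real server exposes wsgi.input as a bytes stream, so io.BytesIO would be the closer stand-in; the text stream passes here only because the patched requests.request never touches a real socket. A minimal sketch of a bytes-correct environ, with illustrative names that are not part of this change:

    import io

    def make_wsgi_environ(body):
        # hypothetical helper: the smallest environ the middleware reads
        body_bytes = body.encode('utf-8')
        return {
            'REQUEST_METHOD': 'OPTIONS',
            'PATH_INFO': '/group/my-repo',
            'wsgi.input': io.BytesIO(body_bytes),  # bytes stream, per PEP 3333
            'CONTENT_TYPE': 'text/xml',
            'CONTENT_LENGTH': str(len(body_bytes)),
        }

    environ = make_wsgi_environ('<xml></xml>')
    assert environ['wsgi.input'].read() == b'<xml></xml>'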
@@ -1,342 +1,342 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import json
21 import json
22 import logging
22 import logging
23 from io import StringIO
23 import io
24
24
25 import mock
25 import mock
26 import pytest
26 import pytest
27
27
28 from rhodecode.lib import hooks_daemon
28 from rhodecode.lib import hooks_daemon
29 from rhodecode.tests.utils import assert_message_in_log
29 from rhodecode.tests.utils import assert_message_in_log
30
30
31
31
32 class TestDummyHooksCallbackDaemon(object):
32 class TestDummyHooksCallbackDaemon(object):
33 def test_hooks_module_path_set_properly(self):
33 def test_hooks_module_path_set_properly(self):
34 daemon = hooks_daemon.DummyHooksCallbackDaemon()
34 daemon = hooks_daemon.DummyHooksCallbackDaemon()
35 assert daemon.hooks_module == 'rhodecode.lib.hooks_daemon'
35 assert daemon.hooks_module == 'rhodecode.lib.hooks_daemon'
36
36
37 def test_logs_entering_the_hook(self):
37 def test_logs_entering_the_hook(self):
38 daemon = hooks_daemon.DummyHooksCallbackDaemon()
38 daemon = hooks_daemon.DummyHooksCallbackDaemon()
39 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
39 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
40 with daemon as return_value:
40 with daemon as return_value:
41 log_mock.assert_called_once_with(
41 log_mock.assert_called_once_with(
42 'Running `%s` callback daemon', 'DummyHooksCallbackDaemon')
42 'Running `%s` callback daemon', 'DummyHooksCallbackDaemon')
43 assert return_value == daemon
43 assert return_value == daemon
44
44
45 def test_logs_exiting_the_hook(self):
45 def test_logs_exiting_the_hook(self):
46 daemon = hooks_daemon.DummyHooksCallbackDaemon()
46 daemon = hooks_daemon.DummyHooksCallbackDaemon()
47 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
47 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
48 with daemon:
48 with daemon:
49 pass
49 pass
50 log_mock.assert_called_with(
50 log_mock.assert_called_with(
51 'Exiting `%s` callback daemon', 'DummyHooksCallbackDaemon')
51 'Exiting `%s` callback daemon', 'DummyHooksCallbackDaemon')
52
52
53
53
54 class TestHooks(object):
54 class TestHooks(object):
55 def test_hooks_can_be_used_as_a_context_processor(self):
55 def test_hooks_can_be_used_as_a_context_processor(self):
56 hooks = hooks_daemon.Hooks()
56 hooks = hooks_daemon.Hooks()
57 with hooks as return_value:
57 with hooks as return_value:
58 pass
58 pass
59 assert hooks == return_value
59 assert hooks == return_value
60
60
61
61
62 class TestHooksHttpHandler(object):
62 class TestHooksHttpHandler(object):
63 def test_read_request_parses_method_name_and_arguments(self):
63 def test_read_request_parses_method_name_and_arguments(self):
64 data = {
64 data = {
65 'method': 'test',
65 'method': 'test',
66 'extras': {
66 'extras': {
67 'param1': 1,
67 'param1': 1,
68 'param2': 'a'
68 'param2': 'a'
69 }
69 }
70 }
70 }
71 request = self._generate_post_request(data)
71 request = self._generate_post_request(data)
72 hooks_patcher = mock.patch.object(
72 hooks_patcher = mock.patch.object(
73 hooks_daemon.Hooks, data['method'], create=True, return_value=1)
73 hooks_daemon.Hooks, data['method'], create=True, return_value=1)
74
74
75 with hooks_patcher as hooks_mock:
75 with hooks_patcher as hooks_mock:
76 MockServer(hooks_daemon.HooksHttpHandler, request)
76 MockServer(hooks_daemon.HooksHttpHandler, request)
77
77
78 hooks_mock.assert_called_once_with(data['extras'])
78 hooks_mock.assert_called_once_with(data['extras'])
79
79
80 def test_hooks_serialized_result_is_returned(self):
80 def test_hooks_serialized_result_is_returned(self):
81 request = self._generate_post_request({})
81 request = self._generate_post_request({})
82 rpc_method = 'test'
82 rpc_method = 'test'
83 hook_result = {
83 hook_result = {
84 'first': 'one',
84 'first': 'one',
85 'second': 2
85 'second': 2
86 }
86 }
87 read_patcher = mock.patch.object(
87 read_patcher = mock.patch.object(
88 hooks_daemon.HooksHttpHandler, '_read_request',
88 hooks_daemon.HooksHttpHandler, '_read_request',
89 return_value=(rpc_method, {}))
89 return_value=(rpc_method, {}))
90 hooks_patcher = mock.patch.object(
90 hooks_patcher = mock.patch.object(
91 hooks_daemon.Hooks, rpc_method, create=True,
91 hooks_daemon.Hooks, rpc_method, create=True,
92 return_value=hook_result)
92 return_value=hook_result)
93
93
94 with read_patcher, hooks_patcher:
94 with read_patcher, hooks_patcher:
95 server = MockServer(hooks_daemon.HooksHttpHandler, request)
95 server = MockServer(hooks_daemon.HooksHttpHandler, request)
96
96
97 expected_result = json.dumps(hook_result)
97 expected_result = json.dumps(hook_result)
98 assert server.request.output_stream.buflist[-1] == expected_result
98 assert server.request.output_stream.getvalue().splitlines()[-1] == expected_result
99
99
100 def test_exception_is_returned_in_response(self):
100 def test_exception_is_returned_in_response(self):
101 request = self._generate_post_request({})
101 request = self._generate_post_request({})
102 rpc_method = 'test'
102 rpc_method = 'test'
103 read_patcher = mock.patch.object(
103 read_patcher = mock.patch.object(
104 hooks_daemon.HooksHttpHandler, '_read_request',
104 hooks_daemon.HooksHttpHandler, '_read_request',
105 return_value=(rpc_method, {}))
105 return_value=(rpc_method, {}))
106 hooks_patcher = mock.patch.object(
106 hooks_patcher = mock.patch.object(
107 hooks_daemon.Hooks, rpc_method, create=True,
107 hooks_daemon.Hooks, rpc_method, create=True,
108 side_effect=Exception('Test exception'))
108 side_effect=Exception('Test exception'))
109
109
110 with read_patcher, hooks_patcher:
110 with read_patcher, hooks_patcher:
111 server = MockServer(hooks_daemon.HooksHttpHandler, request)
111 server = MockServer(hooks_daemon.HooksHttpHandler, request)
112
112
113 org_exc = json.loads(server.request.output_stream.buflist[-1])
113 org_exc = json.loads(server.request.output_stream.getvalue().splitlines()[-1])
114 expected_result = {
114 expected_result = {
115 'exception': 'Exception',
115 'exception': 'Exception',
116 'exception_traceback': org_exc['exception_traceback'],
116 'exception_traceback': org_exc['exception_traceback'],
117 'exception_args': ['Test exception']
117 'exception_args': ['Test exception']
118 }
118 }
119 assert org_exc == expected_result
119 assert org_exc == expected_result
120
120
121 def test_log_message_writes_to_debug_log(self, caplog):
121 def test_log_message_writes_to_debug_log(self, caplog):
122 ip_port = ('0.0.0.0', 8888)
122 ip_port = ('0.0.0.0', 8888)
123 handler = hooks_daemon.HooksHttpHandler(
123 handler = hooks_daemon.HooksHttpHandler(
124 MockRequest('POST /'), ip_port, mock.Mock())
124 MockRequest('POST /'), ip_port, mock.Mock())
125 fake_date = '1/Nov/2015 00:00:00'
125 fake_date = '1/Nov/2015 00:00:00'
126 date_patcher = mock.patch.object(
126 date_patcher = mock.patch.object(
127 handler, 'log_date_time_string', return_value=fake_date)
127 handler, 'log_date_time_string', return_value=fake_date)
128 with date_patcher, caplog.at_level(logging.DEBUG):
128 with date_patcher, caplog.at_level(logging.DEBUG):
129 handler.log_message('Some message %d, %s', 123, 'string')
129 handler.log_message('Some message %d, %s', 123, 'string')
130
130
131 expected_message = "HOOKS: {} - - [{}] Some message 123, string".format(ip_port, fake_date)
131 expected_message = "HOOKS: {} - - [{}] Some message 123, string".format(ip_port, fake_date)
132 assert_message_in_log(
132 assert_message_in_log(
133 caplog.records, expected_message,
133 caplog.records, expected_message,
134 levelno=logging.DEBUG, module='hooks_daemon')
134 levelno=logging.DEBUG, module='hooks_daemon')
135
135
136 def _generate_post_request(self, data):
136 def _generate_post_request(self, data):
137 payload = json.dumps(data)
137 payload = json.dumps(data)
138 return 'POST / HTTP/1.0\nContent-Length: {}\n\n{}'.format(
138 return 'POST / HTTP/1.0\nContent-Length: {}\n\n{}'.format(
139 len(payload), payload)
139 len(payload), payload)
140
140
141
141
142 class ThreadedHookCallbackDaemon(object):
142 class ThreadedHookCallbackDaemon(object):
143 def test_constructor_calls_prepare(self):
143 def test_constructor_calls_prepare(self):
144 prepare_daemon_patcher = mock.patch.object(
144 prepare_daemon_patcher = mock.patch.object(
145 hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
145 hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
146 with prepare_daemon_patcher as prepare_daemon_mock:
146 with prepare_daemon_patcher as prepare_daemon_mock:
147 hooks_daemon.ThreadedHookCallbackDaemon()
147 hooks_daemon.ThreadedHookCallbackDaemon()
148 prepare_daemon_mock.assert_called_once_with()
148 prepare_daemon_mock.assert_called_once_with()
149
149
150 def test_run_is_called_on_context_start(self):
150 def test_run_is_called_on_context_start(self):
151 patchers = mock.patch.multiple(
151 patchers = mock.patch.multiple(
152 hooks_daemon.ThreadedHookCallbackDaemon,
152 hooks_daemon.ThreadedHookCallbackDaemon,
153 _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)
153 _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)
154
154
155 with patchers as mocks:
155 with patchers as mocks:
156 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
156 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
157 with daemon as daemon_context:
157 with daemon as daemon_context:
158 pass
158 pass
159 mocks['_run'].assert_called_once_with()
159 mocks['_run'].assert_called_once_with()
160 assert daemon_context == daemon
160 assert daemon_context == daemon
161
161
162 def test_stop_is_called_on_context_exit(self):
162 def test_stop_is_called_on_context_exit(self):
163 patchers = mock.patch.multiple(
163 patchers = mock.patch.multiple(
164 hooks_daemon.ThreadedHookCallbackDaemon,
164 hooks_daemon.ThreadedHookCallbackDaemon,
165 _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)
165 _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)
166
166
167 with patchers as mocks:
167 with patchers as mocks:
168 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
168 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
169 with daemon as daemon_context:
169 with daemon as daemon_context:
170 assert mocks['_stop'].call_count == 0
170 assert mocks['_stop'].call_count == 0
171
171
172 mocks['_stop'].assert_called_once_with()
172 mocks['_stop'].assert_called_once_with()
173 assert daemon_context == daemon
173 assert daemon_context == daemon
174
174
175
175
176 class TestHttpHooksCallbackDaemon(object):
176 class TestHttpHooksCallbackDaemon(object):
177 def test_hooks_callback_generates_new_port(self, caplog):
177 def test_hooks_callback_generates_new_port(self, caplog):
178 with caplog.at_level(logging.DEBUG):
178 with caplog.at_level(logging.DEBUG):
179 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
179 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
180 assert daemon._daemon.server_address == ('127.0.0.1', 8881)
180 assert daemon._daemon.server_address == ('127.0.0.1', 8881)
181
181
182 with caplog.at_level(logging.DEBUG):
182 with caplog.at_level(logging.DEBUG):
183 daemon = hooks_daemon.HttpHooksCallbackDaemon(host=None, port=None)
183 daemon = hooks_daemon.HttpHooksCallbackDaemon(host=None, port=None)
184 assert daemon._daemon.server_address[1] in range(0, 66000)
184 assert daemon._daemon.server_address[1] in range(0, 65536)
185 assert daemon._daemon.server_address[0] != '127.0.0.1'
185 assert daemon._daemon.server_address[0] != '127.0.0.1'
186
186
187 def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
187 def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
188 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
188 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
189 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
189 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
190 assert daemon._daemon == tcp_server
190 assert daemon._daemon == tcp_server
191
191
192 _, port = tcp_server.server_address
192 _, port = tcp_server.server_address
193 expected_uri = '{}:{}'.format('127.0.0.1', port)
193 expected_uri = '{}:{}'.format('127.0.0.1', port)
194 msg = 'HOOKS: {} Preparing HTTP callback daemon registering ' \
194 msg = 'HOOKS: {} Preparing HTTP callback daemon registering ' \
195 'hook object: rhodecode.lib.hooks_daemon.HooksHttpHandler'.format(expected_uri)
195 'hook object: rhodecode.lib.hooks_daemon.HooksHttpHandler'.format(expected_uri)
196 assert_message_in_log(
196 assert_message_in_log(
197 caplog.records, msg, levelno=logging.DEBUG, module='hooks_daemon')
197 caplog.records, msg, levelno=logging.DEBUG, module='hooks_daemon')
198
198
199 def test_prepare_inits_hooks_uri_and_logs_it(
199 def test_prepare_inits_hooks_uri_and_logs_it(
200 self, tcp_server, caplog):
200 self, tcp_server, caplog):
201 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
201 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
202 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
202 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
203
203
204 _, port = tcp_server.server_address
204 _, port = tcp_server.server_address
205 expected_uri = '{}:{}'.format('127.0.0.1', port)
205 expected_uri = '{}:{}'.format('127.0.0.1', port)
206 assert daemon.hooks_uri == expected_uri
206 assert daemon.hooks_uri == expected_uri
207
207
208 msg = 'HOOKS: {} Preparing HTTP callback daemon registering ' \
208 msg = 'HOOKS: {} Preparing HTTP callback daemon registering ' \
209 'hook object: rhodecode.lib.hooks_daemon.HooksHttpHandler'.format(expected_uri)
209 'hook object: rhodecode.lib.hooks_daemon.HooksHttpHandler'.format(expected_uri)
210 assert_message_in_log(
210 assert_message_in_log(
211 caplog.records, msg,
211 caplog.records, msg,
212 levelno=logging.DEBUG, module='hooks_daemon')
212 levelno=logging.DEBUG, module='hooks_daemon')
213
213
214 def test_run_creates_a_thread(self, tcp_server):
214 def test_run_creates_a_thread(self, tcp_server):
215 thread = mock.Mock()
215 thread = mock.Mock()
216
216
217 with self._tcp_patcher(tcp_server):
217 with self._tcp_patcher(tcp_server):
218 daemon = hooks_daemon.HttpHooksCallbackDaemon()
218 daemon = hooks_daemon.HttpHooksCallbackDaemon()
219
219
220 with self._thread_patcher(thread) as thread_mock:
220 with self._thread_patcher(thread) as thread_mock:
221 daemon._run()
221 daemon._run()
222
222
223 thread_mock.assert_called_once_with(
223 thread_mock.assert_called_once_with(
224 target=tcp_server.serve_forever,
224 target=tcp_server.serve_forever,
225 kwargs={'poll_interval': daemon.POLL_INTERVAL})
225 kwargs={'poll_interval': daemon.POLL_INTERVAL})
226 assert thread.daemon is True
226 assert thread.daemon is True
227 thread.start.assert_called_once_with()
227 thread.start.assert_called_once_with()
228
228
229 def test_run_logs(self, tcp_server, caplog):
229 def test_run_logs(self, tcp_server, caplog):
230
230
231 with self._tcp_patcher(tcp_server):
231 with self._tcp_patcher(tcp_server):
232 daemon = hooks_daemon.HttpHooksCallbackDaemon()
232 daemon = hooks_daemon.HttpHooksCallbackDaemon()
233
233
234 with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
234 with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
235 daemon._run()
235 daemon._run()
236
236
237 assert_message_in_log(
237 assert_message_in_log(
238 caplog.records,
238 caplog.records,
239 'Running event loop of callback daemon in background thread',
239 'Running event loop of callback daemon in background thread',
240 levelno=logging.DEBUG, module='hooks_daemon')
240 levelno=logging.DEBUG, module='hooks_daemon')
241
241
242 def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
242 def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
243 thread = mock.Mock()
243 thread = mock.Mock()
244
244
245 with self._tcp_patcher(tcp_server):
245 with self._tcp_patcher(tcp_server):
246 daemon = hooks_daemon.HttpHooksCallbackDaemon()
246 daemon = hooks_daemon.HttpHooksCallbackDaemon()
247
247
248 with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
248 with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
249 with daemon:
249 with daemon:
250 assert daemon._daemon == tcp_server
250 assert daemon._daemon == tcp_server
251 assert daemon._callback_thread == thread
251 assert daemon._callback_thread == thread
252
252
253 assert daemon._daemon is None
253 assert daemon._daemon is None
254 assert daemon._callback_thread is None
254 assert daemon._callback_thread is None
255 tcp_server.shutdown.assert_called_with()
255 tcp_server.shutdown.assert_called_with()
256 thread.join.assert_called_once_with()
256 thread.join.assert_called_once_with()
257
257
258 assert_message_in_log(
258 assert_message_in_log(
259 caplog.records, 'Waiting for background thread to finish.',
259 caplog.records, 'Waiting for background thread to finish.',
260 levelno=logging.DEBUG, module='hooks_daemon')
260 levelno=logging.DEBUG, module='hooks_daemon')
261
261
262 def _tcp_patcher(self, tcp_server):
262 def _tcp_patcher(self, tcp_server):
263 return mock.patch.object(
263 return mock.patch.object(
264 hooks_daemon, 'TCPServer', return_value=tcp_server)
264 hooks_daemon, 'TCPServer', return_value=tcp_server)
265
265
266 def _thread_patcher(self, thread):
266 def _thread_patcher(self, thread):
267 return mock.patch.object(
267 return mock.patch.object(
268 hooks_daemon.threading, 'Thread', return_value=thread)
268 hooks_daemon.threading, 'Thread', return_value=thread)
269
269
270
270
271 class TestPrepareHooksDaemon(object):
271 class TestPrepareHooksDaemon(object):
272 @pytest.mark.parametrize('protocol', ('http',))
272 @pytest.mark.parametrize('protocol', ('http',))
273 def test_returns_dummy_hooks_callback_daemon_when_using_direct_calls(
273 def test_returns_dummy_hooks_callback_daemon_when_using_direct_calls(
274 self, protocol):
274 self, protocol):
275 expected_extras = {'extra1': 'value1'}
275 expected_extras = {'extra1': 'value1'}
276 callback, extras = hooks_daemon.prepare_callback_daemon(
276 callback, extras = hooks_daemon.prepare_callback_daemon(
277 expected_extras.copy(), protocol=protocol,
277 expected_extras.copy(), protocol=protocol,
278 host='127.0.0.1', use_direct_calls=True)
278 host='127.0.0.1', use_direct_calls=True)
279 assert isinstance(callback, hooks_daemon.DummyHooksCallbackDaemon)
279 assert isinstance(callback, hooks_daemon.DummyHooksCallbackDaemon)
280 expected_extras['hooks_module'] = 'rhodecode.lib.hooks_daemon'
280 expected_extras['hooks_module'] = 'rhodecode.lib.hooks_daemon'
281 expected_extras['time'] = extras['time']
281 expected_extras['time'] = extras['time']
282 assert 'extra1' in extras
282 assert 'extra1' in extras
283
283
284 @pytest.mark.parametrize('protocol, expected_class', (
284 @pytest.mark.parametrize('protocol, expected_class', (
285 ('http', hooks_daemon.HttpHooksCallbackDaemon),
285 ('http', hooks_daemon.HttpHooksCallbackDaemon),
286 ))
286 ))
287 def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
287 def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
288 self, protocol, expected_class):
288 self, protocol, expected_class):
289 expected_extras = {
289 expected_extras = {
290 'extra1': 'value1',
290 'extra1': 'value1',
291 'txn_id': 'txnid2',
291 'txn_id': 'txnid2',
292 'hooks_protocol': protocol.lower()
292 'hooks_protocol': protocol.lower()
293 }
293 }
294 callback, extras = hooks_daemon.prepare_callback_daemon(
294 callback, extras = hooks_daemon.prepare_callback_daemon(
295 expected_extras.copy(), protocol=protocol, host='127.0.0.1',
295 expected_extras.copy(), protocol=protocol, host='127.0.0.1',
296 use_direct_calls=False,
296 use_direct_calls=False,
297 txn_id='txnid2')
297 txn_id='txnid2')
298 assert isinstance(callback, expected_class)
298 assert isinstance(callback, expected_class)
299 extras.pop('hooks_uri')
299 extras.pop('hooks_uri')
300 expected_extras['time'] = extras['time']
300 expected_extras['time'] = extras['time']
301 assert extras == expected_extras
301 assert extras == expected_extras
302
302
303 @pytest.mark.parametrize('protocol', (
303 @pytest.mark.parametrize('protocol', (
304 'invalid',
304 'invalid',
305 'Http',
305 'Http',
306 'HTTP',
306 'HTTP',
307 ))
307 ))
308 def test_raises_on_invalid_protocol(self, protocol):
308 def test_raises_on_invalid_protocol(self, protocol):
309 expected_extras = {
309 expected_extras = {
310 'extra1': 'value1',
310 'extra1': 'value1',
311 'hooks_protocol': protocol.lower()
311 'hooks_protocol': protocol.lower()
312 }
312 }
313 with pytest.raises(Exception):
313 with pytest.raises(Exception):
314 callback, extras = hooks_daemon.prepare_callback_daemon(
314 callback, extras = hooks_daemon.prepare_callback_daemon(
315 expected_extras.copy(),
315 expected_extras.copy(),
316 protocol=protocol, host='127.0.0.1',
316 protocol=protocol, host='127.0.0.1',
317 use_direct_calls=False)
317 use_direct_calls=False)
318
318
319
319
320 class MockRequest(object):
320 class MockRequest(object):
321 def __init__(self, request):
321 def __init__(self, request):
322 self.request = request
322 self.request = request
323 self.input_stream = StringIO(b'{}'.format(self.request))
323 self.input_stream = io.StringIO('{}'.format(self.request))
324 self.output_stream = StringIO()
324 self.output_stream = io.StringIO()
325
325
326 def makefile(self, mode, *args, **kwargs):
326 def makefile(self, mode, *args, **kwargs):
327 return self.output_stream if mode == 'wb' else self.input_stream
327 return self.output_stream if mode == 'wb' else self.input_stream
328
328
329
329
330 class MockServer(object):
330 class MockServer(object):
331 def __init__(self, handler_cls, request):
331 def __init__(self, handler_cls, request):
332 ip_port = ('0.0.0.0', 8888)
332 ip_port = ('0.0.0.0', 8888)
333 self.request = MockRequest(request)
333 self.request = MockRequest(request)
334 self.server_address = ip_port
334 self.server_address = ip_port
335 self.handler = handler_cls(self.request, ip_port, self)
335 self.handler = handler_cls(self.request, ip_port, self)
336
336
337
337
338 @pytest.fixture()
338 @pytest.fixture()
339 def tcp_server():
339 def tcp_server():
340 server = mock.Mock()
340 server = mock.Mock()
341 server.server_address = ('127.0.0.1', 8881)
341 server.server_address = ('127.0.0.1', 8881)
342 return server
342 return server
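Note on MockRequest and MockServer above: Python 3 tightens two things these test doubles rely on. Bytes literals have no .format() method (so b'{}'.format(...) raises AttributeError), and io.StringIO does not provide the buflist attribute of the old py2 StringIO module. On top of that, http.server.BaseHTTPRequestHandler reads the request line from rfile as bytes and emits the response through wfile as bytes, so a strictly correct double would hand out byte streams. A sketch under those stdlib assumptions; the class name is hypothetical, not part of this change:

    import io

    class BytesMockRequest(object):
        # hypothetical variant of MockRequest built on byte streams, so the
        # handler can readline() and write() bytes the way http.server expects
        def __init__(self, request):
            self.input_stream = io.BytesIO(request.encode('utf-8'))
            self.output_stream = io.BytesIO()

        def makefile(self, mode, *args, **kwargs):
            # the handler asks for 'rb' (request) and 'wb' (response) files
            return self.output_stream if mode == 'wb' else self.input_stream

Assertions would then read output_stream.getvalue().decode('utf-8') instead of indexing into buflist.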
@@ -1,176 +1,176 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import datetime
21 import datetime
22 import os
22 import os
23 import shutil
23 import shutil
24 import tarfile
24 import tarfile
25 import tempfile
25 import tempfile
26 import zipfile
26 import zipfile
27 from io import StringIO
27 import io
28
28
29 import mock
29 import mock
30 import pytest
30 import pytest
31
31
32 from rhodecode.lib.vcs.backends import base
32 from rhodecode.lib.vcs.backends import base
33 from rhodecode.lib.vcs.exceptions import ImproperArchiveTypeError, VCSError
33 from rhodecode.lib.vcs.exceptions import ImproperArchiveTypeError, VCSError
34 from rhodecode.lib.vcs.nodes import FileNode
34 from rhodecode.lib.vcs.nodes import FileNode
35 from rhodecode.tests.vcs.conftest import BackendTestMixin
35 from rhodecode.tests.vcs.conftest import BackendTestMixin
36
36
37
37
38 @pytest.mark.usefixtures("vcs_repository_support")
38 @pytest.mark.usefixtures("vcs_repository_support")
39 class TestArchives(BackendTestMixin):
39 class TestArchives(BackendTestMixin):
40
40
41 @pytest.fixture(autouse=True)
41 @pytest.fixture(autouse=True)
42 def tempfile(self, request):
42 def tempfile(self, request):
43 self.temp_file = tempfile.mkstemp()[1]
43 self.temp_file = tempfile.mkstemp()[1]
44
44
45 @request.addfinalizer
45 @request.addfinalizer
46 def cleanup():
46 def cleanup():
47 os.remove(self.temp_file)
47 os.remove(self.temp_file)
48
48
49 @classmethod
49 @classmethod
50 def _get_commits(cls):
50 def _get_commits(cls):
51 start_date = datetime.datetime(2010, 1, 1, 20)
51 start_date = datetime.datetime(2010, 1, 1, 20)
52 yield {
52 yield {
53 'message': 'Initial Commit',
53 'message': 'Initial Commit',
54 'author': 'Joe Doe <joe.doe@example.com>',
54 'author': 'Joe Doe <joe.doe@example.com>',
55 'date': start_date + datetime.timedelta(hours=12),
55 'date': start_date + datetime.timedelta(hours=12),
56 'added': [
56 'added': [
57 FileNode('executable_0o100755', '...', mode=0o100755),
57 FileNode('executable_0o100755', '...', mode=0o100755),
58 FileNode('executable_0o100500', '...', mode=0o100500),
58 FileNode('executable_0o100500', '...', mode=0o100500),
59 FileNode('not_executable', '...', mode=0o100644),
59 FileNode('not_executable', '...', mode=0o100644),
60 ],
60 ],
61 }
61 }
62 for x in range(5):
62 for x in range(5):
63 yield {
63 yield {
64 'message': 'Commit %d' % x,
64 'message': 'Commit %d' % x,
65 'author': 'Joe Doe <joe.doe@example.com>',
65 'author': 'Joe Doe <joe.doe@example.com>',
66 'date': start_date + datetime.timedelta(hours=12 * x),
66 'date': start_date + datetime.timedelta(hours=12 * x),
67 'added': [
67 'added': [
68 FileNode('%d/file_%d.txt' % (x, x), content='Foobar %d' % x),
68 FileNode('%d/file_%d.txt' % (x, x), content='Foobar %d' % x),
69 ],
69 ],
70 }
70 }
71
71
72 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
72 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
73 def test_archive_tar(self, compressor):
73 def test_archive_tar(self, compressor):
74 self.tip.archive_repo(
74 self.tip.archive_repo(
75 self.temp_file, kind='t{}'.format(compressor), archive_dir_name='repo')
75 self.temp_file, kind='t{}'.format(compressor), archive_dir_name='repo')
76 out_dir = tempfile.mkdtemp()
76 out_dir = tempfile.mkdtemp()
77 out_file = tarfile.open(self.temp_file, 'r|{}'.format(compressor))
77 out_file = tarfile.open(self.temp_file, 'r|{}'.format(compressor))
78 out_file.extractall(out_dir)
78 out_file.extractall(out_dir)
79 out_file.close()
79 out_file.close()
80
80
81 for x in range(5):
81 for x in range(5):
82 node_path = '%d/file_%d.txt' % (x, x)
82 node_path = '%d/file_%d.txt' % (x, x)
83 with open(os.path.join(out_dir, 'repo/' + node_path)) as f:
83 with open(os.path.join(out_dir, 'repo/' + node_path)) as f:
84 file_content = f.read()
84 file_content = f.read()
85 assert file_content == self.tip.get_node(node_path).content
85 assert file_content == self.tip.get_node(node_path).content
86
86
87 shutil.rmtree(out_dir)
87 shutil.rmtree(out_dir)
88
88
89 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
89 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
90 def test_archive_tar_symlink(self, compressor):
90 def test_archive_tar_symlink(self, compressor):
91 return False
91 return False
92
92
93 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
93 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
94 def test_archive_tar_file_modes(self, compressor):
94 def test_archive_tar_file_modes(self, compressor):
95 self.tip.archive_repo(
95 self.tip.archive_repo(
96 self.temp_file, kind='t{}'.format(compressor), archive_dir_name='repo')
96 self.temp_file, kind='t{}'.format(compressor), archive_dir_name='repo')
97 out_dir = tempfile.mkdtemp()
97 out_dir = tempfile.mkdtemp()
98 out_file = tarfile.open(self.temp_file, 'r|{}'.format(compressor))
98 out_file = tarfile.open(self.temp_file, 'r|{}'.format(compressor))
99 out_file.extractall(out_dir)
99 out_file.extractall(out_dir)
100 out_file.close()
100 out_file.close()
101 dest = lambda inp: os.path.join(out_dir, 'repo/' + inp)
101 dest = lambda inp: os.path.join(out_dir, 'repo/' + inp)
102
102
103 assert oct(os.stat(dest('not_executable')).st_mode) == '0100644'
103 assert oct(os.stat(dest('not_executable')).st_mode) == '0o100644'
104
104
105 def test_archive_zip(self):
105 def test_archive_zip(self):
106 self.tip.archive_repo(self.temp_file, kind='zip', archive_dir_name='repo')
106 self.tip.archive_repo(self.temp_file, kind='zip', archive_dir_name='repo')
107 out = zipfile.ZipFile(self.temp_file)
107 out = zipfile.ZipFile(self.temp_file)
108
108
109 for x in range(5):
109 for x in range(5):
110 node_path = '%d/file_%d.txt' % (x, x)
110 node_path = '%d/file_%d.txt' % (x, x)
111 decompressed = StringIO.StringIO()
111 decompressed = io.StringIO()
112 decompressed.write(out.read('repo/' + node_path))
112 decompressed.write(out.read('repo/' + node_path).decode())
113 assert decompressed.getvalue() == \
113 assert decompressed.getvalue() == \
114 self.tip.get_node(node_path).content
114 self.tip.get_node(node_path).content
115 decompressed.close()
115 decompressed.close()
116
116
117 def test_archive_zip_with_metadata(self):
117 def test_archive_zip_with_metadata(self):
118 self.tip.archive_repo(self.temp_file, kind='zip',
118 self.tip.archive_repo(self.temp_file, kind='zip',
119 archive_dir_name='repo', write_metadata=True)
119 archive_dir_name='repo', write_metadata=True)
120
120
121 out = zipfile.ZipFile(self.temp_file)
121 out = zipfile.ZipFile(self.temp_file)
122 metafile = out.read('repo/.archival.txt')
122 metafile = out.read('repo/.archival.txt').decode()
123
123
124 raw_id = self.tip.raw_id
124 raw_id = self.tip.raw_id
125 assert 'commit_id:%s' % raw_id in metafile
125 assert 'commit_id:%s' % raw_id in metafile
126
126
127 for x in range(5):
127 for x in range(5):
128 node_path = '%d/file_%d.txt' % (x, x)
128 node_path = '%d/file_%d.txt' % (x, x)
129 decompressed = StringIO.StringIO()
129 decompressed = io.StringIO()
130 decompressed.write(out.read('repo/' + node_path))
130 decompressed.write(out.read('repo/' + node_path).decode())
131 assert decompressed.getvalue() == \
131 assert decompressed.getvalue() == \
132 self.tip.get_node(node_path).content
132 self.tip.get_node(node_path).content
133 decompressed.close()
133 decompressed.close()
134
134
135 def test_archive_wrong_kind(self):
135 def test_archive_wrong_kind(self):
136 with pytest.raises(ImproperArchiveTypeError):
136 with pytest.raises(ImproperArchiveTypeError):
137 self.tip.archive_repo(self.temp_file, kind='wrong kind')
137 self.tip.archive_repo(self.temp_file, kind='wrong kind')
138
138
139
139
140 @pytest.fixture()
140 @pytest.fixture()
141 def base_commit():
141 def base_commit():
142 """
142 """
143 Prepare a `base.BaseCommit` just enough for `_validate_archive_prefix`.
143 Prepare a `base.BaseCommit` just enough for `_validate_archive_prefix`.
144 """
144 """
145 commit = base.BaseCommit()
145 commit = base.BaseCommit()
146 commit.repository = mock.Mock()
146 commit.repository = mock.Mock()
147 commit.repository.name = u'fake_repo'
147 commit.repository.name = u'fake_repo'
148 commit.short_id = 'fake_id'
148 commit.short_id = 'fake_id'
149 return commit
149 return commit
150
150
151
151
152 @pytest.mark.parametrize("prefix", [u"unicode-prefix", u"Ünïcödë"])
152 @pytest.mark.parametrize("prefix", [u"unicode-prefix", u"Ünïcödë"])
153 def test_validate_archive_prefix_enforces_bytes_as_prefix(prefix, base_commit):
153 def test_validate_archive_prefix_enforces_bytes_as_prefix(prefix, base_commit):
154 with pytest.raises(ValueError):
154 with pytest.raises(ValueError):
155 base_commit._validate_archive_prefix(prefix)
155 base_commit._validate_archive_prefix(prefix)
156
156
157
157
158 def test_validate_archive_prefix_empty_prefix(base_commit):
158 def test_validate_archive_prefix_empty_prefix(base_commit):
159 # TODO: johbo: Should raise a ValueError here.
159 # TODO: johbo: Should raise a ValueError here.
160 with pytest.raises(VCSError):
160 with pytest.raises(VCSError):
161 base_commit._validate_archive_prefix('')
161 base_commit._validate_archive_prefix('')
162
162
163
163
164 def test_validate_archive_prefix_with_leading_slash(base_commit):
164 def test_validate_archive_prefix_with_leading_slash(base_commit):
165 # TODO: johbo: Should raise a ValueError here.
165 # TODO: johbo: Should raise a ValueError here.
166 with pytest.raises(VCSError):
166 with pytest.raises(VCSError):
167 base_commit._validate_archive_prefix('/any')
167 base_commit._validate_archive_prefix('/any')
168
168
169
169
170 def test_validate_archive_prefix_falls_back_to_repository_name(base_commit):
170 def test_validate_archive_prefix_falls_back_to_repository_name(base_commit):
171 prefix = base_commit._validate_archive_prefix(None)
171 prefix = base_commit._validate_archive_prefix(None)
172 expected_prefix = base_commit._ARCHIVE_PREFIX_TEMPLATE.format(
172 expected_prefix = base_commit._ARCHIVE_PREFIX_TEMPLATE.format(
173 repo_name='fake_repo',
173 repo_name='fake_repo',
174 short_id='fake_id')
174 short_id='fake_id')
175 assert isinstance(prefix, str)
175 assert isinstance(prefix, str)
176 assert prefix == expected_prefix
176 assert prefix == expected_prefix
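Note on the archive assertions above: two Python 3 behaviours are in play in this file. zipfile.ZipFile.read() returns bytes, so comparing against str content needs an explicit decode, and an in-memory buffer for the raw payload belongs in io.BytesIO rather than io.StringIO. Also, oct() now renders with the 0o prefix, so the expected file mode reads '0o100644' instead of the py2 '0100644'. A small self-contained sketch of the bytes boundary; the function name and paths are illustrative, not from this codebase:

    import io
    import zipfile

    def read_archive_member(zip_path, member):
        # ZipFile.read() yields bytes on Python 3
        with zipfile.ZipFile(zip_path) as archive:
            raw = archive.read(member)
        buf = io.BytesIO()                 # a bytes buffer accepts raw payloads
        buf.write(raw)
        return buf.getvalue().decode('utf-8')  # decode once, at the boundary

    # octal repr changed between py2 ('0100644') and py3 ('0o100644')
    assert oct(0o100644) == '0o100644'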